#include <stdint.h>
#include <stdio.h>
#include <sys/time.h>

#include <emmintrin.h>

/* Element type aliases. Neither is referenced by the active code in this
   file as shown (the disabled template below uses "data" only as a macro
   parameter name). */
typedef uint8_t data;   /* 8-bit unsigned element  */
typedef uint16_t data2; /* 16-bit unsigned element */

/*
 * INT_MULT(a, b, t): fixed-point product of two 8-bit values scaled by 1/255,
 * rounded to nearest (e.g. INT_MULT(255, 255, t) == 255).
 * t is a caller-supplied scratch lvalue: it is assigned a*b + 0x80, and the
 * expression's value is (t + (t >> 8)) >> 8. a and b are evaluated once each;
 * t is evaluated several times, so pass a plain variable.
 */
#define INT_MULT(a, b, t) ((t) = (a) * (b) + 0x80, (((t) + ((t) >> 8)) >> 8))

/*
 * Disabled by #if 0, so none of this is compiled.
 * GEN_combineVectors(suffix, data) would define
 *   void combineVectors<suffix>(c1, c2, n, dest, s1, s2)
 * computing dest[i] = c1 * s1[i] + c2 * s2[i] for 0 <= i < n, with every
 * value of the element type passed as "data". The two invocations below
 * would instantiate combineVectorsF (float) and combineVectorsD (double).
 * NOTE(review): the reason this is disabled is not visible here; either
 * enable it deliberately or delete it.
 */
#if 0
#define GEN_combineVectors(suffix, data)			\
void combineVectors##suffix(                                    \
                    data c1,					\
		    data c2,					\
		    int n,					\
		    data * restrict dest,			\
		    data * restrict const s1,				\
		    data * restrict const s2) {			\
  for(int i=0; i<n; i++) {					\
    dest[i] = (c1 * s1[i]) + (c2 * s2[i]);			\
  }								\
}

GEN_combineVectors(F, float)
GEN_combineVectors(D, double)
#endif

/*
 * Scalar byte blend: for 0 <= k < n,
 *   dest[k] = (acc + (acc >> 8)) >> 8,  acc = c1 * s1[k] + c2 * s2[k] + 0x80,
 * i.e. the weighted sum divided by 255 with rounding (same trick as INT_MULT).
 *
 * acc is held in 16 bits. When c1 + c2 <= 255 it never exceeds 65153.
 * Larger weights make it wrap modulo 2^16, and the result wraps with it.
 * n <= 0 writes nothing.
 */
void combineVectorsInt8(uint8_t c1, uint8_t c2, int n,
                        uint8_t * restrict dest,
                        uint8_t * restrict const s1,
                        uint8_t * restrict const s2) {
  int k = 0;
  while (k < n) {
    /* Bias by 0x80 so the final shifts round instead of truncating. */
    const uint16_t acc = (uint16_t)(c1 * s1[k] + c2 * s2[k] + 0x80);
    dest[k] = (uint8_t)((acc + (acc >> 8)) >> 8);
    ++k;
  }
}

/*
 * 64-bit value viewable over the same storage as 8 bytes, 4 16-bit words,
 * or as GCC vector-extension values with 8-bit or 16-bit lanes.
 * Not referenced by the code shown in this file.
 */
typedef union {
  uint8_t t8[8];
  uint16_t t16[4];
  uint8_t v8 __attribute__ ((vector_size (8)));
  uint16_t v16 __attribute__ ((vector_size (8)));
} v64;

/*
 * 128-bit value viewable over the same storage as 16 bytes, 8 16-bit words,
 * or as GCC vector-extension values with 8-bit or 16-bit lanes. The 16-byte
 * vector members give the union 16-byte alignment.
 */
typedef union {
  uint8_t t8[16];
  uint16_t t16[8];
  uint8_t v8 __attribute__ ((vector_size (16)));
  uint16_t v16 __attribute__ ((vector_size (16)));
} v128;

/* Blend eight zero-extended 16-bit lanes with the scalar formula, using
   16-bit wrap-around arithmetic in every step:
     t = c1*a + c2*b + 0x80;  result = (t + (t >> 8)) >> 8. */
static inline __m128i combineLanesInt16(__m128i a, __m128i b, __m128i c1v,
                                        __m128i c2v, __m128i bias) {
  const __m128i t = _mm_add_epi16(
      _mm_add_epi16(_mm_mullo_epi16(c1v, a), _mm_mullo_epi16(c2v, b)), bias);
  return _mm_srli_epi16(_mm_add_epi16(t, _mm_srli_epi16(t, 8)), 8);
}

/*
 * SSE2 version of combineVectorsInt8. For each of the n bytes:
 *   dest = (t + (t >> 8)) >> 8,  t = c1 * s1 + c2 * s2 + 0x80,
 * computed in 16-bit lanes, 16 bytes per iteration. The n % 16 trailing
 * bytes are finished by a scalar loop that uses the same 16-bit arithmetic,
 * so every one of the n bytes is written.
 *
 * For c1 + c2 <= 255 no intermediate exceeds 16 bits, so the output equals
 * combineVectorsInt8 byte for byte.
 *
 * dest, s1 and s2 must each address at least n bytes and must not overlap.
 * n <= 0 writes nothing. Loads and stores are unaligned, so a buffer that is
 * not 16-byte aligned does not fault.
 */
void combineVectorsInt8SSE_v8(uint8_t c1, uint8_t c2, int n,
                              v128 * restrict dest,
                              v128 * restrict const s1,
                              v128 * restrict const s2) {
  /* n is signed: dividing a negative n by sizeof(v128) would convert it to a
     huge size_t and run far past the buffers. */
  if (n <= 0) {
    return;
  }

  const __m128i zero = _mm_setzero_si128();
  const __m128i c1v = _mm_set1_epi16(c1);
  const __m128i c2v = _mm_set1_epi16(c2);
  const __m128i bias = _mm_set1_epi16(0x80);

  const size_t len = (size_t)n;
  const size_t blocks = len / sizeof(v128);

  for (size_t i = 0; i < blocks; i++) {
    const __m128i a = _mm_loadu_si128((const __m128i *)&s1[i]);
    const __m128i b = _mm_loadu_si128((const __m128i *)&s2[i]);

    /* Zero-extend each byte to a 16-bit lane so the products fit. */
    const __m128i lo = combineLanesInt16(_mm_unpacklo_epi8(a, zero),
                                         _mm_unpacklo_epi8(b, zero),
                                         c1v, c2v, bias);
    const __m128i hi = combineLanesInt16(_mm_unpackhi_epi8(a, zero),
                                         _mm_unpackhi_epi8(b, zero),
                                         c1v, c2v, bias);

    /* Every lane is <= 255 after the final shift, so packus's unsigned
       saturation never triggers; it only narrows back to bytes (low first). */
    _mm_storeu_si128((__m128i *)&dest[i], _mm_packus_epi16(lo, hi));
  }

  /* Scalar tail for the last n % 16 bytes, matching the lanes' 16-bit
     wrap-around in both additions. */
  uint8_t *const d8 = (uint8_t *)dest;
  const uint8_t *const a8 = (const uint8_t *)s1;
  const uint8_t *const b8 = (const uint8_t *)s2;
  for (size_t k = blocks * sizeof(v128); k < len; k++) {
    const uint16_t t = (uint16_t)(c1 * a8[k] + c2 * b8[k] + 0x80);
    d8[k] = (uint8_t)((uint16_t)(t + (t >> 8)) >> 8);
  }
}

int main(int argc, char **argv) {
  const int nr = 65536;
  struct timeval chrono1, chrono2;

  uint8_t dest_n[nr], dest_v[nr], s1[nr], s2[nr];
  for(int i=0; i<nr; i++) {
    s1[i] = i;
  }
  {
    uint8_t z = 43;
    for(int i=0; i<nr; i++) {
      s2[i] = z;
      z = z * 87 + 138;
    }
  }
  uint8_t c1 = 100, c2 = 155;

  gettimeofday(&chrono1, NULL);
  for(int count=0; count<1000; count++) {
    combineVectorsInt8(c1, c2, nr, dest_n, s1, s2);
  }
  gettimeofday(&chrono2, NULL);
  printf("%ld\n", (chrono2.tv_sec - chrono1.tv_sec)*1000 +
	 (chrono2.tv_usec - chrono1.tv_usec)/1000);

  gettimeofday(&chrono1, NULL);
  for(int count=0; count<1000; count++) {
    combineVectorsInt8SSE_v8(c1, c2, nr, dest_v, s1, s2);
  }
  gettimeofday(&chrono2, NULL);

  printf("%ld\n", (chrono2.tv_sec - chrono1.tv_sec)*1000 +
	 (chrono2.tv_usec - chrono1.tv_usec)/1000);

  for(int i=0; i<nr; i++) {
    if(dest_n[i] != dest_v[i]) {
      printf("error at %d\n", i);
    }
  }
}
