MMX pmulhw and pmullw ?

From: Bernie (spamtrap_at_crayne.org)
Date: 09/23/04


Date: Thu, 23 Sep 2004 09:04:52 +0000 (UTC)

Hi, I have some C source code that I would like to optimize.
To do this, I have begun by writing SIMD code with the Intel C
intrinsics.
However, I have the following problem: I don't know how to rewrite my code
after the first multiply operation.

In the code below I have commented out the C code and used only the pmullw
instruction. It works with small numbers, but with larger ones some
computations fail. I have marked in the code (search for "WARNING") the line
at which I don't know how to rewrite the computation and manage the upper
16 bits of the internal 32-bit product result.

any help or advice will be welcome
ab

#include <mmintrin.h>
#define _MM_4W(vector,element) (*((short*)&##vector + ##element))
#define C1_OFFSET 0
#define S1_OFFSET 4
#define C3_OFFSET 8
#define S3_OFFSET 12
#define R2C6_OFFSET 16
#define R2S6_OFFSET 20
#define R2_OFFSET 24
#define NS1C1_OFFSET 28
#define S1C1_OFFSET 32
#define NS3C3_OFFSET 36
#define S3C3_OFFSET 40
#define NR2S6C6_OFFSET 44
#define R2S6C6_OFFSET 48

#define SHIFT_ACC 3//2 ou 2

static const short ps_coef_trans[52] =
 {
  1004, // c1=cos(pi/16) << 10 (+0)
  1004,
  1004,
  1004,
  200, // s1=sin(pi/16) (+4)
  200,
  200,
  200,
  851, // c3=cos(3pi/16) << 10 (+8)
  851,
  851,
  851,
  569, // s3=sin(3pi/16) << 10 (+12)
  569,
  569,
  569,
  554, // r2c6=sqrt(2)*cos(6pi/16) << 10 (+16)
  554,
  554,
  554,
  1337, // r2s6=sqrt(2)*sin(6pi/16) << 10 (+20)
  1337,
  1337,
  1337,
  181, // r2=sqrt(2) << 7 (+24)
  181,
  181,
  181,
  -1204,//ns1c1=-s1-c1 (+28)
  -1204,
  -1204,
  -1204,
  -804,//s1c1=s1-c1 (+32)
  -804,
  -804,
  -804,
  -1420,//ns3c3=-s3-c3 (+36)
  -1420,
  -1420,
  -1420,
  -282,//s3c3=s3-c3 (+40)
  -282,
  -282,
  -282,
  -1891,//nr2s6c6=-r2s6-r2c6 (+44)
  -1891,
  -1891,
  -1891,
  783,//r2s6c6=r2s6-r2c6 (+48)
  783,
  783,
  783
 };

void dct_aan_col_4_mmx(short *pre_dct_col_data, short *dct)
{
 /*static const int c1=1004; // cos(pi/16) << 10
 static const int s1=200; // sin(pi/16)
 static const int c3=851; // cos(3pi/16) << 10
 static const int s3=569; // sin(3pi/16) << 10
 static const int r2c6=554; // sqrt(2)*cos(6pi/16) << 10
 static const int r2s6=1337; // sqrt(2)*sin(6pi/16) << 10
 static const int r2=181; // sqrt(2) << 7

 static const int ns1c1=-s1-c1;
 static const int s1c1=s1-c1;
 static const int ns3c3=-s3-c3;
 static const int s3c3=s3-c3;

 static const int nr2s6c6=-r2s6-r2c6;
 static const int r2s6c6=r2s6-r2c6;

 int x0,x1,x2,x3,x4,x5,x6,x7,x8;
 x8=0;
 //*/

 __m64 c1,ns1c1,s1c1,c3,ns3c3,s3c3,r2c6,nr2s6c6,r2s6c6,r2;

 c1 =
_mm_set_pi16(ps_coef_trans[C1_OFFSET+3],ps_coef_trans[C1_OFFSET+2],ps_coef_t
rans[C1_OFFSET+1],ps_coef_trans[C1_OFFSET]);
 ns1c1 =
_mm_set_pi16(ps_coef_trans[NS1C1_OFFSET+3],ps_coef_trans[NS1C1_OFFSET+2],ps_
coef_trans[NS1C1_OFFSET+1],ps_coef_trans[NS1C1_OFFSET]);
 s1c1 =
_mm_set_pi16(ps_coef_trans[S1C1_OFFSET+3],ps_coef_trans[S1C1_OFFSET+2],ps_co
ef_trans[S1C1_OFFSET+1],ps_coef_trans[S1C1_OFFSET]);
 c3 =
_mm_set_pi16(ps_coef_trans[C3_OFFSET+3],ps_coef_trans[C3_OFFSET+2],ps_coef_t
rans[C3_OFFSET+1],ps_coef_trans[C3_OFFSET]);
 ns3c3 =
_mm_set_pi16(ps_coef_trans[NS3C3_OFFSET+3],ps_coef_trans[NS3C3_OFFSET+2],ps_
coef_trans[NS3C3_OFFSET+1],ps_coef_trans[NS3C3_OFFSET]);
 s3c3 =
_mm_set_pi16(ps_coef_trans[S3C3_OFFSET+3],ps_coef_trans[S3C3_OFFSET+2],ps_co
ef_trans[S3C3_OFFSET+1],ps_coef_trans[S3C3_OFFSET]);
 r2c6 =
_mm_set_pi16(ps_coef_trans[R2C6_OFFSET+3],ps_coef_trans[R2C6_OFFSET+2],ps_co
ef_trans[R2C6_OFFSET+1],ps_coef_trans[R2C6_OFFSET]);
 nr2s6c6 =
_mm_set_pi16(ps_coef_trans[NR2S6C6_OFFSET+3],ps_coef_trans[NR2S6C6_OFFSET+2]
,ps_coef_trans[NR2S6C6_OFFSET+1],ps_coef_trans[NR2S6C6_OFFSET]);
 r2s6c6 =
_mm_set_pi16(ps_coef_trans[R2S6C6_OFFSET+3],ps_coef_trans[R2S6C6_OFFSET+2],p
s_coef_trans[R2S6C6_OFFSET+1],ps_coef_trans[R2S6C6_OFFSET]);
 r2 =
_mm_set_pi16(ps_coef_trans[R2_OFFSET+3],ps_coef_trans[R2_OFFSET+2],ps_coef_t
rans[R2_OFFSET+1],ps_coef_trans[R2_OFFSET]);

 __m64 vect16; vect16 = _mm_set_pi16(16,16,16,16);
 __m64 vect16384; vect16384 = _mm_set_pi16(16384,16384,16384,16384);
 __m64 vect8192; vect8192 = _mm_set_pi16(8192,8192,8192,8192);

 __m64 x0,x1,x2,x3,x4,x5,x6,x7,x8;

 x0 =
_mm_set_pi16(pre_dct_col_data[3],pre_dct_col_data[2],pre_dct_col_data[1],pre
_dct_col_data[0]);
 x1 =
_mm_set_pi16(pre_dct_col_data[11],pre_dct_col_data[10],pre_dct_col_data[9],p
re_dct_col_data[8]);
 x2 =
_mm_set_pi16(pre_dct_col_data[19],pre_dct_col_data[18],pre_dct_col_data[17],
pre_dct_col_data[16]);
 x3 =
_mm_set_pi16(pre_dct_col_data[27],pre_dct_col_data[26],pre_dct_col_data[25],
pre_dct_col_data[24]);
 x4 =
_mm_set_pi16(pre_dct_col_data[35],pre_dct_col_data[34],pre_dct_col_data[33],
pre_dct_col_data[32]);
 x5 =
_mm_set_pi16(pre_dct_col_data[43],pre_dct_col_data[42],pre_dct_col_data[41],
pre_dct_col_data[40]);
 x6 =
_mm_set_pi16(pre_dct_col_data[51],pre_dct_col_data[50],pre_dct_col_data[49],
pre_dct_col_data[48]);
 x7 =
_mm_set_pi16(pre_dct_col_data[59],pre_dct_col_data[58],pre_dct_col_data[57],
pre_dct_col_data[56]);

 // Stage 1
 //x8=x7+x0;
 x8=_m_paddsw(x7,x0);
 //x0-=x7;
 x0= _m_psubsw(x0,x7);
 //x7=x1+x6;
 x7=_m_paddsw(x1,x6);
 //x1-=x6;
 x1=_m_psubsw(x1,x6);
 //x6=x2+x5;
 x6=_m_paddsw(x2,x5);
 //x2-=x5;
 x2=_m_psubsw(x2,x5);
 //x5=x3+x4;
 x5=_m_paddsw(x3,x4);
 //x3-=x4;
 x3=_m_psubsw(x3,x4);
 //*/

 // Stage 2
 //x4=x8+x5;
 x4=_m_paddsw(x8,x5);
 //x8-=x5;
 x8=_m_psubsw(x8,x5);
 //x5=x7+x6;
 x5=_m_paddsw(x7,x6);
 //x7-=x6;
 x7=_m_psubsw(x7,x6);

 //x6=c1*(x1+x2); //WARNING : Here begin my trouble
 __m64 x6_mull=_m_pmullw(c1,_m_paddsw(x1,x2));
//If i do __m64 x6_mulh=_m_pmulhw(c1,_m_paddsw(x1,x2));

 //x2=ns1c1*x2+x6; //How do this operation, manage carry between lower and
higher 16 bits.
 __m64 x2_mull=_m_paddsw(_m_pmullw(ns1c1,x2),x6_mull);
//maybe __m64 x2_mulh=_m_paddsw(_m_pmulhw(ns1c1,x2),x6_mulh);

 //x1=s1c1*x1+x6;
 __m64 x1_mull=_m_paddsw(_m_pmullw(s1c1,x1),x6_mull);

 //x6=c3*(x0+x3););
 x6_mull=_m_pmullw(c3,_m_paddsw(x0,x3));

 //x3=ns3c3*x3+x6;
 __m64 x3_mull=_m_paddsw(_m_pmullw(ns3c3,x3),x6_mull);

 //x0=s3c3*x0+x6;
 __m64 x0_mull=_m_paddsw(_m_pmullw(s3c3,x0),x6_mull);
 //*/

 // Stage 3
 //x6=x4+x5;
 x6=_m_paddsw(x4,x5);

 //x5=r2c6*(x7+x8);
 __m64 x5_mull=_m_pmullw(r2c6,_m_paddsw(x7,x8));

 //x8=r2s6c6*x8+x5;
 __m64 x8_mull=_m_paddsw(_m_pmullw(r2s6c6,x8),x5_mull);

 //x5=x0+x2;
 x5_mull=_m_paddsw(x0_mull,x2_mull);

 //x2=x3+x1;
 x2_mull=_m_paddsw(x3_mull,x1_mull);

 //x3-=x1;
 x3_mull=_m_psubsw(x3_mull,x1_mull);
 //*/

 // Stage 4 and output

 //dct[0]=(short)((x6+16)>>3);
 x6=_m_paddsw(x6,vect16);
 x6=_mm_srai_pi16(x6,3);
 dct[0] = _MM_4W(x6,0);dct[1] = _MM_4W(x6,1);dct[2] = _MM_4W(x6,2);dct[3] =
_MM_4W(x6,3);

 //dct[16]=(short)((x8+16384)>>13);
 x8_mull=_m_paddsw(x8_mull,vect16384);
 x8_mull=_mm_srai_pi16(x8_mull,13);
 dct[0+16] = _MM_4W(x8_mull,0);dct[1+16] = _MM_4W(x8_mull,1);dct[2+16] =
_MM_4W(x8_mull,2);dct[3+16] = _MM_4W(x8_mull,3);

 //dct[8]=(short)((x2+x5+16384)>>13);
 x2_mull=_m_paddsw(_m_paddsw(x2_mull,x5_mull),vect16384);
 x2_mull=_mm_srai_pi16(x2_mull,13);
 dct[0+8] = _MM_4W(x2_mull,0);dct[1+8] = _MM_4W(x2_mull,1);dct[2+8] =
_MM_4W(x2_mull,2);dct[3+8] = _MM_4W(x2_mull,3);

 //dct[24]=(short)(((x3>>8)*r2+8192)>>12);
 x3_mull =
_mm_srai_pi16(_m_paddsw(_m_pmullw(_mm_srai_pi16(x3_mull,8),r2),vect8192),12)
;
 dct[0+24] = _MM_4W(x3_mull,0);dct[1+24] = _MM_4W(x3_mull,1);dct[2+24] =
_MM_4W(x3_mull,2);dct[3+24] = _MM_4W(x3_mull,3);
}