您的位置:首页 > 其它

矩阵相乘nxn block的计算过程

2013-11-25 19:15 477 查看
/* Create macros so that the matrices are stored in column-major order */

#define A(i,j) a[ (j)*lda + (i) ]
#define B(i,j) b[ (j)*ldb + (i) ]
#define C(i,j) c[ (j)*ldc + (i) ]

/* Routine for computing C = A * B + C */

void AddDot4x4( int, float *, int, float *, int, float *, int );

void MY_MMult( int m, int n, int k, float *a, int lda,
float *b, int ldb,
float *c, int ldc )
{
int i, j;

for ( j=0; j<n; j+=4 ){        /* Loop over the columns of C, unrolled by 4 */
for ( i=0; i<m; i+=4 ){        /* Loop over the rows of C */
/* Update C( i,j ), C( i,j+1 ), C( i,j+2 ), and C( i,j+3 ) in
one routine (four inner products) */

AddDot4x4( k, &A( i,0 ), lda, &B( 0,j ), ldb, &C( i,j ), ldc );
}
}
}

void AddDot4x4( int k, float *a, int lda,  float *b, int ldb, float *c, int ldc )
{
/* So, this routine computes a 4x4 block of matrix A

C( 0, 0 ), C( 0, 1 ), C( 0, 2 ), C( 0, 3 ).
C( 1, 0 ), C( 1, 1 ), C( 1, 2 ), C( 1, 3 ).
C( 2, 0 ), C( 2, 1 ), C( 2, 2 ), C( 2, 3 ).
C( 3, 0 ), C( 3, 1 ), C( 3, 2 ), C( 3, 3 ).

Notice that this routine is called with c = C( i, j ) in the
previous routine, so these are actually the elements

C( i  , j ), C( i  , j+1 ), C( i  , j+2 ), C( i  , j+3 )
C( i+1, j ), C( i+1, j+1 ), C( i+1, j+2 ), C( i+1, j+3 )
C( i+2, j ), C( i+2, j+1 ), C( i+2, j+2 ), C( i+2, j+3 )
C( i+3, j ), C( i+3, j+1 ), C( i+3, j+2 ), C( i+3, j+3 )

in the original matrix C

A simple rearrangement to prepare for the use of vector registers */

int p;
register float
/* hold contributions to
C( 0, 0 ), C( 0, 1 ), C( 0, 2 ), C( 0, 3 )
C( 1, 0 ), C( 1, 1 ), C( 1, 2 ), C( 1, 3 )
C( 2, 0 ), C( 2, 1 ), C( 2, 2 ), C( 2, 3 )
C( 3, 0 ), C( 3, 1 ), C( 3, 2 ), C( 3, 3 )   */
c_00_reg,   c_01_reg,   c_02_reg,   c_03_reg,
c_10_reg,   c_11_reg,   c_12_reg,   c_13_reg,
c_20_reg,   c_21_reg,   c_22_reg,   c_23_reg,
c_30_reg,   c_31_reg,   c_32_reg,   c_33_reg,
/* hold
A( 0, p )
A( 1, p )
A( 2, p )
A( 3, p ) */
a_0p_reg,
a_1p_reg,
a_2p_reg,
a_3p_reg,
b_p0_reg,
b_p1_reg,
b_p2_reg,
b_p3_reg;

float
/* Point to the current elements in the four columns of B */
*b_p0_pntr, *b_p1_pntr, *b_p2_pntr, *b_p3_pntr;

b_p0_pntr = &B( 0, 0 );
b_p1_pntr = &B( 0, 1 );
b_p2_pntr = &B( 0, 2 );
b_p3_pntr = &B( 0, 3 );

c_00_reg = 0.0;   c_01_reg = 0.0;   c_02_reg = 0.0;   c_03_reg = 0.0;
c_10_reg = 0.0;   c_11_reg = 0.0;   c_12_reg = 0.0;   c_13_reg = 0.0;
c_20_reg = 0.0;   c_21_reg = 0.0;   c_22_reg = 0.0;   c_23_reg = 0.0;
c_30_reg = 0.0;   c_31_reg = 0.0;   c_32_reg = 0.0;   c_33_reg = 0.0;

/*
TODO:
process of calculation

c00 = a00xb00
c10 = a10xb00
c01 = a00xb01
c11 = a10xb01

c00 = a00xb00 + a01xb10
c10 = a10xb00 + a11xb10
c01 = a00xb01 + a01xb11
c11 = a10xb01 + a11xb11

c00 = a00xb00 + a01xb10 + a02xb20
c10 = a10xb00 + a11xb10 + a12xb20
c01 = a00xb01 + a01xb11 + a02xb21
c11 = a10xb01 + a11xb11 + a12xb21

c00 = a00xb00 + a01xb10 + a02xb20 + a03xb30
c10 = a10xb00 + a11xb10 + a12xb20 + a13xb30
c01 = a00xb01 + a01xb11 + a02xb21 + a03xb31
c11 = a10xb01 + a11xb11 + a12xb21 + a13xb31

...
...
...
k = k -1

c00 = a00xb00 + a01xb10 + a02xb20 + a03xb30 + ... + a0k * bk0
c10 = a10xb00 + a11xb10 + a12xb20 + a13xb30 + ... + a1k * bk0
c01 = a00xb01 + a01xb11 + a02xb21 + a03xb31 + ... + a0k * bk1
c11 = a10xb01 + a11xb11 + a12xb21 + a13xb31 + ... + a1k * bk1

*/
for ( p=0; p<k; p++ ){
a_0p_reg = A( 0, p );
a_1p_reg = A( 1, p );
a_2p_reg = A( 2, p );
a_3p_reg = A( 3, p );

b_p0_reg = *b_p0_pntr++;
b_p1_reg = *b_p1_pntr++;
b_p2_reg = *b_p2_pntr++;
b_p3_reg = *b_p3_pntr++;

/* First row and second rows */
c_00_reg += a_0p_reg * b_p0_reg;
c_10_reg += a_1p_reg * b_p0_reg;

c_01_reg += a_0p_reg * b_p1_reg;
c_11_reg += a_1p_reg * b_p1_reg;

c_02_reg += a_0p_reg * b_p2_reg;
c_12_reg += a_1p_reg * b_p2_reg;

c_03_reg += a_0p_reg * b_p3_reg;
c_13_reg += a_1p_reg * b_p3_reg;

/* Third and fourth rows */
c_20_reg += a_2p_reg * b_p0_reg;
c_30_reg += a_3p_reg * b_p0_reg;

c_21_reg += a_2p_reg * b_p1_reg;
c_31_reg += a_3p_reg * b_p1_reg;

c_22_reg += a_2p_reg * b_p2_reg;
c_32_reg += a_3p_reg * b_p2_reg;

c_23_reg += a_2p_reg * b_p3_reg;
c_33_reg += a_3p_reg * b_p3_reg;
}

C( 0, 0 ) += c_00_reg;   C( 0, 1 ) += c_01_reg;   C( 0, 2 ) += c_02_reg;   C( 0, 3 ) += c_03_reg;
C( 1, 0 ) += c_10_reg;   C( 1, 1 ) += c_11_reg;   C( 1, 2 ) += c_12_reg;   C( 1, 3 ) += c_13_reg;
C( 2, 0 ) += c_20_reg;   C( 2, 1 ) += c_21_reg;   C( 2, 2 ) += c_22_reg;   C( 2, 3 ) += c_23_reg;
C( 3, 0 ) += c_30_reg;   C( 3, 1 ) += c_31_reg;   C( 3, 2 ) += c_32_reg;   C( 3, 3 ) += c_33_reg;
}


注意其中的计算过程,迭代过程。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: