您的位置:首页 > 其它

矩阵乘法的三个版本实现

2015-02-24 15:35 253 查看
version1

#include<stdio.h>
//Matrix multiplication O(n^3) 按照定义算
const int N=200;
void Multiply(int A[]
,int B[]
,int C[]
,int n)
{//C=A*B
int i,j,k;
for(i=0;i<n;i++)
for(j=0;j<n;j++){
C[i][j]=0;
for(k=0;k<n;k++) C[i][j]+=A[i][k]*B[k][j];
}
}
int main(void)
{
//freopen("1.txt","r",stdin);
//freopen("2.txt","w",stdout);
int A

,B

,C

,n;
scanf("%d",&n);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
scanf("%d",*(A+i)+j);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
scanf("%d",*(B+i)+j);
Multiply(A,B,C,n);//计算C=A*B
for(int i=0;i<n;i++){
for(int j=0;j<n;j++)
printf("%d ",C[i][j]);
putchar('\n');
}
return 0;
}


version2

#include<stdio.h>
//Matrix multiplication O(n^3) T(n)=8T(n/2)+O(n^2)
//虽然时间复杂度仍旧一样,但是其思想为strassen算法作铺垫
const int N=200;
void Add(int A[]
,int B[]
,int C[]
,int n)
{//C=A+B
int i,j;
for(i=0;i<n;i++)
for(j=0;j<n;j++)
C[i][j]=A[i][j]+B[i][j];
}

void Multiply(int A[]
,int B[]
,int C[]
,int n)
{//C=A*B
if(n==1) C[0][0]=A[0][0]*B[0][0];
else{
int halfn=n/2,i,j;
int A11[halfn]
,A12[halfn]
,A21[halfn]
,A22[halfn]
;//其实N应该是halfn,但考虑到函数参数传递,二维数组第二维要对应
int B11[halfn]
,B12[halfn]
,B21[halfn]
,B22[halfn]
;
//A、B分解成4个小矩阵
for(i=0;i<halfn;i++)
for(j=0;j<halfn;j++){
A11[i][j]=A[i][j];A12[i][j]=A[i][j+halfn];
A21[i][j]=A[i+halfn][j];A22[i][j]=A[i+halfn][j+halfn];
B11[i][j]=B[i][j];B12[i][j]=B[i][j+halfn];
B21[i][j]=B[i+halfn][j];B22[i][j]=B[i+halfn][j+halfn];
}

//递归计算8次矩阵乘法,同时做加法
int temp1[halfn]
,temp2[halfn]
;
Multiply(A11,B11,temp1,halfn);
Multiply(A12,B21,temp2,halfn);
int C11[halfn]
;
Add(temp1,temp2,C11,halfn);

Multiply(A11,B12,temp1,halfn);
Multiply(A12,B22,temp2,halfn);
int C12[halfn]
;
Add(temp1,temp2,C12,halfn);

Multiply(A21,B11,temp1,halfn);
Multiply(A22,B21,temp2,halfn);
int C21[halfn]
;
Add(temp1,temp2,C21,halfn);

Multiply(A21,B12,temp1,halfn);
Multiply(A22,B22,temp2,halfn);
int C22[halfn]
;
Add(temp1,temp2,C22,halfn);

//将C11,C12,C21,C22合并起来
for(i=0;i<halfn;i++)
for(j=0;j<halfn;j++){
C[i][j]=C11[i][j];C[i][j+halfn]=C12[i][j];
C[i+halfn][j]=C21[i][j];C[i+halfn][j+halfn]=C22[i][j];
}
}
}
int main(void)
{
//freopen("1.txt","r",stdin);
//freopen("2.txt","w",stdout);
int A

,B

,C

,n;
scanf("%d",&n);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
scanf("%d",*(A+i)+j);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
scanf("%d",*(B+i)+j);

Multiply(A,B,C,n);//计算C=A*B

for(int i=0;i<n;i++){
for(int j=0;j<n;j++)
printf("%d ",C[i][j]);
putchar('\n');
}
return 0;
}


version3  strassen算法。书上说,可以用下标运算来增加速度,还可以大量减少空间使用,不过我暂时没想到。
#include<stdio.h>
//Matrix multiplication O(n^lg7)
const int N=200;//n=128时,就stackoverflow了。所以都感觉不到n^lg7与n^3的区别了
void Add(int A[]
,int B[]
,int C[]
,int n)
{//C=A+B
int i,j;
for(i=0;i<n;i++)
for(j=0;j<n;j++)
C[i][j]=A[i][j]+B[i][j];
}
void Add1(int A[]
,int B[]
,int C[]
,int D[]
,int E[]
,int n)
{//E=A+B-C+D
int i,j;
for(i=0;i<n;i++)
for(j=0;j<n;j++)
E[i][j]=A[i][j]+B[i][j]-C[i][j]+D[i][j];
}
void Add2(int A[]
,int B[]
,int C[]
,int D[]
,int E[]
,int n)
{//E=A+B-C-D
int i,j;
for(i=0;i<n;i++)
for(j=0;j<n;j++)
E[i][j]=A[i][j]+B[i][j]-C[i][j]-D[i][j];
}
void Minus(int A[]
,int B[]
,int C[]
,int n)
{//C=A-B
int i,j;
for(i=0;i<n;i++)
for(j=0;j<n;j++)
C[i][j]=A[i][j]-B[i][j];
}

void Multiply(int A[]
,int B[]
,int C[]
,int n)
{//C=A*B
if(n==1) C[0][0]=A[0][0]*B[0][0];
else{
int halfn=n/2,i,j;
int A11[halfn]
,A12[halfn]
,A21[halfn]
,A22[halfn]
;//其实N应该是halfn,但考虑到函数参数传递,二维数组第二维要对应
int B11[halfn]
,B12[halfn]
,B21[halfn]
,B22[halfn]
;
//A、B分解成4个小矩阵
for(i=0;i<halfn;i++)
for(j=0;j<halfn;j++){
A11[i][j]=A[i][j];A12[i][j]=A[i][j+halfn];
A21[i][j]=A[i+halfn][j];A22[i][j]=A[i+halfn][j+halfn];
B11[i][j]=B[i][j];B12[i][j]=B[i][j+halfn];
B21[i][j]=B[i+halfn][j];B22[i][j]=B[i+halfn][j+halfn];
}

//10次加减法
int S1[halfn]
,S2[halfn]
,S3[halfn]
,S4[halfn]
,S5[halfn]
,S6[halfn]
,S7[halfn]
,S8[halfn]
,S9[halfn]
,S10[halfn]
;
Minus(B12,B22,S1,halfn);//S1=B12-B22
Add(A11,A12,S2,halfn);//S2=A11+A12
Add(A21,A22,S3,halfn);//S3=A21+A22
Minus(B21,B11,S4,halfn);//S4=B21-B11
Add(A11,A22,S5,halfn);//S5=A11+A22
Add(B11,B22,S6,halfn);//S6=B11+B22
Minus(A12,A22,S7,halfn);//S7=A12-A22
Add(B21,B22,S8,halfn);//S8=B21+B22
Minus(A11,A21,S9,halfn);//S9=A11-A21
Add(B11,B12,S10,halfn);//S10=B11+B12

//递归计算7次矩阵乘法
int P1[halfn]
,P2[halfn]
,P3[halfn]
,P4[halfn]
,P5[halfn]
,P6[halfn]
,P7[halfn]
;
Multiply(A11,S1,P1,halfn);//P1=A11*S1
Multiply(S2,B22,P2,halfn);//P2=S2*B22
Multiply(S3,B11,P3,halfn);//P3=S3*B11
Multiply(A22,S4,P4,halfn);//P4=A22*S4
Multiply(S5,S6,P5,halfn);//P5=S5*S6
Multiply(S7,S8,P6,halfn);//P6=S7*S8
Multiply(S9,S10,P7,halfn);//P7=S9*S10

int C11[halfn]
,C12[halfn]
,C21[halfn]
,C22[halfn]
;
Add1(P5,P4,P2,P6,C11,halfn);//C11=P5+P4-P2+P6
Add(P1,P2,C12,halfn);//C12=P1+P2
Add(P3,P4,C21,halfn);//C21=P3+P4
Add2(P5,P1,P3,P7,C22,halfn);//C22=P5+P1-P3-P7

//将C11,C12,C21,C22合并起来
for(i=0;i<halfn;i++)
for(j=0;j<halfn;j++){
C[i][j]=C11[i][j];C[i][j+halfn]=C12[i][j];
C[i+halfn][j]=C21[i][j];C[i+halfn][j+halfn]=C22[i][j];
}
}
}
int main(void)
{
//freopen("1.txt","r",stdin);
//freopen("2.txt","w",stdout);
int A

,B

,C

,n;
scanf("%d",&n);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
scanf("%d",*(A+i)+j);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
scanf("%d",*(B+i)+j);

Multiply(A,B,C,n);//计算C=A*B

for(int i=0;i<n;i++){
for(int j=0;j<n;j++)
printf("%d ",C[i][j]);
putchar('\n');
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: