您的位置:首页 > 其它

算法导论 矩阵相乘(Strassen方法)

2017-03-22 20:17 375 查看
#include <stdio.h>
#include <stdlib.h>
/* Side length of the two demo matrices that main() multiplies. */
#define LEN 4

/*
 * Square: an n-by-n matrix of ints held as an array of n row pointers.
 *   S - row-pointer array; S[i][j] is the element at row i, column j.
 *   n - number of rows, which is also the number of columns.
 */
typedef struct
{
    int **S;
    int n;
} Square;

/*
 * printSquare: write A to stdout.
 * The first line is "length is <n>". Then one line is printed per row, with
 * each element right-aligned in a 4-character field.
 */
void printSquare(Square A)
{
    printf("length is %d\n", A.n);
    for (int row = 0; row < A.n; ++row) {
        const int *cells = A.S[row];
        for (int col = 0; col < A.n; ++col)
            printf("%4d", cells[col]);
        putchar('\n');
    }
}

/*
 * freeSquare: release A's storage.
 * Each of the A.n row buffers is freed first, then the row-pointer array.
 * A is passed by value, so the caller's copy of A.S still holds the
 * now-dangling pointer afterwards.
 */
void freeSquare(Square A)
{
    int **rows = A.S;
    int k = 0;
    while (k < A.n)
        free(rows[k++]);
    free(rows);
}

Square makeSquare(int n)
{
int** S=(int **)malloc(sizeof(int *)*n);
for(int i=0;i<n;i++)
{
*(S+i)=(int*)malloc(sizeof(int)*n);
}
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
S[i][j]=0;
}
}
Square sq={S,n};
return sq;
}

/*
 * divideSquare: return a newly allocated copy of a sub-block of A.
 * The sub-block spans rows rowstart..rowend and columns colstart..colend,
 * both inclusive. The copy's side length comes from the row span only, so the
 * column span is expected to have the same length. Free the result with
 * freeSquare().
 */
Square divideSquare(Square A,int rowstart,int rowend,int colstart,int colend)
{
    Square block = makeSquare(rowend - rowstart + 1);
    int dstRow = 0;
    for (int srcRow = rowstart; srcRow <= rowend; ++srcRow, ++dstRow) {
        int *dst = block.S[dstRow];
        const int *src = A.S[srcRow];
        for (int k = 0; colstart + k <= colend; ++k)
            dst[k] = src[colstart + k];
    }
    return block;
}

/*
 * squareCalc: element-wise combination of two equally sized matrices.
 * op '+' yields A + B and op '-' yields A - B. Any other op leaves the result
 * as the zero matrix that makeSquare() returns. The size is taken from A.
 * The result is newly allocated; free it with freeSquare().
 */
Square squareCalc(Square A,Square B,char op)
{
    const int n = A.n;
    Square out = makeSquare(n);

    /* The operator is fixed for the whole call, so choose it once up front. */
    switch (op) {
    case '+':
        for (int r = 0; r < n; ++r)
            for (int c = 0; c < n; ++c)
                out.S[r][c] = A.S[r][c] + B.S[r][c];
        break;
    case '-':
        for (int r = 0; r < n; ++r)
            for (int c = 0; c < n; ++c)
                out.S[r][c] = A.S[r][c] - B.S[r][c];
        break;
    default:
        /* Unsupported operator: result stays zero-filled. */
        break;
    }
    return out;
}

/* Forward declaration: strassenMulOdd (below) recurses back into StrassenMul. */
void StrassenMul(Square C,Square A,Square B);

/*
 * strassenPlaceQuadrant: copy the Q.n-by-Q.n matrix Q into C so that Q's
 * top-left element lands at C.S[rowOffset][colOffset].
 */
static void strassenPlaceQuadrant(Square C, Square Q, int rowOffset, int colOffset)
{
    for (int i = 0; i < Q.n; i++) {
        for (int j = 0; j < Q.n; j++) {
            C.S[rowOffset + i][colOffset + j] = Q.S[i][j];
        }
    }
}

/*
 * strassenMulOdd: multiply odd-sized n-by-n A and B into C.
 * Both operands are copied into (n+1)-by-(n+1) matrices; makeSquare()
 * zero-fills them, which adds one zero row and one zero column. The padded
 * copies (now even-sized) are multiplied, and the top-left n-by-n block of
 * the product is copied into C. Zero padding leaves that block equal to A*B.
 */
static void strassenMulOdd(Square C, Square A, Square B)
{
    const int n = A.n;
    Square Ap = makeSquare(n + 1);
    Square Bp = makeSquare(n + 1);
    Square Cp = makeSquare(n + 1);

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            Ap.S[i][j] = A.S[i][j];
            Bp.S[i][j] = B.S[i][j];
        }
    }

    StrassenMul(Cp, Ap, Bp);

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            C.S[i][j] = Cp.S[i][j];
        }
    }

    freeSquare(Ap);
    freeSquare(Bp);
    freeSquare(Cp);
}

/*
 * StrassenMul: compute C = A * B with Strassen's seven-product method.
 * The scheme uses sums S1..S10 and products P1..P7, as in "Introduction to
 * Algorithms".
 *
 * Contract:
 * - A, B and C must all be n-by-n with the same n.
 * - C must already be allocated (e.g. with makeSquare()); its entries are
 *   overwritten. A and B are only read.
 * - Odd n greater than 1 is handled by zero-padding to n + 1.
 * - n <= 0 is a no-op.
 */
void StrassenMul(Square C,Square A,Square B)
{
    if (A.n <= 0) {
        /* Empty product: nothing to write. Recursing would never terminate. */
        return;
    }
    if (A.n == 1) {
        C.S[0][0] = A.S[0][0] * B.S[0][0];
        return;
    }
    if (A.n % 2 != 0) {
        /* Halving an odd n gives quadrants of different sizes, which would
           produce wrong sums and out-of-bounds reads. Pad to even instead. */
        strassenMulOdd(C, A, B);
        return;
    }

    const int h = A.n / 2;      /* side length of every quadrant */
    const int last = A.n - 1;   /* index of the last row/column */

    /* Split both operands into four h-by-h quadrants. */
    Square A11 = divideSquare(A, 0, h - 1, 0, h - 1);
    Square A12 = divideSquare(A, 0, h - 1, h, last);
    Square A21 = divideSquare(A, h, last, 0, h - 1);
    Square A22 = divideSquare(A, h, last, h, last);
    Square B11 = divideSquare(B, 0, h - 1, 0, h - 1);
    Square B12 = divideSquare(B, 0, h - 1, h, last);
    Square B21 = divideSquare(B, h, last, 0, h - 1);
    Square B22 = divideSquare(B, h, last, h, last);

    /* The ten sums and differences that feed the seven products. */
    Square S1 = squareCalc(B12, B22, '-');
    Square S2 = squareCalc(A11, A12, '+');
    Square S3 = squareCalc(A21, A22, '+');
    Square S4 = squareCalc(B21, B11, '-');
    Square S5 = squareCalc(A11, A22, '+');
    Square S6 = squareCalc(B11, B22, '+');
    Square S7 = squareCalc(A12, A22, '-');
    Square S8 = squareCalc(B21, B22, '+');
    Square S9 = squareCalc(A11, A21, '-');
    Square S10 = squareCalc(B11, B12, '+');

    /* Seven recursive half-size products. */
    Square P1 = makeSquare(h);
    Square P2 = makeSquare(h);
    Square P3 = makeSquare(h);
    Square P4 = makeSquare(h);
    Square P5 = makeSquare(h);
    Square P6 = makeSquare(h);
    Square P7 = makeSquare(h);

    StrassenMul(P1, A11, S1);
    StrassenMul(P2, S2, B22);
    StrassenMul(P3, S3, B11);
    StrassenMul(P4, A22, S4);
    StrassenMul(P5, S5, S6);
    StrassenMul(P6, S7, S8);
    StrassenMul(P7, S9, S10);

    /* Quadrants and sums are no longer needed; release them before combining. */
    freeSquare(A11);
    freeSquare(A12);
    freeSquare(A21);
    freeSquare(A22);
    freeSquare(B11);
    freeSquare(B12);
    freeSquare(B21);
    freeSquare(B22);

    freeSquare(S1);
    freeSquare(S2);
    freeSquare(S3);
    freeSquare(S4);
    freeSquare(S5);
    freeSquare(S6);
    freeSquare(S7);
    freeSquare(S8);
    freeSquare(S9);
    freeSquare(S10);

    /* C11 = P5 + P4 - P2 + P6, C12 = P1 + P2,
       C21 = P3 + P4,           C22 = P5 + P1 - P3 - P7. */
    Square C111 = squareCalc(P5, P4, '+');
    Square C112 = squareCalc(C111, P2, '-');
    Square C11 = squareCalc(C112, P6, '+');
    Square C12 = squareCalc(P1, P2, '+');
    Square C21 = squareCalc(P3, P4, '+');
    Square C221 = squareCalc(P5, P1, '+');
    Square C222 = squareCalc(C221, P3, '-');
    Square C22 = squareCalc(C222, P7, '-');

    strassenPlaceQuadrant(C, C11, 0, 0);
    strassenPlaceQuadrant(C, C12, 0, h);
    strassenPlaceQuadrant(C, C21, h, 0);
    strassenPlaceQuadrant(C, C22, h, h);

    freeSquare(P1);
    freeSquare(P2);
    freeSquare(P3);
    freeSquare(P4);
    freeSquare(P5);
    freeSquare(P6);
    freeSquare(P7);

    freeSquare(C111);
    freeSquare(C112);
    freeSquare(C11);
    freeSquare(C12);
    freeSquare(C21);
    freeSquare(C221);
    freeSquare(C222);
    freeSquare(C22);
}

/*
 * main: multiply two fixed LEN-by-LEN demo matrices with StrassenMul, print
 * the product, wait for one character on stdin, then release every matrix.
 */
int main(void)
{
    static const int AS[LEN][LEN] = {{1,2,3,4},{3,4,5,6},{5,6,7,8},{7,8,9,10}};
    static const int BS[LEN][LEN] = {{5,6,7,8},{7,8,9,10},{11,12,13,14},{15,16,17,18}};

    Square A = makeSquare(LEN);
    Square B = makeSquare(LEN);
    for (int i = 0; i < LEN; i++) {
        for (int j = 0; j < LEN; j++) {
            A.S[i][j] = AS[i][j];
            B.S[i][j] = BS[i][j];
        }
    }

    Square C = makeSquare(LEN);
    StrassenMul(C, A, B);
    printSquare(C);

    /* Presumably keeps a console window open until Enter is pressed. */
    getchar();

    /* The original leaked all three matrices. */
    freeSquare(A);
    freeSquare(B);
    freeSquare(C);
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: