您的位置:首页 > 其它

Strassen’s 矩阵乘法—分治法实现

2017-12-06 23:40 435 查看
内容会持续更新,有错误的地方欢迎指正,谢谢!

前言:博主最近正在学习《算法》这门专业课程,这是该课程的第二次上机题目,我把自己的解题方法分享给大家,欢迎讨论!

题目:

1.比较数学定义的矩阵乘法算法和Strassen’s 矩阵乘法算法的效率;

2.自主生成两个16*16的矩阵,输出Strassen’s 矩阵乘法算法结果。

数学定义的矩阵乘法算法:利用三个for循环来解决,时间复杂度为O(n^3)。

数学定义的矩阵乘法算法的核心代码如下:

//公理:两个矩阵相乘A*B,A的列数必等于B的行数。
int a[2][3] = {1, 1, 1, 1, 1, 1};
int b[3][1] = {1, 1, 1};
for (int i = 0; i < 2; ++i)
{
for (int j = 0; j < 1; ++j)
{
c[i][j] = 0;
for (int k = 0; k < 3; ++k)
c[i][j] += a[i][k] * b[k][j];
}
}


一般算法需要八次乘法:



试试Strassen’s 矩阵乘法算法:



我们可以推出:



上面只有7次乘法和多次加减法,Strassen’s 矩阵乘法算法将其变成7次乘法。大家都知道乘法比加减法消耗更多的性能!所以,该算法能将时间复杂度降低到O( n^lg7 ) = O( n^2.81 )。

代码实现如下:(其中N必须为2的幂,这里N=16)

#include <iostream>
using namespace std;
#define N 16

//矩阵相加
void Plus(int a[N / 2][N / 2], int b[N / 2][N / 2], int c[N / 2][N / 2])
{
int i, j;
for (i = 0; i < N / 2; i++)
{
for (j = 0; j < N / 2; j++)
{
a[i][j] = b[i][j] + c[i][j];
}
}
}

//矩阵相减
void Minus(int a[N / 2][N / 2], int b[N / 2][N / 2], int c[N / 2][N / 2])
{
int i, j;
for (i = 0; i < N / 2; i++)
{
for (j = 0; j < N / 2; j++)
{
a[i][j] = b[i][j] - c[i][j];
}
}
}

//矩阵相乘
void Multiply(int a[N / 2][N / 2], int b[N / 2][N / 2], int c[N / 2][N / 2])
{
int i, j, k;
for (i = 0; i < N / 2; i++)
{
for (j = 0; j < N / 2; j++)
{
a[i][j] = 0;
for (k = 0; k < N / 2; k++)
{
a[i][j] += b[i][k] * c[k][j];
}
}
}
}

int main()
{
int i, j, k;
int m1

;
int m2

;
for (i = 0; i < N; ++i)//初始化要相乘的这两个16*16的矩阵
{
for (j = 0; j < N; ++j)
{
m1[i][j] = 1;
m2[i][j] = 1;
}
}

int I[N / 2][N / 2], J[N / 2][N / 2], K[N / 2][N / 2], L[N / 2][N / 2];
int A[N / 2][N / 2], B[N / 2][N / 2], C[N / 2][N / 2], D[N / 2][N / 2];
int E[N / 2][N / 2], F[N / 2][N / 2], G[N / 2][N / 2], H[N / 2][N / 2];
int S1[N / 2][N / 2], S2[N / 2][N / 2], S3[N / 2][N / 2], S4[N / 2][N / 2];
int S5[N / 2][N / 2], S6[N / 2][N / 2], S7[N / 2][N / 2];
int t1[N / 2][N / 2], t2[N / 2][N / 2];

//将原矩阵m1、m2拆分为A B C D E F G H矩阵
for (i = 0; i < N / 2; i++)
{
for (j = 0; j < N / 2; j++)
{
A[i][j] = m1[i][j];
B[i][j] = m1[i][j + N / 2];
C[i][j] = m1[i + N / 2][j];
D[i][j] = m1[i + N / 2][j + N / 2];

E[i][j] = m2[i][j];
F[i][j] = m2[i][j + N / 2];
G[i][j] = m2[i + N / 2][j];
H[i][j] = m2[i + N / 2][j + N / 2];
}
}

//S1
Minus(I, F, H);
Multiply(S1, A, I);

//S2
Plus(I, A, B);
Multiply(S2, I, H);

//S3
Plus(I, C, D);
Multiply(S3, I, E);

//S4
Minus(I, G, E);
Multiply(S4, D, I);

//S5
Plus(I, A, D);
Plus(J, E, F);
Multiply(S5, I, J);

//S6
Minus(I, B, D);
Plus(J, G, H);
Multiply(S6, I, J);

//S7
Minus(I, A, C);
Plus(J, E, F);
Multiply(S7, I, J);

//计算I J K L矩阵
//I = S5 + S4 - S2 + S6
Plus(t1, S5, S4);
Minus(t2, t1, S2);
Plus(I, t2, S6);

//J = S1 + S2
Plus(J, S1, S2);

//K = S3 + S4
Plus(K, S3, S4);

//L = S5 + S1 - S3 - S7 = S5 + S1 - ( S3 + S7 )
Plus(t1, S5, S1);
Plus(t2, S3, S7);
Minus(L, t1, t2);

//将得到的I J K L矩阵合并到最终结果result矩阵中
int result[N][N] = { 0 };
for (int i = 0; i < N / 2; i++)
{
for (int j = 0; j < N / 2; j++)
{
result[i][j] = I[i][j];
result[i][j + N / 2] = J[i][j];
result[i + N / 2][j] = K[i][j];
result[i + N / 2][j + N / 2] = L[i][j];
}
}

//输出最终的矩阵
for (i = 0; i < N; ++i)
{
k = 0;
for (j = 0; j < N; ++j)
{
cout << result[i][j] << "  ";
++k;
if (k == N)
cout << endl;
}
}
cout << endl;

getchar();
return 0;
}


备注:由于博主时间问题,本代码并未实现递归,也就是并未利用分治法拆分到最小单元再计算再合并,只是阐述了分治法解决该问题的思路,若要实现完整版代码,我指明方法:

新建一个递归函数,需将main()里的部分代码移到递归函数里,并需修改递归函数里的所有二维数组的定义,例如:

int MatrixA[N / 2][N / 2];
int MatrixB[N / 2][N / 2];
int MatrixC[N / 2][N / 2];


应该被修改为如下形式:

//n为递归函数传入的参数
int** MatrixA = new int*
;
int** MatrixB = new int*
;
int** MatrixC = new int*
;
for (int i = 0; i < n; i++)
{
MatrixA[i] = new int
;
MatrixB[i] = new int
;
MatrixC[i] = new int
;
}


用完new的二维数组之后还要记得释放内存,不然,在递归中,很容易产生内存泄漏:

for (int i = 0; i < n; i++)
{
delete[] A[i];
delete[] B[i];
}
delete[] A;
delete[] B;


递归函数的参数有n,MatrixA,MatrixB,MatrixC

n用于传递矩阵维数。

MatrixA矩阵就是上方代码的m1矩阵。该题是求m1乘以m2矩阵,你就知道m1是什么了。

MatrixB矩阵就是上方代码的m2矩阵。该题是求m1乘以m2矩阵,你就知道m2是什么了。

MatrixC矩阵用于记录结果,最后输出MatrixC即是最终结果。

分治法实现的完整代码,能输出最终结果和每一次递归的S1~S7:

http://download.csdn.net/download/billcyj/10157466
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息