矩阵乘法 之 strassen 算法
2015-04-11 19:41
369 查看
一般情况下矩阵乘法需要三个for循环,时间复杂度为O(n^3),现在我们将矩阵分块
一般算法需要八次乘法
r = a * e + b * g ;
s = a * f + b * h ;
t = c * e + d * g;
u = c * f + d * h;
strassen将其变成7次乘法,因为大家都知道乘法比加减法消耗更多,所有时间复杂更高!
strassen的处理是:
令:
p1 = a * ( f - h )
p2 = ( a + b ) * h
p3 = ( c +d ) * e
p4 = d * ( g - e )
p5 = ( a + d ) * ( e + h )
p6 = ( b - d ) * ( g + h )
p7 = ( a - c ) * ( e + f )
那么我们可以知道:
r = p5 + p4 + p6 - p2
s = p1 + p2
t = p3 + p4
u = p5 + p1 - p3 - p7
我们可以看到上面只有7次乘法和多次加减法,最终达到降低复杂度为O( n^lg7 ) ~= O( n^2.81 );
代码实现如下:
现在最好的计算矩阵乘法的复杂度是O( n^2.376 ),不过只是理论上的结果。此处仅仅做参考~
一般算法需要八次乘法
r = a * e + b * g ;
s = a * f + b * h ;
t = c * e + d * g;
u = c * f + d * h;
strassen将其变成7次乘法,因为大家都知道乘法比加减法消耗更多,所有时间复杂更高!
strassen的处理是:
令:
p1 = a * ( f - h )
p2 = ( a + b ) * h
p3 = ( c +d ) * e
p4 = d * ( g - e )
p5 = ( a + d ) * ( e + h )
p6 = ( b - d ) * ( g + h )
p7 = ( a - c ) * ( e + f )
那么我们可以知道:
r = p5 + p4 + p6 - p2
s = p1 + p2
t = p3 + p4
u = p5 + p1 - p3 - p7
我们可以看到上面只有7次乘法和多次加减法,最终达到降低复杂度为O( n^lg7 ) ~= O( n^2.81 );
代码实现如下:
// strassen 算法:将矩阵相乘的复杂度降到O(n^lg7) ~= O(n^2.81) // 原理是将8次乘法减少到7次的处理 // 现在理论上的最好的算法是O(n^2,367),仅仅是理论上的而已 // // // 下面的代码仅仅是简单的实例而已,不必较真哦,呵呵~ // 下面的空间可以优化的,此处就不麻烦了~ #include <stdio.h> #define N 10 //matrix + matrix void plus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] ) { int i, j; for( i = 0; i < N / 2; i++ ) { for( j = 0; j < N / 2; j++ ) { t[i][j] = r[i][j] + s[i][j]; } } } //matrix - matrix void minus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] ) { int i, j; for( i = 0; i < N / 2; i++ ) { for( j = 0; j < N / 2; j++ ) { t[i][j] = r[i][j] - s[i][j]; } } } //matrix * matrix void mul( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] ) { int i, j, k; for( i = 0; i < N / 2; i++ ) { for( j = 0; j < N / 2; j++ ) { t[i][j] = 0; for( k = 0; k < N / 2; k++ ) { t[i][j] += r[i][k] * s[k][j]; } } } } int main() { int i, j, k; int mat ; int m1 ; int m2 ; int a[N/2][N/2],b[N/2][N/2],c[N/2][N/2],d[N/2][N/2]; int e[N/2][N/2],f[N/2][N/2],g[N/2][N/2],h[N/2][N/2]; int p1[N/2][N/2],p2[N/2][N/2],p3[N/2][N/2],p4[N/2][N/2]; int p5[N/2][N/2],p6[N/2][N/2],p7[N/2][N/2]; int r[N/2][N/2], s[N/2][N/2], t[N/2][N/2], u[N/2][N/2], t1[N/2][N/2], t2[N/2][N/2]; printf("\nInput the first matrix...:\n"); for( i = 0; i < N; i++ ) { for( j = 0; j < N; j++ ) { scanf("%d", &m1[i][j]); } } printf("\nInput the second matrix...:\n"); for( i = 0; i < N; i++ ) { for( j = 0; j < N; j++ ) { scanf("%d", &m2[i][j]); } } // a b c d e f g h for( i = 0; i < N / 2; i++ ) { for( j = 0; j < N / 2; j++ ) { a[i][j] = m1[i][j]; b[i][j] = m1[i][j + N / 2]; c[i][j] = m1[i + N / 2][j]; d[i][j] = m1[i + N / 2][j + N / 2]; e[i][j] = m2[i][j]; f[i][j] = m2[i][j + N / 2]; g[i][j] = m2[i + N / 2][j]; h[i][j] = m2[i + N / 2][j + N / 2]; } } //p1 minus( r, f, h ); mul( p1, a, r ); //p2 plus( r, a, b ); mul( p2, r, h ); //p3 plus( r, c, d ); mul( p3, r, e ); //p4 minus( r, g, e ); mul( p4, d, r ); //p5 plus( r, a, d ); plus( s, e, f ); mul( p5, r, s ); //p6 minus( r, b, d ); plus( s, g, h ); mul( p6, r, s ); //p7 minus( r, a, c ); plus( s, e, f ); mul( p7, r, s ); //r = p5 + p4 - p2 + p6 plus( t1, p5, p4 ); minus( t2, t1, p2 ); plus( r, t2, p6 ); //s = p1 + p2 plus( s, p1, p2 ); //t = p3 + p4 plus( t, p3, p4 ); //u = p5 + p1 - p3 - p7 = p5 + p1 - ( p3 + p7 ) plus( t1, p5, p1 ); plus( t2, p3, p7 ); minus( u, t1, t2 ); for( i = 0; i < N / 2; i++ ) { for( j = 0; j < N / 2; j++ ) { mat[i][j] = r[i][j]; mat[i][j + N / 2] = s[i][j]; mat[i + N / 2][j] = t[i][j]; mat[i + N / 2][j + N / 2] = u[i][j]; } } printf("\n下面是strassen算法处理结果:\n"); for( i = 0; i < N; i++ ) { for( j = 0; j < N; j++ ) { printf("%d ", mat[i][j]); } printf("\n"); } //下面是朴素算法处理 printf("\n下面是朴素算法处理结果:\n"); for( i = 0; i < N; i++ ) { for( j = 0; j < N; j++ ) { mat[i][j] = 0; for( k = 0; k < N; k++ ) { mat[i][j] += m1[i][j] * m2[i][j]; } } } for( i = 0; i < N; i++ ) { for( j = 0; j < N; j++ ) { printf("%d ", mat[i][j]); } printf("\n"); } return 0; }
现在最好的计算矩阵乘法的复杂度是O( n^2.376 ),不过只是理论上的结果。此处仅仅做参考~
相关文章推荐
- 纯C语言矩阵乘法的Strassen算法,包含非2次幂的情况
- 五大常用算法(一) 分治算法(3) Strassen矩阵乘法
- 快速矩阵乘法:Strassen 演算法
- Strassen矩阵乘法 + 快速计算乘方的算法 + 矩阵的次幂
- Strassen算法之矩阵乘法
- Strassen矩阵乘法 + 快速计算乘方的算法 + 矩阵的次幂
- 算法提高 矩阵乘法 蓝桥杯
- 基础数论算法(八) 矩阵乘法与线性齐次递推公式的快速求值
- Strassen矩阵乘法
- 算法训练 矩阵乘法
- 算法训练 矩阵乘法
- 算法训练 矩阵乘法
- 稀疏矩阵用三元数组表示后的矩阵乘法算法 集合!
- 【算法】_013_矩阵乘法
- 蓝桥杯算法训练 矩阵乘法
- 蓝桥杯 算法提高 矩阵乘法 【经典区间dp】
- 算法提高 矩阵乘法
- 算法训练 矩阵乘法
- 矩阵乘法算法
- 矩阵的乘法算法