您的位置:首页 > 编程语言 > C语言/C++

基于Strassen算法采用分治的矩阵乘法cpp实现

2017-10-26 20:35 453 查看
直接上代码。

注意:只支持维度为2的幂次的方阵相乘,且维度不能超过代码中的 maxn(50),即实际可用的最大维度为 32。

#include <cstdio>
// Capacity (rows and columns) of the fixed-size storage inside `matrix`.
#define maxn 50

// Square integer matrix with fixed maxn x maxn storage.
struct matrix
{
    // Element storage, indexed as con[row][column].
    int con[maxn][maxn];
    // The matrix is required to be n x n. Defaults to 0; the code visible
    // in this file never assigns or reads this field.
    int size = 0;
};

// The two operand matrices that main() fills from standard input.
matrix m1, m2;

/*
 * Element-wise sum of the top-left len x len corners of A and B.
 *
 * Only the top-left len x len entries of the returned matrix are written;
 * every other entry is left uninitialized.
 */
matrix add(matrix A, matrix B, int len ) {
    matrix total;

    for (int row = 0; row < len; ++row) {
        const int *lhs = A.con[row];
        const int *rhs = B.con[row];
        int *out = total.con[row];

        for (int col = 0; col < len; ++col) {
            out[col] = lhs[col] + rhs[col];
        }
    }
    return total;
}
/*
 * Element-wise difference (A - B) of the top-left len x len corners.
 *
 * Only the top-left len x len entries of the returned matrix are written;
 * every other entry is left uninitialized.
 */
matrix sub(matrix A, matrix B, int len ) {
    matrix diff;

    for (int row = 0; row < len; ++row) {
        const int *minuend = A.con[row];
        const int *subtrahend = B.con[row];
        int *out = diff.con[row];

        for (int col = 0; col < len; ++col) {
            out[col] = minuend[col] - subtrahend[col];
        }
    }
    return diff;
}
/*
 * Print the top-left n x n corner of `a` to standard output, one matrix row
 * per line, with each value followed by a single space.
 */
void print_it(matrix a, int n) {
    for (int row = 0; row < n; ++row) {
        const int *cells = a.con[row];

        for (int col = 0; col < n; ++col) {
            printf("%d ", cells[col]);
        }
        printf("\n");
    }
}

/*
 * Extract a sub-block of `input` into the top-left corner of a new matrix.
 *
 * The copied region is half-open: rows r1 .. r2-1 and columns c1 .. c2-1,
 * so the result holds (r2 - r1) x (c2 - c1) meaningful entries starting at
 * [0][0]. All other entries of the returned matrix are left uninitialized.
 * No bounds checking is done on the arguments.
 */
matrix create(matrix input, int r1, int r2, int c1, int c2) {
    matrix block;
    const int rows = r2 - r1;
    const int cols = c2 - c1;

    for (int dr = 0; dr < rows; ++dr) {
        for (int dc = 0; dc < cols; ++dc) {
            block.con[dr][dc] = input.con[r1 + dr][c1 + dc];
        }
    }
    return block;
}
/*
 * Multiply two square blocks with Strassen's seven-product recursion.
 *
 * The operands are the len x len blocks of A and B whose top-left corner is
 * at row r1, column c1 (the same offset is used for both matrices). The
 * product is returned in the top-left len x len corner of the result; the
 * remaining entries of the result are not meaningful.
 *
 * Parameters:
 *   A, B   - operand matrices (passed by value).
 *   r1, c1 - row/column offset of the block inside A and B.
 *   len    - side length of the block.
 *
 * Behaviour by size:
 *   - len <= 0, a negative offset, or a block that does not fit inside the
 *     maxn x maxn storage: an all-zero matrix is returned. (Without this
 *     guard len <= 0 would recurse forever, because len / 2 stays 0.)
 *   - odd len (including the 1 x 1 base case): the block cannot be split
 *     into four equal quadrants, so the classical triple-loop product is
 *     used for this level.
 *   - even len: the block is split into quadrants and combined from seven
 *     recursive half-size products.
 *
 * Note: every recursion level keeps many maxn x maxn temporaries as
 * automatic variables, so stack usage grows with the recursion depth.
 */
matrix multi(matrix A, matrix B, int r1, int c1, int len) {
    // Written as `len > maxn - r1` so the range check itself cannot overflow.
    if (len <= 0 || r1 < 0 || c1 < 0 || len > maxn - r1 || len > maxn - c1)
    {
        matrix empty = matrix();
        return empty;
    }

    if (len % 2 != 0)
    {
        // Classical product; for len == 1 this is just a single multiply.
        matrix direct;
        for (int i = 0; i < len; i++)
        {
            for (int j = 0; j < len; j++)
            {
                int acc = 0;
                for (int k = 0; k < len; k++)
                {
                    acc += A.con[r1 + i][c1 + k] * B.con[r1 + k][c1 + j];
                }
                direct.con[i][j] = acc;
            }
        }
        return direct;
    }

    const int half = len / 2;
    // Quadrant boundaries are relative to (r1, c1), not to the matrix origin.
    const int row_mid = r1 + half, row_end = r1 + len;
    const int col_mid = c1 + half, col_end = c1 + len;

    // A = [a b; c d], B = [e f; g h]; create() copies rows [r1, r2) and
    // columns [c1, c2) into the top-left corner of its result.
    matrix a = create(A, r1, row_mid, c1, col_mid);
    matrix b = create(A, r1, row_mid, col_mid, col_end);
    matrix c = create(A, row_mid, row_end, c1, col_mid);
    matrix d = create(A, row_mid, row_end, col_mid, col_end);
    matrix e = create(B, r1, row_mid, c1, col_mid);
    matrix f = create(B, r1, row_mid, col_mid, col_end);
    matrix g = create(B, row_mid, row_end, c1, col_mid);
    matrix h = create(B, row_mid, row_end, col_mid, col_end);

    // The seven Strassen products; the extracted quadrants start at (0, 0).
    matrix p1 = multi(a, sub(f, h, half), 0, 0, half);
    matrix p2 = multi(add(a, b, half), h, 0, 0, half);
    matrix p3 = multi(add(c, d, half), e, 0, 0, half);
    matrix p4 = multi(d, sub(g, e, half), 0, 0, half);
    matrix p5 = multi(add(a, d, half), add(e, h, half), 0, 0, half);
    matrix p6 = multi(sub(b, d, half), add(g, h, half), 0, 0, half);
    matrix p7 = multi(sub(a, c, half), add(e, f, half), 0, 0, half);

    // Quadrants of the product: [r s; t u].
    matrix r = sub(add(add(p5, p4, half), p6, half), p2, half);
    matrix s = add(p1, p2, half);
    matrix t = add(p3, p4, half);
    matrix u = sub(add(p5, p1, half), add(p3, p7, half), half);

    // Stitch the four quadrants back into one len x len result.
    matrix rr;
    for (int i = 0; i < half; i++)
    {
        for (int j = 0; j < half; j++)
        {
            rr.con[i][j] = r.con[i][j];
            rr.con[i][j + half] = s.con[i][j];
            rr.con[i + half][j] = t.con[i][j];
            rr.con[i + half][j + half] = u.con[i][j];
        }
    }
    return rr;
}

int main(int argc, char const *argv[])
{
int n ;
printf("输入矩阵的维数:\n");
scanf("%d", &n);
printf("第一个矩阵:\n");
for (int i = 0; i < n ; i++)
{
for (int j = 0; j < n ; j++)
{
scanf("%d", &m1.con[i][j]);
}
}
printf("第二个矩阵:\n");
for (int i = 0; i < n ; i++)
{
for (int j = 0; j < n ; j++)
{
scanf("%d", &m2.con[i][j]);
}
}
printf("计算结果:\n");
print_it(multi(m1, m2, 0, 0, n), n);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: