您的位置:首页 > 编程语言 > C++

Strassen's algorithm for matrix multiplication

2016-05-25 03:54 597 查看
//
//  main.cpp
//  Strassen
//
//  Created by Longxiang Lyu on 5/24/16.
//  Copyright (c) 2016 Longxiang Lyu. All rights reserved.
//

#include <math.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace std;

/// Writes `matrix` to `os`, one row per line, with every element followed by
/// a single space (so each line ends in "elem \n").
/// @param matrix  the matrix to print; rows may have different lengths.
/// @param os      destination stream; defaults to std::cout so existing
///                calls such as printMatrix(ret) behave as before.
void printMatrix(const std::vector<std::vector<int>> &matrix, std::ostream &os = std::cout)
{
    for (const auto &row : matrix)  // bind by reference: no per-row vector copy
    {
        for (const int elem : row)
            os << elem << ' ';
        os << '\n';
    }
}

/// Pads `matrix` in place with zeros so that it becomes square, with a side
/// equal to the smallest power of two that is >= both its row count and the
/// length of its longest row. Existing entries keep their (row, column)
/// positions. A matrix that is already square with a power-of-two side is
/// left unchanged; an empty matrix (no rows) is left empty.
void zeroPadding(std::vector<std::vector<int>> &matrix)
{
    if (matrix.empty())
        return;  // no rows: nothing to pad, and matrix[0] must not be read

    // Largest dimension, considering every row (rows may be ragged).
    std::size_t extent = matrix.size();
    for (const auto &row : matrix)
        extent = std::max(extent, row.size());

    // Round up to a power of two with exact integer arithmetic.
    std::size_t side = 1;
    while (side < extent)
        side <<= 1;

    matrix.resize(side);          // appended rows start out empty...
    for (auto &row : matrix)
        row.resize(side, 0);      // ...and every row is zero-filled to `side`
}

/// Element-wise matrix sum: ret = A + B.
/// A and B are expected to have the same shape (not checked). Works for any
/// rectangular or ragged shape: ret is resized to match A, so it does not
/// need to be pre-allocated, and it may be the same object as A or B.
void sum(const std::vector<std::vector<int>> &A, const std::vector<std::vector<int>> &B, std::vector<std::vector<int>> &ret)
{
    ret.resize(A.size());
    for (std::size_t i = 0; i < A.size(); ++i)
    {
        // Column count comes from the row itself, not from A.size().
        ret[i].resize(A[i].size());
        for (std::size_t j = 0; j < A[i].size(); ++j)
            ret[i][j] = A[i][j] + B[i][j];
    }
}

/// Element-wise matrix difference: ret = A - B.
/// A and B are expected to have the same shape (not checked). Works for any
/// rectangular or ragged shape: ret is resized to match A, so it does not
/// need to be pre-allocated, and it may be the same object as A or B.
void subtract(const std::vector<std::vector<int>> &A, const std::vector<std::vector<int>> &B, std::vector<std::vector<int>> &ret)
{
    ret.resize(A.size());
    for (std::size_t i = 0; i < A.size(); ++i)
    {
        // Column count comes from the row itself, not from A.size().
        ret[i].resize(A[i].size());
        for (std::size_t j = 0; j < A[i].size(); ++j)
            ret[i][j] = A[i][j] - B[i][j];
    }
}

void strassenHelper(vector<vector<int>> &A, vector<vector<int>> &B, vector<vector<int>> &ret)
{
if (A.size() == 1)
{
ret[0][0] = A[0][0] * B[0][0];
return;
}
size_t sz = A.size();
size_t new_sz = sz / 2;
ret = vector<vector<int>>(sz, vector<int>(sz));

vector<vector<int>> a11(new_sz), a12(new_sz), a21(new_sz), a22(new_sz), b11(new_sz), b12(new_sz), b21(new_sz), b22(new_sz);
for (int i = 0; i < new_sz; ++i)
{
for (int j = 0; j < new_sz; ++j)
{
a11[i].push_back(A[i][j]);
a12[i].push_back(A[i][j + new_sz]);
a21[i].push_back(A[i + new_sz][j]);
a22[i].push_back(A[i + new_sz][j + new_sz]);

b11[i].push_back(B[i][j]);
b12[i].push_back(B[i][j + new_sz]);
b21[i].push_back(B[i + new_sz][j]);
b22[i].push_back(B[i + new_sz][j + new_sz]);
}
}
vector<vector<int>> result1(new_sz, vector<int>(new_sz, 0)), result2(new_sz, vector<int>(new_sz, 0));

// p1
vector<vector<int>> p1(new_sz, vector<int>(new_sz, 0));
sum(a11, a22, result1);
sum(b11, b22, result2);
strassenHelper(result1, result2, p1);

// p2
vector<vector<int>> p2(new_sz, vector<int>(new_sz, 0));
sum(a21, a22, result1);
strassenHelper(result1, b11, p2);

// p3
vector<vector<int>> p3(new_sz, vector<int>(new_sz, 0));
subtract(b12, b22, result2);
strassenHelper(a11, result2, p3);

// p4
vector<vector<int>> p4(new_sz, vector<int>(new_sz, 0));
subtract(b21, b11, result2);
strassenHelper(a22, result2, p4);

// p5
vector<vector<int>> p5(new_sz, vector<int>(new_sz, 0));
sum(a11, a12, result1);
strassenHelper(result1, b22, p5);

// p6
vector<vector<int>> p6(new_sz, vector<int>(new_sz, 0));
subtract(a21, a11, result1);
sum(b11, b12, result2);
strassenHelper(result1, result2, p6);

// p7
vector<vector<int>> p7(new_sz, vector<int>(new_sz, 0));
subtract(a12, a22, result1);
sum(b21, b22, result2);
strassenHelper(result1, result2, p7);

vector<vector<int>> c11(new_sz, vector<int>(new_sz, 0));
vector<vector<int>> c12(new_sz, vector<int>(new_sz, 0));
vector<vector<int>> c21(new_sz, vector<int>(new_sz, 0));
vector<vector<int>> c22(new_sz, vector<int>(new_sz, 0));

sum(p3, p5, c12);
sum(p2, p4, c21);

sum(p1, p4, result1);
sum(result1, p7, result2);
subtract(result2, p5, c11);

sum(p1, p3, result1);
sum(result1, p6, result2);
subtract(result2, p2, c22);

for (int i = 0; i < new_sz; ++i)
{
for (int j = 0; j < new_sz; ++j)
{
ret[i][j] = c11[i][j];
ret[i][j + new_sz] = c12[i][j];
ret[i + new_sz][j] = c21[i][j];
ret[i + new_sz][j + new_sz] = c22[i][j];
}
}

}

void strassen(vector<vector<int>> &A, vector<vector<int>> &B, vector<vector<int>> &ret)
{
if (A.empty() || B.empty())
throw runtime_error("empty matrices");
if (A[0].size() != B.size())
throw runtime_error("A's col not equal B's row");
zeroPadding(A);
zeroPadding(B);
strassenHelper(A, B, ret);
}

int main(int argc, const char * argv[]) {
    // Demo: multiply two 3×3 matrices with strassen() and print the product.
    std::vector<std::vector<int>> A{{1, 2, 0}, {1, 2, 3}, {1, 2, 3}};
    std::vector<std::vector<int>> B{{1, 0, 1}, {1, 1, 1}, {2, 1, 1}};
    std::vector<std::vector<int>> ret;  // strassen() sizes and fills this

    try
    {
        strassen(A, B, ret);
    }
    catch (const std::exception &e)
    {
        // Report invalid input instead of terminating with an uncaught exception.
        std::cerr << "strassen failed: " << e.what() << '\n';
        return 1;
    }

    printMatrix(ret);
    return 0;
}

Reference:
https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: