Strassen's algorithm for matrix multiplication
2016-05-25 03:54
597 查看
//
//  main.cpp
//  Strassen
//
//  Created by Longxiang Lyu on 5/24/16.
//  Copyright (c) 2016 Longxiang Lyu. All rights reserved.
//
//  Multiplies two integer matrices with Strassen's algorithm. Operands of any
//  compatible shape are zero-padded to one common power-of-two square size,
//  multiplied recursively, and the padding is stripped from the result.
//

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <math.h>
#include <stdexcept>
#include <string>
#include <vector>

// Row-major integer matrix: a vector of rows.
using Matrix = std::vector<std::vector<int>>;

// Prints `matrix` to stdout, one row per line, each element followed by a space.
void printMatrix(const Matrix &matrix)
{
    for (const auto &row : matrix)
    {
        for (int elem : row)
            std::cout << elem << " ";
        std::cout << '\n';
    }
}

// Returns the smallest power of two that is >= n (1 when n is 0 or 1).
static std::size_t nextPowerOfTwo(std::size_t n)
{
    std::size_t p = 1;
    while (p < n)
        p <<= 1;
    return p;
}

// Zero-pads `matrix` in place to a square whose side is the smallest power of
// two >= max(row count, widest row, minSize). Existing entries keep their
// positions and every new entry is 0. Padding two operands with the same
// minSize guarantees they end up the same size.
void zeroPadding(Matrix &matrix, std::size_t minSize = 0)
{
    std::size_t extent = std::max(matrix.size(), minSize);
    for (const auto &row : matrix)
        extent = std::max(extent, row.size());

    const std::size_t sz = nextPowerOfTwo(extent);
    matrix.resize(sz);
    for (auto &row : matrix)
        row.resize(sz, 0);
}

// Element-wise ret = A + B. A and B must be square and of the same size;
// `ret` must already be at least that size (it is written, not resized).
void sum(const Matrix &A, const Matrix &B, Matrix &ret)
{
    const std::size_t sz = A.size();
    for (std::size_t i = 0; i < sz; ++i)
        for (std::size_t j = 0; j < sz; ++j)
            ret[i][j] = A[i][j] + B[i][j];
}

// Element-wise ret = A - B. Same size requirements as sum().
void subtract(const Matrix &A, const Matrix &B, Matrix &ret)
{
    const std::size_t sz = A.size();
    for (std::size_t i = 0; i < sz; ++i)
        for (std::size_t j = 0; j < sz; ++j)
            ret[i][j] = A[i][j] - B[i][j];
}

// Copies the n x n block of `m` whose top-left corner is (row0, col0).
static Matrix subBlock(const Matrix &m, std::size_t row0, std::size_t col0, std::size_t n)
{
    Matrix block(n, std::vector<int>(n));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            block[i][j] = m[row0 + i][col0 + j];
    return block;
}

// Writes the square matrix `block` into `m` with its top-left corner at (row0, col0).
static void placeBlock(const Matrix &block, Matrix &m, std::size_t row0, std::size_t col0)
{
    const std::size_t n = block.size();
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            m[row0 + i][col0 + j] = block[i][j];
}

// Recursively computes ret = A * B with Strassen's seven-product scheme.
// Precondition: A and B are square and of the same power-of-two size n >= 1.
// `ret` is overwritten with the n x n product; its prior size does not matter.
void strassenHelper(const Matrix &A, const Matrix &B, Matrix &ret)
{
    const std::size_t sz = A.size();
    if (sz == 1)
    {
        // assign() instead of ret[0][0] = ... so an empty `ret` is handled.
        const int product = A[0][0] * B[0][0];
        ret.assign(1, std::vector<int>(1, product));
        return;
    }

    const std::size_t half = sz / 2;
    const Matrix a11 = subBlock(A, 0, 0, half);
    const Matrix a12 = subBlock(A, 0, half, half);
    const Matrix a21 = subBlock(A, half, 0, half);
    const Matrix a22 = subBlock(A, half, half, half);
    const Matrix b11 = subBlock(B, 0, 0, half);
    const Matrix b12 = subBlock(B, 0, half, half);
    const Matrix b21 = subBlock(B, half, 0, half);
    const Matrix b22 = subBlock(B, half, half, half);

    // Scratch operands reused for every sum/difference below; sum() and
    // subtract() write into pre-sized outputs.
    Matrix lhs(half, std::vector<int>(half));
    Matrix rhs(half, std::vector<int>(half));

    // P1 = (A11 + A22)(B11 + B22)
    Matrix p1;
    sum(a11, a22, lhs);
    sum(b11, b22, rhs);
    strassenHelper(lhs, rhs, p1);

    // P2 = (A21 + A22) B11
    Matrix p2;
    sum(a21, a22, lhs);
    strassenHelper(lhs, b11, p2);

    // P3 = A11 (B12 - B22)
    Matrix p3;
    subtract(b12, b22, rhs);
    strassenHelper(a11, rhs, p3);

    // P4 = A22 (B21 - B11)
    Matrix p4;
    subtract(b21, b11, rhs);
    strassenHelper(a22, rhs, p4);

    // P5 = (A11 + A12) B22
    Matrix p5;
    sum(a11, a12, lhs);
    strassenHelper(lhs, b22, p5);

    // P6 = (A21 - A11)(B11 + B12)
    Matrix p6;
    subtract(a21, a11, lhs);
    sum(b11, b12, rhs);
    strassenHelper(lhs, rhs, p6);

    // P7 = (A12 - A22)(B21 + B22)
    Matrix p7;
    subtract(a12, a22, lhs);
    sum(b21, b22, rhs);
    strassenHelper(lhs, rhs, p7);

    Matrix c11(half, std::vector<int>(half));
    Matrix c12(half, std::vector<int>(half));
    Matrix c21(half, std::vector<int>(half));
    Matrix c22(half, std::vector<int>(half));

    // C11 = P1 + P4 - P5 + P7
    sum(p1, p4, lhs);
    sum(lhs, p7, rhs);
    subtract(rhs, p5, c11);
    // C12 = P3 + P5
    sum(p3, p5, c12);
    // C21 = P2 + P4
    sum(p2, p4, c21);
    // C22 = P1 - P2 + P3 + P6
    sum(p1, p3, lhs);
    sum(lhs, p6, rhs);
    subtract(rhs, p2, c22);

    // A and B have been fully read into the sub-blocks, so overwriting `ret`
    // here is safe even if it aliases one of them.
    ret.assign(sz, std::vector<int>(sz));
    placeBlock(c11, ret, 0, 0);
    placeBlock(c12, ret, 0, half);
    placeBlock(c21, ret, half, 0);
    placeBlock(c22, ret, half, half);
}

// Computes ret = A * B, where A is m x k and B is k x p; on return `ret` is
// m x p. A and B are not modified: padding is applied to internal copies.
// Throws std::runtime_error if either matrix is empty, if the inner
// dimensions disagree, or if either matrix has rows of unequal length.
void strassen(const Matrix &A, const Matrix &B, Matrix &ret)
{
    if (A.empty() || B.empty())
        throw std::runtime_error("empty matrices");
    if (A[0].size() != B.size())
        throw std::runtime_error("A's col not equal B's row");

    const std::size_t rows = A.size();
    const std::size_t inner = B.size();
    const std::size_t cols = B[0].size();

    // Ragged rows would otherwise be silently zero-filled by the padding.
    for (const auto &row : A)
        if (row.size() != inner)
            throw std::runtime_error("A is not rectangular");
    for (const auto &row : B)
        if (row.size() != cols)
            throw std::runtime_error("B is not rectangular");

    // Pad both operands to ONE common power-of-two size; padding each from its
    // own dimensions could give A and B different sizes.
    const std::size_t extent = std::max({rows, inner, cols});
    Matrix paddedA = A;
    Matrix paddedB = B;
    zeroPadding(paddedA, extent);
    zeroPadding(paddedB, extent);

    Matrix product;
    strassenHelper(paddedA, paddedB, product);

    // Strip the padding: the true product is rows x cols.
    product.resize(rows);
    for (auto &row : product)
        row.resize(cols);
    ret.swap(product);
}

int main()
{
    const Matrix A{{1, 2, 0}, {1, 2, 3}, {1, 2, 3}};
    const Matrix B{{1, 0, 1}, {1, 1, 1}, {2, 1, 1}};
    Matrix ret;

    // Expected output:
    // 3 2 3
    // 9 5 6
    // 9 5 6
    strassen(A, B, ret);
    printMatrix(ret);
    return 0;
}
Reference:
https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/
相关文章推荐
- Gof 设计模式 创建型
- 根据经纬度,计算两点之间的距离
- beego模板语法 go语言模版语法
- Category 中属性的使用
- django 权限机制
- algorithm 题集三 (16.05.24)
- go学习
- HDU 4722 Good Numbers 数位dp或找规律枚举 数位dp感悟
- GO 语言圣经 -在线阅读
- django 模板中url的处理
- go学习
- Google 2016 i/o 大会
- golang的SHA1withRSA的实现
- django 2
- Google Eventbus优缺点
- Golang游戏服务器
- ViewGroup setVisibility 为GONE 子View依然占用地方,其中的子EditText会出现点击焦点占用
- 【Algorithm】动态规划法实例1
- go语言模板引擎应用以及读取io流
- hdu5512Pagodas