Implementation of Strassen’s Algorithm for Matrix Multiplication
2013-03-21 06:25
357 查看
Strassen’s algorithm is not the most efficient algorithm for matrix multiplication, but it was the first algorithm that was theoretically faster than the naive algorithm. There is very good explanation and implementation of Strassen’s algorithm on Wikipedia.
However, the implementation of Strassen’s algorithm cannot be used directly, because it just sets the base case of the divide-and-conquer to be 1×1 matrix, which would consume huge time cost for iteration. If set the base case to 2×2 matrix, which means 2×2
matrix and 1×1 matrix will be multiplied by naive algorithm, then the Strassen’s algorithm will be more efficient for matrices larger than 512×512.
When set base case to 2×2 matrix, then the Strassen’s algorithm will surpass naive algorithm for matrices larger than 512×512.
When set base case to 6×6 matrix, then the Strassen’s algorithm will surpass naive algorithm for matrices larger than 128×128.
However, the implementation of Strassen’s algorithm cannot be used directly, because it just sets the base case of the divide-and-conquer to be 1×1 matrix, which would consume huge time cost for iteration. If set the base case to 2×2 matrix, which means 2×2
matrix and 1×1 matrix will be multiplied by naive algorithm, then the Strassen’s algorithm will be more efficient for matrices larger than 512×512.
When set base case to 2×2 matrix, then the Strassen’s algorithm will surpass naive algorithm for matrices larger than 512×512.
When set base case to 6×6 matrix, then the Strassen’s algorithm will surpass naive algorithm for matrices larger than 128×128.
/*------------------------------------------------------------------------------*/ // matrix_mult.cc -- Implementation of matrix multiplication with // Strassen's algorithm. // // Compile this file with gcc command: // g++ -Wall -o matrix_mult matrix_mult.cc #include <stdio.h> #include <stdlib.h> #include <time.h> #include <ctype.h> #include <unistd.h> #include <iostream> #include <fstream> #include <cmath> #include <cstring> using namespace std; // This function allocates the matrix inline double** allocate_matrix(int n) { double** mat=new double* ; for(int i=0;i<n;++i) { mat[i]=new double ; memset(mat[i],0,sizeof(double)*n); } return (mat); // returns the pointer to the vector. } /*------------------------------------------------------------------------------*/ // This function unallocates the matrix (frees memory) inline void free_matrix(double **M, int n) { for (int i = 0; i < n; i++) { delete [] M[i]; } delete [] M; // frees the pointer / M = NULL; } /*------------------------------------------------------------------------------*/ // function to sum two matrices inline void sum(double **a, double **b, double **result, int tam) { int i, j; for (i = 0; i < tam; i++) { for (j = 0; j < tam; j++) { result[i][j] = a[i][j] + b[i][j]; } } } /*------------------------------------------------------------------------------*/ // function to subtract two matrices inline void subtract(double **a, double **b, double **result, int tam) { int i, j; for (i = 0; i < tam; i++) { for (j = 0; j < tam; j++) { result[i][j] = a[i][j] - b[i][j]; } } } /*------------------------------------------------------------------------------*/ // naive method void naive(double** A, double** B,double** C, int n) { for (int i=0;i<n;i++) for (int j=0;j<n;j++) for(int k=0;k<n;k++) C[i][j] += A[i][k]*B[k][j]; } /*------------------------------------------------------------------------------*/ // Strassen's method void strassen(double **a, double **b, double **c, int tam) { // Key observation: call naive method for matrices smaller than 2 x 2 if(tam <= 4) { naive(a,b,c,tam); return; } // other cases are treated here: int newTam = tam/2; double **a11, **a12, **a21, **a22; double **b11, **b12, **b21, **b22; double **c11, **c12, **c21, **c22; double **p1, **p2, **p3, **p4, **p5, **p6, **p7; // memory allocation: a11 = allocate_matrix(newTam); a12 = allocate_matrix(newTam); a21 = allocate_matrix(newTam); a22 = allocate_matrix(newTam); b11 = allocate_matrix(newTam); b12 = allocate_matrix(newTam); b21 = allocate_matrix(newTam); b22 = allocate_matrix(newTam); c11 = allocate_matrix(newTam); c12 = allocate_matrix(newTam); c21 = allocate_matrix(newTam); c22 = allocate_matrix(newTam); p1 = allocate_matrix(newTam); p2 = allocate_matrix(newTam); p3 = allocate_matrix(newTam); p4 = allocate_matrix(newTam); p5 = allocate_matrix(newTam); p6 = allocate_matrix(newTam); p7 = allocate_matrix(newTam); double **aResult = allocate_matrix(newTam); double **bResult = allocate_matrix(newTam); //dividing the matrices in 4 sub-matrices: for (int i = 0; i < newTam; i++) { for (int j = 0; j < newTam; j++) { a11[i][j] = a[i][j]; a12[i][j] = a[i][j + newTam]; a21[i][j] = a[i + newTam][j]; a22[i][j] = a[i + newTam][j + newTam]; b11[i][j] = b[i][j]; b12[i][j] = b[i][j + newTam]; b21[i][j] = b[i + newTam][j]; b22[i][j] = b[i + newTam][j + newTam]; } } // Calculating p1 to p7: sum(a11, a22, aResult, newTam); // a11 + a22 sum(b11, b22, bResult, newTam); // b11 + b22 strassen(aResult, bResult, p1, newTam); // p1 = (a11+a22) * (b11+b22) sum(a21, a22, aResult, newTam); // a21 + a22 strassen(aResult, b11, p2, newTam); // p2 = (a21+a22) * (b11) subtract(b12, b22, bResult, newTam); // b12 - b22 strassen(a11, bResult, p3, newTam); // p3 = (a11) * (b12 - b22) subtract(b21, b11, bResult, newTam); // b21 - b11 strassen(a22, bResult, p4, newTam); // p4 = (a22) * (b21 - b11) sum(a11, a12, aResult, newTam); // a11 + a12 strassen(aResult, b22, p5, newTam); // p5 = (a11+a12) * (b22) subtract(a21, a11, aResult, newTam); // a21 - a11 sum(b11, b12, bResult, newTam); // b11 + b12 strassen(aResult, bResult, p6, newTam); // p6 = (a21-a11) * (b11+b12) subtract(a12, a22, aResult, newTam); // a12 - a22 sum(b21, b22, bResult, newTam); // b21 + b22 strassen(aResult, bResult, p7, newTam); // p7 = (a12-a22) * (b21+b22) // calculating c21, c21, c11 e c22: sum(p3, p5, c12, newTam); // c12 = p3 + p5 sum(p2, p4, c21, newTam); // c21 = p2 + p4 sum(p1, p4, aResult, newTam); // p1 + p4 sum(aResult, p7, bResult, newTam); // p1 + p4 + p7 subtract(bResult, p5, c11, newTam); // c11 = p1 + p4 - p5 + p7 sum(p1, p3, aResult, newTam); // p1 + p3 sum(aResult, p6, bResult, newTam); // p1 + p3 + p6 subtract(bResult, p2, c22, newTam); // c22 = p1 + p3 - p2 + p6 // Grouping the results obtained in a single matrix: for (int i = 0; i < newTam ; i++) { for (int j = 0 ; j < newTam ; j++) { c[i][j] = c11[i][j]; c[i][j + newTam] = c12[i][j]; c[i + newTam][j] = c21[i][j]; c[i + newTam][j + newTam] = c22[i][j]; } } // deallocating memory (free): free_matrix(a11, newTam); free_matrix(a12, newTam); free_matrix(a21, newTam); free_matrix(a22, newTam); free_matrix(b11, newTam); free_matrix(b12, newTam); free_matrix(b21, newTam); free_matrix(b22, newTam); free_matrix(c11, newTam); free_matrix(c12, newTam); free_matrix(c21, newTam); free_matrix(c22, newTam); free_matrix(p1, newTam); free_matrix(p2, newTam); free_matrix(p3, newTam); free_matrix(p4, newTam); free_matrix(p5, newTam); free_matrix(p6, newTam); free_matrix(p7, newTam); free_matrix(aResult, newTam); free_matrix(bResult, newTam); } // end of Strassen function /*------------------------------------------------------------------------------*/ // Generate random matrices void gen_matrix(double** M,int n) { for(int i=0;i<n;++i) { for(int j=0;j<n;++j) { M[i][j]=rand()%100; //M[i][j]=1; } } } /*------------------------------------------------------------------------------*/ // print matrix M using specied fstream void print_matrix(fstream& fs, double** M, int n) { for(int i=0;i<n;++i) { for(int j=0;j<n;++j) { fs<<M[i][j]<<" "; } fs<<endl; } fs<<endl; } /*------------------------------------------------------------------------------*/ // record the generated matrix and the final product void mat_mult_log(double** A, double** B,double** C,int n,char* file) { fstream fs; fs.open(file,fstream::out); fs<<"Random Matrix A:"<<endl; print_matrix(fs,A,n); fs<<"Random Matrix B:"<<endl; print_matrix(fs,B,n); fs<<"C=A * B"<<endl; print_matrix(fs,C,n); fs.close(); } /*------------------------------------------------------------------------------*/ int main(int argc, char** argv) { srand(time(NULL)); int mdim=2; // matrix dimension char* output=NULL; bool is_strassen=false; int c; while ((c = getopt (argc, argv, "sn:o:")) != -1) { switch (c) { case 's': is_strassen=true; break; case 'n': mdim = pow((int)2,atoi(optarg)); // 2^n dimensions break; case 'o': output = optarg; // 2^n dimensions break; case '?': if (optopt == 'n') fprintf (stderr, "Option -%c requires an argument.\n", optopt); else if (isprint (optopt)) fprintf (stderr, "Unknown option `-%c'.\n", optopt); else fprintf (stderr, "Unknown option character `\\x%x'.\n", optopt); return 1; default: abort (); } } // create new matrices double** A=allocate_matrix(mdim); double** B=allocate_matrix(mdim); double** C=allocate_matrix(mdim); gen_matrix(A,mdim); gen_matrix(B,mdim); // matrices multiplication if(is_strassen) strassen(A,B,C,mdim); else naive(A,B,C,mdim); if(output!=NULL) mat_mult_log(A,B,C,mdim,output); free_matrix(A,mdim); free_matrix(B,mdim); free_matrix(C,mdim); return 0; }
相关文章推荐
- Strassen’s algorithm for matrix multiplication
- Strassen's Subcubic Matrix Multiplication Algorithm
- 论文笔记:Research and Implementation of a Multi-label Learning Algorithm for Chinese Text Classification
- Strassen's algorithm to compute matrix multiplication
- Implementation of Hierarchical Attention Networks for Document Classification的讲解与Tensorflow实现
- Search a 2D Matrix IIWrite an efficient algorithm that searches for a value in an m x n matrix. This
- Supervised Learning 002: k-Nearest Neighbor - Classify Algorithm for Matrix
- MonoGame Cross Platform Implementation of XNA for iOS, Android, Mac , Linux, Windows, Windows8, OUYA
- 《A Sub-Pixel Edge Detector: an Implementation of the Canny/Devernay Algorithm》
- Linear-time algorithm for the maximum-subarray problem Java implementation
- Rob Hess's C implementation of SIFT algorithm
- algorithm 1.4.19 Local minimum of a matrix
- libcxx——algorithm,functionalConcept is a term that describes a named set of requirements for a type.
- Proof of a PD covariance matrix for conditional Gaussian
- Implementation codes of Data Structures and Algorithm Analysis in C (1)
- A Fast Priority Queue Implementation of the Dijkstra Shortest Path Algorithm
- implementation of General Sort Algorithm - mark
- 研读:Design and Implementation of Efficient Integrity Protection for Open Mobile Platforms
- Collection of algorithm for sorting. 常见排序算法集(一)
- Matrix Chain Multiplication-geeksforgeeks