您的位置:首页 > 编程语言 > Java开发

算法导论学习笔记—Strassen算法的Java实现

2017-11-29 16:27 731 查看
Strassen算法

    Strassen算法的核心思想是令递归树稍微不那么茂盛,相比于简单的“分而治之”的矩阵递归计算,其递归的分支由8条减少到7条。其时间复杂度为O(n的lg7次方)。虽然,它的算法中需要新增10个(n/2 * n/2)的中间矩阵S1-S10。每次子矩阵的加减运算会增加O(n平方/4)的时间消耗,所以代码在执行S1-S10时,这部分的时间复杂度为10*O(n平方/4),但是相比之下,Strassen算法减少了一次递归,所以时间复杂度上会减少。

   下面来看简单“分而治之”矩阵乘法和Strassen算法的对比:

(1) 简单“分而治之”矩阵乘法

package com.oracle.ThirdCharpter;

/**

 * 写一个简单的“分而治之的矩阵乘法”,即A = [A11 A12    B=[B11 B12   C=[A11*B11+A12*B21  A11*B12+A12*B22

 *                                        A21 A22]      B21 B22]     A21*B11+A22*B21  A21*B12+A22*B22]

 * 经过算法分析发现,其时间复杂度依然还是O(n的3次方)

 * @author zhegao

 *

 */

public class Practice1_4 {

    public int[][] matrix_multiply(int[][] a,int[][] b) {        

        if(a.length==1) {

            return new int[][] {{a[0][0]*b[0][0]}};

        }else {

            int[][] A11 = partition(a,1);

            int[][] B11 = partition(b,1);

            int[][] A12 = partition(a,2);

            int[][] B12 = partition(b,2);

            int[][] A21 = partition(a,3);

            int[][] B21 = partition(b,3);

            int[][] A22 = partition(a,4);

            int[][] B22 = partition(b,4);

            //进行加法运算

            int[][] C11 = matrixAdd(matrix_multiply(A11,B11),matrix_multiply(A12,B21));

            int[][] C12 = matrixAdd(matrix_multiply(A11,B12),matrix_multiply(A12,B22));

            int[][] C21 = matrixAdd(matrix_multiply(A21,B11),matrix_multiply(A22,B21));

            int[][] C22 = matrixAdd(matrix_multiply(A21,B12),matrix_multiply(A22,B22));

            int[][] C = merge(C11,C12,C21,C22);

            return C;

        }

        

    }

    //拆分矩阵,得到四个子矩阵,把不同位置的子矩阵标记成1,2,3,4。1——左上;2——右上;3——左下;4——右下

    public int[][] partition(int[][] arr,int index) {

        int len = arr.length;

        int[][] result = new int[len/2][len/2];

        switch(index) {

        case 1:

            for(int i=0;i<len/2;i++) {

                for(int j=0;j<len/2;j++) {                    

                    result[i][j]=arr[i][j];        

                    //System.out.println(result[i][j]);

                }

            };

            break;

        case 2:

            for(int i=0;i<len/2;i++) {

                for(int j=len/2;j<len;j++) {                

                    result[i][j-len/2]=arr[i][j];                        

                }

            };

            break;

        case 3:

            for(int i=len/2;i<len;i++) {                

                for(int j=0;j<len/2;j++) {                                    

                    result[i-len/2][j] = arr[i][j];

                }

            };

            break;

        case 4:

            for(int i=len/2;i<len;i++) {

                for(int j=len/2;j<len;j++) {                    

                    result[i-len/2][j-len/2]=arr[i][j];

                }

            };

        }

        return result;

    }

    //矩阵的加运算

    public int[][] matrixAdd(int[][] a,int[][] b){

        int[][] result = new int[a.length][a.length];

        for(int i=0;i<a.length;i++) {

            for(int j=0;j<a.length;j++) {

                result[i][j]=a[i][j]+b[i][j];

            }

        }

        return result;

    }

    public  void display(int[][] arr) {

        System.out.print("[");

        for(int i=0;i<arr.length;i++) {

            for(int j=0;j<arr.length;j++) {

                if(j==arr.length-1) {

                    System.out.print(arr[i][j]);

                }else {

                    System.out.print(arr[i][j]+ " ");

                }

            }

            System.out.print("]");

            System.out.print("\n");

        }    

    }

    //将四个子矩阵合并成一个整体的大矩阵

    public int[][] merge(int[][] a1,int[][] a2,int[][] a3,int[][] a4){

        int len = a1.length;

        int[][] result = new int[len*2][len*2];

        for(int i=0;i<result.length;i++) {

            if(i<len) {

                for(int j=0;j<result.length;j++) {

                    if(j<len) {

                        result[i][j]=a1[i][j];

                    }else {

                        result[i][j] = a2[i][j-len];

                    }

                }

            }else {

                for(int j=0;j<result.length;j++) {

                    if(j<len) {

                        result[i][j]=a3[i-len][j];

                    }else {

                        result[i][j] = a4[i-len][j-len];

                    }

                }

            }

        }

        return result;

    }

    public static void main(String[] args) {

        int[][] arr = new int[][] {{1,2,3,4},{5,6,7,8},{9,10,11,12},{13,14,15,16}};

        Practice1_4 prac = new Practice1_4();

        int[][] result1 = prac.partition(arr, 1);

        int[][] result2 = prac.partition(arr, 2);

        int[][] result3 = prac.partition(arr, 3);

        int[][] result4 = prac.partition(arr, 4);

      //测试矩阵分离方法

        prac.display(result1);

        prac.display(result2);

        prac.display(result3);

        prac.display(result4);

   

        //测试矩阵的合并方法

        int[][] merge = prac.merge(result1, result2, result3, result4);

        prac.display(merge);

        

        //测试分而治之的矩阵乘法

        int[][] a1 = new int[][] {{1,2,3,4},{4,3,2,1},{0,1,2,3},{3,2,1,0}};

        int[][] a2 = new int[][] {{1,2,3,4},{4,3,2,1},{0,1,2,3},{3,2,1,0}};

        prac.display(prac.matrix_multiply(a1, a2));

    }

}

(2)Strassen算法

package com.oracle.ThirdCharpter;

/**

 * 使用Strassen算法进行矩阵乘法

 * 相比于“分而治之”的矩阵乘法,Stassen算法的递归分支只有7条,所以其时间复杂度为O(log2 7)

 *

 * 分析:Strassen算法相比传统”分而治之“的算法,它的递归分支只有7条,P1-P7
4000


 * @author zhegao

 *

 */

public class Practice1_5 {

    public int[][] matrix_multiply(int[][] a,int[][] b) {        

        if(a.length==1) {

            return new int[][] {{a[0][0]*b[0][0]}};

        }else {

            int[][] A11 = partition(a,1);

            int[][] B11 = partition(b,1);

            int[][] A12 = partition(a,2);

            int[][] B12 = partition(b,2);

            int[][] A21 = partition(a,3);

            int[][] B21 = partition(b,3);

            int[][] A22 = partition(a,4);

            int[][] B22 = partition(b,4);

            

            //计算S1-S10的中间矩阵

            int[][] S1 = matrixSubstract(B12,B22);

            int[][] S2 = matrixAdd(A11,A12);

            int[][] S3 = matrixAdd(A21,A22);

            int[][] S4 = matrixSubstract(B21,B11);

            int[][] S5 = matrixAdd(A11,A22);

            int[][] S6 = matrixAdd(B11,B22);

            int[][] S7 = matrixSubstract(A12,A22);

            int[][] S8 = matrixAdd(B21,B22);

            int[][] S9 = matrixSubstract(A11,A21);

            int[][] S10 = matrixAdd(B11,B12);

            

            //计算P1-P7的几个递归矩阵

            int[][] P1 = matrix_multiply(A11,S1);

            int[][] P2 = matrix_multiply(S2,B22);

            int[][] P3 = matrix_multiply(S3,B11);

            int[][] P4 = matrix_multiply(A22,S4);

            int[][] P5 = matrix_multiply(S5,S6);

            int[][] P6 = matrix_multiply(S7,S8);

            int[][] P7 = matrix_multiply(S9,S10);

            

            //进行加减运算

            int[][] C11 = matrixAdd(matrixSubstract(matrixAdd(P5,P4),P2),P6);

            int[][] C12 = matrixAdd(P1,P2);

            int[][] C21 = matrixAdd(P3,P4);

            int[][] C22 = matrixSubstract(matrixSubstract(matrixAdd(P5,P1),P3),P7);

            

            //合并各个子矩阵

            int[][] C = merge(C11,C12,C21,C22);

            return C;

        }

    }

    //拆分矩阵,得到四个子矩阵,把不同位置的子矩阵标记成1,2,3,4。1——左上;2——右上;3——左下;4——右下

    public int[][] partition(int[][] arr,int index) {

        int len = arr.length;

        int[][] result = new int[len/2][len/2];

        switch(index) {

        case 1:

            for(int i=0;i<len/2;i++) {

                for(int j=0;j<len/2;j++) {                    

                    result[i][j]=arr[i][j];        

                    //System.out.println(result[i][j]);

                }

            };

            break;

        case 2:

            for(int i=0;i<len/2;i++) {

                for(int j=len/2;j<len;j++) {                

                    result[i][j-len/2]=arr[i][j];                        

                }

            };

            break;

        case 3:

            for(int i=len/2;i<len;i++) {                

                for(int j=0;j<len/2;j++) {                                    

                    result[i-len/2][j] = arr[i][j];

                }

            };

            break;

        case 4:

            for(int i=len/2;i<len;i++) {

                for(int j=len/2;j<len;j++) {                    

                    result[i-len/2][j-len/2]=arr[i][j];

                }

            };

        }

        return result;

    }

    

    //矩阵的加运算

    public int[][] matrixAdd(int[][] a,int[][] b){

        int[][] result = new int[a.length][a.length];

        for(int i=0;i<a.length;i++) {

            for(int j=0;j<a.length;j++) {

                result[i][j]=a[i][j]+b[i][j];

            }

        }

        return result;

    }

    

    //矩阵的减运算

    public int[][] matrixSubstract(int[][] a,int[][] b){

        int[][] result = new int[a.length][a.length];

        for(int i=0;i<a.length;i++) {

            for(int j=0;j<a.length;j++) {

                result[i][j]=a[i][j]-b[i][j];

            }

        }

        return result;

    }

    public  void display(int[][] arr) {

        System.out.print("[");

        for(int i=0;i<arr.length;i++) {

            for(int j=0;j<arr.length;j++) {

                if(j==arr.length-1) {

                    System.out.print(arr[i][j]);

                }else {

                    System.out.print(arr[i][j]+ " ");

                }

            }

            System.out.print("]");

            System.out.print("\n");

        }    

    }

    

    //将四个子矩阵合并成一个整体的大矩阵

        public int[][] merge(int[][] a1,int[][] a2,int[][] a3,int[][] a4){

            int len = a1.length;

            int[][] result = new int[len*2][len*2];

            for(int i=0;i<result.length;i++) {

                if(i<len) {

                    for(int j=0;j<result.length;j++) {

                        if(j<len) {

                            result[i][j]=a1[i][j];

                        }else {

                            result[i][j] = a2[i][j-len];

                        }

                    }

                }else {

                    for(int j=0;j<result.length;j++) {

                        if(j<len) {

                            result[i][j]=a3[i-len][j];

                        }else {

                            result[i][j] = a4[i-len][j-len];

                        }

                    }

                }

            }

            return result;

        }

    public static void main(String[] args) {

        //测试分而治之的矩阵乘法

        Practice1_5 prac = new Practice1_5();

        int[][] a1 = new int[][] {{1,2,3,4},{4,3,2,1},{0,1,2,3},{3,2,1,0}};

        int[][] a2 = new int[][] {{1,2,3,4},{4,3,2,1},{0,1,2,3},{3,2,1,0}};

        prac.display(prac.matrix_multiply(a1, a2));

    }

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: