DynamicProgramming——矩阵连乘

来源:互联网 发布:gta5fps优化 编辑:程序博客网 时间:2024/05/18 01:29

矩阵连乘问题:

一、思考

对于N个矩阵相乘A1,A2,A3,A4,A5,A6,A7,A8,A9

由于两个矩阵相乘(以A*B为例)的充要条件是矩阵A的列数与矩阵B的行数相等,

而对于N个不同的矩阵连乘时,想邻两个矩阵总是有如下关系:即左矩阵的列数与右矩阵的行数相等。

二、矩阵相乘之间的关系

假设两个矩阵A=n*m,B=m*l;
  我们对于两个相乘的矩阵,他们需要的乘积次数count=n*m*l;而在N个数组相乘的过程中需要的总次数根据你相乘的方式不同,所需要的总次数是不同的。以三个数组A=10*20,B=20*30,C=30*40为例,如果以(AB)C方式相乘需要18000次,而以A(BC)方式相乘则需要32000次,显示以(AB)C方式相乘能够使得相乘次数最少,从而能够降低矩阵相乘的运行时间。
而对于多个矩阵相乘,我们该如何找到最优的相乘方式呢?
考虑如下,对于A1.....An矩阵序列,我们要求A[1:N]的最优次数,可以假设在A[1:N]矩阵序列中,存在一个k值,使得对于A[1:k]+A[k+1:n]+sum的结果最小,其中,k将A[1:N]划分成两个矩阵序列A[1:K],A[K+1:N],sum是两个矩阵相乘之后的两个数组乘积需要的次数。在这里,我们将A[1:N]的问题转化为求解A[1:K],A[K+1:N]两个子问题的最优化,这样递归的求解到最基本的两个矩阵相乘的最优化,最后将最优子问题合并来求解原问题的最优方法。
递归求解A[i:j]:
if(i==j)
A[i:j]=0;
else{
A[i:j]=min{A[i:k]+A[k+1:j]+sum};// k从i+1循环遍历,在i到j之间找到最优
}

三、子问题重复

在这个问题中,还存在着一个子问题重复求解问题,
例如要求解A[1:5],可以分为求解
1.A[1:2]+A[3:5]+sum,
2.A[1:3]+A[4:5]+sum,
在这里,我们需要求解A[1:3],而我们在求解A[1:4]时已经求解过A[1:3]了,这样就存在着子问题重复求解,增加了算法的计算量,使得时间大大增加。
我们采用建立一个数组来记录已经求过的子问题,即,建立一个数组m[i][j],用来存储计算过的A[i:j],这样在遇到子问题时,我们查找备忘录就可以得到了,而不用重复计算。
同时也可以采用s[i][j],用来记录在A[i:j]之间找到的最优的k值。

四、JAVA代码实现

package DynamicProgramming;

import java.util.Arrays;
/**
* 矩阵连乘:
* 是利用动态规划的思想
* 求取N个矩阵相乘需要的计算次数最少的乘法方式
* 这个方式我们用括号来体现,利用动态规划,将
* 大问题分解成小问题,大问题的最优解依赖于
* 小问题的最优解,而在动态规划方面,很多小问题被重复计算了,
* 所以我们需要一个备忘录来记录已经求得的子问题的最优解
*/
public class MatrixMultiply {
private int[] p; //用来存储N个矩阵的行数与列数
private int[][] m; //备忘录,用来记录已经求得的子问题的最优解
private int[][] s; //用来记录断点K值
private int n; //用来存储有多少个矩阵相乘
MatrixMultiply(int n,int[] p){
this.n=n+1;
this.p=new int[this.n];
this.p=Arrays.copyOfRange(p, 0, p.length);
this.m=new int[this.n][this.n];
this.s=new int[this.n][this.n];
}
public void matrixSubWays(){
for(int i=0;i<this.n;i++)
m[i][i]=0;
for(int r=2;r<this.n;r++){ //递归的求解最少的相乘方式
for(int i=1;i<this.n-r+1;i++){
int j=i+r-1;
m[i][j]=m[i][i]+m[i+1][j]+p[i-1]*p[i]*p[j];
s[i][j]=i;
for(int k=i+1;k<j;k++){
int small=m[i][k]+m[k+1][j]+p[i-1]*p[k]*p[j]; //以k为中断点,M[i][j]的最少次数等于M[i][k]
if(small<m[i][j]){ //+M[K+1][j]+p[i-1]*p[i]*p[j]
m[i][j]=small; //在i到j中循环知道找到最小k,并把k存入S[i][j]
s[i][j]=k;
}
}
}
}
}
public void traceBack(int i,int j,int[][] s){ //通过递归给数组加上括号
if(i==j){
System.out.print("A"+i);
return;
}
System.out.print("(");
traceBack(i,s[i][j],s);
traceBack(s[i][j]+1,j,s);
System.out.print(")");
}
public void printMatrix(){
int i=1;
int j=this.n-1;
traceBack(i,j,this.s);
}
public void getNumber(){
int count=0;
for(int i=1;i<this.n;i++){
for(int j=1;j<this.n;j++){
System.out.print(m[i][j]);
count++;
if(count%(this.n-1)==0){
System.out.println();
}
else
System.out.print("\t");
}
}
}
public void getK(){
int count=0;
for(int i=1;i<this.n;i++){
for(int j=1;j<this.n;j++){
System.out.print(s[i][j]);
count++;
if(count%(this.n-1)==0){
System.out.println();
}
else
System.out.print("\t");
}
}
}
public static void main(String[] args){
int[] p={30,35,15,5,10,20,25};
int n=p.length-1;
MatrixMultiply mm=new MatrixMultiply(n,p);
mm.matrixSubWays();
mm.printMatrix();
}
}

0 0