(原创)一般矩阵 Matrix类

来源:互联网 发布:重置sql密码用户名 编辑:程序博客网 时间:2024/06/05 19:39

讨论矩阵的两种表示方法,一种用一维数组来存储矩阵元素,另一种用二维数组来存储矩阵元素。然后比较两种方法,并测试它们的性能,做出总结。

1.一维数组形式
主要代码:

public class Matrix implements CloneableObject{    int rows,cols;    Object [] element;    public Matrix(int theRows,int theCols){        rows=theRows;        cols=theCols;        element=new Object[rows*cols];    }    public Object clone(){        Matrix x=new Matrix(rows,cols);        for(int i=0;i<rows*cols;i++){            x.element[i]=((CloneableObject)element[i]).clone();        }        return x;    }    public void copy(Matrix m){        if(this!=m){            rows=m.rows;            cols=m.cols;            element=new Object[rows*cols];            for(int i=0;i<rows*cols;i++){                element[i]=((CloneableObject)m.element[i]).clone();            }        }    }    public Object get(int i,int j){        checkIndex(i,j);        return element[(i-1)*cols+(j-1)];    }    private void checkIndex(int i, int j) {        // TODO Auto-generated method stub        if(i<1||j<1||i>rows||j>cols){            throw new IndexOutOfBoundsException("");        }    }    public void set(int i,int j,Object newValue){        checkIndex(i,j);        element[(i-1)*cols+(j-1)]=((CloneableObject)newValue).clone();    }    public Matrix add(Matrix m){        if(rows!=m.rows||cols!=m.cols){            throw new IllegalArgumentException("can not add");        }        Matrix w=new Matrix(rows,cols);        int n=rows*cols;        for(int i=0;i<n;i++){            w.element[i]=((Computable)element[i]).add(m.element[i]);        }        return w;    }    public Matrix sub(Matrix m){        if(rows!=m.rows||cols!=m.cols){            throw new IllegalArgumentException("can not subtract");        }        Matrix w=new Matrix(rows,cols);        int n=rows*cols;        for(int i=0;i<n;i++){            w.element[i]=((Computable)element[i]).subtract(m.element[i]);        }        return w;    }    public Matrix multiply(Matrix m){        if(cols!=m.rows){            throw new IllegalArgumentException("can not multiply");        }        Matrix w=new Matrix(rows,m.cols);        int ct=0,cm=0,cw=0;        for(int i=1;i<=rows;i++){            for(int j=1;j<=m.cols;j++){                Computable sum=(Computable) (((Computable)element[ct]).multiply(m.element[cm]));                for(int k=2;k<=cols;k++){                    ct++;                    cm+=m.cols;                    sum.increment(((Computable)element[ct]).multiply(m.element[cm]));                }                w.element[cw++]=sum;                ct-=cols-1;                cm=j;            }            ct+=cols;            cm=0;        }        return w;    }    public Matrix transpose(){        Matrix w=new Matrix(cols,rows);        for(int i=1;i<=rows;i++){            for(int j=1;j<=cols;j++){                w.element[(j-1)*cols+i-1]=element[(i-1)*rows+j-1];            }        }        return w;    }    public Matrix decrement(Object x){        Matrix w=new Matrix(rows,cols);        int n=rows*cols;        for(int i=0;i<n;i++){            w.element[i]=((Computable)element[i]).decrement(x);        }        return w;    }    public Matrix increment(Object x){        Matrix w=new Matrix(rows,cols);        int n=rows*cols;        for(int i=0;i<n;i++){            w.element[i]=((Computable)element[i]).increment(x);        }        return w;    }    public Matrix multiplyByConstant(Object x){        Matrix w=new Matrix(rows,cols);        int n=rows*cols;        for(int i=0;i<n;i++){            w.element[i]=((Computable)element[i]).multiply(x);        }        return w;    }    public Matrix dividedByConstant(Object x){        Matrix w=new Matrix(rows,cols);        int n=rows*cols;        for(int i=0;i<n;i++){            w.element[i]=((Computable)element[i]).divide(x);        }        return w;    }    public String toString(){        StringBuilder s=new StringBuilder();        int n=rows*cols;        for(int i=0;i<n;i++){            s=s.append("\t"+element[i].toString()+" ");            if((i+1)%cols==0){                s.append("\n");            }        }        return s.toString();    }}

这里的接口CloneableObject只有一个clone方法
public interface CloneableObject extends Cloneable
{public Object clone();}

Computable接口代码:

public interface Computable{   /** @return this + x */   public Object add(Object x);   /** @return this - x */   public Object subtract(Object x);   /** @return this * x */   public Object multiply(Object x);   /** @return quotient of this / x */   public Object divide(Object x);   /** @return remainder of this / x */   public Object mod(Object x);   /** @return this incremented by x */   public Object increment(Object x);   /** @return this decremented by x */   public Object decrement(Object x);   /** @return the additive zero element */   public Object zero();   /** @return the multiplicative identity element */   public Object identity();}

这里的Matrix类有三个数据成员rows,cols,element[],分别表示矩阵的行数,列数和矩阵的内容。并定义了构造方法public Matrix(int theRows,int theCols),矩阵构造函数的复杂度是O(rows*cols),如果我们假设复制一个矩阵项,两个矩阵项想家以及将一个矩阵项转换为字符串的时间为θ(1),那么方法clone,copy,add,toString的渐进复杂度也都是O(rows*cols)。矩阵乘法的复杂度是O(rows*cols*m.cols)。

2.二维数组形式
主要代码

public class MatrixAs2DArray implements CloneableObject{    Object [][] element;    int rows,cols;    public MatrixAs2DArray(int rows,int cols){        element=new Object[rows][cols];        this.rows=rows;        this.cols=cols;    }    public Object clone(){        MatrixAs2DArray w=new MatrixAs2DArray(rows,cols);        for(int i=0;i<rows;i++){            for(int j=0;j<cols;j++){                w.element[i][j]=((CloneableObject)element[i][j]).clone();            }        }        return w;    }    public void copy(MatrixAs2DArray m){        rows=m.rows;        cols=m.cols;        element=new Object[rows][cols];        for(int i=0;i<rows;i++){            for(int j=0;j<cols;j++){                element[i][j]=((CloneableObject)m.element[i][j]).clone();            }        }    }    public Object get(int i,int j){        checkIndex(i,j);        return element[i-1][j-1];       }    private void checkIndex(int i, int j) {        // TODO Auto-generated method stub        if(i<1||j<1||i>rows||j>cols){            throw new IndexOutOfBoundsException("");        }    }    public void set(int i,int j,Object newValue){        checkIndex(i,j);        element[i-1][j-1]=((CloneableObject)newValue).clone();    }    public MatrixAs2DArray add(MatrixAs2DArray m){        if(rows!=m.rows||cols!=m.cols){            throw new IllegalArgumentException("can not add");        }        MatrixAs2DArray w=new MatrixAs2DArray(rows,cols);        for(int i=0;i<rows;i++){            for(int j=0;j<cols;j++){                w.element[i][j]=((Computable)element[i][j]).add(m.element[i][j]);            }        }        return w;    }    public MatrixAs2DArray subtract(MatrixAs2DArray m){        if(rows!=m.rows||cols!=m.cols){            throw new IllegalArgumentException("can not add");        }        MatrixAs2DArray w=new MatrixAs2DArray(rows,cols);        for(int i=0;i<rows;i++){            for(int j=0;j<cols;j++){                w.element[i][j]=((Computable)element[i][j]).subtract(m.element[i][j]);            }        }        return w;    }    public MatrixAs2DArray multiply(MatrixAs2DArray m){        if(cols!=m.rows){            throw new IllegalArgumentException("can not multiply");        }        MatrixAs2DArray w=new MatrixAs2DArray(rows,m.cols);         for (int i = 0; i < rows; i++)             for (int j = 0; j < w.cols; j++)             {// compute [i][j] term of result                // compute first term of w(i,j)                Computable sum =  (Computable) ((Computable)element[i][0])                                   .multiply(m.element[0][j]);                // add in remaining terms                for (int k = 1; k < cols; k++)                   sum.increment(((Computable) element[i][k]).multiply                                  (m.element[k][j]));                w.element[i][j] = sum;             }        return w;    }    public String toString(){        StringBuilder s=new StringBuilder();        for(int i=0;i<rows;i++){            for(int j=0;j<cols;j++){                s.append("\t"+element[i][j].toString()+" ");            }            s.append("\n");        }        return s.toString();    }}

这里为了以示区别类名称用的是MatrixAs2DArray,依然是三个数据成员rows,cols,element[][],只是这里element变成二维数组了。其各个方法的复杂度与Matrix类是一样的。

3.比较
设要表示的矩阵大小为m*n的,假设元素都是int类型的,如果用一维数组(x[mn])存储要占用4mn+4个字节,其中4mn个字节用来存储数据,4个字节用来存储数组长度。如果用二维数组(x[m][n])来存储要占用4mn+8m+4个字节,其中4m个字节用来存储x[]指针,4个字节用来存储x[]指针长度,每一个x[]数组要4n+4个字节来存储,共有m个。从内存角度上来说,一维数组形式所占内存较少更有优势。
接下来对矩阵加法操作和乘法操作进行试验来比较性能。
代码:

public class MatrixPerformanceTest {    public static void main(String args[]){        for(int n=30;n<1000;n=n*2){            Matrix m1=new Matrix(n,n);            Matrix m2=new Matrix(n,n);            MatrixAs2DArray ma1=new MatrixAs2DArray(n,n);            MatrixAs2DArray ma2=new MatrixAs2DArray(n,n);            for (int i = 1; i <= n; i++)                for (int j = 1; j <= n; j++)                {                   MyInteger q1 = new MyInteger(2 * i + j);                   MyInteger q2 = new MyInteger(2 *j+3*i);                   m1.set(i, j, q1);                   m2.set(i, j, q2);                   ma1.set(i, j, q1);                   ma2.set(i, j, q2);                }            long startTime1=System.currentTimeMillis();            int count1=0;            do{                m1.add(m2);                count1++;            }while((System.currentTimeMillis()-startTime1)<1000);            long elapsedTime1=(System.currentTimeMillis()-startTime1)/count1;            System.out.print("n="+n+"时:1方法加法耗时:"+elapsedTime1+"  ");            long startTime2=System.currentTimeMillis();            int count2=0;            do{                ma1.add(ma2);                count2++;            }while((System.currentTimeMillis()-startTime2)<1000);            long elapsedTime2=(System.currentTimeMillis()-startTime2)/count2;            System.out.println("n="+n+"时:2方法加法耗时:"+elapsedTime2);            long startTime3=System.currentTimeMillis();            int count3=0;            do{                m1.multiply(m2);                count3++;            }while((System.currentTimeMillis()-startTime3)<1000);            long elapsedTime3=(System.currentTimeMillis()-startTime3)/count3;            System.out.print("n="+n+"时:1方法乘法耗时:"+elapsedTime3+"  ");            long startTime4=System.currentTimeMillis();            int count4=0;            do{                ma1.multiply(ma2);                count4++;            }while((System.currentTimeMillis()-startTime4)<1000);            long elapsedTime4=(System.currentTimeMillis()-startTime4)/count4;            System.out.println("n="+n+"时:2方法乘法耗时:"+elapsedTime4);        }        System.out.println("测试结束");    }}

结果:
n=30时:1方法加法耗时:0 n=30时:2方法加法耗时:0
n=30时:1方法乘法耗时:0 n=30时:2方法乘法耗时:0
n=60时:1方法加法耗时:0 n=60时:2方法加法耗时:0
n=60时:1方法乘法耗时:4 n=60时:2方法乘法耗时:3
n=120时:1方法加法耗时:0 n=120时:2方法加法耗时:0
n=120时:1方法乘法耗时:35 n=120时:2方法乘法耗时:35
n=240时:1方法加法耗时:1 n=240时:2方法加法耗时:2
n=240时:1方法乘法耗时:362 n=240时:2方法乘法耗时:378
n=480时:1方法加法耗时:5 n=480时:2方法加法耗时:8
n=480时:1方法乘法耗时:3752 n=480时:2方法乘法耗时:3952
n=960时:1方法加法耗时:44 n=960时:2方法加法耗时:34
n=960时:1方法乘法耗时:37277 n=960时:2方法乘法耗时:39904
测试结束

从结果可以看到对于加法,二维数组形式表现更好,但差别不大,对于乘法一维数组形式表现更好。

4.总结
实际上以上两种方法矩阵的乘法还可以进一步改进。这里以对二维形式表示的矩阵类的乘法为例进行改进(一维数组形式可以类似进行修改)
代码:

public MatrixAs2DArray multiply(MatrixAs2DArray m){        if(cols!=m.rows){            throw new IllegalArgumentException("can not multiply");        }        MatrixAs2DArray w=new MatrixAs2DArray(rows,m.cols);        for(int i=0;i<rows;i++){            for(int j=0;j<m.cols;j++){                w.element[i][j]=((Computable)element[i][0]).multiply(m.element[0][j]);            }        }        for(int i=0;i<rows;i++){            for(int k=1;k<cols;k++){                for(int j=0;j<m.cols;j++){                    Object temp=((Computable)element[i][k]).multiply(m.element[k][j]);                    w.element[i][j]=((Computable) w.element[i][j]).add(temp);                }            }        }}

实际上这里只是对乘法里面的三个嵌套的for循环顺序做了修改,但是由于将相乘的两个矩阵都行优先进行读取计算,使得缓存遗漏减少,增加了运算效率。下面是用改进后的二维数组形式进行的与之前一样的测试。
结果:
n=30时:1方法加法耗时:0 n=30时:2方法加法耗时:0
n=30时:1方法乘法耗时:0 n=30时:2方法乘法耗时:1
n=60时:1方法加法耗时:0 n=60时:2方法加法耗时:0
n=60时:1方法乘法耗时:3 n=60时:2方法乘法耗时:8
n=120时:1方法加法耗时:0 n=120时:2方法加法耗时:0
n=120时:1方法乘法耗时:34 n=120时:2方法乘法耗时:71
n=240时:1方法加法耗时:1 n=240时:2方法加法耗时:1
n=240时:1方法乘法耗时:331 n=240时:2方法乘法耗时:568
n=480时:1方法加法耗时:5 n=480时:2方法加法耗时:8
n=480时:1方法乘法耗时:3827 n=480时:2方法乘法耗时:4051
n=960时:1方法加法耗时:31 n=960时:2方法加法耗时:20
n=960时:1方法乘法耗时:36809 n=960时:2方法乘法耗时:31258
测试结束

可以看到乘法明显比修改之前更快,甚至超过了一维数组形式的乘法。不过如果对一维数组形式的乘法也进行改进,其乘法运算速度还是会超过二维数组形式的。

1 0
原创粉丝点击