稀疏矩阵的存储以及转置、加法、乘法操作实现

来源:互联网 发布:ubuntu vsftpd安装 编辑:程序博客网 时间:2024/06/01 09:20

1、稀疏矩阵的存储与表示

只存储稀疏矩阵中极少数的非零元素,采用一个三元组<row,column,value>来唯一确定一个矩阵元素;因此,稀疏矩阵可用一个三元组数组来表示。另外,还需要存储原矩阵的行数、列数和非零元素个数。

2、代码实现

快速转置和矩阵乘法函数中都用到了两个辅助数组

+ rowSize[] 存储矩阵每一行非零元素的个数,下标表示行号+ rowStart[] 存储矩阵每一行非零元素在三元组数组存储的起始位置,下标表示行号

因为矩阵的存储是按照从左到右从上到下元素的顺序存储的,利用以上两个辅助数组可以避免多次重复访问三元组数组。
+ SparseMatrix.h

/* * @author:Curya * @time:2017-09-08 * @theme:稀疏矩阵————包括矩阵转置、矩阵乘法、矩阵加法 * */#ifndef SPARSEMATRIX_H#define SPARSEMATRIX_H#include <iostream>#include <cstdlib>using namespace std;template<class T>struct Trituple {    int row;    //行号    int col;    //列号    T value;    //非零元素的值    Trituple<T>& operator = (const Trituple<T>& x)    {        row = x.row;        col = x.col;        value = x.value;        return *this;    }};template<class T>class SparseMatrix{    //输入输出运算符重载    template<class U>    friend ostream& operator << (ostream& out, const SparseMatrix<U>& myMatrix);    template<class U>    friend istream& operator >> (istream& in, SparseMatrix<U>& myMatrix);private:    int rowCnts;    int colCnts;    int valueCnts;    Trituple<T> *sparseMatrix;    int maxCnts;public:    SparseMatrix(int sz = 100);                         //构造函数    SparseMatrix(SparseMatrix<T>& x);                   //复制构造函数    ~SparseMatrix();                                   //析构函数    SparseMatrix<T>& operator = (const SparseMatrix<T>& x); //赋值运算符重载(不加const报错)    SparseMatrix<T> transpose();                        //矩阵转置    SparseMatrix<T> fastTranspose();                    //快速矩阵转置    SparseMatrix<T> addMatrix(SparseMatrix<T>& b);      //矩阵加法    SparseMatrix<T> mulMatrix(SparseMatrix<T>& b);      //矩阵乘法};template<class T>SparseMatrix<T>::SparseMatrix(int sz){    maxCnts = sz;    if(maxCnts < 1) {        cerr << "matrix init error!" << endl;        exit(1);    }    sparseMatrix = new Trituple<T>[maxCnts];    if(sparseMatrix == NULL) {        cerr << "memory alloc error!" << endl;        exit(1);    }    rowCnts = 0;    colCnts = 0;    valueCnts = 0;}template<class T>SparseMatrix<T>::SparseMatrix(SparseMatrix<T>& x){    rowCnts = x.rowCnts;    colCnts = x.colCnts;    valueCnts = x.valueCnts;    maxCnts = x.maxCnts;    sparseMatrix = new Trituple<T>[maxCnts];    if(sparseMatrix == NULL) {        cerr << "memory alloc error!" << endl;        exit(1);    }    for(int i = 0; i < valueCnts; ++i) {        sparseMatrix[i] = x.sparseMatrix[i];    }}template<class T>SparseMatrix<T>::~SparseMatrix(){    delete []sparseMatrix;}template<class T>SparseMatrix<T>& SparseMatrix<T>::operator=(const SparseMatrix<T>& x){    if(this == &x)        return *this;    rowCnts = x.rowCnts;    colCnts = x.colCnts;    valueCnts = x.valueCnts;    maxCnts = x.maxCnts;    if(sparseMatrix == NULL) {        cerr << "memory alloc error!" << endl;        exit(1);    }    for(int i = 0; i < valueCnts; ++i) {        sparseMatrix[i] = x.sparseMatrix[i];    }    return *this;}//效率低template<class T>SparseMatrix<T> SparseMatrix<T>::transpose(){    SparseMatrix<T> trans(100);    trans.rowCnts = this->colCnts;    trans.colCnts = this->rowCnts;    trans.valueCnts = this->valueCnts;    if(this->valueCnts > 0) {        int k, i, currentPtr = 0;        for(k = 0; k < this->colCnts; ++k) {            for(i = 0; i < this->valueCnts; ++i) {                if(this->sparseMatrix[i].col == k) {                    trans.sparseMatrix[currentPtr].row = k;                    trans.sparseMatrix[currentPtr].col = this->sparseMatrix[i].row;                    trans.sparseMatrix[currentPtr].value = this->sparseMatrix[i].value;                    currentPtr++;                }            }        }    }    return trans;}template<class T>SparseMatrix<T> SparseMatrix<T>::fastTranspose(){    //两个辅助数组    int *rowSize = new int[colCnts];        //统计转置后各行元素的个数(转置前各列元素个数)    int *rowStart = new int[colCnts];       //计算转置后,每行第一个非零元素存放的位置    //Tips:    //(1)总的元素个数已知    //(2)转置之后各行元素已知    //==>可以计算转之后每行的第一个非零元素在总的矩阵(一维数组)存放位置    //eg.假设一共有6个元素,    //转置之前存储顺序为(1,2)、(2,1)、(3,0)、(3,2)、(4,2)、(4,3)    //转置之后的存储顺序(0,3)、(1,2)、(1,4)、(2,1)、(2,3)、(3,4)    //可以计算,第一行非零元素个数c0=1;第二行c1=2;第三行c2=2;第四行c3=1    //第一行第1个非零元素存放位置为0;第二行(0+c0)=1;第三行(1+c1)=3;第四行(3+c2)=5    SparseMatrix<T> trans(100);             //存放矩阵转置后的结果    trans.colCnts = rowCnts;    trans.rowCnts = colCnts;    trans.valueCnts = valueCnts;    if(valueCnts > 0) {        int i, j;        for(i = 0; i < colCnts; ++i)    rowSize[i] = 0;        for(i = 0; i < valueCnts; ++i)  rowSize[sparseMatrix[i].col]++;        rowStart[0] = 0;        for(i = 1; i < colCnts; ++i)            rowStart[i] = rowStart[i - 1] + rowSize[i - 1];        for(i = 0; i < valueCnts; ++i) {            //查询该元素(转置后行号:sparseMatrix[i].col)应该存放的位置j            j = rowStart[sparseMatrix[i].col];            trans.sparseMatrix[j].row = sparseMatrix[i].col;            trans.sparseMatrix[j].col = sparseMatrix[i].row;            trans.sparseMatrix[j].value = sparseMatrix[i].value;            //原来应该存放的位置已占用,将其+1            rowStart[sparseMatrix[i].col]++;        }    }    delete[] rowSize;    delete[] rowStart;    return trans;}template<class T>SparseMatrix<T> SparseMatrix<T>::addMatrix(SparseMatrix<T>& b){    SparseMatrix<T> addResult;    if(this->rowCnts != b.rowCnts || this->colCnts != b.colCnts)        return addResult;    addResult.colCnts = colCnts;    addResult.rowCnts = rowCnts;    addResult.valueCnts = 0;    int i = 0, j = 0;    //思路与一元多项式计算一致(数据有序存储)    int index_a, index_b;           //根据三元组元素计算的该元素在数组中的索引位置    while(i < this->valueCnts && j < this->valueCnts) {        index_a = this->sparseMatrix[i].row * colCnts + this->sparseMatrix[i].col;        index_b = b.sparseMatrix[j].row * colCnts + b.sparseMatrix[j].col;        if(index_a < index_b) {            addResult.sparseMatrix[addResult.valueCnts] = this->sparseMatrix[i];            i++;        } else if(index_a > index_b) {            addResult.sparseMatrix[addResult.valueCnts] = b.sparseMatrix[j];            j++;        } else {            addResult.sparseMatrix[addResult.valueCnts] = this->sparseMatrix[i];            addResult.sparseMatrix[addResult.valueCnts].value = this->sparseMatrix[i].value + b.sparseMatrix[j].value;            i++;            j++;        }        addResult.valueCnts++;    }    //复制剩下的元素    for(; i < valueCnts; ++i) {        addResult.sparseMatrix[addResult.valueCnts] = this->sparseMatrix[i];        addResult.valueCnts++;    }    for(; j < valueCnts; ++j) {        addResult.sparseMatrix[addResult.valueCnts] = b.sparseMatrix[j];        addResult.valueCnts++;    }    return addResult;}//Tips:A[i][k]*B[k][j]//对A矩阵进行遍历,取得一个A[i][k],根据列号k,到矩阵B抽取所有行号为k的元素//在B矩阵的存储中,同行号的相邻存储,因此只需要知道该行元素存储的起始位置以及元素的个数//建立两个辅助矩阵://rowSize[]===>存储每一行的元素个数//rowStart[]===>存储每一行元素存储的起始位置template<class T>SparseMatrix<T> SparseMatrix<T>::mulMatrix(SparseMatrix<T>& b){    SparseMatrix<T> mulResult;    if(this->colCnts != b.rowCnts) {        cerr << "cannot multiply!" << endl;        return mulResult;    }    if(this->valueCnts == 100 || b.valueCnts == 100) {        cerr << "space needed in a or b" << endl;        return mulResult;    }    mulResult.rowCnts = this->rowCnts;    mulResult.colCnts = b.colCnts;    //辅助矩阵    int *rowSize = new int[b.rowCnts];          //B矩阵每一行所含元素个数    int *rowStart = new int[b.rowCnts + 1];     //B矩阵每一行第一个元素开始位置    T *tmp = new T[b.colCnts];    //辅助矩阵数据初始化    for(int i = 0; i < b.rowCnts; ++i)        rowSize[i] = 0;    for(int i = 0; i < b.valueCnts; ++i)        rowSize[b.sparseMatrix[i].row]++;    rowStart[0] = 0;    for(int i = 1; i <= b.rowCnts; ++i)        rowStart[i] = rowStart[i - 1] + rowSize[i - 1];    //对矩阵A进行遍历    int current = 0;                        //遍历指针    int lastInResult = -1;    int rowA, colA, colB;                   //当前处理元素的行号、列号信息    while(current < this->valueCnts) {        rowA = this->sparseMatrix[current].row;        //tmp用来暂存当前处理行的结果,初始化为0        for(int i = 0; i < b.colCnts; ++i)            tmp[i] = 0;        //对A矩阵第rowA行进行处理,该循环每次处理A矩阵一行数据与B矩阵一列数据对应的相乘        while(current < this->valueCnts && this->sparseMatrix[current].row == rowA) {            //获取当前A矩阵元素的列号,用于获取对应操作的B矩阵元素            colA = this->sparseMatrix[current].col;            for(int i =  rowStart[colA]; i < rowStart[colA + 1]; ++i) {                //获取当前B矩阵元素列号,将结果存储到以当前B矩阵元素列号为下标的tmp中                colB = b.sparseMatrix[i].col;                tmp[colB] += this->sparseMatrix[current].value * b.sparseMatrix[i].value;            }            current++;        }        //上一个while循环结束==>一行处理结束,将临时存储数据保存到结果中        //将临时存储的tmp转储到矩阵相乘结果矩阵(mulResult)中        for(int i = 0; i < b.colCnts; ++i) {            if(tmp[i] != 0) {                lastInResult++;                mulResult.sparseMatrix[lastInResult].row = rowA;                mulResult.sparseMatrix[lastInResult].col = i;                mulResult.sparseMatrix[lastInResult].value = tmp[i];            }        }    }    mulResult.valueCnts = lastInResult + 1;    delete []rowSize;    delete []rowStart;    delete []tmp;    return mulResult;}template<class T>ostream& operator << (ostream & out, const SparseMatrix<T>& myMatrix){    out << "rowCnts:" << myMatrix.rowCnts << " "        << "colCnts:" << myMatrix.colCnts << " "        << "nonZero valueCnts:" << myMatrix.valueCnts << endl;    for(int i = 0; i < myMatrix.valueCnts; ++i) {        out << "M[" << myMatrix.sparseMatrix[i].row            << "][" << myMatrix.sparseMatrix[i].col            << "] = " << myMatrix.sparseMatrix[i].value << endl;    }    return out;}template<class T>istream& operator >> (istream & in, SparseMatrix<T>& myMatrix){    cout << "input numbers of rowCnts, columnCnts and valueCnts" << endl;    in >> myMatrix.rowCnts >> myMatrix.colCnts >> myMatrix.valueCnts;    if(myMatrix.valueCnts > 100) {        cerr << "number of terms overflow!" << endl;        exit(1);    }    for(int i = 0; i < myMatrix.valueCnts; ++i) {        cout << "input row, column and value of term " << i << ": ";        in >> myMatrix.sparseMatrix[i].row           >> myMatrix.sparseMatrix[i].col           >> myMatrix.sparseMatrix[i].value;    }    return in;}#endif // SPARSEMATRIX_H
  • test.cpp
#include <iostream>#include "SparseMatrix.h"int main(int argc, char **argv){//  SparseMatrix<int> myMatrix;//  SparseMatrix<int> trans;//  cin >> myMatrix;//  cout << myMatrix;//  trans = myMatrix.transpose();       //SparseMatrix类的赋值运算符重载函数不添加const,该语句报错//  cout << trans;//  trans = trans.fastTranspose();//  cout << trans;//  trans = trans.addMatrix(trans);//  cout << trans;    SparseMatrix<int> a, b;    cin >> a;    cin >> b;    cout << a;    cout << b;    cout << a.mulMatrix(b);    return 0;}//tips:问题解释参见下链接//应该是因为,transpose函数返回值是一个临时对象,//赋值运算符重载函数形参是引用类型(引用传递),即引用形参是它对应实参的别名//https://stackoverflow.com/questions/20247525/about-c-conversion-no-known-conversion-for-argument-1-from-some-class-to/*Add function10 10 90 2 21 0 31 3 -112 3 -63 5 -174 1 94 4 195 3 -85 6 -52Multiply function3 4 70 0 100 2 50 3 71 0 21 1 12 0 32 2 44 2 60 0 21 0 41 1 82 1 143 0 33 1 5*/

测试结果1

测试结果2

原创粉丝点击