矩阵类

来源:互联网 发布:醍醐灌顶的一句话 知乎 编辑:程序博客网 时间:2024/06/14 21:41

此次的矩阵类可以使用双下标,并且带有越界检查能力

用例:

 jks::CMatrix<int> m(3,4);
 int i,j;
 for (i=0;i<m.getHeight();i++)
 {
  for (j=0;j<m.getWidth();j++)
  {
   m[i][j] = i*10+j;
  }
 }

 cerr<<m;

=========================================================================

头文件


#if !defined(__JKS_MATRIX_HPP_)
#define __JKS_MATRIX_HPP_

#if _MSC_VER > 1000
#pragma once
#endif // _MSC_VER > 1000

#include <cassert>
#include <iostream>
namespace jks
{
//////////////////////////////////////////////////////////////////////////
 
template<typename T>
class CMatrix 
{
class CLine
{
 T* _pLine;
 CMatrix* _pMat;

public:
 CLine(T* pLine,CMatrix* pMat):_pLine(pLine),_pMat(pMat) {};
 T& operator[] (long nCol)
 {assert(nCol>=0 && nCol<_pMat->getWidth());return _pLine[nCol];}

};

 long _sLn,_sCol;
 T* _pData;
public:
 CMatrix();
 CMatrix(T * arrAddress,long arrWidth);//构造一维矩阵(一行,arrWidth列)
 CMatrix(T * arrAddress,long arrHeight,long arrWidth);//构造二维矩阵
 CMatrix(long Height,long Width);//构造空矩阵
 CMatrix(const CMatrix<T> &);//复制构造函数
 virtual ~CMatrix(void);//默认析构函数

 //属性
 long getHeight() const {return _sLn; }
 long getWidth() const {return _sCol;}
 T* getData(long& arrHeight,long& arrWidth,T* p=NULL) const;
 T* getData(int& arrHeight,int& arrWidth,T* p=NULL) const
 {long t1 = arrHeight;long t2 = arrWidth;
 arrHeight = _sLn;arrWidth = _sCol;
  return getData(t1,t2,p);}

 operator void* () {return isValid()?this:0;}
 bool isValid() const
 {
  if(this!=NULL && _pData!=NULL) return true;
  else return false;
 }
 bool isInDomain(long nLn,long nCol)
 {
  if (nLn<_sLn && nLn>=0 && nCol<_sCol && nCol>=0)
   return true;
  else
   return false;
 }
 int isVector();//如果是0,那就是一个数;1,列向量;-1行向量

 //运算
 CMatrix operator+(CMatrix<T> &);
 CMatrix operator-(CMatrix<T> &);
 CMatrix operator*(CMatrix<T> &);
 friend CMatrix operator*(double alpha,CMatrix<T> &);//实数与矩阵相乘
 CMatrix operator*(double alpha);//矩阵与实数相乘
 CMatrix operator/(CMatrix<T> &);//实际是实数相除或矩阵和实数相除
 CMatrix operator/(double sub);
 CMatrix operator+=(CMatrix<T> &);
 CMatrix operator-=(CMatrix<T> &);
 CMatrix operator*=(CMatrix<T> &);//矩阵与实数相乘
 CMatrix operator*=(double alpha);//矩阵与实数相乘
 CMatrix & operator = (CMatrix<T> &);//赋值
 CLine operator[](long heightPos);//用于实现用[][]操作矩阵元素
 friend CMatrix  sqrt(CMatrix<T> m);//开方
 friend double abs(CMatrix<T> &);//取绝对值(泛数)
 friend double sum(CMatrix<T> &);//求和
 friend CMatrix multiply(CMatrix<T> &m1,CMatrix<T> & m2);//按元素相乘
 friend T operator+(double dbl,CMatrix<T> &);
 friend T operator+(CMatrix<T> &,double dbl);
 friend T operator-(double dbl,CMatrix<T> &);
 friend T operator-(CMatrix<T> &,double dbl);
 const T* c_ptr() const;
 friend bool operator == (CMatrix<T> &,CMatrix<T> &);

 //输出
 friend std::ostream& operator<<(std::ostream &,jks::CMatrix<T> &);

public:
 //公有属性
 static float m_fPrecision;//控制==的精度
 static double m_dPrecision;

};

//////////////////////////////////////////////////////////////////////////
//函数实现
template<typename T>
float CMatrix<T>::m_fPrecision = 0.0000001;
template<typename T>
double CMatrix<T>::m_dPrecision = 1e-20;

//////////////////////////////////////////////////////////////////////////
//构造与析构
template<typename T>
CMatrix<T>::CMatrix(void)//:_sCol(1),_sLn(1)

 _sCol = 0;
 _sLn = 0;
 _pData=NULL;
}

template<typename T>
CMatrix<T>::CMatrix(T * arrAddress,long arrWidth)
{
 long arrHeight=1;
 _pData=new T[arrWidth*arrHeight];
 memcpy(_pData,arrAddress,arrWidth*arrHeight*sizeof(T));

 _sCol=arrWidth;
 _sLn=arrHeight;
}

template<typename T>
CMatrix<T>::CMatrix(T * arrAddress,long arrHeight,long arrWidth)
{
 _pData=new T[arrWidth*arrHeight];
 memcpy(_pData,arrAddress,arrWidth*arrHeight*sizeof(T));

 _sCol=arrWidth;
 _sLn=arrHeight;
}

template<typename T>
CMatrix<T>::CMatrix(long height,long width)
{
 _sCol=width;
 _sLn=height; 
 _pData=new T[height*width];
}

template<typename T>
CMatrix<T>::CMatrix(const CMatrix<T> & m)//copy constructor
{
 _sLn=m._sLn;
 _sCol=m._sCol;

 _pData=new T[_sLn*_sCol];
 memcpy(_pData,m._pData,_sLn*_sCol*sizeof(T));
}

template<typename T>
CMatrix<T>::~CMatrix()
{
 if (_pData)
 {
  delete []_pData;
 }
 _pData = NULL;

 _sLn = 0;
 _sCol = 0;

}

//////////////////////////////////////////////////////////////////////////
//运算

template<typename T>
CMatrix<T> CMatrix<T>::operator +(CMatrix &m1)
{
 assert(m1._sLn==_sLn && m1._sCol==_sCol);
 long tmpHeight=m1._sLn;
 long tmpWidth=m1._sCol;
 T * t=new T[tmpWidth*tmpHeight];
 for(long i=0;i<tmpHeight;i++){
  for(long j=0;j<tmpWidth;j++){
   *(t+tmpWidth*i+j)=*((T*)m1._pData+tmpWidth*i+j)+*((T*)_pData+tmpWidth*i+j);
  }
 }
 CMatrix<T> m(t,tmpHeight,tmpWidth);
 delete [] t;
 return m;
}

template<typename T>
CMatrix<T> CMatrix<T>::operator -(CMatrix &m1)
{
 assert(m1._sLn==_sLn && m1._sCol==_sCol);
 long tmpHeight=m1._sLn;
 long tmpWidth=m1._sCol;
 T * t=new T[tmpWidth*tmpHeight];
 for(long i=0;i<tmpHeight;i++){
  for(long j=0;j<tmpWidth;j++){
   *(t+tmpWidth*i+j)=*((T*)_pData+tmpWidth*i+j)-*((T*)m1._pData+tmpWidth*i+j);
  }
 }
 CMatrix<T> m(t,tmpHeight,tmpWidth);
 delete [] t;
 return m;
}

template<typename T>
CMatrix<T> CMatrix<T>::operator *(CMatrix &m1)
{
 if(!this->isVector() && m1.isVector()){//左为数,右为矩阵
  CMatrix<T> m;
   m=((T*)_pData)[0]*m1;
  return m;
 }else if(this->isVector() && !m1.isVector()){//左为矩阵,右为数
  CMatrix m;
   m=*this*m1[0][0];
  return m;
 }else if(!this->isVector() && m1.isVector()){//左右都为数
  T * t=new T[1];
  t[0]=((T*)_pData)[0]*m1[0][0];
  CMatrix<T> m(t,1,1);
  delete [] t;
  return m;
 }else if(this->isVector() && m1.isVector() && _sCol==m1._sLn){//左为矩阵,右为矩阵
  double sum;
  T * t=new T[_sLn*m1._sCol];
  for(long i=0;i<_sLn;i++){
   for(long j=0;j<m1._sCol;j++){
    sum=0;
    for(long k=0;k<_sCol;k++){
     sum+=(*((T*)_pData+_sCol*i+k))*(m1[k][j]);
    }
    *(t+m1._sCol*i+j)=sum;
   }
  }
  CMatrix<T> m(t,_sLn,m1._sCol);
  delete [] t;
  return m;
 }else{
  assert(0);//未知运算
  return *this;
 }
}

template<typename T>
CMatrix<T> operator*(double alpha,CMatrix<T> & m1)
{
 CMatrix<T> m=m1;
 for(long i=0;i<m._sLn;i++){
  for(long j=0;j<m._sCol;j++){
   m[i][j]=alpha*m1[i][j];
  }
 }
 return m;
}

template<typename T>
CMatrix<T> CMatrix<T>::operator*(double alpha)
{
 return alpha*(*this);
}

template<typename T>
CMatrix<T> CMatrix<T>::operator+=(CMatrix<T> & m)
{
 return *this+m;
}


template<typename T>
CMatrix<T> CMatrix<T>::operator-=(CMatrix<T> & m)
{
 return *this-m;
}


template<typename T>
CMatrix<T> CMatrix<T>::operator *=(double alpha)
{
 return *this*alpha;
}

template<typename T>
CMatrix<T> CMatrix<T>::operator *=(CMatrix<T> & m1)
{
 return *this*m1;
}


template<typename T>
const T* CMatrix<T>::c_ptr () const
{
 return _pData;
}


template<typename T>
CMatrix<T> CMatrix<T>::operator /(CMatrix<T> &m1)
{
 assert(m1._sCol==1 && m1._sLn==1);
 assert(m1[0][0]!=0);
 return *this/m1[0][0];
}

template<typename T>
CMatrix<T> CMatrix<T>::operator /(double sub)
{
 assert(sub!=0);
 CMatrix<T> m=*this;
 for(long i=0;i<_sLn;i++){
  for(long j=0;j<_sCol;j++){
   m[i][j]=*((T*)_pData+_sCol*i+j)/sub;
  }
 }
 return m;
}

template<typename T>
CMatrix<T> & CMatrix<T>::operator =(CMatrix<T> & m)
{
 if(&m==this) return *this;

 _sLn=m._sLn;
 _sCol=m._sCol;
 if(_pData)
 {
  delete [] _pData;
  _pData = NULL;
 }
 _pData=new T[_sLn*_sCol];
 memcpy(_pData,m._pData,_sLn*_sCol*sizeof(T));

 return *this;
}

template<typename T>
bool operator == (CMatrix<T>& m1,CMatrix<T>& m2)
{
 if (&m1 == &m2)
 {
  return true;
 }

 if (m1.getWidth()!=m2.getWidth() || m1.getHeight()!=m2.getHeight())
 {
  return false;
 }

 T *p1,*p2;
 p1 = m1[0];
 p2 = m2[0];
 long sum = m1.getHeight()*m1.getWidth();
 long i;
 if(typeid(T) == typeid(float))
  for (i=0;i<sum;i++)
  {
   if (fabs(*p1++ - *p2++)<CMatrix<T>::m_fPrecision)
   {
    return false;
   }
  }
 if(typeid(T) == typeid(double))
  for (i=0;i<sum;i++)
  {
   if (fabs(*p1++ - *p2++)<CMatrix<T>::m_dPrecision)
   {
    return false;
   }
  }
 else
  for (i=0;i<sum;i++)
  {
   if (*p1++ != *p2++)
   {
    return false;
   }
  }

 return true;
}

template<typename T>
T operator+(double dbl,CMatrix<T> & m)
{
 assert(m.getHeight()==1 && m.getWidth()==1);
 return dbl+m[0][0];
}

template<typename T>
T operator+(CMatrix<T> & m,double dbl)
{
 return dbl+m;
}

template<typename T>
T operator-(double dbl,CMatrix<T> & m)
{
 assert(m.getHeight()==1 && m.getWidth()==1);
 return dbl-m[0][0];
}

template<typename T>
T operator-(CMatrix<T> & m,double dbl)
{
 return -(dbl-m);
}

template<typename T>
CMatrix<T>::CLine CMatrix<T>::operator [](long heightPos)
{
 assert(isValid());
 assert(heightPos>=0 && heightPos<=_sLn);//报错

 CLine rLine(_pData+heightPos*_sCol,this);
 return rLine;//取回的是行头指针
}

//////////////////////////////////////////////////////////////////////////
//输出

template<typename T>
std::ostream & operator<<(std::ostream & os,jks::CMatrix<T> & m)
{
 os<<"Sum Ln:"<<m._sLn<<" "<<"Sum Col:"<<m._sCol<<std::endl;
 long i,j;

 if(typeid(T)==typeid(unsigned char))
  for (i=0;i<m._sLn;i++)
  {
   for (j=0;j<m._sCol;j++)
   {
    os<<(int)m._pData[i*m._sCol+j]<<"/t";
   }
   os<<std::endl;
  }
 else
  for (i=0;i<m._sLn;i++)
  {
   for (j=0;j<m._sCol;j++)
   {
    os<<m._pData[i*m._sCol+j]<<"/t";
   }

   os<<std::endl;
  }

 return os;
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
int CMatrix<T>::isVector()
{
 //return !(nWidth==1 && nHeight==1);
 if (_sCol==1)
  if (_sLn==1)
   return 0;
  else
   return 1;
 else
  return -1;
}

//////////////////////////////////////////////////////////////////////////
} //
#endif // !defined(__JKS_MATRIX_HPP_)
 

原创粉丝点击