Java实现MultivariateGaussian

来源:互联网 发布:mac文件放在桌面 编辑:程序博客网 时间:2024/06/04 19:18

用 Java 实现了 MultivariateGaussian 类,用于计算多元高斯(正态)分布的概率密度。

该类可用于高斯混合模型(GMM)中。


代码如下:

import java.io.Serializable;


import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.DenseMatrix;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.util.MLUtils$;
import org.netlib.util.intW;


import com.github.fommil.netlib.LAPACK;


import scala.Array;
import scala.Tuple2;
import scala.reflect.ClassManifestFactory;


/**
 * 
 */


public class JavaMultivariateGaussian implements Serializable
{
private Vector mu;
private Matrix sigma;
private Matrix rootSigmaInv;
private double u ;
public JavaMultivariateGaussian( Vector mu, Matrix sigma )
{
super( );
this.mu = mu;
this.sigma = sigma;

Tuple2<Matrix, Double> value = calculateCovarianceConstants( );
rootSigmaInv = value._1;
u= value._2;
}

private Tuple2<Matrix, Double> calculateCovarianceConstants()
{
Tuple2<Vector, Matrix> value = eigSym( sigma );
Vector d = value._1;
Matrix u = value._2;

double tol = MLUtils$.MODULE$.EPSILON( )*d.apply( d.argmax( ) ) * d.size( );

double logPseudoDetSigma = caleSum( d, tol );

Matrix pingS = Matrices.diag( calePinvS( d, tol ) );


return new Tuple2<Matrix, Double>(pingS.multiply( (DenseMatrix)u.transpose( ) ), -0.5*(mu.size( ) * Math.log( 2.0*Math.PI ) + logPseudoDetSigma));
}

public double pdf(Vector x)
{
return Math.exp( logpdf( x ) );
}

private double logpdf(Vector x)
{
Vector delta = x.copy( );
BLAS.axpy( -1.0, mu, delta );

Vector v = rootSigmaInv.multiply( delta );

return u + BLAS.dot( v, v )*(-0.5);
}

private static Vector calePinvS(Vector v, double tol)
{
double[] ds = new double[v.size( )];
for (int i=0; i<v.size( ); i++)
{
double value = v.apply( i );
if (value > tol)
{
ds[i] = Math.sqrt( 1.0/value );
}
else
{
ds[i] = 0.0;
}
}

return Vectors.dense( ds );
}
private static double caleSum(Vector v, double tol)
{
double retValue = 0.0;
for (int i=0; i<v.size( ); i++)
{
double d = v.apply( i );
if (d > tol)
{
retValue = retValue + Math.log( d );
}
}

return retValue;
}


private static Tuple2<Vector, Matrix> eigSym(Matrix m)
{
int N = m.numRows( );
int C = m.numCols( );
Matrix A = Matrices.dense( N, C,  lowerTriangular( m.toArray( ), N, C ));
double[] ms = A.toArray( );

Vector evs = Vectors.zeros( N );

int lwork = Math.max( 1, 3*N-1);

double[] work = (double[])Array.ofDim( lwork, ClassManifestFactory.classType( double.class ) );

intW info = new intW(0);
LAPACK lapack = LAPACK.getInstance( );

lapack.dsyev( "V", "L", N, ms, Math.max( 1, N ), evs.toArray( ), work, lwork, info );
A = Matrices.dense( N, C, ms );
return new Tuple2<Vector, Matrix>(evs, A);
}

private static double[] lowerTriangular(double[] martice, int row, int col)
{
int len = martice.length;
double[] retValue = new double[len];
for (int i=0; i<col; i++)
{
for (int j=0; j<row; j++)
{
int pos = i*col + j;
if (j>=i)
{
retValue[pos] = martice[pos];
}
else
{
retValue[pos] = 0.0;
}
}

}

return retValue;
}

}