LDA(Linear Discriminant Analysis,线性判别分析)算法代码

来源:互联网 发布:阿里云logo 编辑:程序博客网 时间:2024/06/08 19:24

做文本聚类分析时,采用 PCA 等方法降维效果都不好,于是决定采用有监督的学习算法 LDA。在网上找到一份代码,但看不懂它如何实现降维,于是自己改写,代码如下:

package lda;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

import Jama.Matrix;

public class LDA
{
    private double[][] groupRataTengah;
    private double[][] kovarianGlobal;
    private double[] probabilitas;
    private ArrayList<Integer> groupList = new ArrayList<Integer>();
    static int hasil;
    static double f1, f2, f3;
    private HashMap _map = new HashMap();
    private RealVector[] _top2vec = new RealVector[2];
    public LDA()
    {

    }
    /**
     *
     * @param d 聚类结果数组
     * @param g 聚类的类别标识,和前面的d关系一致
     * @param p
     */
    public LDA(double[][] d, int[] g, boolean p)
    {
 // memeriksa apakah data dan kelompok array mempunyai ukuran yang sama
 if (d.length != g.length)
     return;

 double[][] data = new double[d.length][d[0].length];// panjang data(i)
           // dan fitur(j)
 for (int i = 0; i < d.length; i++)
 {
     for (int j = 0; j < d[i].length; j++)
     {
  data[i][j] = d[i][j];
     }
 }
 int[] group = new int[g.length];
 for (int j = 0; j < g.length; j++)
 {
     group[j] = g[j];
 }

 double[] rataTengah;
 double[][][] kovarian;

 // memisahkan berdasarkan grup atau kelas
 for (int i = 0; i < group.length; i++)
 {
     if (!groupList.contains(group[i]))
     {
  groupList.add(group[i]);
     }
 }

 // membagi data ke dalam subset
 ArrayList<double[]>[] subset = new ArrayList[groupList.size()];
 for (int i = 0; i < subset.length; i++)
 {
     subset[i] = new ArrayList<double[]>();
     for (int j = 0; j < data.length; j++)
     {
  if (group[j] == groupList.get(i))
  {
      subset[i].add(data[j]);
  }
     }
 }

 // menghitung mean tiap fitur tiap kelas
 groupRataTengah = new double[subset.length][data[0].length];
 for (int i = 0; i < groupRataTengah.length; i++)
 {
     for (int j = 0; j < groupRataTengah[i].length; j++)
     {
  groupRataTengah[i][j] = getGroupMean(j, subset[i]);
     }
 }

 // menghitung global mean atau mean tiap fitur pada semua kelas
 rataTengah = new double[data[0].length];
 for (int i = 0; i < data[0].length; i++)
 {
     rataTengah[i] = getGlobalMean(i, data);
 }

 
 double[][] tempMatrix = new double[subset.length][data[0].length];
 for (int i = 0; i < subset.length; i++)
 {
     for (int j = 0; j < data[0].length; j++)
     {
  tempMatrix[i][j] = groupRataTengah[i][j] - rataTengah[j];
     }
 }

 double[][] SB = new double[data[0].length][data[0].length];
 for (int k = 0; k < subset.length; k++)
 {
     int t = subset[k].size();
     for (int i = 0; i < SB.length; i++)
     {
  for (int j = 0; j < SB[i].length; j++)
  {      
      SB[i][j] = SB[i][j] + (t * (tempMatrix[k][i] * tempMatrix[k][j]))/data.length;
  }
     }
 }
 
 
 
 
 double[][] SW = new double[data[0].length][data[0].length];
 
 for (int k = 0; k < groupList.size(); k++)
 {
     ArrayList _class = subset[k];  
     for (int l = 0; l < _class.size(); l++)
     {
  double [] _el = (double[])_class.get(l);
  for (int i = 0; i < SB.length; i++)
  {
      for (int j = 0; j < SB[i].length; j++)
      {

   SW[i][j] = SW[i][j] + (groupRataTengah[k][i]-_el[i])*(groupRataTengah[k][j]-_el[j]);
    
      }
  }
     }
 }
 RealMatrix rsw = MatrixUtils.createRealMatrix(SW);
 RealMatrix rswInverse = new LUDecomposition(rsw).getSolver().getInverse();
 RealMatrix rsb = MatrixUtils.createRealMatrix(SB);
 RealMatrix r = rswInverse.multiply(rsb);
 
 
 EigenDecomposition  en = new EigenDecomposition(r);
 double [] eg =en.getRealEigenvalues();
 for(int i=0;i<eg.length;i++)
 {
    RealVector rv = en.getEigenvector(i);
    System.out.println(eg[i] + ":");
    System.out.println(rv.toString());
    _map.put(eg[i], rv);
 }
    }
   
    /**
     * 得到前两个特征向量
     * @return
     */
    public RealVector[] getTop2Vector()
    {
 RealVector [] arrVec = new RealVector[2];
 Iterator iter = this._map.entrySet().iterator();
 ArrayList _list = new ArrayList();
 while (iter.hasNext()) {
 Map.Entry entry = (Map.Entry) iter.next();
          Object key = entry.getKey();
   //Object val = entry.getValue();
          _list.add(key);
        }
 Collections.sort(_list);
 int j=0;
 for(int i=_list.size()-1;i>-1;i--)
 {
     System.out.println(_list.get(i));
     arrVec[j] = (RealVector)this._map.get(_list.get(i));
     j++;
     if(j==2)
     {
  break;
     }
 }
 this._top2vec = arrVec;
 
 return arrVec;
    }
   
    /**
     * 输入点向量,得到对应的二维平面坐标
     * @param in
     * @return
     */
    public double[] getxydot(double[]in)
    {
 
 RealVector inv = MatrixUtils.createRealVector(in);
 double[] out = new double[2];
 for(int i=0;i<2;i++)
 {
     out[i] = this._top2vec[i].dotProduct(inv);
 }
 
 return out;
    }
   
   
   
   
    private double getGroupMean(int column, ArrayList<double[]> data)
    {
 double[] d = new double[data.size()];
 for (int i = 0; i < data.size(); i++)
 {
     d[i] = data.get(i)[column];
 }

 return getMean(d);
    }

    private double getGlobalMean(int column, double data[][])
    {
 double[] d = new double[data.length];
 for (int i = 0; i < data.length; i++)
 {
     d[i] = data[i][column];
 }

 return getMean(d);
    }

    // menghitung nilai fungsi discriminant untuk kelas yang berbeda
    public double[] getDiscriminantFunctionValues(double[] values)
    {
 double[] function = new double[groupList.size()];
 for (int i = 0; i < groupList.size(); i++)
 {
     double[] tmp = matrixMultiplication(groupRataTengah[i], kovarianGlobal);
     function[i] = (matrixMultiplication(tmp, values))// fi=miu i*invers
            // kovarian*data
            // testing-1/2 miu
            // i*invers
            // kovarian*miu i
            // trans+ln(pi)
      - (.5d * matrixMultiplication(tmp, groupRataTengah[i]))
      + Math.log(probabilitas[i]);
 }

 return function;
    }

    // memprediksi masuk kelas mana
    public int predict(double[] values)
    {
 int group = -1;
 double max = Double.NEGATIVE_INFINITY;
 double[] discr = this.getDiscriminantFunctionValues(values);
 for (int i = 0; i < discr.length; i++)
 {
     if (discr[i] > max)
     {
  max = discr[i];
  group = groupList.get(i);
     }
 }

 return group;
    }

    // mengalikan dua matriks
    private double[] matrixMultiplication(double[] matrixA, double[][] matrixB)
    {

 double c[] = new double[matrixA.length];
 for (int i = 0; i < matrixA.length; i++)
 {
     c[i] = 0;
     for (int j = 0; j < matrixB[i].length; j++)
     {
  c[i] += matrixA[i] * matrixB[i][j];
     }
 }

 return c;
    }

    private double matrixMultiplication(double[] matrixA, double[] matrixB)
    {

 double c = 0d;
 for (int i = 0; i < matrixA.length; i++)
 {
     c += matrixA[i] * matrixB[i];
 }

 return c;
    }

    public static double getMean(final double[] values)
    {
 if (values == null || values.length == 0)
     return Double.NaN;

 double mean = 0.0d;

 for (int index = 0; index < values.length; index++)
     mean += values[index];

 return mean / (double) values.length;
    }

    public static void test(extraksi_fitur e, double a, double b, double c, double d)
    {
 extraksi_fitur ef = e;
 int[] group = { 1, 1, 1, 1, 2, 2, 2 };// 1=lemon,2=manis,3=nipis
 double[][] data = new double[7][2];// 15=jumlah data,4=fitur R,G,B,D
 int count = 0;

 data[0][0] = 2.95;
 data[0][1] = 6.63;
 data[0][0] = 2.53;
 data[0][1] = 7.79;
 data[0][0] = 3.57;
 data[0][1] = 5.65;
 data[0][0] = 3.16;
 data[0][1] = 5.47;
 data[0][0] = 2.58;
 data[0][1] = 4.46;
 data[0][0] = 2.16;
 data[0][1] = 6.22;
 data[0][0] = 3.27;
 data[0][1] = 3.52;

 LDA test = new LDA(data, group, true);
 double[] testData = { a, b, c, d };

 // test
 double[] values = test.getDiscriminantFunctionValues(testData);
 for (int i = 0; i < values.length; i++)
 {
     System.out.println("Discriminant function " + (i + 1) + ": " + values[i]);
 }

 System.out.println("Predicted group: " + test.predict(testData));
 hasil = test.predict(testData);
 f1 = values[0];
 f2 = values[1];
 f3 = values[2];
    }

    public static void main(String[] args)
    {
 double[][] data = new double[7][3];// 15=jumlah data,4=fitur R,G,B,D
 data[0][0] = 2.95;
 data[0][1] = 6.63;
 data[0][2] = 2.34;
 data[1][0] = 2.53;
 data[1][1] = 7.79;
 data[1][2] = 2.56;
 data[2][0] = 3.57;
 data[2][1] = 5.65;
 data[2][2] = 2.76;
 data[3][0] = 3.16;
 data[3][1] = 5.47;
 data[3][2] = 2.36;
 data[4][0] = 2.58;
 data[4][1] = 4.46;
 data[4][2] = 5.2;
 data[5][0] = 2.16;
 data[5][1] = 6.22;
 data[5][2] = 5.4;
 data[6][0] = 3.27;
 data[6][1] = 3.52;
 data[6][2] = 6;
 int[] group = { 1, 1, 1, 1, 2, 2, 2 };// 1=lemon,2=manis,3=nipis
 LDA lda = new LDA(data, group, true);
 lda.getTop2Vector();
 double[] tt = {123.0,23.0,4};
 double[] out = lda.getxydot(tt);
 for(int i=0;i<out.length;i++)
 {
     System.out.println(out[i]);
 }
    }
}

 

试用效果良好,实验数据为200维矩阵。
0 0
原创粉丝点击