An SGD (Stochastic Gradient Descent) Implementation of NMF (Non-negative Matrix Factorization)


NMF factorizes a matrix into the product of two matrices and can be used to solve many problems, such as user clustering, item clustering, predicting (completing) users' ratings of items, and personalized recommendation. Fitting the factorization can be cast as minimizing a loss (error) function, so the whole task is really just an optimization problem. The detailed implementation is given below. (The input matrix is often very sparse, i.e. most entries are missing, so the data is stored in a libsvm-style format; the Dataset class that reads it is omitted here, but a minimal sketch is given right after the main listing.)
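
Concretely, the objective that the code minimizes and the per-entry updates it applies can be read off the getSSE and updatePQ methods below: for each observed entry $R_{ij}$ of the $M \times V$ input matrix, with factors $P \in \mathbb{R}^{M \times K}$ and $Q \in \mathbb{R}^{K \times V}$,

$$e_{ij} = R_{ij} - \sum_{k=1}^{K} P_{ik} Q_{kj}, \qquad L = \sum_{(i,j)\ \mathrm{observed}} e_{ij}^{2} + \frac{\beta}{2}\Big(\sum_{i,k} P_{ik}^{2} + \sum_{k,j} Q_{kj}^{2}\Big),$$

and each SGD step moves the factors along the negative gradient of the per-entry loss with learning rate $\alpha$:

$$P_{ik} \leftarrow P_{ik} + \alpha\,(2\,e_{ij} Q_{kj} - \beta P_{ik}), \qquad Q_{kj} \leftarrow Q_{kj} + \alpha\,(2\,e_{ij} P_{ik} - \beta Q_{kj}).$$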


package NMF_danji;

import java.io.File;
import java.util.ArrayList;

/**
 * @author 玉心sober: http://weibo.com/karensober
 * @date 2013-05-19
 */
public class NMF {
    private Dataset dataset = null;
    private int M = -1; // number of rows
    private int V = -1; // number of columns
    private int K = -1; // number of latent topics
    double[][] P;
    double[][] Q;

    public NMF(String datafileName, int topics) {
        File datafile = new File(datafileName);
        if (datafile.exists()) {
            if ((this.dataset = new Dataset(datafile)) == null) {
                System.out.println(datafileName + " is null");
            }
            this.M = this.dataset.size();
            this.V = this.dataset.getFeatureNum();
            this.K = topics;
        } else {
            System.out.println(datafileName + " doesn't exist");
        }
    }

    public void initPQ() {
        P = new double[this.M][this.K];
        Q = new double[this.K][this.V];

        for (int k = 0; k < K; k++) {
            for (int i = 0; i < M; i++) {
                P[i][k] = Math.random();
            }
            for (int j = 0; j < V; j++) {
                Q[k][j] = Math.random();
            }
        }
    }

    // Stochastic gradient descent: update the parameters
    public void updatePQ(double alpha, double beta) {
        for (int i = 0; i < M; i++) {
            ArrayList<Feature> Ri = this.dataset.getDataAt(i).getAllFeature();
            for (Feature Rij : Ri) {
                // eij = Rij.weight - (P*Q)ij, used for updating P and Q
                double PQ = 0;
                for (int k = 0; k < K; k++) {
                    PQ += P[i][k] * Q[k][Rij.dim];
                }
                double eij = Rij.weight - PQ;

                // update Pik and Qkj
                for (int k = 0; k < K; k++) {
                    double oldPik = P[i][k];
                    P[i][k] += alpha
                            * (2 * eij * Q[k][Rij.dim] - beta * P[i][k]);
                    Q[k][Rij.dim] += alpha
                            * (2 * eij * oldPik - beta * Q[k][Rij.dim]);
                }
            }
        }
    }

    // Compute the regularized SSE after each iteration
    public double getSSE(double beta) {
        double sse = 0;
        for (int i = 0; i < M; i++) {
            ArrayList<Feature> Ri = this.dataset.getDataAt(i).getAllFeature();
            for (Feature Rij : Ri) {
                double PQ = 0;
                for (int k = 0; k < K; k++) {
                    PQ += P[i][k] * Q[k][Rij.dim];
                }
                sse += Math.pow((Rij.weight - PQ), 2);
            }
        }

        for (int i = 0; i < M; i++) {
            for (int k = 0; k < K; k++) {
                sse += ((beta / 2) * (Math.pow(P[i][k], 2)));
            }
        }

        for (int i = 0; i < V; i++) {
            for (int k = 0; k < K; k++) {
                sse += ((beta / 2) * (Math.pow(Q[k][i], 2)));
            }
        }

        return sse;
    }

    // Solve for the parameters with iterative SGD, i.e. obtain the final factor matrices
    public boolean doNMF(int iters, double alpha, double beta) {
        for (int step = 0; step < iters; step++) {
            updatePQ(alpha, beta);
            double sse = getSSE(beta);
            if (step % 100 == 0)
                System.out.println("step " + step + " SSE = " + sse);
        }
        return true;
    }

    public void printMatrix() {
        System.out.println("=========== Original matrix ==============");
        for (int i = 0; i < this.dataset.size(); i++) {
            for (Feature feature : this.dataset.getDataAt(i).getAllFeature()) {
                System.out.print(feature.dim + ":" + feature.weight + " ");
            }
            System.out.println();
        }
    }

    public void printFacMatrix() {
        System.out.println("=========== Factorized matrix (P*Q) ==============");
        for (int i = 0; i < P.length; i++) {
            for (int j = 0; j < Q[0].length; j++) {
                double cell = 0;
                for (int k = 0; k < K; k++) {
                    cell += P[i][k] * Q[k][j];
                }
                System.out.print(baoliu(cell, 3) + " ");
            }
            System.out.println();
        }
    }

    // Round a double to n decimal places
    public static double baoliu(double d, int n) {
        double p = Math.pow(10, n);
        return Math.round(d * p) / p;
    }

    public static void main(String[] args) {
        double alpha = 0.002;
        double beta = 0.02;

        NMF nmf = new NMF("D:\\myEclipse\\graphModel\\data\\nmfinput.txt", 10);
        nmf.initPQ();
        nmf.doNMF(3000, alpha, beta);

        // Print the original matrix
        nmf.printMatrix();

        // Print the factorized matrix
        nmf.printFacMatrix();
    }
}
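
The Dataset class used above is not included in the original post. The following is a minimal sketch of what it could look like, inferred purely from how the NMF class calls it: a Dataset(File) constructor, size(), getFeatureNum(), getDataAt(i).getAllFeature(), and a Feature with dim/weight fields. The row class name Data and the parsing details are assumptions; the sketch reads space-separated index:value pairs per line (no leading label), which matches what printMatrix prints.

package NMF_danji;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

/** A sparse matrix entry: column index and value. */
class Feature {
    int dim;        // column index
    double weight;  // entry value

    Feature(int dim, double weight) {
        this.dim = dim;
        this.weight = weight;
    }
}

/** One row of the sparse matrix (the class name "Data" is an assumption). */
class Data {
    private final ArrayList<Feature> features = new ArrayList<Feature>();

    void addFeature(Feature f) {
        features.add(f);
    }

    ArrayList<Feature> getAllFeature() {
        return features;
    }
}

/** Loads a libsvm-style sparse matrix: one row per line, space-separated "index:value" pairs. */
class Dataset {
    private final ArrayList<Data> rows = new ArrayList<Data>();
    private int featureNum = 0; // number of columns = largest index + 1

    Dataset(File file) {
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty())
                    continue;
                Data row = new Data();
                for (String token : line.split("\\s+")) {
                    String[] parts = token.split(":");
                    int dim = Integer.parseInt(parts[0]);
                    double weight = Double.parseDouble(parts[1]);
                    row.addFeature(new Feature(dim, weight));
                    featureNum = Math.max(featureNum, dim + 1);
                }
                rows.add(row);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    int size() {          // number of rows (M)
        return rows.size();
    }

    int getFeatureNum() { // number of columns (V)
        return featureNum;
    }

    Data getDataAt(int i) {
        return rows.get(i);
    }
}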
Result:
...

step 2900 SSE = 0.5878774074369989
=========== Original matrix ==============
0:9.0 1:2.0 2:1.0 3:1.0 4:1.0
0:8.0 1:3.0 2:2.0 3:1.0
0:3.0 3:1.0 4:2.0 5:8.0
1:1.0 3:2.0 4:4.0 5:7.0
0:2.0 1:1.0 2:1.0 4:1.0 5:3.0
=========== Factorized matrix (P*Q) ==============
8.959 2.007 1.007 0.996 1.007 6.293
7.981 2.972 1.989 1.005 2.046 7.076
3.01 1.601 1.773 1.003 2.005 7.968
4.821 1.009 2.209 1.984 3.968 6.988
2.0 0.991 0.984 0.51 1.0 2.994
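
Once training has finished, a missing (unobserved) entry can be completed with the same dot product the training loop uses; for instance, the 6.293 in the first row of the factorized matrix above is the completed value for column 5, which is absent from the first row of the original matrix. A small helper of this kind could be added to the NMF class; the method name predict is not in the original listing and is only an illustrative addition:

    // Illustrative helper, not part of the original listing: complete/predict
    // the entry at row i, column j as the dot product of P's i-th row with
    // Q's j-th column, i.e. (P * Q)[i][j].
    public double predict(int i, int j) {
        double value = 0;
        for (int k = 0; k < K; k++) {
            value += P[i][k] * Q[k][j];
        }
        return value;
    }

With the run above, nmf.predict(0, 5) would return roughly 6.29.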
