RBM 推荐系统 Java代码(质量堪忧,仅供参考,欢迎讨论)
来源:互联网 发布:json格式怎么打开 mock 编辑:程序博客网 时间:2024/06/04 19:27
这是一段代码段,用Java写的,关于RBM ,也关于推荐系统。效率比较低,代码也有点问题。贴出来,仅仅是为了给大家提供一点思路,也是希望大家多多指教。仅仅贴出最重要的代码块,完整的代码,大家可以给我留言,我抽出时间来再来给大家发一份。抛砖引玉,希望能有同学或工程师来一起交流,大家共同进步,共同提高。
package rbm_1th;import java.awt.List;import java.util.ArrayList;import java.util.HashMap;import java.util.Map;import java.util.Random;import static rbm_1th.utils.*;import static rbm_1th.Get_data.*;public class RBM{ public int N; public int n_visible; public int n_hidden; public double[][] W; public double[] hbias; public double[] vbias; public Random rng; //RBM的构造函数 public RBM(int N,int n_visiable,int n_hidden,double[][] W,double[] hbias,double[] vbias,Random rng) { this.N = N; this.n_visible = n_visiable; this.n_hidden = n_hidden; if (rng == null) { this.rng = new Random(1234); } else{ this.rng = rng; } if (W == null){ this.W =new double[this.n_hidden][this.n_visible]; double a = 1.0/this.n_visible; for (int i = 0; i < this.n_hidden; i++) { for (int j = 0; j< this.n_visible ; j++ ) { this.W[i][j] = uniform(-a,a,rng); } } }else { this.W = W; } if(hbias == null){ this.hbias = new double[this.n_hidden]; for (int i = 0;i<this.n_hidden;i++ ) { this.hbias[i] = 0; } }else{ this.hbias = hbias; } if (vbias == null) { this.vbias = new double[this.n_visible]; for (int j = 0; j < this.n_visible ; j++ ) { this.vbias[j] = 0; } }else{ this.vbias = vbias; } } //CD-k算法 public void contrastive_divergence(Map<Integer,Integer> input,double lr,int k) { double[] ph_mean = new double[n_hidden]; int[] ph_sample = new int[n_hidden]; double[] nv_means = new double[n_visible]; Map<Integer,Integer> nv_samples = new HashMap<>(); double[] nh_means = new double[n_hidden]; int[] nh_samples = new int[n_hidden]; sample_h_given_v(input,ph_mean,ph_sample); for (int step = 0; step < k;step++) { if (step == 0) { gibbs_hvh(ph_sample,nv_means,nv_samples,nh_means,nh_samples); }else { gibbs_hvh(nh_samples,nv_means,nv_samples,nh_means,nh_samples); } } /**这块儿代码可能有点问题,即在推荐系统里面是应该对所有的权重都改变还是只改变有过用户行为的显层和隐层的连接权重,这会儿脑子太乱,不想想了,欢迎大家拍砖指正**/ for (int i=0; i<n_hidden; i++) { for (int j: input.keySet())//此处可能有问题,各位看看 { W[i][j] += lr *(ph_mean[i]*input.get(j) - nh_means[i]*nv_samples.get(j));//此处可能有问题,各位看看,欢迎各位指正 } hbias[i] += lr * (ph_sample[i] - nh_means[i]); } for (int j: input.keySet()) { vbias[j] += lr *(input.get(j) - nv_samples.get(j)); } } //可见层到隐藏层 public double propup(Map<Integer,Integer> v,double[] w,double b) { double pre_sigmoid_activation = 0.0; for (int j :v.keySet() ) { pre_sigmoid_activation += w[j] * v.get(j); } pre_sigmoid_activation += b; return sigmoid(pre_sigmoid_activation); } //隐藏层到可见层 public double propdown(int[] h,int j,double b) { double pre_sigmoid_activation = 0.0; for (int i =0; i < h.length ; i++ ) { pre_sigmoid_activation += W[i][j] * h[i]; } pre_sigmoid_activation += b; return sigmoid(pre_sigmoid_activation); } //吉布斯 public void gibbs_hvh(int[] h0_sample,double[] nv_means,Map<Integer,Integer> nv_samples,double[] nh_means,int[] nh_samples) { sample_v_given_h(h0_sample,nv_means,nv_samples); sample_h_given_v(nv_samples,nh_means,nh_samples); } //sample given hidden get visible public void sample_v_given_h(int[] h0_sample,double[] mean,Map<Integer,Integer> sample) { for (int j:sample.keySet() ) { mean[j] = propdown(h0_sample, j,vbias[j]); int oz = binomial(1,mean[j],rng); sample.put(j, oz); } } //given visible get hidden public void sample_h_given_v(Map<Integer,Integer> v0_sample,double[] mean , int[] sample) { for (int i = 0;i<n_hidden;i++) { mean[i] = propup(v0_sample,W[i],hbias[i]); sample[i] = binomial(1,mean[i],rng); } } //函数重构 public void reconstruct(Map<Integer,Integer> v,double[] reconstructed_v) { double[] h = new double[n_hidden]; double pre_sigmoid_activation; for (int i = 0;i<n_hidden ;i++ ) { h[i] = propup(v,W[i],hbias[i]); } for (int j :v.keySet() ) { pre_sigmoid_activation = 0.0; for (int i =0; i<n_hidden ;i++ ) { pre_sigmoid_activation += W[i][j] * h[i]; } pre_sigmoid_activation += vbias[j]; reconstructed_v[j] = sigmoid(pre_sigmoid_activation); } } //训练模型 public static double[][] train_rbm() { Random rng = new Random(123); double learning_rate = 0.1; int training_epochs = 1; int k = 1; //获得数据 String inpath = "D:/ml-1m/ratings.dat"; Map<String,Map<String,Double>> u_p_s = loadData(inpath); Map<String,Integer> pid2num = get_pid2num(u_p_s); Map<String,Double> u_means = getU_mean(u_p_s); int train_N = u_p_s.size(); int n_visible = pid2num.size(); int n_hidden = 2; RBM rbm = new RBM(train_N,n_visible,n_hidden,null,null,null,rng); ArrayList<String> uid_list = new ArrayList<String>(u_p_s.keySet()); for(int epoch = 0;epoch < training_epochs;epoch++) { for (int i =0 ; i <uid_list.size() ;i++) { String uid = uid_list.get(i); Map<String,Double> p_s = u_p_s.get(uid); double u_m = u_means.get(uid); Map<Integer,Integer> pid_oz = new HashMap<Integer,Integer>(); for (String pid:p_s.keySet()) { Double score = u_p_s.get(uid).get(pid); int x = pid2num.get(pid); if (score > u_m) { pid_oz.put(x,1); } else { pid_oz.put(x,0); } } rbm.contrastive_divergence(pid_oz,learning_rate,k);//注意此处 } } return rbm.W; } public static void main(String[] args) { double[][] w =train_rbm(); } }
0 0
- RBM 推荐系统 Java代码(质量堪忧,仅供参考,欢迎讨论)
- RBM应用于推荐系统
- RBM算法模型应用在推荐系统 Python代码实现
- 【推荐】IT职业规划堪忧
- 推荐一下日常效率工具,欢迎讨论
- RBM(受限玻尔兹曼机)原理及代码
- Java 代码质量专题
- jAVA 代码质量
- 提高 Java 代码质量
- Excel数据提取C++代码(仅供参考)
- Hinton关于RBM的代码注解之(一)rbm.m
- Hinton关于RBM的代码注解之(一)rbm.m
- Hinton关于RBM的代码注解之(一)rbm.m
- Hinton关于RBM的代码注解之(一)rbm.m
- rbm C++代码理解
- 【RBM】代码学习--DeepLearningToolBox
- 提高你的Java代码质量吧:推荐在复杂字符串操作中使用正则表达式
- 提高你的Java代码质量吧:推荐在复杂字符串操作中使用正则表达式 .
- 外部中断
- 历代LINUX的版本区分 以及计算机的组成及其功能
- Android中设置分割线
- Java NIO使用及原理分析 (四)
- 信息系统项目管理简谈
- RBM 推荐系统 Java代码(质量堪忧,仅供参考,欢迎讨论)
- 第三周项目1--顺序表的基本运算
- 「译」JUnit 5 系列:环境搭建
- VS1003详解
- LeetCode---7. Reverse Integer
- C#控制台基础 判断指定目录下文件夹是否存在
- Unity客户端架构-DialogManager
- SPI总线的初步认识
- POJ 1042 Gone Fishing