RBM 推荐系统 Java 代码（质量堪忧，仅供参考，欢迎讨论）

来源:互联网 发布:json格式怎么打开 mock 编辑:程序博客网 时间:2024/06/04 19:27

这是一段用 Java 编写的代码，涉及 RBM（受限玻尔兹曼机），也涉及推荐系统。代码效率比较低，也有一些问题。贴出来只是为了给大家提供一点思路，也希望大家多多指教。这里只贴出最重要的代码块；需要完整代码的同学可以给我留言，我抽空再发给大家。抛砖引玉，希望能有同学或工程师一起交流，共同进步，共同提高。


package rbm_1th;import java.awt.List;import java.util.ArrayList;import java.util.HashMap;import java.util.Map;import java.util.Random;import static rbm_1th.utils.*;import static rbm_1th.Get_data.*;public class RBM{    public int N;    public int n_visible;    public int n_hidden;    public double[][] W;    public double[] hbias;    public double[] vbias;    public Random rng;    //RBM的构造函数    public RBM(int N,int n_visiable,int n_hidden,double[][] W,double[] hbias,double[] vbias,Random rng)    {        this.N = N;        this.n_visible = n_visiable;        this.n_hidden = n_hidden;                if (rng == null)        {            this.rng = new Random(1234);        }        else{            this.rng = rng;        }        if (W == null){            this.W =new double[this.n_hidden][this.n_visible];            double a = 1.0/this.n_visible;            for (int i = 0; i < this.n_hidden; i++)            {                for (int j = 0; j< this.n_visible ; j++ )                {                    this.W[i][j] = uniform(-a,a,rng);                }            }        }else        {            this.W = W;        }                if(hbias == null){            this.hbias = new double[this.n_hidden];            for (int i = 0;i<this.n_hidden;i++ )            {                this.hbias[i] = 0;            }        }else{            this.hbias = hbias;        }        if (vbias == null)        {            this.vbias = new double[this.n_visible];            for (int j = 0; j < this.n_visible ; j++ )            {                this.vbias[j] = 0;            }        }else{            this.vbias = vbias;        }    }    //CD-k算法    public void contrastive_divergence(Map<Integer,Integer> input,double lr,int k)    {                double[] ph_mean = new double[n_hidden];        int[] ph_sample = new int[n_hidden];        double[] nv_means = new double[n_visible];        Map<Integer,Integer> nv_samples = new HashMap<>();        double[] nh_means = new double[n_hidden];        
int[] nh_samples = new int[n_hidden];                sample_h_given_v(input,ph_mean,ph_sample);        for (int step = 0; step < k;step++)        {            if (step == 0)            {                gibbs_hvh(ph_sample,nv_means,nv_samples,nh_means,nh_samples);            }else {                gibbs_hvh(nh_samples,nv_means,nv_samples,nh_means,nh_samples);            }        }        /**这块儿代码可能有点问题,即在推荐系统里面是应该对所有的权重都改变还是只改变有过用户行为的显层和隐层的连接权重,这会儿脑子太乱,不想想了,欢迎大家拍砖指正**/        for (int i=0; i<n_hidden; i++)        {            for (int j: input.keySet())//此处可能有问题,各位看看            {                W[i][j] += lr *(ph_mean[i]*input.get(j) - nh_means[i]*nv_samples.get(j));//此处可能有问题,各位看看,欢迎各位指正                            }            hbias[i] += lr * (ph_sample[i] - nh_means[i]);        }        for (int j: input.keySet())        {            vbias[j] += lr *(input.get(j) - nv_samples.get(j));        }    }    //可见层到隐藏层    public double propup(Map<Integer,Integer> v,double[] w,double b)    {        double pre_sigmoid_activation = 0.0;        for (int j :v.keySet() )        {            pre_sigmoid_activation += w[j] * v.get(j);        }        pre_sigmoid_activation += b;        return sigmoid(pre_sigmoid_activation);    }        //隐藏层到可见层    public double propdown(int[] h,int j,double b)    {        double pre_sigmoid_activation = 0.0;        for (int i =0; i < h.length ; i++ )        {            pre_sigmoid_activation += W[i][j] * h[i];        }        pre_sigmoid_activation += b;                return sigmoid(pre_sigmoid_activation);    }        //吉布斯    public void gibbs_hvh(int[] h0_sample,double[] nv_means,Map<Integer,Integer> nv_samples,double[] nh_means,int[] nh_samples)    {        sample_v_given_h(h0_sample,nv_means,nv_samples);        sample_h_given_v(nv_samples,nh_means,nh_samples);    }        //sample given hidden get visible    public void sample_v_given_h(int[] h0_sample,double[] mean,Map<Integer,Integer> sample)    {        for (int j:sample.keySet() )    
    {            mean[j] = propdown(h0_sample, j,vbias[j]);            int oz = binomial(1,mean[j],rng);            sample.put(j, oz);        }    }    //given visible get hidden    public void sample_h_given_v(Map<Integer,Integer> v0_sample,double[] mean , int[] sample)    {        for (int i = 0;i<n_hidden;i++)        {            mean[i] = propup(v0_sample,W[i],hbias[i]);            sample[i] = binomial(1,mean[i],rng);        }    }        //函数重构    public void reconstruct(Map<Integer,Integer> v,double[] reconstructed_v)    {        double[] h = new double[n_hidden];        double pre_sigmoid_activation;        for (int i = 0;i<n_hidden ;i++ )        {            h[i] = propup(v,W[i],hbias[i]);        }        for (int j :v.keySet() )        {            pre_sigmoid_activation = 0.0;            for (int i =0; i<n_hidden ;i++ )            {                pre_sigmoid_activation += W[i][j] * h[i];            }            pre_sigmoid_activation += vbias[j];            reconstructed_v[j] = sigmoid(pre_sigmoid_activation);        }    }    //训练模型    public static double[][] train_rbm()    {        Random rng = new Random(123);        double learning_rate = 0.1;        int training_epochs = 1;        int k = 1;        //获得数据        String inpath = "D:/ml-1m/ratings.dat";        Map<String,Map<String,Double>> u_p_s = loadData(inpath);        Map<String,Integer> pid2num = get_pid2num(u_p_s);        Map<String,Double> u_means = getU_mean(u_p_s);        int train_N = u_p_s.size();                int n_visible = pid2num.size();        int n_hidden = 2;        RBM rbm = new RBM(train_N,n_visible,n_hidden,null,null,null,rng);        ArrayList<String> uid_list = new ArrayList<String>(u_p_s.keySet());            for(int epoch = 0;epoch < training_epochs;epoch++)        {            for (int i =0 ; i <uid_list.size() ;i++)            {                String uid = uid_list.get(i);                Map<String,Double> p_s = u_p_s.get(uid);                                double u_m = 
u_means.get(uid);                Map<Integer,Integer> pid_oz = new HashMap<Integer,Integer>();                for (String pid:p_s.keySet())                {                    Double score = u_p_s.get(uid).get(pid);                    int x = pid2num.get(pid);                    if (score > u_m)                    {                        pid_oz.put(x,1);                    }                    else                    {                        pid_oz.put(x,0);                    }                }                rbm.contrastive_divergence(pid_oz,learning_rate,k);//注意此处            }        }                return rbm.W;                    }            public static void main(String[] args)    {        double[][] w =train_rbm();            }    }


0 0