从多项分布采样的Java实现

来源:互联网 发布:新中新dkq a16d软件 编辑:程序博客网 时间:2024/06/02 04:26
思路:
将每个概率值对应到[0,1]区间内的各个子区间(概率值大小体现在子区间的长度上),每次采样时,按照均匀分布随机生成一个[0,1]区间内的值,其落到哪个区间,则该区间概率值对应的元素即为被采样的元素;


算法:
1、先对概率值从大到小排列(不是必要过程,是便于加速的技巧,这样每次查找时优先检测随机数是否落在大概率的区间内,减少比较次数);
2、生成一个[0,1)区间内的随机数x (注意,Rand().nextDouble()得到的是[0,1)区间内的数,而wikipedia给出的算法中要求生成的是(0,1)区间的数);
3、将x与概率值列表中的各值pi逐个比较,并累加已比较过的前i-1个概率值的累加和sum:
若x落在[sum, sum+pi)区间内,则pi对应的元素被采样并返回 (注意区间的开闭应该参考步骤2中的情况);
否则,将pi累加入sum,继续将x与p(i+1)比较;

Tips:
若程序退出时仍未采到合法样本,则可能给定的概率分布不满足∑pi=1的条件(且x刚好落在[1-sum, 1)区间内);

应用场景:
机器学习(如强化学习)中,利用softmax函数定义policy,根据多项分布选择对应的action(使得agent有较大概率选到当前模型下的最佳action,又有一定的几率去探索其他action);
softmax policy的另一种替代方式是epsilon-learning中用epsilon来控制探索和利用的几率的方式,即以epsilon的概率进行探索(随机选一个action),以1-epsilon的概率进行利用(选当前模型下最佳action);

算法代码:

/**
        * sample from amultinomialdistribution
        * https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
        *@parampdist a list of <item,probablity>
        *@returnthe selected item, i.e. result belongs to pdist.getFirstValues
        *@authorqxliuOct16, 2017 11:47:33 AM
        */
       publicstaticintsampleFromMultinomialDistribution(List<TwoTuple<Integer,Double>>pdist){
              List<TwoTuple<Integer,Double>>pidxlist=newArrayList<>(pdist);//avoid changing pdist
              intitemNum=pidxlist.size();
              Collections.sort(pidxlist,newReRanker().setIsDesc(true));
              Randomrand=newRandom();
              doublerd=rand.nextDouble();//a random double in [0,1)
              doublesum=0;
              intsampledIdx=-1;
              for(intk=0;k<itemNum;k++){
                     doublepk=pidxlist.get(k).getSecond();
                     if(rd>=sum&&rd<sum+pk){
                           sampledIdx=pidxlist.get(k).getFirst();//====
                           break;
                     }
                     sum+=pk;
              }
              if(sampledIdx<0&&sum!=1){
                     thrownewIllegalArgumentException("error distribution! sampledIdx="+sampledIdx+", distribution="+pidxlist);
              }
              returnsampledIdx;
       }
测试代码:
/**
 *@authorqxliu2017Oct10, 2017 4:34:12 PM
 *
 */
publicstaticvoidmain(String[]args){             
              intsampleNum=10;//打算采样的次数(决定最终采得的样本数)
              List<TwoTuple<Integer,Double>>pdist=newArrayList<>();//定义多项式分布
              pdist.add(newTwoTuple<Integer,Double>(1,0.5));//每个元素为:样本的标记(比如id),样本被选中的概率;
              pdist.add(newTwoTuple<Integer,Double>(2,0.3));//应该要求概率分布之和为1,但算法里并未检查概率值总和是否为1;
              pdist.add(newTwoTuple<Integer,Double>(3,0.2));//若概率值之和不为1,则有可能报错(当随机数出现在[sum, 1)区间内时),也可能不报错;
//            pdist.add(new TwoTuple<Integer, Double>(9, 0.5));
//            pdist.add(new TwoTuple<Integer, Double>(7, 0.5));
              for(inti=0;i<sampleNum;i++){
                     System.out.println(sampleFromMultinomialDistribution(pdist));
              }
       }

输出:(10次采样的结果,也可能是满足该分布的其他情况;由于采样次数少,有时结果也可能看起来不满足原分布)
1
3
1
2
1
2
2
2
1
2
原创粉丝点击