贝叶斯分类器

来源:互联网 发布:编译单个java文件 编辑:程序博客网 时间:2024/06/05 17:52

贝叶斯


贝叶斯公式:

P(A|C)=P(C|A)P(A)P(C)
事件A在事件C发生的概率为事件C在A发生下的概率乘以事件A发生的概率,最后除上事件C发生的概率

经典场景

射击问题

A,B两人射击,A有50%的概率命中,B有60%概率命中,已知目标被命中,求分别为A、B的概率。

令目标被命中事件为C,则有:
由求贝叶斯公式可得:
P(C)=0.50.6+0.50.6+0.50.4=0.8

P(A|C)=P(C|A)P(A)P(C)=58

同理可得P(B|C)=34

医疗检测

已知条件如下:
1. 人口统计先验有:
得癌症的概率:P(ω1)=0.008
不得癌症概率:P(ω2)=0.992
2. 医疗检测中:
阳性:
P(+|ω1)=0.98
P(+|ω2)=0.02
阴性:
P(|ω2)=0.97
P(|ω1)=0.03
那么当一次检测为阳性时,得癌症的概率有多大?
P(+)=P(ω1)P(+|ω1)+P(ω2)P(+|ω2)=0.1948
P(ω1|+)=P(+|ω1)P(ω1)P(+)=0.28
当第二次检测为阳性时,得癌症的概率为多少?
这里的计算过程不变,但是先验概率P(ω1)改变了,为0.28,所以要重新计算P(+)

用贝叶斯做分类

推导过程

1.开始公式

ωmap=argmaxωiωP(ωi|a1,a2,a3..an)
其中,ai为其中的属性。整个公式的解释是:这条数据的最终类别是ωi在条件{a_的概率最大的那个分类

2.用贝叶斯公式

ωmap=argmaxωiωP(a1,a2...an|ωi)P(ωi)P(a1,a2...an)

3.化简

去掉P(a1,a2...an),因为每个都一样
其中P(ωi)是可以从训练集中统计出来的先验概率

4.引入独立条件

P(a1,a2...an|ωi)=P(a1|ωi)P(a2|ωi)...P(an|ωi)

5.最终可得到公式

ωmap=argmaxωiωP(ωi)jP(aj|ωi)
输入数据就是aj,也就是各个属性的值
在数据集中可以获得的数据有:
1. P(ωi)=
2. P(aj|ωi)=ωiajωi

决策树

ID3

定义:
Entropy(S)=i=1cpilog(pi)
解释:
1. S是最后的标签属性,取值范围为c
2. pi=i

信息增益

在上节中,只是计算了当前集合的总体熵,信息增益=总体熵-(用标签外的属性X来划分之后的熵)
Gain(S,X)=Entropy(S)Entropy(S|X)
Entropy(S|X)=vX|Sv||S|Entropy(Sv)
Entropy(Sv)=i=1cpilog(pi)这个样本公式的Sv代表属性X=v的所有属性局
例子:

id 是否抽烟 头发长度 鞋码 性别(男|女) 1 false 100 mid 女 2 true 100 small 女 3 true 10 big 男 4 false 20 mid 男 5 true 30 mid 女 6 true 70 big 男 7 false 100 small 女 8 false 50 small 女

1. 总体熵:
p()=3/8
p()=5/8
Entropy()=p()log2p()p()log2p()=0.96
2. 计算Entropy(|=true):
p(=|=true)=0.5
p(=|=true)=0.5
Entropy(|=true)=1.0
3. 计算Entropy(|=false)
p(=|=false)=1/4
p(=|=false)=3/4
Entropy(|=false)=14log21434log234=0.81
4. 计算Entropy(|)
p(=true)=0.5
p(=false)=0.5
Entropy(|)=0.51.0+0.50.81=0.905
5. 信息增益
Gain(,)=0.960.905=0.095

利用信息熵

分别计算各个属性的信息增益,去最大的那个属性作为节点label

过拟合

两个分类器A、B,A在训练集中的效果比B好,但是在测试集中比B差,我们说A过拟合。

限制决策树高度

剪枝

将两个叶子节点,合并后,按照少数服从多数得出label
需要增设一个校验集,用于剪枝过程中的误差比较。
当剪枝进行到在校验集上误差由减小到增大的拐点时,停止剪枝

处理连续性数据

采用信息增益衡量按照进行对阈值切分点后的数据集的纯度,采用信息增益比较大的。

贝叶斯分类器实现

package com.liuyanzuo.datamining.classification;import java.util.*;/** * 朴素贝叶斯分类器实现 * Created by tempuser on 2017/1/19. */public class NaiveBayesClassification {    //定义常量    public static final String NOT_DEFINE_ATTR="not build the attributeList";    public static final String SUCCESS="success";    //定义存储类别信息的结构    private Map<String,Map<String,Map<String,Integer>>> statisticsMsg;    //定义属性名称集合    private List<String> attributeList;    //每个label的数量统计    private Map<String,Integer> labelCountMap;    //每个属性可取值的范围    private Map<String,List<String>> attrValue;    //每个label的每个属性的百分比统计    private Map<String,Map<String,Map<String,Double>>> labelAttrPercentMap;    //label在属性的下标    private int labelIndex;    //数据的总数量    private int totalCount;    /**     * 构造分类器     * @param data     * @param labelIndex     */    public String build(List<List<String>> data,int labelIndex){        if(null==attributeList || "".equals(attributeList)){            return NOT_DEFINE_ATTR;        }        this.labelIndex=labelIndex;        //初始化各个属性        statisticsMsg=new HashMap<>();        labelCountMap=new HashMap<>();        attrValue=new HashMap<>();        for(List<String> attributeLabelList : data){            //这行数据的标签            String label=attributeLabelList.get(labelIndex);            //统计这行数据的label            Integer labelPercentValue=labelCountMap.get(label);            if(labelPercentValue==null){                labelPercentValue=0;            }            labelPercentValue++;            labelCountMap.put(label,labelPercentValue);            totalCount++;            Map<String,Map<String,Integer>> labelMap= statisticsMsg.get(label);            if(null == labelMap){                labelMap=new HashMap<>();                statisticsMsg.put(label,labelMap);            }            for(int i=0;i<attributeLabelList.size();i++){                if(i != labelIndex){                    //现在所在下标的属性名称                    String attributeName=attributeList.get(i);                    //现在所在下标的属性值                    String attributeValue=attributeLabelList.get(i);                    //统计属性的取值范围                    List<String> attrValueList=attrValue.get(attributeName);                    if(attrValueList==null){                        attrValueList=new ArrayList<>();                    }                    if(!attrValueList.contains(attributeValue)){                        attrValueList.add(attributeValue);                    }                    attrValue.put(attributeName,attrValueList);                    Map<String,Integer> attributeMap=labelMap.get(attributeName);                    if( null == attributeMap){                        attributeMap=new HashMap<>();                        labelMap.put(attributeList.get(i),attributeMap);                    }                    Integer attributeCountValue=attributeMap.get(attributeValue);                    if(null==attributeCountValue){                        attributeCountValue=0;                    }                    attributeCountValue++;                    attributeMap.put(attributeValue,attributeCountValue);                }            }        }        labelAttrPercentMap=new HashMap<>();        //统计label百分比        Set<String> labelSet=statisticsMsg.keySet();        for(String label:labelSet){            //这个label的总长度            int labelCount=labelCountMap.get(label);            Map<String,Map<String,Integer>> statisticsLabelAttrMap=statisticsMsg.get(label);            //统计每个label下的各个属性的各个取值的数量            Map<String,Map<String,Double>> percentValue=new HashMap<>();            Set<String> attrSet=statisticsLabelAttrMap.keySet();            for(String attribute:attrSet){                Map<String,Integer> attributeValueMap=statisticsLabelAttrMap.get(attribute);                Set<String> attributeValueSet=attributeValueMap.keySet();                Map<String,Double> percentAttributeValueMap=new HashMap<>();                for(String attributeValue:attributeValueSet){                    //最终属性取值的百分比                    percentAttributeValueMap.put(attributeValue,attributeValueMap.get(attributeValue)/(labelCount*1.0));                }                percentValue.put(attribute,percentAttributeValueMap);            }            labelAttrPercentMap.put(label,percentValue);        }        return SUCCESS;    }    /**     * 对传入数据进行分类     * @param needClassify     */    public Map<String,Double> classify(List<String> needClassify){        Map<String,Double> result=new HashMap<>();        if(null == statisticsMsg || statisticsMsg.size()==0){            return result;        }        for(String label:labelCountMap.keySet()){            double prediction=1.0;            for(int i=0;i<attributeList.size();i++){                //当前属性名称                String attrName=attributeList.get(i);                if(i != labelIndex){                    //要做一个laplace平滑                    Integer labelAttrPercentValue=statisticsMsg.get(label).get(attributeList.get(i)).get(needClassify.get(i));                    if(labelAttrPercentValue==null){                        labelAttrPercentValue=0;                    }                    prediction*=(labelAttrPercentValue+1.0)/(attrValue.get(attrName).size()*1.0+labelCountMap.get(label)*1.0);                }            }            result.put(label,prediction);        }        return result;    }    public Map<String, Map<String, Map<String, Integer>>> getStatisticsMsg() {        return statisticsMsg;    }    public void setStatisticsMsg(Map<String, Map<String, Map<String, Integer>>> statisticsMsg) {        this.statisticsMsg = statisticsMsg;    }    public List<String> getAttributeList() {        return attributeList;    }    public void setAttributeList(List<String> attributeList) {        this.attributeList = attributeList;    }    public Map<String, Integer> getLabelCountMap() {        return labelCountMap;    }    public Map<String, Map<String, Map<String, Double>>> getLabelAttrPercentMap() {        return labelAttrPercentMap;    }    public void setLabelAttrPercentMap(Map<String, Map<String, Map<String, Double>>> labelAttrPercentMap) {        this.labelAttrPercentMap = labelAttrPercentMap;    }}
0 0