贝叶斯文本分类

来源:互联网 发布:网络战争游戏排行榜 编辑:程序博客网 时间:2024/04/28 03:17

昨天实现了一个基于贝叶斯定理的的文本分类,贝叶斯定理假设特征属性(在文本中就是词汇)对待分类项的影响都是独立的,道理比较简单,在中文分类系统中,分类的准确性与分词系统的好坏有很大的关系,这段代码也是试验不同分词系统才顺手写的一个。 
    试验数据用的sogou实验室的文本分类样本,一共分为9个类别,每个类别文件夹下大约有2000篇文章。由于文本数据量确实较大,所以得想办法让每次训练的结果都能保存起来,以便于下次直接使用,我这里使用序列化的方式保存在硬盘。 
  训练代码如下: 
Java代码  收藏代码
  1. /** 
  2.  * 训练器 
  3.  *  
  4.  * @author duyf 
  5.  *  
  6.  */  
  7. class Train implements Serializable {  
  8.   
  9.     /** 
  10.      *  
  11.      */  
  12.     private static final long serialVersionUID = 1L;  
  13.   
  14.     public final static String SERIALIZABLE_PATH = "D:\\workspace\\Test\\SogouC.mini\\Sample\\Train.ser";  
  15.     // 训练集的位置  
  16.     private String trainPath = "D:\\workspace\\Test\\SogouC.mini\\Sample";  
  17.   
  18.     // 类别序号对应的实际名称  
  19.     private Map<String, String> classMap = new HashMap<String, String>();  
  20.   
  21.     // 类别对应的txt文本数  
  22.     private Map<String, Integer> classP = new ConcurrentHashMap<String, Integer>();  
  23.   
  24.     // 所有文本数  
  25.     private AtomicInteger actCount = new AtomicInteger(0);  
  26.   
  27.       
  28.   
  29.     // 每个类别对应的词典和频数  
  30.     private Map<String, Map<String, Double>> classWordMap = new ConcurrentHashMap<String, Map<String, Double>>();  
  31.   
  32.     // 分词器  
  33.     private transient Participle participle;  
  34.   
  35.     private static Train trainInstance = new Train();  
  36.   
  37.     public static Train getInstance() {  
  38.         trainInstance = new Train();  
  39.   
  40.         // 读取序列化在硬盘的本类对象  
  41.         FileInputStream fis;  
  42.         try {  
  43.             File f = new File(SERIALIZABLE_PATH);  
  44.             if (f.length() != 0) {  
  45.                 fis = new FileInputStream(SERIALIZABLE_PATH);  
  46.                 ObjectInputStream oos = new ObjectInputStream(fis);  
  47.                 trainInstance = (Train) oos.readObject();  
  48.                 trainInstance.participle = new IkParticiple();  
  49.             } else {  
  50.                 trainInstance = new Train();  
  51.             }  
  52.         } catch (Exception e) {  
  53.             e.printStackTrace();  
  54.         }  
  55.   
  56.         return trainInstance;  
  57.     }  
  58.   
  59.     private Train() {  
  60.         this.participle = new IkParticiple();  
  61.     }  
  62.   
  63.     public String readtxt(String path) {  
  64.         BufferedReader br = null;  
  65.         StringBuilder str = null;  
  66.         try {  
  67.             br = new BufferedReader(new FileReader(path));  
  68.   
  69.             str = new StringBuilder();  
  70.   
  71.             String r = br.readLine();  
  72.   
  73.             while (r != null) {  
  74.                 str.append(r);  
  75.                 r = br.readLine();  
  76.   
  77.             }  
  78.   
  79.             return str.toString();  
  80.         } catch (IOException ex) {  
  81.             ex.printStackTrace();  
  82.         } finally {  
  83.             if (br != null) {  
  84.                 try {  
  85.                     br.close();  
  86.                 } catch (IOException e) {  
  87.                     e.printStackTrace();  
  88.                 }  
  89.             }  
  90.             str = null;  
  91.             br = null;  
  92.         }  
  93.   
  94.         return "";  
  95.     }  
  96.   
  97.     /** 
  98.      * 训练数据 
  99.      */  
  100.     public void realTrain() {  
  101.         // 初始化  
  102.         classMap = new HashMap<String, String>();  
  103.         classP = new HashMap<String, Integer>();  
  104.         actCount.set(0);  
  105.         classWordMap = new HashMap<String, Map<String, Double>>();  
  106.   
  107.         // classMap.put("C000007", "汽车");  
  108.         classMap.put("C000008""财经");  
  109.         classMap.put("C000010""IT");  
  110.         classMap.put("C000013""健康");  
  111.         classMap.put("C000014""体育");  
  112.         classMap.put("C000016""旅游");  
  113.         classMap.put("C000020""教育");  
  114.         classMap.put("C000022""招聘");  
  115.         classMap.put("C000023""文化");  
  116.         classMap.put("C000024""军事");  
  117.   
  118.         // 计算各个类别的样本数  
  119.         Set<String> keySet = classMap.keySet();  
  120.   
  121.         // 所有词汇的集合,是为了计算每个单词在多少篇文章中出现,用于后面计算df  
  122.         final Set<String> allWords = new HashSet<String>();  
  123.   
  124.         // 存放每个类别的文件词汇内容  
  125.         final Map<String, List<String[]>> classContentMap = new ConcurrentHashMap<String, List<String[]>>();  
  126.   
  127.         for (String classKey : keySet) {  
  128.   
  129.             Participle participle = new IkParticiple();  
  130.             Map<String, Double> wordMap = new HashMap<String, Double>();  
  131.             File f = new File(trainPath + File.separator + classKey);  
  132.             File[] files = f.listFiles(new FileFilter() {  
  133.   
  134.                 @Override  
  135.                 public boolean accept(File pathname) {  
  136.                     if (pathname.getName().endsWith(".txt")) {  
  137.                         return true;  
  138.                     }  
  139.                     return false;  
  140.                 }  
  141.   
  142.             });  
  143.   
  144.             // 存储每个类别的文件词汇向量  
  145.             List<String[]> fileContent = new ArrayList<String[]>();  
  146.             if (files != null) {  
  147.                 for (File txt : files) {  
  148.                     String content = readtxt(txt.getAbsolutePath());  
  149.                     // 分词  
  150.                     String[] word_arr = participle.participle(content, false);  
  151.                     fileContent.add(word_arr);  
  152.                     // 统计每个词出现的个数  
  153.                     for (String word : word_arr) {  
  154.                         if (wordMap.containsKey(word)) {  
  155.                             Double wordCount = wordMap.get(word);  
  156.                             wordMap.put(word, wordCount + 1);  
  157.                         } else {  
  158.                             wordMap.put(word, 1.0);  
  159.                         }  
  160.                           
  161.                     }  
  162.                 }  
  163.             }  
  164.   
  165.             // 每个类别对应的词典和频数  
  166.             classWordMap.put(classKey, wordMap);  
  167.   
  168.             // 每个类别的文章数目  
  169.             classP.put(classKey, files.length);  
  170.             actCount.addAndGet(files.length);  
  171.             classContentMap.put(classKey, fileContent);  
  172.   
  173.         }  
  174.   
  175.           
  176.   
  177.           
  178.   
  179.         // 把训练好的训练器对象序列化到本地 (空间换时间)  
  180.         FileOutputStream fos;  
  181.         try {  
  182.             fos = new FileOutputStream(SERIALIZABLE_PATH);  
  183.             ObjectOutputStream oos = new ObjectOutputStream(fos);  
  184.             oos.writeObject(this);  
  185.         } catch (Exception e) {  
  186.             e.printStackTrace();  
  187.         }  
  188.   
  189.     }  
  190.   
  191.     /** 
  192.      * 分类 
  193.      *  
  194.      * @param text 
  195.      * @return 返回各个类别的概率大小 
  196.      */  
  197.     public Map<String, Double> classify(String text) {  
  198.         // 分词,并且去重  
  199.         String[] text_words = participle.participle(text, false);  
  200.   
  201.         Map<String, Double> frequencyOfType = new HashMap<String, Double>();  
  202.         Set<String> keySet = classMap.keySet();  
  203.         for (String classKey : keySet) {  
  204.             double typeOfThis = 1.0;  
  205.             Map<String, Double> wordMap = classWordMap.get(classKey);  
  206.             for (String word : text_words) {  
  207.                 Double wordCount = wordMap.get(word);  
  208.                 int articleCount = classP.get(classKey);  
  209.   
  210.                 /* 
  211.                  * Double wordidf = idfMap.get(word); if(wordidf==null){ 
  212.                  * wordidf=0.001; }else{ wordidf = Math.log(actCount / wordidf); } 
  213.                  */  
  214.   
  215.                 // 假如这个词在类别下的所有文章中木有,那么给定个极小的值 不影响计算  
  216.                 double term_frequency = (wordCount == null) ? ((double1 / (articleCount + 1))  
  217.                         : (wordCount / articleCount);  
  218.   
  219.                 // 文本在类别的概率 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。  
  220.                 // 当double无限小的时候会归为0,为了避免 *10  
  221.   
  222.                 typeOfThis = typeOfThis * term_frequency * 10;  
  223.                 typeOfThis = ((typeOfThis == 0.0) ? Double.MIN_VALUE  
  224.                         : typeOfThis);  
  225.                 // System.out.println(typeOfThis+" : "+term_frequency+" :  
  226.                 // "+actCount);  
  227.             }  
  228.   
  229.             typeOfThis = ((typeOfThis == 1.0) ? 0.0 : typeOfThis);  
  230.   
  231.             // 此类别文章出现的概率  
  232.             double classOfAll = classP.get(classKey) / actCount.doubleValue();  
  233.   
  234.             // 根据贝叶斯公式 $(A|B)=S(B|A)*S(A)/S(B),由于$(B)是常数,在这里不做计算,不影响分类结果  
  235.             frequencyOfType.put(classKey, typeOfThis * classOfAll);  
  236.         }  
  237.   
  238.         return frequencyOfType;  
  239.     }  
  240.   
  241.     public void pringAll() {  
  242.         Set<Entry<String, Map<String, Double>>> classWordEntry = classWordMap  
  243.                 .entrySet();  
  244.         for (Entry<String, Map<String, Double>> ent : classWordEntry) {  
  245.             System.out.println("类别: " + ent.getKey());  
  246.             Map<String, Double> wordMap = ent.getValue();  
  247.             Set<Entry<String, Double>> wordMapSet = wordMap.entrySet();  
  248.             for (Entry<String, Double> wordEnt : wordMapSet) {  
  249.                 System.out.println(wordEnt.getKey() + ":" + wordEnt.getValue());  
  250.             }  
  251.         }  
  252.     }  
  253.   
  254.     public Map<String, String> getClassMap() {  
  255.         return classMap;  
  256.     }  
  257.   
  258.     public void setClassMap(Map<String, String> classMap) {  
  259.         this.classMap = classMap;  
  260.     }  
  261.   
  262. }  

在试验过程中,发觉某篇文章的分类不太准,某篇IT文章分到招聘类别下了,在仔细对比了训练数据后,发觉这是由于招聘类别每篇文章下面都带有“搜狗”的标志,而待分类的这篇IT文章里面充斥这搜狗这类词汇,结果招聘类下的概率比较大。由此想到,在除了做常规的贝叶斯计算时,需要把不同文本中出现次数多的词汇权重降低甚至删除(好比关键词搜索中的tf-idf),通俗点讲就是,在所有训练文本中某词汇(如的,地,得)出现的次数越多,这个词越不重要,比如IT文章中“软件”和“应用”这两个词汇,“应用”应该是很多文章类别下都有的,反而不太重要,但是“软件”这个词汇大多只出现在IT文章里,出现在大量文章的概率并不大。 我这里原本打算计算每个词的idf,然后给定一个阀值来判断是否需要纳入计算,但是由于词汇太多,计算量较大(等待结果时间较长),所以暂时注释掉了。 

原创粉丝点击