计算信息增益(Information Gain),考虑交叉feature

来源:互联网 发布:在淘宝中rol ppc是什么 编辑:程序博客网 时间:2024/06/07 00:20
 import java.io.BufferedReader;import java.io.FileReader;import java.util.ArrayList;import java.util.Collections;import java.util.Comparator;import java.util.HashMap;import java.util.List;import java.util.Map;import java.util.Map.Entry; /** *  * @author qibaoyuan *  */public class InformationGain {/** * calculate the info(entrophy from a list of classes) *  * @param classes *            字符类型的分类信息 * @return info entropy */static Double calculateEntrophy(List<String> classes) {Double info = 0.0;try {// 总的个数int size = classes.size();// map to store the count of each unique classMap<String, Integer> counter = new HashMap<String, Integer>();// iter all the classfor (String key : classes) {// already exists,incrementalif (counter.containsKey(key.trim()))counter.put(key.trim(), counter.get(key.trim()) + 1);else// set 1counter.put(key.trim(), 1);}// iter the mapfor (Entry<String, Integer> entry : counter.entrySet()) {Double ratio = Double.parseDouble(Integer.toString((entry.getValue()))) / size;info -= ratio * (Math.log(ratio) / Math.log(2));}} catch (Exception e) {e.printStackTrace();}return info;}/** *  * @param records *            输入记录 example:{[我 n 1 0 0 0 0 0 YES],[是 n 0 0 0 0 0 0 NO]} * @return */static Map<Integer, Double> calculateIG(List<String[]> records,Boolean isSingleFeature) {Map<Integer, Double> index4select = new HashMap<Integer, Double>();try {// 1.计算总的infoList<String> labels = new ArrayList<String>();int feature_size = 0;for (String[] arr : records) {String label = arr[arr.length - 1];labels.add(label);feature_size = arr.length - 1;}Map<Integer, List<Object>> features = PermutationTest.genPerLess(feature_size, 3);Double total = calculateEntrophy(labels);System.out.print("label的熵信息:");System.out.println(total);// 2.计算每个feature的entrophy// int i=0;for (Entry<Integer, List<Object>> entry1 : features.entrySet()) {Double info_i = 0.0;Map<String, List<String>> featureMap = new HashMap<String, List<String>>();// divide the records according to the featurefor (String[] arr : records) {// get the featureString feature = "";if (entry1.getValue().size() > 1 && isSingleFeature)continue;for (Object obj : entry1.getValue()) {if (obj instanceof Integer)feature += arr[(Integer) obj];}// check whether if it's countedif (featureMap.containsKey(feature)) {List<String> featureList = featureMap.get(feature);featureList.add(arr[arr.length - 1]);featureMap.put(feature, featureList);} else {List<String> featureList = new ArrayList<String>();featureList.add(arr[arr.length - 1]);featureMap.put(feature, featureList);}}// calculate entrophy of each value of the featurefor (Entry<String, List<String>> entry : featureMap.entrySet()) {Double score = calculateEntrophy(entry.getValue());info_i += (Double.parseDouble(Integer.toString(entry.getValue().size())) / records.size()) * score;}System.out.print("feature " + entry1.getKey() + " ig:");System.out.println(total - info_i);// ig=f-totalindex4select.put(entry1.getKey(), total - info_i);}// ///sort by the valueArrayList<Integer> keys = new ArrayList<Integer>(index4select.keySet());// 得到key集合final Map<Integer, Double> scoreMap_temp = index4select;Collections.sort(keys, new Comparator<Object>() {public int compare(Object o1, Object o2) {if (Double.parseDouble(scoreMap_temp.get(o1).toString()) < Double.parseDouble(scoreMap_temp.get(o2).toString()))return 1;if (Double.parseDouble(scoreMap_temp.get(o1).toString()) == Double.parseDouble(scoreMap_temp.get(o2).toString()))return 0;elsereturn -1;}});int y = 0;for (Integer key : keys) {System.out.println(key + "" + features.get(key) + "= "+ scoreMap_temp.get(key));}// //////////////////////} catch (Exception e) {e.printStackTrace();}return index4select;}/** * 从文件读入输入,计算每个feature的ig,最後一列是手工標註的label *  * @param file *            存放手工标注语料的路径 */static void calculateIG(String file) {try {FileReader reader = new FileReader(file);BufferedReader br = new BufferedReader(reader);String line = null;List<String[]> lists = new ArrayList<String[]>();while ((line = br.readLine()) != null) {if (line.trim().length() == 0)continue;lists.add(line.split("\t"));}System.out.print(calculateIG(lists,false));} catch (Exception e) {e.printStackTrace();}}/** * @param args */public static void main(String[] args) {calculateIG("/home/qibaoyuan/qibaoyuan/lexo/cv/all.txt");}}