Computing Text Similarity with TF-IDF and Cosine Similarity


TF-IDF (term frequency–inverse document frequency) is a weighting technique widely used in information retrieval and data mining. TF stands for term frequency; IDF stands for inverse document frequency.
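Concretely, the weight of a term is the product of these two statistics. In the form used by the code below (base-10 logarithm, matching the IDF comment in the code):

    tf(t, d)    = count(t, d) / |d|
    idf(t)      = log10((1 + |D|) / |Dt|)
    tfidf(t, d) = tf(t, d) * idf(t)

where |d| is the number of tokens in document d, |D| is the total number of documents, and |Dt| is the number of documents containing term t.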

Idea: segment each text into words, use the TF-IDF algorithm to turn each text into a word vector, then compute the cosine of the angle between the two vectors as the similarity score.
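For two TF-IDF vectors A and B built over the same vocabulary, the cosine similarity is:

    sim(A, B) = (A · B) / (|A| * |B|) = Σ Ai*Bi / (sqrt(Σ Ai²) * sqrt(Σ Bi²))

It is 1 when the vectors point in the same direction and 0 when the documents share no weighted terms. This is exactly what the sim() methods below compute.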
Required jars: je-analysis-1.5.3.jar and lucene-core-2.4.1.jar (Lucene versions above 4 conflict with je-analysis).

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import jeasy.analysis.MMAnalyzer;

/**
 * Match two texts directly.
 *
 * @author rock
 */
public class GetText {

    private static List<String> fileList = new ArrayList<String>();
    private static HashMap<String, HashMap<String, Double>> allTheTf = new HashMap<String, HashMap<String, Double>>();
    private static HashMap<String, HashMap<String, Integer>> allTheNormalTF = new HashMap<String, HashMap<String, Integer>>();
    private static LinkedHashMap<String, Double[]> vectorMap = new LinkedHashMap<String, Double[]>();

    /**
     * Word segmentation.
     */
    public static String[] TextcutWord(String text) throws IOException {
        MMAnalyzer analyzer = new MMAnalyzer();
        String tempCutWordResult = analyzer.segment(text, " ");
        return tempCutWordResult.split(" ");
    }

    public static Map<String, HashMap<String, Integer>> NormalTFOfAll(String key1, String key2,
            String text1, String text2) throws IOException {
        if (allTheNormalTF.get(key1) == null) {
            allTheNormalTF.put(key1, normalTF(TextcutWord(text1)));
        }
        if (allTheNormalTF.get(key2) == null) {
            allTheNormalTF.put(key2, normalTF(TextcutWord(text2)));
        }
        return allTheNormalTF;
    }

    public static Map<String, HashMap<String, Double>> tfOfAll(String key1, String key2,
            String text1, String text2) throws IOException {
        allTheTf.clear();
        allTheTf.put(key1, tf(TextcutWord(text1)));
        allTheTf.put(key2, tf(TextcutWord(text2)));
        return allTheTf;
    }

    /**
     * Compute term frequency, normalized by document length.
     */
    public static HashMap<String, Double> tf(String[] cutWordResult) {
        HashMap<String, Double> tf = new HashMap<String, Double>();
        int wordNum = cutWordResult.length;
        for (int i = 0; i < wordNum; i++) {
            int wordtf = 0;
            if (!" ".equals(cutWordResult[i])) {
                for (int j = 0; j < wordNum; j++) {
                    if (i != j && cutWordResult[i].equals(cutWordResult[j])) {
                        cutWordResult[j] = " "; // blank out duplicates so each word is counted once
                        wordtf++;
                    }
                }
                tf.put(cutWordResult[i], (double) (++wordtf) / wordNum);
                cutWordResult[i] = " ";
            }
        }
        return tf;
    }

    /**
     * Raw term counts, without normalization.
     */
    public static HashMap<String, Integer> normalTF(String[] cutWordResult) {
        HashMap<String, Integer> tfNormal = new HashMap<String, Integer>();
        int wordNum = cutWordResult.length;
        for (int i = 0; i < wordNum; i++) {
            int wordtf = 0;
            if (!" ".equals(cutWordResult[i])) {
                for (int j = 0; j < wordNum; j++) {
                    if (i != j && cutWordResult[i].equals(cutWordResult[j])) {
                        cutWordResult[j] = " ";
                        wordtf++;
                    }
                }
                tfNormal.put(cutWordResult[i], ++wordtf);
                cutWordResult[i] = " ";
            }
        }
        return tfNormal;
    }

    public static Map<String, Double> idf(String key1, String key2, String text1, String text2)
            throws IOException {
        // Formula: IDF = log((1 + |D|) / |Dt|), where |D| is the total number of
        // documents and |Dt| is the number of documents containing term t.
        Map<String, Double> idf = new HashMap<String, Double>();
        List<String> located = new ArrayList<String>();
        NormalTFOfAll(key1, key2, text1, text2);
        float Dt = 1;
        float D = allTheNormalTF.size(); // total number of documents
        List<String> key = fileList;     // keys of the documents seen so far
        String[] keyarr = { key1, key2 };
        for (String item : keyarr) {
            if (!fileList.contains(item)) {
                fileList.add(item);
            }
        }
        Map<String, HashMap<String, Integer>> tfInIdf = allTheNormalTF; // per-document term counts
        for (int i = 0; i < D; i++) {
            HashMap<String, Integer> temp = tfInIdf.get(key.get(i));
            for (String word : temp.keySet()) {
                Dt = 1;
                if (!located.contains(word)) {
                    for (int k = 0; k < D; k++) {
                        if (k != i) {
                            HashMap<String, Integer> temp2 = tfInIdf.get(key.get(k));
                            if (temp2.keySet().contains(word)) {
                                located.add(word);
                                Dt = Dt + 1;
                            }
                        }
                    }
                    // Math.log10 replaces the undefined Log.log helper from the original
                    idf.put(word, Math.log10((1 + D) / Dt));
                }
            }
        }
        return idf;
    }

    public static Map<String, HashMap<String, Double>> tfidf(String key1, String key2,
            String text1, String text2) throws IOException {
        Map<String, Double> idf = idf(key1, key2, text1, text2);
        tfOfAll(key1, key2, text1, text2);
        for (String key : allTheTf.keySet()) {
            Map<String, Double> singelFile = allTheTf.get(key);
            Double[] arr = new Double[idf.size()];
            int index = 0;
            // weight = tf * idf
            for (String word : singelFile.keySet()) {
                singelFile.put(word, idf.get(word) * singelFile.get(word));
            }
            // project every document onto the shared vocabulary so the vectors align
            for (String word : idf.keySet()) {
                arr[index] = singelFile.get(word) != null ? singelFile.get(word) : 0d;
                index++;
            }
            vectorMap.put(key, arr);
        }
        return allTheTf;
    }

    /* Once the word vectors are built, match them with cosine similarity. */
    public static Double sim(String key1, String key2) {
        Double[] arr1 = vectorMap.get(key1);
        Double[] arr2 = vectorMap.get(key2);
        int length = arr1.length;
        Double result1 = 0.00; // norm of vector 1
        Double result2 = 0.00; // norm of vector 2
        Double sum = 0d;       // dot product
        if (length == 0) {
            return 0d;
        }
        for (int i = 0; i < length; i++) {
            result1 += arr1[i] * arr1[i];
            result2 += arr2[i] * arr2[i];
            sum += arr1[i] * arr2[i];
        }
        Double result = Math.sqrt(result1 * result2);
        System.out.println("Similarity of " + key1 + " and " + key2 + ": " + sum / result);
        return sum / result;
    }
}
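A minimal usage sketch, assuming GetText and the two jars listed above are on the classpath; the sample sentences and the labels text1/text2 are made up for illustration:

public class GetTextDemo {
    public static void main(String[] args) throws Exception {
        String t1 = "今天天气很好,我们去公园散步。";
        String t2 = "今天天气不错,我们去公园走走。";
        // build TF-IDF vectors for both texts under arbitrary keys, then compare
        GetText.tfidf("text1", "text2", t1, t2);
        Double score = GetText.sim("text1", "text2");
        System.out.println("score = " + score);
    }
}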

Matching multiple files

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import jeasy.analysis.MMAnalyzer;

/**
 * Match against a corpus of files.
 *
 * @author rock
 */
public class ReadFiles {

    private static List<String> fileList = new ArrayList<String>();
    private static HashMap<String, HashMap<String, Float>> allTheTf = new HashMap<String, HashMap<String, Float>>();
    private static HashMap<String, HashMap<String, Integer>> allTheNormalTF = new HashMap<String, HashMap<String, Integer>>();
    private static LinkedHashMap<String, Float[]> vectorMap = new LinkedHashMap<String, Float[]>();

    /**
     * Recursively collect all file paths under the corpus directory.
     */
    public static List<String> readDirs(String filepath) throws IOException {
        try {
            File file = new File(filepath);
            if (!file.isDirectory()) {
                System.out.println("The argument should be a directory name.");
                System.out.println("filepath: " + file.getAbsolutePath());
            } else {
                String[] filelist = file.list();
                for (int i = 0; i < filelist.length; i++) {
                    File readfile = new File(filepath + File.separator + filelist[i]);
                    if (!readfile.isDirectory()) {
                        fileList.add(readfile.getAbsolutePath());
                    } else {
                        readDirs(filepath + File.separator + filelist[i]);
                    }
                }
            }
        } catch (FileNotFoundException e) {
            System.out.println(e.getMessage());
        }
        return fileList;
    }

    /**
     * Read a text file as UTF-8.
     */
    public static String readFiles(String file) throws IOException {
        StringBuffer sb = new StringBuffer();
        InputStreamReader is = new InputStreamReader(new FileInputStream(file), "utf-8");
        BufferedReader br = new BufferedReader(is);
        String line = br.readLine();
        while (line != null) {
            sb.append(line).append("\r\n");
            line = br.readLine();
        }
        br.close();
        return sb.toString();
    }

    /**
     * Word segmentation.
     */
    public static String[] cutWord(String file) throws IOException {
        String text = ReadFiles.readFiles(file);
        MMAnalyzer analyzer = new MMAnalyzer();
        String tempCutWordResult = analyzer.segment(text, " ");
        return tempCutWordResult.split(" ");
    }

    /**
     * Compute term frequency, normalized by document length.
     */
    public static HashMap<String, Float> tf(String[] cutWordResult) {
        HashMap<String, Float> tf = new HashMap<String, Float>();
        int wordNum = cutWordResult.length;
        for (int i = 0; i < wordNum; i++) {
            int wordtf = 0;
            for (int j = 0; j < wordNum; j++) {
                if (!" ".equals(cutWordResult[i]) && i != j
                        && cutWordResult[i].equals(cutWordResult[j])) {
                    cutWordResult[j] = " "; // blank out duplicates so each word is counted once
                    wordtf++;
                }
            }
            if (!" ".equals(cutWordResult[i])) {
                tf.put(cutWordResult[i], (float) (++wordtf) / wordNum);
                cutWordResult[i] = " ";
            }
        }
        return tf;
    }

    /**
     * Raw term counts, without normalization.
     */
    public static HashMap<String, Integer> normalTF(String[] cutWordResult) {
        HashMap<String, Integer> tfNormal = new HashMap<String, Integer>();
        int wordNum = cutWordResult.length;
        for (int i = 0; i < wordNum; i++) {
            int wordtf = 0;
            if (!" ".equals(cutWordResult[i])) {
                for (int j = 0; j < wordNum; j++) {
                    if (i != j && cutWordResult[i].equals(cutWordResult[j])) {
                        cutWordResult[j] = " ";
                        wordtf++;
                    }
                }
                tfNormal.put(cutWordResult[i], ++wordtf);
                cutWordResult[i] = " ";
            }
        }
        return tfNormal;
    }

    public static Map<String, HashMap<String, Float>> tfOfAll(String dir) throws IOException {
        List<String> fileList = ReadFiles.readDirs(dir);
        for (String file : fileList) {
            allTheTf.put(file, ReadFiles.tf(ReadFiles.cutWord(file)));
        }
        return allTheTf;
    }

    /**
     * Same as above, but for a caller-supplied list of files.
     */
    public static Map<String, HashMap<String, Float>> tfOfAll(String[] files) throws IOException {
        for (String file : files) {
            allTheTf.put(file, ReadFiles.tf(ReadFiles.cutWord(file)));
        }
        return allTheTf;
    }

    public static Map<String, HashMap<String, Integer>> NormalTFOfAll(String dir) throws IOException {
        List<String> fileList = ReadFiles.readDirs(dir);
        for (int i = 0; i < fileList.size(); i++) {
            allTheNormalTF.put(fileList.get(i), ReadFiles.normalTF(ReadFiles.cutWord(fileList.get(i))));
        }
        return allTheNormalTF;
    }

    public static Map<String, Float> idf(String dir) throws IOException {
        // Formula: IDF = log((1 + |D|) / |Dt|), where |D| is the total number of
        // documents and |Dt| is the number of documents containing term t.
        Map<String, Float> idf = new HashMap<String, Float>();
        List<String> located = new ArrayList<String>();
        NormalTFOfAll(dir);
        float Dt = 1;
        float D = allTheNormalTF.size(); // total number of documents
        List<String> key = fileList;     // list of document paths
        Map<String, HashMap<String, Integer>> tfInIdf = allTheNormalTF; // per-document term counts
        for (int i = 0; i < D; i++) {
            HashMap<String, Integer> temp = tfInIdf.get(key.get(i));
            for (String word : temp.keySet()) {
                Dt = 1;
                if (!located.contains(word)) {
                    for (int k = 0; k < D; k++) {
                        if (k != i) {
                            HashMap<String, Integer> temp2 = tfInIdf.get(key.get(k));
                            if (temp2.keySet().contains(word)) {
                                located.add(word);
                                Dt = Dt + 1;
                            }
                        }
                    }
                    // Math.log10 replaces the undefined Log.log helper from the original
                    idf.put(word, (float) Math.log10((1 + D) / Dt));
                }
            }
        }
        return idf;
    }

    public static Map<String, HashMap<String, Float>> tfidf(String dir) throws IOException {
        Map<String, Float> idf = ReadFiles.idf(dir);
        Map<String, HashMap<String, Float>> tf = ReadFiles.tfOfAll(dir);
        for (String file : tf.keySet()) {
            Map<String, Float> singelFile = tf.get(file);
            Float[] arr = new Float[idf.size()];
            int index = 0;
            // weight = tf * idf
            for (String word : singelFile.keySet()) {
                singelFile.put(word, idf.get(word) * singelFile.get(word));
            }
            // project every document onto the shared vocabulary so the vectors align
            for (String word : idf.keySet()) {
                arr[index] = singelFile.get(word) != null ? singelFile.get(word) : 0f;
                index++;
            }
            vectorMap.put(file, arr);
        }
        return tf;
    }

    /* Once the word vectors are built, match them with cosine similarity. */
    public static double sim(String file1, String file2) {
        Float[] arr1 = vectorMap.get(file1);
        Float[] arr2 = vectorMap.get(file2);
        int length = arr1.length;
        double result1 = 0.00; // norm of vector 1
        double result2 = 0.00; // norm of vector 2
        Float sum = 0f;        // dot product
        for (int i = 0; i < length; i++) {
            result1 += arr1[i] * arr1[i];
            result2 += arr2[i] * arr2[i];
            sum += arr1[i] * arr2[i];
        }
        double result = Math.sqrt(result1 * result2);
        System.out.println(sum / result);
        return sum / result;
    }
}
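A minimal usage sketch for the corpus version; the directory D:\corpus and the file names a.txt/b.txt are hypothetical. Note that sim() looks vectors up by the absolute paths that readDirs stored, so the paths passed to sim() must resolve to the same absolute form:

public class ReadFilesDemo {
    public static void main(String[] args) throws Exception {
        // hypothetical corpus directory containing UTF-8 .txt files
        String dir = "D:\\corpus";
        ReadFiles.tfidf(dir); // builds a TF-IDF vector for every file under dir
        ReadFiles.sim("D:\\corpus\\a.txt", "D:\\corpus\\b.txt");
    }
}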