TF-IDF 算法的基本实现(Java)

来源:互联网 发布:淘宝宝贝原价和折扣价 编辑:程序博客网 时间:2024/06/05 12:05

声明

  以下代码只是对tf-idf算法思想的基本实现,许多地方有待完善,总结如下:
  1.实现逻辑问题:处于特殊位置(比如段首)的词,以及名词(相对于动词),应该有更大的权重;
  2.分词前应该对文本进行基本处理:去掉标点;并以合适的方式调用分词接口,使得文本内容较长时能够分多次调用,且结果与一次调用相同;
  3.速度有待提升:总文本数一星期更新一次就行;包含关键词的文本数,其现有的测量方式(每个词都请求一次搜索引擎)也需要改进;

实现 

package demo.utils;import com.google.common.util.concurrent.ThreadFactoryBuilder;import org.springframework.beans.factory.annotation.Autowired;import org.springframework.beans.factory.annotation.Value;import org.springframework.http.ResponseEntity;import org.springframework.stereotype.Component;import org.springframework.web.client.RestTemplate;import java.util.*;import java.util.concurrent.*;import java.util.function.Function;import java.util.regex.Matcher;import java.util.regex.Pattern;import java.util.stream.Collectors;/** * @author 杜艮魁 * @date 2017/12/1 */@Componentpublic class LTPUtils {    @Value("${demo.ltp-url}")    private String LTPURL;    @Value("${demo.api-key}")    private String apiKey;    private ExecutorService pool;    private final Pattern SUM_PATTERN= Pattern.compile("\\d+(,\\d{3})*\\s条结果");    @Autowired    public LTPUtils() {        ThreadFactory namedThreadFactory = new ThreadFactoryBuilder().setNameFormat("Thread-pool-%d").build();        ExecutorService ex = new ThreadPoolExecutor(5, 20, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingDeque<Runnable>(1024), namedThreadFactory, new ThreadPoolExecutor.AbortPolicy());        pool = ex;    }    /**     * tf-idf算法的实现     * @param content 需要分词、计算词频、逆文本频率、TF-IDF值的内容     * @return  关键词和相应的tf-idf值,按照value值降序排列     */    public  Map<String,Double> tfIdf(String content){        try {            String[] strArrs = getInfoForLTP(content, "ws", "plain");            Map<String, Double> tf = countTF(strArrs);            //获取tf-idf值            Map<String, Double> result = new HashMap<>();            for (Map.Entry<String,Double> ele:tf.entrySet()) {                result.put(ele.getKey(), ele.getValue() * getIDF(ele.getKey()));            }            //根据value值进行排序            result=result.entrySet().stream().sorted(Map.Entry.comparingByValue(Collections.reverseOrder())).collect(Collectors.toMap(                    
//result=result.entrySet().stream().sorted(Map.Entry.comparingByValue(/*去掉这个将升序排列*/)).collect(Collectors.toMap(                    Map.Entry::getKey,                    Map.Entry::getValue,                    (e1,e2)->e1,                    LinkedHashMap::new            ));            return result;        }catch (Exception e){            //todo 细分            throw new RuntimeException(e.getMessage());        }    }    /**     * 调用哈工大分词,并返回结果     *     * @param text    要处理的文本     * @param pattern 匹配模式     * @param format 返回数据格式     * @return     */    public String[] getInfoForLTP(String text, String pattern, String format) throws ClassNotFoundException, ExecutionException, InterruptedException {        String url = LTPURL + "?api_key=" + apiKey + "&text=" + text + "&pattern=" + pattern + "&format=" + format;        RestTemplate restTemplate = new RestTemplate();        Future<ResponseEntity> resp = pool.submit(() -> restTemplate.getForEntity(url, String.class, "分词"));        ResponseEntity<String> respBody=resp.get();        String [] respArrs=respBody.getBody().split(" ");        return respArrs;    }    /**     * 统计词频,频率归一化,即出现次数比总次数     * @param strArrs`     * @return     */    public Map<String,Double> countTF(String [] strArrs){        Map<String,Long> map=Arrays.stream(strArrs).collect(Collectors.groupingBy(Function.identity(),Collectors.counting()));        map=map.entrySet().stream().sorted(Map.Entry.comparingByValue(Collections.reverseOrder())).collect(Collectors.toMap(        //result=result.entrySet().stream().sorted(Map.Entry.comparingByValue(/*去掉这个将升序排列*/)).collect(Collectors.toMap(                Map.Entry::getKey,                Map.Entry::getValue,                (e1,e2)->e1,                LinkedHashMap::new        ));        Map<String,Double> result=new HashMap<>();        map.entrySet().stream().forEach(x->            result.put(x.getKey(),x.getValue()/(double)strArrs.length)        );        return result;    }    /**     * 获取词语str的IDF值     * 
@param str     * @return     */    public double getIDF(String str){        RestTemplate restTemplate=new RestTemplate();        String respSum="",respArr="";        try {            respSum = restTemplate.getForObject("https://cn.bing.com/search?q=的", String.class, "总");            respArr = restTemplate.getForObject("https://cn.bing.com/search?q=" + str, String.class, "出现某词文档数");        }catch(Exception e){            e.printStackTrace();            return 0;        }        Long sumResp=666L;        Long arrResp=666L;        Matcher m= SUM_PATTERN.matcher(respSum);        if(m.find()){            String patternStr=m.group();            sumResp= Long.parseLong(patternStr.substring(0,patternStr.indexOf(" 条结果")).replace(",",""));        }        m= SUM_PATTERN.matcher(respArr);        if(m.find()){            String patternStr=m.group();            arrResp= Long.parseLong(patternStr.substring(0,patternStr.indexOf(" 条结果")).replace(",",""));        }        if(sumResp!=666L&&arrResp!=666L){//如果都有返回结果            return Math.log(sumResp/arrResp);        }else{//            throw new RuntimeException("返回结果有误");            System.out.println("返回结果有误:"+str);            return 0;        }    }}
原创粉丝点击