深度学习-文档分类

来源:互联网 发布:大数据分析快3 编辑:程序博客网 时间:2024/04/27 09:53
本文主要是用 ParagraphVectors 方法做文档分类:训练数据是一些带类别标签的文档,然后预测无标签文档属于哪个类别。这里简单说一下 ParagraphVectors 模型:每篇文档被映射到一个唯一的向量上,由一个矩阵中的一列表示;每个词也类似地被映射到向量上,该向量由另一个矩阵的列表示。预测新词时将文档向量与词向量连接起来使用。可以说,ParagraphVectors 是在 word2vec 的基础上,额外加入一组 paragraph 输入列向量共同训练而构成的模型。
public class ParagraphVectorsClassifierExample { ParagraphVectors paragraphVectors;//声明ParagraphVectors类
LabelAwareIterator iterator;//声明要实现的迭代器接口,用来识别句子或文档及标签,这里假定所有的文档已变成字符串或词表的形式 TokenizerFactory tokenizerFactory;//声明字符串分割器 private static final Logger log = LoggerFactory.getLogger(ParagraphVectorsClassifierExample.class); public static void main(String[] args) throws Exception { ParagraphVectorsClassifierExample app = new ParagraphVectorsClassifierExample();//又是这种写法,构建实现类 app.makeParagraphVectors();//调用构建模型方法 app.checkUnlabeledData();//检查标签数据 /* Your output should be like this: Document 'health' falls into the following categories: health: 0.29721372296220205 science: 0.011684473733853906 finance: -0.14755302887323793 Document 'finance' falls into the following categories: health: -0.17290237675941766 science: -0.09579267574606627 finance: 0.4460859189453788 so,now we know categories for yet unseen documents */ } void makeParagraphVectors() throws Exception { ClassPathResource resource = new ClassPathResource("paravec/labeled");//弄一个带标签的文档路径 // build a iterator for our dataset iterator = new FileLabelAwareIterator.Builder()//实现LabelAwareIterator接口,添加数据源,构成迭代器 .addSourceFolder(resource.getFile()) .build(); tokenizerFactory = new DefaultTokenizerFactory();//构建逗号分割器 tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); // ParagraphVectors training configuration paragraphVectors = new ParagraphVectors.Builder()//ParagraphVectors继承Word2Vec,Word2Vec继承SequenceVectors,
配置ParagraphVectors的学习率,最小学习率,批大小,步数,迭代器,同时构建词和文档,词分割器
.learningRate(0.025) .minLearningRate(0.001) .batchSize(1000) .epochs(20) .iterate(iterator) .trainWordVectors(true) .tokenizerFactory(tokenizerFactory) .build(); // Start model training paragraphVectors.fit();//模型定型 } void checkUnlabeledData() throws FileNotFoundException { /* At this point we assume that we have model built and we can check//这里假定模型已经构建好,现在预测无标签的文档属于哪个类,我们装载无标签文档并对其进行检测 which categories our unlabeled document falls into. So we'll start loading our unlabeled documents and checking them */ ClassPathResource unClassifiedResource = new ClassPathResource("paravec/unlabeled");//构建无标签文档读取器 FileLabelAwareIterator unClassifiedIterator = new FileLabelAwareIterator.Builder() .addSourceFolder(unClassifiedResource.getFile()) .build(); /* Now we'll iterate over unlabeled data, and check which label it could be assigned to//预测未标记文档,很多情况一个文档可能对应多个类别,只不过每个类别值有高有低 Please note: for many domains it's normal to have 1 document fall into few labels at once, with different "weight" for each. */ MeansBuilder meansBuilder = new MeansBuilder(//构建了求质心的类, (InMemoryLookupTable<VocabWord>)paragraphVectors.getLookupTable(),//通过获取WordVectors实现类WordVectorsImpl中的getLookupTable方法获取查询table及tokenizerFactory构造MeansBuilder类
tokenizerFactory); LabelSeeker seeker = new LabelSeeker(iterator.getLabelsSource().getLabels(),//同理通过获取WordVectors实现类WordVectorsImpl中的getLookupTable方法获取查询table及标签列表构造LabelSeeker类 (InMemoryLookupTable<VocabWord>) paragraphVectors.getLookupTable()); while (unClassifiedIterator.hasNextDocument()) {//遍历未分类文档 LabelledDocument document = unClassifiedIterator.nextDocument(); INDArray documentAsCentroid = meansBuilder.documentAsVector(document);//把文档转成向量 List<Pair<String, Double>> scores = seeker.getScores(documentAsCentroid);//获取文档的类别得分 /* please note, document.getLabel() is used just to show which document we're looking at now, as a substitute for printing out the whole document name. So, labels on these two documents are used like titles, just to visualize our classification done properly//注意getLabel是获取当前文档的标签 */ log.info("Document '" + document.getLabel() + "' falls into the following categories: "); for (Pair<String, Double> score: scores) {//遍历标签得分 log.info(" " + score.getFirst() + ": " + score.getSecond());//打印元素的第一个第二个元素 } } }
public class MeansBuilder {//平均值类    private VocabCache<VocabWord> vocabCache;//词汇表    private InMemoryLookupTable<VocabWord> lookupTable;//查询table    private TokenizerFactory tokenizerFactory;//分词器    public MeansBuilder(@NonNull InMemoryLookupTable<VocabWord> lookupTable,//构造方法,根据传入的参数赋值当前对象的词汇表,查询table,分词器        @NonNull TokenizerFactory tokenizerFactory) {        this.lookupTable = lookupTable;        this.vocabCache = lookupTable.getVocab();        this.tokenizerFactory = tokenizerFactory;    }    /**     * This method returns centroid (mean vector) for document.//返回文档的质心,也就是向量的平均值     *     * @param document     * @return     */    public INDArray documentAsVector(@NonNull LabelledDocument document) {//传入有标记的文档        List<String> documentAsTokens = tokenizerFactory.create(document.getContent()).getTokens();//切割文档,获取词列表        AtomicInteger cnt = new AtomicInteger(0);//声明一个原子整数0,保证线程安全        for (String word: documentAsTokens) {//统计独立词计数            if (vocabCache.containsWord(word)) cnt.incrementAndGet();        }        INDArray allWords = Nd4j.create(cnt.get(), lookupTable.layerSize());//根据词计数构建词矩阵,行是词计数,列是每个词对应的向量长度,默认100        cnt.set(0);//词计数清零        for (String word: documentAsTokens) {//给词矩阵赋值,            if (vocabCache.containsWord(word))                allWords.putRow(cnt.getAndIncrement(), lookupTable.vector(word));//根据词表索引,取出对应词权重向量的行,放入allWords矩阵        }        INDArray mean = allWords.mean(0);//通过mean(0)把矩阵合成一行,0代表维度,也是就求质心并返回        return mean;    }}
public class LabelSeeker {//寻找标签类    private List<String> labelsUsed;//声明标签列表    private InMemoryLookupTable<VocabWord> lookupTable;//声明查询table    public LabelSeeker(@NonNull List<String> labelsUsed, @NonNull InMemoryLookupTable<VocabWord> lookupTable) {//构造器        if (labelsUsed.isEmpty()) throw new IllegalStateException("You can't have 0 labels used for ParagraphVectors");        this.lookupTable = lookupTable;        this.labelsUsed = labelsUsed;    }    /**     * This method accepts vector, that represents any document,//这方法接收表示文档的向量,返回文档的距离,之前训练的类别     * and returns distances between this document, and previously trained categories     * @return     */    public List<Pair<String, Double>> getScores(@NonNull INDArray vector) {//获取得分的方法        List<Pair<String, Double>> result = new ArrayList<>();//声明列表,每个元素都是元素        for (String label: labelsUsed) {遍历标签列表            INDArray vecLabel = lookupTable.vector(label);//同理根据词表索引,取出对应词权重向量的行            if (vecLabel == null) throw new IllegalStateException("Label '"+ label+"' has no known vector!");            double sim = Transforms.cosineSim(vector, vecLabel);//把词权重向量和传入的文档做相似度            result.add(new Pair<String, Double>(label, sim));//返回和每个标签的相似度        }        return result;    }}

}
0 0
原创粉丝点击