深度学习-文档分类
来源:互联网 发布:大数据分析快3 编辑:程序博客网 时间:2024/04/27 09:53
本文主要是用ParagraphVectors方法做文档分类,训练数据有一些带类别的文档,预测没有类别的文档属于哪个类别。这里简单说下ParagraphVectors模型,每篇文档映射在一个唯一的向量上,由矩阵中的一列表示,每个word则类似的被映射到向量上,这个向量由另一个矩阵的列表示。使用连接方式获得新word的预测,可以说ParagraphVectors是在word2vec基础上加了一组paragraph输入列向量一起训练构成的模型。

/**
 * Document-classification example built on ParagraphVectors.
 *
 * <p>A model is trained on the labelled documents under {@code paravec/labeled}; each document
 * under {@code paravec/unlabeled} is then scored against every label seen during training.
 */
public class ParagraphVectorsClassifierExample {

    /** The ParagraphVectors model; assigned by {@link #makeParagraphVectors()}. */
    ParagraphVectors paragraphVectors;

    /** Iterator over the labelled training documents; also the source of the label list. */
    LabelAwareIterator iterator;

    /** Tokenizer factory shared by training and by the centroid computation. */
    TokenizerFactory tokenizerFactory;

    private static final Logger log = LoggerFactory.getLogger(ParagraphVectorsClassifierExample.class);

    /**
     * Trains the model and then scores the unlabelled documents.
     *
     * @param args unused
     * @throws Exception propagated from training or from resolving the classpath resources
     */
    public static void main(String[] args) throws Exception {
        ParagraphVectorsClassifierExample app = new ParagraphVectorsClassifierExample();
        // Build and train the model first; checkUnlabeledData() reads the fields it assigns.
        app.makeParagraphVectors();
        // Score every unlabelled document against the trained labels.
        app.checkUnlabeledData();
        /*
            Your output should be like this:

            Document 'health' falls into the following categories:
                health: 0.29721372296220205
                science: 0.011684473733853906
                finance: -0.14755302887323793

            Document 'finance' falls into the following categories:
                health: -0.17290237675941766
                science: -0.09579267574606627
                finance: 0.4460859189453788

            so,now we know categories for yet unseen documents
         */
    }

    /**
     * Builds the labelled-document iterator and the tokenizer factory, configures the
     * ParagraphVectors model, and trains it.
     *
     * @throws Exception propagated from resolving the {@code paravec/labeled} resource or from training
     */
    void makeParagraphVectors() throws Exception {
        // Classpath folder holding the labelled training documents.
        ClassPathResource resource = new ClassPathResource("paravec/labeled");

        // build a iterator for our dataset
        iterator = new FileLabelAwareIterator.Builder()
                .addSourceFolder(resource.getFile())
                .build();

        // Default tokenizer, with CommonPreprocessor applied to every token.
        tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

        // ParagraphVectors training configuration: learning rate, minimum learning rate,
        // batch size, epochs, the document iterator, word-vector training, and the tokenizer.
        paragraphVectors = new ParagraphVectors.Builder()
                .learningRate(0.025)
                .minLearningRate(0.001)
                .batchSize(1000)
                .epochs(20)
                .iterate(iterator)
                .trainWordVectors(true)
                .tokenizerFactory(tokenizerFactory)
                .build();

        // Start model training
        paragraphVectors.fit();
    }

    /**
     * Scores every document under {@code paravec/unlabeled} against each training label and logs
     * the per-label scores.
     *
     * <p>Reads {@code paragraphVectors}, {@code iterator} and {@code tokenizerFactory}, so
     * {@link #makeParagraphVectors()} must have been called first.
     *
     * @throws FileNotFoundException propagated from resolving the {@code paravec/unlabeled} resource
     */
    void checkUnlabeledData() throws FileNotFoundException {
        /*
          At this point we assume that we have model built and we can check
          which categories our unlabeled document falls into.
          So we'll start loading our unlabeled documents and checking them
         */
        ClassPathResource unClassifiedResource = new ClassPathResource("paravec/unlabeled");
        FileLabelAwareIterator unClassifiedIterator = new FileLabelAwareIterator.Builder()
                .addSourceFolder(unClassifiedResource.getFile())
                .build();

        /*
          Now we'll iterate over unlabeled data, and check which label it could be assigned to
          Please note: for many domains it's normal to have 1 document fall into few labels at once,
          with different "weight" for each.
         */
        // Both helpers work on the same lookup table of the trained model, so fetch and cast it once.
        @SuppressWarnings("unchecked")
        InMemoryLookupTable<VocabWord> lookupTable =
                (InMemoryLookupTable<VocabWord>) paragraphVectors.getLookupTable();

        // Turns a document into the mean of its known word vectors.
        MeansBuilder meansBuilder = new MeansBuilder(lookupTable, tokenizerFactory);
        // Compares a document vector with the vector of every training label.
        LabelSeeker seeker = new LabelSeeker(iterator.getLabelsSource().getLabels(), lookupTable);

        while (unClassifiedIterator.hasNextDocument()) {
            LabelledDocument document = unClassifiedIterator.nextDocument();
            INDArray documentAsCentroid = meansBuilder.documentAsVector(document);
            List<Pair<String, Double>> scores = seeker.getScores(documentAsCentroid);

            /*
              please note, document.getLabel() is used just to show which document we're looking at now,
              as a substitute for printing out the whole document name.
              So, labels on these two documents are used like titles,
              just to visualize our classification done properly
             */
            log.info("Document '{}' falls into the following categories: ", document.getLabel());
            for (Pair<String, Double> score : scores) {
                log.info(" {}: {}", score.getFirst(), score.getSecond());
            }
        }
    }
public class MeansBuilder {//平均值类 private VocabCache<VocabWord> vocabCache;//词汇表 private InMemoryLookupTable<VocabWord> lookupTable;//查询table private TokenizerFactory tokenizerFactory;//分词器 public MeansBuilder(@NonNull InMemoryLookupTable<VocabWord> lookupTable,//构造方法,根据传入的参数赋值当前对象的词汇表,查询table,分词器 @NonNull TokenizerFactory tokenizerFactory) { this.lookupTable = lookupTable; this.vocabCache = lookupTable.getVocab(); this.tokenizerFactory = tokenizerFactory; } /** * This method returns centroid (mean vector) for document.//返回文档的质心,也就是向量的平均值 * * @param document * @return */ public INDArray documentAsVector(@NonNull LabelledDocument document) {//传入有标记的文档 List<String> documentAsTokens = tokenizerFactory.create(document.getContent()).getTokens();//切割文档,获取词列表 AtomicInteger cnt = new AtomicInteger(0);//声明一个原子整数0,保证线程安全 for (String word: documentAsTokens) {//统计独立词计数 if (vocabCache.containsWord(word)) cnt.incrementAndGet(); } INDArray allWords = Nd4j.create(cnt.get(), lookupTable.layerSize());//根据词计数构建词矩阵,行是词计数,列是每个词对应的向量长度,默认100 cnt.set(0);//词计数清零 for (String word: documentAsTokens) {//给词矩阵赋值, if (vocabCache.containsWord(word)) allWords.putRow(cnt.getAndIncrement(), lookupTable.vector(word));//根据词表索引,取出对应词权重向量的行,放入allWords矩阵 } INDArray mean = allWords.mean(0);//通过mean(0)把矩阵合成一行,0代表维度,也是就求质心并返回 return mean; }}public class LabelSeeker {//寻找标签类 private List<String> labelsUsed;//声明标签列表 private InMemoryLookupTable<VocabWord> lookupTable;//声明查询table public LabelSeeker(@NonNull List<String> labelsUsed, @NonNull InMemoryLookupTable<VocabWord> lookupTable) {//构造器 if (labelsUsed.isEmpty()) throw new IllegalStateException("You can't have 0 labels used for ParagraphVectors"); this.lookupTable = lookupTable; this.labelsUsed = labelsUsed; } /** * This method accepts vector, that represents any document,//这方法接收表示文档的向量,返回文档的距离,之前训练的类别 * and returns distances between this document, and previously trained categories * @return */ public List<Pair<String, Double>> getScores(@NonNull INDArray 
vector) {//获取得分的方法 List<Pair<String, Double>> result = new ArrayList<>();//声明列表,每个元素都是元素 for (String label: labelsUsed) {遍历标签列表 INDArray vecLabel = lookupTable.vector(label);//同理根据词表索引,取出对应词权重向量的行 if (vecLabel == null) throw new IllegalStateException("Label '"+ label+"' has no known vector!"); double sim = Transforms.cosineSim(vector, vecLabel);//把词权重向量和传入的文档做相似度 result.add(new Pair<String, Double>(label, sim));//返回和每个标签的相似度 } return result; }}
}
0 0
- 深度学习-文档分类
- 深度学习-线性分类
- 深度学习分类网络
- 深度学习分类网络
- 深度学习文档1.0
- 深度学习有关文档
- 深度学习-分类卷积网络
- 深度学习---之回归,分类
- 深度学习在图像分类中的应用
- 深度学习在文本分类中的应用
- 深度学习笔记(一):logistic分类
- [深度学习基础] 2. 线性分类器
- 深度学习系列之图像分类
- 关于深度学习中的分类器
- 深度学习入门笔记--图像线性分类
- 深度学习基础(一):logistic分类
- 基于深度学习的场景分类算法
- 深度学习框架Caffe图片分类教程
- 函数重载、覆盖与隐藏
- MySQL分区表详解
- 关于Bitmap使用的笔记汇总
- RecyclerView 自适应高度 正确做法
- HTML复合表格
- 深度学习-文档分类
- Symbols 错误符号分析
- 【GDOI2017模拟12.9】最近公共祖先
- 第十三周项目1
- Marriage Match IV HDU3416 spfa+isap
- Android深入浅出之Binder机制
- 数数 noip 单调队列
- PHP实现各种经典算法
- java基础学习(3)