CNN Text Classification with deeplearning4j


Introduction

deeplearning4j ships with an English text-classification example. Although it targets English, the same recipe works for Chinese once the text has been word-segmented. The pipeline has two steps: first, use word2vec to build a word-vector model (the generation code is covered in the previous post); second, build a CNN that consumes those vectors for text classification.
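
Since the previous post's generation code is not reproduced here, the following is a minimal word2vec training sketch against the 0.8.x-era DL4J API used throughout this post. The corpus path and hyperparameters are illustrative assumptions; only layerSize must match the vectorSize of 100 that the CNN below expects.

import java.io.File;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

public class TrainWord2Vec {
    public static void main(String[] args) throws Exception {
        // One pre-segmented sentence per line, tokens separated by spaces
        SentenceIterator iter = new BasicLineIterator(new File("adx/corpus_segmented.txt"));
        TokenizerFactory t = new DefaultTokenizerFactory(); // splits on whitespace

        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(5)   // drop rare tokens
                .layerSize(100)        // must match vectorSize in the CNN below
                .windowSize(5)
                .seed(42)
                .iterate(iter)
                .tokenizerFactory(t)
                .build();
        vec.fit();

        // Save in the text format that TrainAdxCnnModel loads via loadTxt(...)
        WordVectorSerializer.writeWordVectors(vec, "adx/word2vec.model");
    }
}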


Training data format: one sample per line, consisting of a label, a tab, and the space-separated segmented tokens. In these samples, label 1 marks promotional/spam text and label 0 ordinary reviews; a segmentation sketch follows the samples.

1	小额贷款
1	商家 @ 回应 十里洋场 江景会 订房 经理 电话 五星级 餐饮 五星级 k 包房 体验 8
1	可以 微信同号
1	看头像 加微信
1	看头像 加微信
1	可以 加研欧 书院 微信公众号 里面 具体 收费 本月 25 书院 开展 试听 公开课 如果 有时间 可以来 试听
0	强烈推荐 4 4 店长 非常 满意 朋友 发型 合适 我的 开心 因为 自然 每年 头发 去过 不少 理发店 接触 不少 发型师 目前为止 觉得 4 ( 帅哥 一枚 ) 称心如意 根据 发质 情况 自己的 想法 给你 设计 适合 发型 不仅 专业 而且 非常 负责 细心 最喜欢 细心 因为 只有 这样 做出 好的 效果 这次 发型 做出来 4 店长 建议 染色 那样 效果 更好 不用 那么 辛苦 太久 贴心 是不是 总之 满意 染色 后的 效果 重要 4 亲自动手 发型 强烈推荐 大家 过来 4 店长 下次 4 店长 支持 djehfjdushdjd
0	非常感谢 老师 小高 5 点评 想起 老师 一起 游览 情景 历历在目 时间 有时 一个 小偷 不知不觉 四月 已经 过完 2017 已经 过去 三分之一 老话 连雨 不知 方知夏 还没 来得及 好好 享受 温熏 春光 夏日 远处 徐徐 而来 夏天 青岛 美的 青岛 避暑 理想 欢迎 老师 有机会 夏天 再来 青岛 我的 手机号 微信号 到时 记得 微信 小高 酒店 青岛 小高 老师 身体健康 阖家欢乐 万事如意
0	时候 他家 根本 拒绝 兑换 也是 得到 气死人
0	商家回应 尊敬 客人 感谢您 抽出 宝贵 时间 我们 评价 心瑞 国际 月子会所 来源于 台湾 拥有 26 台式 母婴护理 经验 聚集 专业 精湛 台湾 护理 技术 团队 精心 定制 专属 护理服务 秉承 规范 操作 细致入微 精益求精 服务理念 后妈 提供 饮食 护理 康复 早教 心理健康 全方位 贴身 服务模式 孕产 期间 家庭 全方位 专业 照护 舒适 体验 会所 紧邻 国内 最好 医院 协和医院 产后 妈咪 宝宝 都有 坚实 医疗保障 我们 免费提供 健身会所 游泳馆 入住 客人 家属 使用 , 我们 不定期 举办 丰富多彩 活动 更多 孕妈咪 了解 孕期 保健知识 新生儿 喂养 知识 非常 期待 下次 妈咪 见面 心瑞 国际 月子会所 全体员工 服务热线 座机号码 请关注 微信号 微信账号
0	商家回应 亲爱 贵宾 感谢您 襄阳 巴厘岛 休闲 度假酒店 肯定 支持 酒店 休闲会所 主要 提供 休闲 洗浴 游泳 桑拿 干湿 足疗 按摩 项目 如此 体验 没能 满意 我们 表示 深深 歉意 我们 足疗 专业 休闲 手法 所有 技师 素质 专业 魅力 服务 更好 巴厘岛 一切 舒适 休闲 方式 为先 满意 我们 继续 进步 动力 酒店 全体员工 期待 再次光临
0	直接 不行 必须 每人 一份 主食 给你 东西 不用 上了 不掉 浪费 服务员 b 来个 服务员 c 过来 我们 杯茶 可能 听懂 bill 拿走 一会 送来 一个新 上面 一杯 我们 收费 不要 直接 气呼呼 拿走 bill 回来 时候 打印 去掉 杯茶 bill 直接 桌子 然后 回头 牛逼 这么多 英联邦 国家 地区 别人 一看 中国 客客气气 何况 还是 顾客 妹的
0	维权 怎么 套路
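
The post assumes the corpus is already word-segmented. For completeness, here is a hedged sketch of producing one such line; HanLP is used purely as an example segmenter (the post does not say which tokenizer was actually used), so the exact token boundaries are illustrative:

import java.util.stream.Collectors;
import com.hankcs.hanlp.HanLP;

// Build one training line: label "1" (promotional), a tab, then space-joined tokens
String raw = "可以微信同号";
String segmented = HanLP.segment(raw).stream()
        .map(term -> term.word)           // Term.word is the token's surface form
        .collect(Collectors.joining(" "));
String line = "1" + "\t" + segmented;     // e.g. "1\t可以 微信 同号"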

Training code:

LabeledSentence.java

package com.dianping.cnn.textclassify;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.datavec.api.util.RandomUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.nd4j.linalg.collection.CompactHeapStringList;

/**
 * Reads tab-separated "label<TAB>sentence" lines and serves them as a
 * LabeledSentenceProvider, optionally shuffled by the supplied Random.
 */
public class LabeledSentence implements LabeledSentenceProvider {

    private int totalCount;
    private Map<String, List<String>> filesByLabel;
    private List<String> normList;
    private List<String> negList;
    private final List<String> sentenceList;
    private final int[] labelIndexes;
    private final Random rng;
    private final int[] order;
    private final List<String> allLabels;
    private int cursor = 0;

    public LabeledSentence(String path) {
        this(path, new Random());
    }

    public LabeledSentence(String path, Random rng) {
        totalCount = 0;
        filesByLabel = new HashMap<String, List<String>>();
        normList = new ArrayList<String>();
        negList = new ArrayList<>();
        BufferedReader buffered = null;
        try {
            buffered = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
            String line = buffered.readLine();
            while (line != null) {
                String[] lines = line.split("\t");
                String label = lines[0];
                String content = lines[1];
                if ("1".equalsIgnoreCase(label)) {
                    normList.add(content);
                } else if ("0".equalsIgnoreCase(label)) {
                    negList.add(content);
                }
                totalCount++;
                line = buffered.readLine();
            }
            buffered.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
        System.out.println("totalCount is:" + totalCount);
        filesByLabel.put("1", normList);
        filesByLabel.put("0", negList);

        this.rng = rng;
        if (rng == null) {
            order = null;
        } else {
            // Shuffled index order; reshuffled again on each reset()
            order = new int[totalCount];
            for (int i = 0; i < totalCount; i++) {
                order[i] = i;
            }
            RandomUtils.shuffleInPlace(order, rng);
        }

        // Sort labels so the label -> index mapping is deterministic (train vs. test)
        allLabels = new ArrayList<>(filesByLabel.keySet());
        Collections.sort(allLabels);
        Map<String, Integer> labelsToIdx = new HashMap<>();
        for (int i = 0; i < allLabels.size(); i++) {
            labelsToIdx.put(allLabels.get(i), i);
        }

        sentenceList = new CompactHeapStringList();
        labelIndexes = new int[totalCount];
        int position = 0;
        for (Map.Entry<String, List<String>> entry : filesByLabel.entrySet()) {
            int labelIdx = labelsToIdx.get(entry.getKey());
            for (String f : entry.getValue()) {
                sentenceList.add(f);
                labelIndexes[position] = labelIdx;
                position++;
            }
        }
    }

    @Override
    public boolean hasNext() {
        return cursor < totalCount;
    }

    @Override
    public Pair<String, String> nextSentence() {
        int idx;
        if (rng == null) {
            idx = cursor++;
        } else {
            idx = order[cursor++];
        }
        String label = allLabels.get(labelIndexes[idx]);
        String sentence = sentenceList.get(idx);
        return new Pair<>(sentence, label);
    }

    @Override
    public void reset() {
        cursor = 0;
        if (rng != null) {
            RandomUtils.shuffleInPlace(order, rng);
        }
    }

    @Override
    public int totalNumSentences() {
        return totalCount;
    }

    @Override
    public List<String> allLabels() {
        return allLabels;
    }

    @Override
    public int numLabelClasses() {
        return allLabels.size();
    }
}
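
A quick sanity check of the provider, assuming the training file path used later in TrainAdxCnnModel and the same Pair/Random imports as the class above:

LabeledSentence provider = new LabeledSentence("adx/rnnsenec.txt", new Random(42));
System.out.println(provider.numLabelClasses() + " classes: " + provider.allLabels());
while (provider.hasNext()) {
    Pair<String, String> p = provider.nextSentence();
    System.out.println(p.getSecond() + " -> " + p.getFirst()); // label -> sentence
}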

CnnSentenceDataSetIterator.java

package com.dianping.cnn.textclassify;

import lombok.NonNull;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.iterator.provider.LabelAwareConverter;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.documentiterator.LabelAwareDocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.*;

/**
 * Converts labelled sentences into 4d feature tensors
 * ([minibatch, 1, sentenceLength, wordVectorSize] by default) suitable for CNN layers,
 * plus one-hot labels and, when sentence lengths differ, a padding mask.
 */
public class CnnSentenceDataSetIterator implements DataSetIterator {

    public enum UnknownWordHandling {
        RemoveWord, UseUnknownVector
    }

    private static final String UNKNOWN_WORD_SENTINEL = "UNKNOWN_WORD_SENTINEL";

    private LabeledSentenceProvider sentenceProvider = null;
    private WordVectors wordVectors;
    private TokenizerFactory tokenizerFactory;
    private UnknownWordHandling unknownWordHandling;
    private boolean useNormalizedWordVectors;
    private int minibatchSize;
    private int maxSentenceLength;
    private boolean sentencesAlongHeight;
    private DataSetPreProcessor dataSetPreProcessor;
    private int wordVectorSize;
    private int numClasses;
    private Map<String, Integer> labelClassMap;
    private INDArray unknown;
    private int cursor = 0;

    private CnnSentenceDataSetIterator(Builder builder) {
        this.sentenceProvider = builder.sentenceProvider;
        this.wordVectors = builder.wordVectors;
        this.tokenizerFactory = builder.tokenizerFactory;
        this.unknownWordHandling = builder.unknownWordHandling;
        this.useNormalizedWordVectors = builder.useNormalizedWordVectors;
        this.minibatchSize = builder.minibatchSize;
        // -1 (the builder default) means "no maximum"; map it to Integer.MAX_VALUE so the
        // truncation checks below behave correctly when no maximum is set
        this.maxSentenceLength = (builder.maxSentenceLength > 0 ? builder.maxSentenceLength : Integer.MAX_VALUE);
        this.sentencesAlongHeight = builder.sentencesAlongHeight;
        this.dataSetPreProcessor = builder.dataSetPreProcessor;
        this.numClasses = this.sentenceProvider.numLabelClasses();
        this.labelClassMap = new HashMap<>();
        int count = 0;
        //First: sort the labels to ensure the same label assignment order (say train vs. test)
        List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
        Collections.sort(sortedLabels);
        for (String s : sortedLabels) {
            this.labelClassMap.put(s, count++);
        }
        if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
            // Cache the UNK vector (the original listing dropped this assignment, leaving 'unknown' null)
            if (useNormalizedWordVectors) {
                unknown = wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
            } else {
                unknown = wordVectors.getWordVectorMatrix(wordVectors.getUNK());
            }
        }
        this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
    }

    /**
     * Generally used post training time to load a single sentence for predictions
     */
    public INDArray loadSingleSentence(String sentence) {
        List<String> tokens = tokenizeSentence(sentence);
        int[] featuresShape = new int[] {1, 1, 0, 0};
        if (sentencesAlongHeight) {
            featuresShape[2] = Math.min(maxSentenceLength, tokens.size());
            featuresShape[3] = wordVectorSize;
        } else {
            featuresShape[2] = wordVectorSize;
            featuresShape[3] = Math.min(maxSentenceLength, tokens.size());
        }
        INDArray features = Nd4j.create(featuresShape);
        int length = (sentencesAlongHeight ? featuresShape[2] : featuresShape[3]);
        for (int i = 0; i < length; i++) {
            INDArray vector = getVector(tokens.get(i));
            INDArrayIndex[] indices = new INDArrayIndex[4];
            indices[0] = NDArrayIndex.point(0);
            indices[1] = NDArrayIndex.point(0);
            if (sentencesAlongHeight) {
                indices[2] = NDArrayIndex.point(i);
                indices[3] = NDArrayIndex.all();
            } else {
                indices[2] = NDArrayIndex.all();
                indices[3] = NDArrayIndex.point(i);
            }
            features.put(indices, vector);
        }
        return features;
    }

    private INDArray getVector(String word) {
        INDArray vector;
        if (unknownWordHandling == UnknownWordHandling.UseUnknownVector && word == UNKNOWN_WORD_SENTINEL) { //Yes, this *should* be using == for the sentinel String here
            vector = unknown;
        } else {
            if (useNormalizedWordVectors) {
                vector = wordVectors.getWordVectorMatrixNormalized(word);
            } else {
                vector = wordVectors.getWordVectorMatrix(word);
            }
        }
        return vector;
    }

    private List<String> tokenizeSentence(String sentence) {
        Tokenizer t = tokenizerFactory.create(sentence);
        List<String> tokens = new ArrayList<>();
        while (t.hasMoreTokens()) {
            String token = t.nextToken();
            if (!wordVectors.hasWord(token)) {
                switch (unknownWordHandling) {
                    case RemoveWord:
                        continue;
                    case UseUnknownVector:
                        token = UNKNOWN_WORD_SENTINEL;
                }
            }
            tokens.add(token);
        }
        return tokens;
    }

    public Map<String, Integer> getLabelClassMap() {
        return new HashMap<>(labelClassMap);
    }

    @Override
    public List<String> getLabels() {
        //We don't want to just return the list from the LabelledSentenceProvider, as we sorted them earlier to do the
        // String -> Integer mapping
        String[] str = new String[labelClassMap.size()];
        for (Map.Entry<String, Integer> e : labelClassMap.entrySet()) {
            str[e.getValue()] = e.getKey();
        }
        return Arrays.asList(str);
    }

    @Override
    public boolean hasNext() {
        if (sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        return sentenceProvider.hasNext();
    }

    @Override
    public DataSet next() {
        return next(minibatchSize);
    }

    @Override
    public DataSet next(int num) {
        if (sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }

        List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num);
        int maxLength = -1;
        int minLength = Integer.MAX_VALUE; //Track so we know if we can skip mask creation for the "all same length" case
        for (int i = 0; i < num && sentenceProvider.hasNext(); i++) {
            Pair<String, String> p = sentenceProvider.nextSentence();
            List<String> tokens = tokenizeSentence(p.getFirst());
            maxLength = Math.max(maxLength, tokens.size());
            minLength = Math.min(minLength, tokens.size()); //The original listing never updated minLength, so a mask was always created
            tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
        }
        if (maxSentenceLength > 0 && maxLength > maxSentenceLength) {
            maxLength = maxSentenceLength;
        }

        int currMinibatchSize = tokenizedSentences.size();

        //One-hot labels: [minibatch, numClasses]
        INDArray labels = Nd4j.create(currMinibatchSize, numClasses);
        for (int i = 0; i < tokenizedSentences.size(); i++) {
            String labelStr = tokenizedSentences.get(i).getSecond();
            if (!labelClassMap.containsKey(labelStr)) {
                throw new IllegalStateException("Got label \"" + labelStr
                        + "\" that is not present in list of LabeledSentenceProvider labels");
            }
            int labelIdx = labelClassMap.get(labelStr);
            labels.putScalar(i, labelIdx, 1.0);
        }

        int[] featuresShape = new int[4];
        featuresShape[0] = currMinibatchSize;
        featuresShape[1] = 1;
        if (sentencesAlongHeight) {
            featuresShape[2] = maxLength;
            featuresShape[3] = wordVectorSize;
        } else {
            featuresShape[2] = wordVectorSize;
            featuresShape[3] = maxLength;
        }

        INDArray features = Nd4j.create(featuresShape);
        for (int i = 0; i < currMinibatchSize; i++) {
            List<String> currSentence = tokenizedSentences.get(i).getFirst();
            for (int j = 0; j < currSentence.size() && j < maxSentenceLength; j++) {
                INDArray vector = getVector(currSentence.get(j));
                INDArrayIndex[] indices = new INDArrayIndex[4];
                //TODO REUSE
                indices[0] = NDArrayIndex.point(i);
                indices[1] = NDArrayIndex.point(0);
                if (sentencesAlongHeight) {
                    indices[2] = NDArrayIndex.point(j);
                    indices[3] = NDArrayIndex.all();
                } else {
                    indices[2] = NDArrayIndex.all();
                    indices[3] = NDArrayIndex.point(j);
                }
                features.put(indices, vector);
            }
        }

        //Padding mask: 1 for real tokens, 0 for padding; skipped when all sentences have the same length
        INDArray featuresMask = null;
        if (minLength != maxLength) {
            featuresMask = Nd4j.create(currMinibatchSize, maxLength);
            for (int i = 0; i < currMinibatchSize; i++) {
                int sentenceLength = tokenizedSentences.get(i).getFirst().size();
                if (sentenceLength >= maxLength) {
                    featuresMask.getRow(i).assign(1.0);
                } else {
                    featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
                }
            }
        }

        DataSet ds = new DataSet(features, labels, featuresMask, null);
        if (dataSetPreProcessor != null) {
            dataSetPreProcessor.preProcess(ds);
        }
        cursor += ds.numExamples();
        return ds;
    }

    @Override
    public int totalExamples() {
        return sentenceProvider.totalNumSentences();
    }

    @Override
    public int inputColumns() {
        return wordVectorSize;
    }

    @Override
    public int totalOutcomes() {
        return numClasses;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    @Override
    public void reset() {
        cursor = 0;
        sentenceProvider.reset();
    }

    @Override
    public int batch() {
        return minibatchSize;
    }

    @Override
    public int cursor() {
        return cursor;
    }

    @Override
    public int numExamples() {
        return totalExamples();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.dataSetPreProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return dataSetPreProcessor;
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    public static class Builder {

        private LabeledSentenceProvider sentenceProvider = null;
        private WordVectors wordVectors;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private UnknownWordHandling unknownWordHandling = UnknownWordHandling.RemoveWord;
        private boolean useNormalizedWordVectors = true;
        private int maxSentenceLength = -1;
        private int minibatchSize = 32;
        private boolean sentencesAlongHeight = true;
        private DataSetPreProcessor dataSetPreProcessor;

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabeledSentenceProvider labeledSentenceProvider) {
            this.sentenceProvider = labeledSentenceProvider;
            return this;
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabelAwareIterator iterator, @NonNull List<String> labels) {
            LabelAwareConverter converter = new LabelAwareConverter(iterator, labels);
            return sentenceProvider(converter);
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabelAwareDocumentIterator iterator, @NonNull List<String> labels) {
            DocumentIteratorConverter converter = new DocumentIteratorConverter(iterator);
            return sentenceProvider(converter, labels);
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabelAwareSentenceIterator iterator, @NonNull List<String> labels) {
            SentenceIteratorConverter converter = new SentenceIteratorConverter(iterator);
            return sentenceProvider(converter, labels);
        }

        /**
         * Provide the WordVectors instance that should be used for training
         */
        public Builder wordVectors(WordVectors wordVectors) {
            this.wordVectors = wordVectors;
            return this;
        }

        /**
         * The {@link TokenizerFactory} that should be used. Defaults to {@link DefaultTokenizerFactory}
         */
        public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Specify how unknown words (those that don't have a word vector in the provided WordVectors instance) should be
         * handled. Default: remove/ignore unknown words.
         */
        public Builder unknownWordHandling(UnknownWordHandling unknownWordHandling) {
            this.unknownWordHandling = unknownWordHandling;
            return this;
        }

        /**
         * Minibatch size to use for the DataSetIterator
         */
        public Builder minibatchSize(int minibatchSize) {
            this.minibatchSize = minibatchSize;
            return this;
        }

        /**
         * Whether normalized word vectors should be used. Default: true
         */
        public Builder useNormalizedWordVectors(boolean useNormalizedWordVectors) {
            this.useNormalizedWordVectors = useNormalizedWordVectors;
            return this;
        }

        /**
         * Maximum sentence/document length. If sentences exceed this, they will be truncated to this length by
         * taking the first 'maxSentenceLength' known words.
         */
        public Builder maxSentenceLength(int maxSentenceLength) {
            this.maxSentenceLength = maxSentenceLength;
            return this;
        }

        /**
         * If true (default): output features data with shape [minibatchSize, 1, maxSentenceLength, wordVectorSize]<br>
         * If false: output features with shape [minibatchSize, 1, wordVectorSize, maxSentenceLength]
         */
        public Builder sentencesAlongHeight(boolean sentencesAlongHeight) {
            this.sentencesAlongHeight = sentencesAlongHeight;
            return this;
        }

        /**
         * Optional DataSetPreProcessor
         */
        public Builder dataSetPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
            this.dataSetPreProcessor = dataSetPreProcessor;
            return this;
        }

        public CnnSentenceDataSetIterator build() {
            if (wordVectors == null) {
                throw new IllegalStateException(
                        "Cannot build CnnSentenceDataSetIterator without a WordVectors instance");
            }
            return new CnnSentenceDataSetIterator(this);
        }
    }
}
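
A quick sketch of wiring the iterator together with the provider; the values mirror those used in TrainAdxCnnModel below, wordVectors is assumed to be already loaded, and the usual java.util / nd4j imports are assumed:

LabeledSentenceProvider provider = new LabeledSentence("adx/rnnsenec.txt", new Random(100));
DataSetIterator iter = new CnnSentenceDataSetIterator.Builder()
        .sentenceProvider(provider)
        .wordVectors(wordVectors)       // loaded via WordVectorSerializer
        .minibatchSize(10)
        .maxSentenceLength(256)
        .useNormalizedWordVectors(false)
        .build();
DataSet ds = iter.next();
// Features: [10, 1, maxLengthInBatch, 100]; labels: [10, 2]
System.out.println(Arrays.toString(ds.getFeatures().shape()));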

TrainAdxCnnModel.java

package com.dianping.cnn.textclassify;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.UnsupportedEncodingException;
import java.util.List;
import java.util.Random;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TrainAdxCnnModel {

    public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException {
        String WORD_VECTORS_PATH = "adx/word2vec.model";

        // Basic configuration
        int batchSize = 10;
        int vectorSize = 100;              // word vector dimensionality (100 here)
        int nEpochs = 3;                   // number of training epochs
        int truncateReviewsToLength = 256; // sentences longer than 256 tokens are truncated
        int cnnLayerFeatureMaps = 100;     // feature maps / channels (depth) per convolutional layer
        PoolingType globalPoolingType = PoolingType.MAX;
        Random rng = new Random(100);      // for shuffling the training data

        // Network configuration: three parallel convolutional layers with filter widths 3, 4 and 5
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.RELU)
                .activation(Activation.LEAKYRELU)
                .updater(Updater.ADAM)
                .convolutionMode(ConvolutionMode.Same) //This is important so we can 'stack' the results later
                .regularization(true).l2(0.0001)
                .learningRate(0.01)
                .graphBuilder()
                .addInputs("input")
                .addLayer("cnn3", new ConvolutionLayer.Builder()
                        .kernelSize(3, vectorSize)   // each filter spans 3 words x the full vector width
                        .stride(1, vectorSize)
                        .nIn(1)
                        .nOut(cnnLayerFeatureMaps)
                        .build(), "input")
                .addLayer("cnn4", new ConvolutionLayer.Builder()
                        .kernelSize(4, vectorSize)
                        .stride(1, vectorSize)
                        .nIn(1)
                        .nOut(cnnLayerFeatureMaps)
                        .build(), "input")
                .addLayer("cnn5", new ConvolutionLayer.Builder()
                        .kernelSize(5, vectorSize)
                        .stride(1, vectorSize)
                        .nIn(1)
                        .nOut(cnnLayerFeatureMaps)
                        .build(), "input")
                .addVertex("merge", new MergeVertex(), "cnn3", "cnn4", "cnn5") //Perform depth concatenation
                .addLayer("globalPool", new GlobalPoolingLayer.Builder()
                        .poolingType(globalPoolingType)
                        .build(), "merge")
                .addLayer("out", new OutputLayer.Builder()
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(3 * cnnLayerFeatureMaps)
                        .nOut(2) //2 classes (labels 0 and 1)
                        .build(), "globalPool")
                .setOutputs("out")
                .build();

        ComputationGraph net = new ComputationGraph(config);
        net.init();
        net.setListeners(new ScoreIterationListener(1));

        // Load the word vectors and create DataSetIterators for the training and test sets
        System.out.println("Loading word vectors and creating DataSetIterators");
        WordVectors wordVectors = WordVectorSerializer.fromPair(WordVectorSerializer.loadTxt(new File(WORD_VECTORS_PATH)));
        DataSetIterator trainIter = getDataSetIterator(true, wordVectors, batchSize, truncateReviewsToLength, rng);
        DataSetIterator testIter = getDataSetIterator(false, wordVectors, batchSize, truncateReviewsToLength, rng);

        System.out.println("Starting training");
        for (int i = 0; i < nEpochs; i++) {
            net.fit(trainIter);
            trainIter.reset();
            // Evaluate on the test set after each epoch
            Evaluation evaluation = net.evaluate(testIter);
            testIter.reset();
            System.out.println(evaluation.stats());
        }

        // After training: load a single sentence and print its predicted class probabilities
        String contentsFirstPas = "我的 手机 是 手机号码";
        INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator) testIter).loadSingleSentence(contentsFirstPas);
        INDArray predictionsFirstNegative = net.outputSingle(featuresFirstNegative);
        List<String> labels = testIter.getLabels();
        System.out.println("\n\nPredictions for first negative review:");
        for (int i = 0; i < labels.size(); i++) {
            System.out.println("P(" + labels.get(i) + ") = " + predictionsFirstNegative.getDouble(i));
        }
    }

    private static DataSetIterator getDataSetIterator(boolean isTraining,
            WordVectors wordVectors, int minibatchSize, int maxSentenceLength, Random rng) {
        String path = isTraining ? "adx/rnnsenec.txt" : "adx/rnnsenectest.txt";
        LabeledSentenceProvider sentenceProvider = new LabeledSentence(path, rng);
        return new CnnSentenceDataSetIterator.Builder()
                .sentenceProvider(sentenceProvider)
                .wordVectors(wordVectors)
                .minibatchSize(minibatchSize)
                .maxSentenceLength(maxSentenceLength)
                .useNormalizedWordVectors(false)
                .build();
    }
}
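
One practical addition the post omits: persisting the trained graph so prediction doesn't require retraining. A minimal sketch with DL4J's ModelSerializer (the file name is an assumption):

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;

// Save after the training loop (true = also save updater state, for further training)
ModelSerializer.writeModel(net, new File("adx/cnn-text-model.zip"), true);

// Later: restore and reuse with loadSingleSentence(...) as above
ComputationGraph restored = ModelSerializer.restoreComputationGraph(new File("adx/cnn-text-model.zip"));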


Training output:

Loading word vectors and creating DataSetIterators
totalCount is:60
totalCount is:17
Starting training
Examples labeled as 0 classified by model as 0: 9 times
Examples labeled as 1 classified by model as 0: 2 times
Examples labeled as 1 classified by model as 1: 6 times
==========================Scores========================================
Accuracy: 0.8824
Precision: 0.9091
Recall: 0.875
F1 Score: 0.8917
========================================================================
Examples labeled as 0 classified by model as 0: 9 times
Examples labeled as 1 classified by model as 0: 1 times
Examples labeled as 1 classified by model as 1: 7 times
==========================Scores========================================
Accuracy: 0.9412
Precision: 0.95
Recall: 0.9375
F1 Score: 0.9437
========================================================================
Examples labeled as 0 classified by model as 0: 9 times
Examples labeled as 1 classified by model as 1: 8 times
==========================Scores========================================
Accuracy: 1
Precision: 1
Recall: 1
F1 Score: 1
========================================================================
Predictions for first negative review:
P(0) = 0.4453294575214386
P(1) = 0.554670512676239
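
A note on the metrics: the precision/recall/F1 values in these logs match a per-class (macro) average over the two labels. Taking the second epoch as a worked example: for label 1, precision = 7/7 = 1.0 and recall = 7/8 = 0.875; for label 0, precision = 9/10 = 0.9 and recall = 9/9 = 1.0. Averaging gives precision = (1.0 + 0.9)/2 = 0.95 and recall = (0.875 + 1.0)/2 = 0.9375, so F1 = 2 * 0.95 * 0.9375 / (0.95 + 0.9375) ≈ 0.9437, exactly as reported. Note also how small the dataset is (60 training and 17 test sentences, per the totalCount lines), so the perfect third-epoch scores should be read with caution. The final prediction assigns the test sentence to label 1 (promotional), though only with probability 0.55.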

If you have any questions, contact me on WeChat: xuxu_ge