LSTM text classification with deeplearning4j


Introduction

This afternoon I noticed that deeplearning4j ships several text-classification examples that combine word2vec with an LSTM. I adapted one of them to my own data format and ran it; this post records the result. My data format is shown below.

Training set format (one example per line: the class label, a tab, then the pre-segmented text; the tab is rendered below as plain whitespace):

  1	可以 兰容网 问下 咨询师 里面 不仅 可以 咨询 任何 整形 问题 并且 全国各地 可以 帮你 查询 推荐 最好 医院 专家 公立医院 三甲医院 专科 整形医院 可以 查询 之前 就是 上面
  1	小额贷款
  1	商家 @ 回应 十里洋场 江景会 订房 经理 电话 五星级 餐饮 五星级 k 包房 体验 8
  1	可以 微信同号
  1	看头像 加微信
  1	看头像 加微信
  1	可以 加研欧 书院 微信公众号 里面 具体 收费 本月 25 书院 开展 试听 公开课 如果 有时间 可以来 试听
  0	强烈推荐 4 4 店长 非常 满意 朋友 发型 合适 我的 开心 因为 自然 每年 头发 去过 不少 理发店 接触 不少 发型师 目前为止 觉得 4 ( 帅哥 一枚 ) 称心如意 根据 发质 情况 自己的 想法 给你 设计 适合 发型 不仅 专业 而且 非常 负责 细心 最喜欢 细心 因为 只有 这样 做出 好的 效果 这次 发型 做出来 4 店长 建议 染色 那样 效果 更好 不用 那么 辛苦 太久 贴心 是不是 总之 满意 染色 后的 效果 重要 4 亲自动手 发型 强烈推荐 大家 过来 4 店长 下次 4 店长 支持 djehfjdushdjd
  0	非常感谢 老师 小高 5 点评 想起 老师 一起 游览 情景 历历在目 时间 有时 一个 小偷 不知不觉 四月 已经 过完 2017 已经 过去 三分之一 老话 连雨 不知 方知夏 还没 来得及 好好 享受 温熏 春光 夏日 远处 徐徐 而来 夏天 青岛 美的 青岛 避暑 理想 欢迎 老师 有机会 夏天 再来 青岛 我的 手机号 微信号 到时 记得 微信 小高 酒店 青岛 小高 老师 身体健康 阖家欢乐 万事如意
  0	时候 他家 根本 拒绝 兑换 也是 得到 气死人
  0	商家回应 尊敬 客人 感谢您 抽出 宝贵 时间 我们 评价 心瑞 国际 月子会所 来源于 台湾 拥有 26 台式 母婴护理 经验 聚集 专业 精湛 台湾 护理 技术 团队 精心 定制 专属 护理服务 秉承 规范 操作 细致入微 精益求精 服务理念 后妈 提供 饮食 护理 康复 早教 心理健康 全方位 贴身 服务模式 孕产 期间 家庭 全方位 专业 照护 舒适 体验 会所 紧邻 国内 最好 医院 协和医院 产后 妈咪 宝宝 都有 坚实 医疗保障 我们 免费提供 健身会所 游泳馆 入住 客人 家属 使用 , 我们 不定期 举办 丰富多彩 活动 更多 孕妈咪 了解 孕期 保健知识 新生儿 喂养 知识 非常 期待 下次 妈咪 见面 心瑞 国际 月子会所 全体员工 服务热线 座机号码 请关注 微信号 微信账号
  0	商家回应 亲爱 贵宾 感谢您 襄阳 巴厘岛 休闲 度假酒店 肯定 支持 酒店 休闲会所 主要 提供 休闲 洗浴 游泳 桑拿 干湿 足疗 按摩 项目 如此 体验 没能 满意 我们 表示 深深 歉意 我们 足疗 专业 休闲 手法 所有 技师 素质 专业 魅力 服务 更好 巴厘岛 一切 舒适 休闲 方式 为先 满意 我们 继续 进步 动力 酒店 全体员工 期待 再次光临
  0	直接 不行 必须 每人 一份 主食 给你 东西 不用 上了 不掉 浪费 服务员 b 来个 服务员 c 过来 我们 杯茶 可能 听懂 bill 拿走 一会 送来 一个新 上面 一杯 我们 收费 不要 直接 气呼呼 拿走 bill 回来 时候 打印 去掉 杯茶 bill 直接 桌子 然后 回头 牛逼 这么多 英联邦 国家 地区 别人 一看 中国 客客气气 何况 还是 顾客 妹的
  0	维权 怎么 套路

The test set uses the same format. Judging from the samples, label 1 marks promotional/spam-like text and label 0 marks normal review text.
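As a quick illustration of how one line decomposes (a minimal standalone sketch; the class name is made up and the sample string is one of the lines above):

  public class LineFormatDemo {
      public static void main(String[] args) {
          // Each dataset line: "<label>\t<space-separated tokens>"
          String line = "1\t看头像 加微信";
          String[] parts = line.split("\t", 2);
          String label = parts[0];                // "1" or "0"
          String[] tokens = parts[1].split(" ");  // pre-segmented words
          System.out.println(label + " -> " + tokens.length + " tokens");
      }
  }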



Maven dependencies:

  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <nd4j.version>0.8.0</nd4j.version>
    <dl4j.version>0.8.0</dl4j.version>
    <datavec.version>0.8.0</datavec.version>
    <maven.compiler.target>1.8</maven.compiler.target>
    <maven.compiler.source>1.8</maven.compiler.source>
  </properties>
  <dependencies>
    <dependency>
      <groupId>org.nd4j</groupId>
      <artifactId>nd4j-native</artifactId>
      <version>${nd4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-core</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-nlp</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.datavec</groupId>
      <artifactId>datavec-api</artifactId>
      <version>${datavec.version}</version>
    </dependency>
    <!-- Internal utility library providing WordUtil/TextUtil; not in public repositories -->
    <dependency>
      <groupId>com.meituan</groupId>
      <artifactId>nlp-utils</artifactId>
      <version>0.0.1-SNAPSHOT</version>
    </dependency>
  </dependencies>



First, convert the raw text into the formats needed for word2vec and LSTM training, then train and save the word2vec model:

  package com.dianping.recurrent.adx;

  import java.io.BufferedReader;
  import java.io.BufferedWriter;
  import java.io.FileInputStream;
  import java.io.FileOutputStream;
  import java.io.IOException;
  import java.io.InputStreamReader;
  import java.io.OutputStreamWriter;
  import java.util.stream.Collectors;

  import org.apache.commons.lang.StringUtils;
  import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
  import org.deeplearning4j.models.word2vec.Word2Vec;
  import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
  import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
  import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
  import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
  import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
  import org.slf4j.Logger;
  import org.slf4j.LoggerFactory;

  import com.dianping.recurrent.util.PathUtils;
  import com.meituan.nlp.util.TextUtil;
  import com.meituan.nlp.util.WordUtil;

  public class PrepareWordVector {

      private static Logger log = LoggerFactory.getLogger(PrepareWordVector.class);
      private static String datapath = PathUtils.INPUT_ADX;

      /**
       * Reads "label<TAB>text" lines and writes two files: the cleaned,
       * segmented text alone (word2vec training input) and the same text
       * with its label (LSTM iterator input).
       */
      public static void transToModel(String input, String outputWord2vec, String outputRnn) throws IOException {
          BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(input)));
          BufferedWriter writerWord2vec = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputWord2vec)));
          BufferedWriter writerRnn = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputRnn)));
          String line = reader.readLine();
          while (line != null) {
              String[] parts = line.split("\t");
              // Skip malformed lines rather than crashing on a missing tab
              if (parts.length == 2 && StringUtils.isNotBlank(parts[1])) {
                  String label = parts[0];
                  String content = parts[1];
                  // Normalize: lower-case, strip ad markup/URLs, unify digits,
                  // traditional-to-simplified conversion, synonym replacement,
                  // then re-join the segmented words with single spaces
                  String result = WordUtil
                          .getAdSegmentNotURL(
                                  WordUtil.replaceAllADXSynonyms(TextUtil.fan2Jian(WordUtil
                                          .converToDigitStr(WordUtil
                                                  .replaceAdxAll(content.toLowerCase())))))
                          .stream().collect(Collectors.joining(" "));
                  writerRnn.write(label + "\t" + result + "\n");
                  writerWord2vec.write(result + "\n");
              }
              line = reader.readLine();
          }
          reader.close();
          writerRnn.close();
          writerWord2vec.close();
      }

      public static void trainWord2vec(String inputPath, String outputPath) throws IOException {
          SentenceIterator iter = new BasicLineIterator(inputPath);
          TokenizerFactory t = new DefaultTokenizerFactory();
          t.setTokenPreProcessor(new CommonPreprocessor());
          log.info("build word2vec will start");
          Word2Vec vec = new Word2Vec.Builder().minWordFrequency(1).iterations(5)
                  .layerSize(100).seed(42).windowSize(20).iterate(iter)
                  .tokenizerFactory(t).build();
          log.info("Fitting Word2Vec model....");
          vec.fit();
          log.info("Writing word vectors to text file....");
          // Write word vectors to file
          WordVectorSerializer.writeWordVectors(vec, outputPath);
      }

      public static void main(String[] args) throws IOException {
          // transToModel(datapath, "adx/wordvecsence.txt", "adx/rnnsenec.txt");
          trainWord2vec("adx/wordvecsence.txt", "adx/word2vec.model");
      }
  }
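
After training it is worth sanity-checking the vectors before moving on. A minimal sketch (the class name and probe words are illustrative; wordsNearest and similarity are standard WordVectors methods, and the model file is the one written above):

  import java.io.File;
  import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
  import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;

  public class CheckVectors {
      public static void main(String[] args) throws Exception {
          // Load the text-format vectors written by trainWord2vec above
          WordVectors vec = WordVectorSerializer.fromPair(
                  WordVectorSerializer.loadTxt(new File("adx/word2vec.model")));
          // Nearest neighbours of a corpus word ("微信" is just an example)
          System.out.println(vec.wordsNearest("微信", 10));
          System.out.println(vec.similarity("微信", "电话"));
      }
  }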

The ADXIterator, a custom DataSetIterator over this data:

  package com.dianping.recurrent.adx;

  import static org.nd4j.linalg.indexing.NDArrayIndex.all;
  import static org.nd4j.linalg.indexing.NDArrayIndex.point;

  import java.io.BufferedReader;
  import java.io.File;
  import java.io.FileReader;
  import java.io.IOException;
  import java.util.ArrayList;
  import java.util.List;
  import java.util.NoSuchElementException;

  import org.apache.commons.io.FileUtils;
  import org.apache.commons.lang3.tuple.Pair;
  import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
  import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
  import org.nd4j.linalg.api.ndarray.INDArray;
  import org.nd4j.linalg.dataset.DataSet;
  import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
  import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  import org.nd4j.linalg.factory.Nd4j;
  import org.nd4j.linalg.indexing.INDArrayIndex;

  public class ADXIterator implements DataSetIterator {

      private final WordVectors wordVectors;
      private final int batchSize;
      private final int vectorSize;
      private final int truncateLength;
      private int maxLength;
      private final String dataDirectory;
      private final List<Pair<String, List<String>>> categoryData = new ArrayList<>();
      private int cursor = 0;
      private int totalNews = 0;
      private final TokenizerFactory tokenizerFactory;
      private int newsPosition = 0;
      private final List<String> labels;
      private int currCategory = 0;

      private ADXIterator(String dataDirectory, WordVectors wordVectors,
              int batchSize, int truncateLength, boolean train,
              TokenizerFactory tokenizerFactory) {
          this.dataDirectory = dataDirectory;
          this.batchSize = batchSize;
          this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
          this.wordVectors = wordVectors;
          this.truncateLength = truncateLength;
          this.tokenizerFactory = tokenizerFactory;
          this.populateData(train);
          this.labels = new ArrayList<>();
          for (int i = 0; i < 2; i++) {
              this.labels.add(String.valueOf(i));
          }
      }

      public static Builder Builder() {
          return new Builder();
      }

      @Override
      public DataSet next(int num) {
          if (cursor >= this.totalNews)
              throw new NoSuchElementException();
          try {
              return nextDataSet(num);
          } catch (IOException e) {
              throw new RuntimeException(e);
          }
      }

      private DataSet nextDataSet(int num) throws IOException {
          // First: pull texts from the categoryData list, alternating between
          // categories so each minibatch is roughly class-balanced
          List<String> news = new ArrayList<>(num);
          int[] category = new int[num];
          for (int i = 0; i < num && cursor < totalExamples(); i++) {
              if (currCategory < categoryData.size()) {
                  // NOTE: no bounds check on newsPosition -- this assumes every
                  // category list has the same length (see the note at the end)
                  news.add(this.categoryData.get(currCategory).getValue().get(newsPosition));
                  category[i] = Integer.parseInt(this.categoryData.get(currCategory).getKey());
                  currCategory++;
                  cursor++;
              } else {
                  currCategory = 0;
                  newsPosition++;
                  i--;
              }
          }
          // Second: tokenize the texts and filter out words without a vector
          List<List<String>> allTokens = new ArrayList<>(news.size());
          maxLength = 0;
          for (String s : news) {
              List<String> tokens = tokenizerFactory.create(s).getTokens();
              List<String> tokensFiltered = new ArrayList<>();
              for (String t : tokens) {
                  if (wordVectors.hasWord(t))
                      tokensFiltered.add(t);
              }
              allTokens.add(tokensFiltered);
              maxLength = Math.max(maxLength, tokensFiltered.size());
          }
          // If the longest text exceeds 'truncateLength': only take the first
          // 'truncateLength' words
          if (maxLength > truncateLength)
              maxLength = truncateLength;
          // Create data for training. Here: news.size() examples of varying lengths
          INDArray features = Nd4j.create(news.size(), vectorSize, maxLength);
          INDArray labels = Nd4j.create(news.size(), this.categoryData.size(), maxLength); // two labels here: 1 and 0
          // Because we are dealing with texts of different lengths and only one
          // output at the final time step: use padding and mask arrays.
          // Mask arrays contain 1 if data is present at that time step for that
          // example, or 0 if data is just padding
          INDArray featuresMask = Nd4j.zeros(news.size(), maxLength);
          INDArray labelsMask = Nd4j.zeros(news.size(), maxLength);
          int[] temp = new int[2];
          for (int i = 0; i < news.size(); i++) {
              List<String> tokens = allTokens.get(i);
              temp[0] = i;
              // Get the word vector for each word and put it into the features array
              for (int j = 0; j < tokens.size() && j < maxLength; j++) {
                  String token = tokens.get(j);
                  INDArray vector = wordVectors.getWordVectorMatrix(token);
                  features.put(new INDArrayIndex[] { point(i), all(), point(j) }, vector);
                  temp[1] = j;
                  featuresMask.putScalar(temp, 1.0);
              }
              int idx = category[i];
              int lastIdx = Math.min(tokens.size(), maxLength);
              // The label goes at the last time step that contains real data
              labels.putScalar(new int[] { i, idx, lastIdx - 1 }, 1.0);
              labelsMask.putScalar(new int[] { i, lastIdx - 1 }, 1.0);
          }
          return new DataSet(features, labels, featuresMask, labelsMask);
      }

      public INDArray loadFeaturesFromFile(File file, int maxLength) throws IOException {
          String news = FileUtils.readFileToString(file);
          return loadFeaturesFromString(news, maxLength);
      }

      public INDArray loadFeaturesFromString(String reviewContents, int maxLength) {
          List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
          List<String> tokensFiltered = new ArrayList<>();
          for (String t : tokens) {
              if (wordVectors.hasWord(t))
                  tokensFiltered.add(t);
          }
          int outputLength = Math.max(maxLength, tokensFiltered.size());
          INDArray features = Nd4j.create(1, vectorSize, outputLength);
          // Iterate over the filtered tokens so unknown words cannot yield null vectors
          for (int j = 0; j < tokensFiltered.size() && j < maxLength; j++) {
              String token = tokensFiltered.get(j);
              INDArray vector = wordVectors.getWordVectorMatrix(token);
              features.put(new INDArrayIndex[] { point(0), all(), point(j) }, vector);
          }
          return features;
      }

      /*
       * Loads "label<TAB>text" lines from the train or test file into the
       * categoryData list, one entry per class.
       */
      private void populateData(boolean train) {
          String name = train ? "rnnsenec.txt" : "rnnsenectest.txt";
          String curFileName = this.dataDirectory + name;
          File currFile = new File(curFileName);
          try {
              BufferedReader currBR = new BufferedReader(new FileReader(currFile));
              String tempCurrLine;
              List<String> tempListnorme = new ArrayList<>();
              List<String> tempListneg = new ArrayList<>();
              while ((tempCurrLine = currBR.readLine()) != null) {
                  String[] lines = tempCurrLine.split("\t");
                  String label = lines[0];
                  if ("1".equalsIgnoreCase(label)) {
                      tempListnorme.add(lines[1]);
                  } else if ("0".equalsIgnoreCase(label)) {
                      tempListneg.add(lines[1]);
                  }
                  this.totalNews++;
              }
              currBR.close();
              this.categoryData.add(Pair.of("1", tempListnorme));
              this.categoryData.add(Pair.of("0", tempListneg));
          } catch (Exception e) {
              e.printStackTrace();
          }
      }

      @Override
      public int totalExamples() {
          return this.totalNews;
      }

      @Override
      public int inputColumns() {
          return vectorSize;
      }

      @Override
      public int totalOutcomes() {
          return this.categoryData.size();
      }

      @Override
      public void reset() {
          cursor = 0;
          newsPosition = 0;
          currCategory = 0;
      }

      @Override
      public boolean resetSupported() {
          return true;
      }

      @Override
      public boolean asyncSupported() {
          return true;
      }

      @Override
      public int batch() {
          return batchSize;
      }

      @Override
      public int cursor() {
          return cursor;
      }

      @Override
      public int numExamples() {
          return totalExamples();
      }

      @Override
      public void setPreProcessor(DataSetPreProcessor preProcessor) {
          throw new UnsupportedOperationException();
      }

      @Override
      public List<String> getLabels() {
          return this.labels;
      }

      @Override
      public boolean hasNext() {
          return cursor < numExamples();
      }

      @Override
      public DataSet next() {
          return next(batchSize);
      }

      @Override
      public void remove() {
      }

      @Override
      public DataSetPreProcessor getPreProcessor() {
          throw new UnsupportedOperationException("Not implemented");
      }

      public int getMaxLength() {
          return this.maxLength;
      }

      public static class Builder {
          private String dataDirectory;
          private WordVectors wordVectors;
          private int batchSize;
          private int truncateLength;
          TokenizerFactory tokenizerFactory;
          private boolean train;

          Builder() {
          }

          public ADXIterator.Builder dataDirectory(String dataDirectory) {
              this.dataDirectory = dataDirectory;
              return this;
          }

          public ADXIterator.Builder wordVectors(WordVectors wordVectors) {
              this.wordVectors = wordVectors;
              return this;
          }

          public ADXIterator.Builder batchSize(int batchSize) {
              this.batchSize = batchSize;
              return this;
          }

          public ADXIterator.Builder truncateLength(int truncateLength) {
              this.truncateLength = truncateLength;
              return this;
          }

          public ADXIterator.Builder train(boolean train) {
              this.train = train;
              return this;
          }

          public ADXIterator.Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
              this.tokenizerFactory = tokenizerFactory;
              return this;
          }

          public ADXIterator build() {
              return new ADXIterator(dataDirectory, wordVectors, batchSize,
                      truncateLength, train, tokenizerFactory);
          }

          public String toString() {
              return "ADXIterator.Builder(dataDirectory=" + this.dataDirectory
                      + ", wordVectors=" + this.wordVectors
                      + ", batchSize=" + this.batchSize
                      + ", truncateLength=" + this.truncateLength
                      + ", train=" + this.train + ")";
          }
      }
  }
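
Before wiring the iterator into training, it helps to look at the shape of one batch: the LSTM consumes features of shape [miniBatch, vectorSize, maxLength] with [miniBatch, maxLength] masks. A minimal sketch, assuming the adx/ files produced earlier and the same batch size and truncate length as the training code below (the class name is made up):

  package com.dianping.recurrent.adx;

  import java.io.File;
  import java.util.Arrays;
  import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
  import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
  import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
  import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
  import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
  import org.nd4j.linalg.dataset.DataSet;

  public class InspectBatch {
      public static void main(String[] args) throws Exception {
          WordVectors wv = WordVectorSerializer.fromPair(
                  WordVectorSerializer.loadTxt(new File("adx/word2vec.model")));
          TokenizerFactory tf = new DefaultTokenizerFactory();
          tf.setTokenPreProcessor(new CommonPreprocessor());
          ADXIterator it = new ADXIterator.Builder().dataDirectory("adx/")
                  .wordVectors(wv).batchSize(6).truncateLength(300)
                  .tokenizerFactory(tf).train(true).build();
          DataSet batch = it.next();
          // features: [miniBatch, vectorSize, maxLength]; masks: [miniBatch, maxLength]
          System.out.println(Arrays.toString(batch.getFeatures().shape()));
          System.out.println(Arrays.toString(batch.getFeaturesMaskArray().shape()));
      }
  }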


Model training code:

  package com.dianping.recurrent.adx;

  import java.io.File;
  import java.io.IOException;

  import org.deeplearning4j.eval.Evaluation;
  import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
  import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
  import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  import org.deeplearning4j.nn.conf.GradientNormalization;
  import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  import org.deeplearning4j.nn.conf.Updater;
  import org.deeplearning4j.nn.conf.layers.GravesLSTM;
  import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
  import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  import org.deeplearning4j.nn.weights.WeightInit;
  import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
  import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
  import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
  import org.deeplearning4j.util.ModelSerializer;
  import org.nd4j.linalg.activations.Activation;
  import org.nd4j.linalg.api.ndarray.INDArray;
  import org.nd4j.linalg.dataset.api.DataSet;
  import org.nd4j.linalg.lossfunctions.LossFunctions;

  public class TrainAdxRnnModel {

      public static String DATA_PATH = "";
      public static String WORD_VECTORS_PATH = "";
      public static WordVectors wordVectors;

      public static void main(String[] args) throws IOException {
          DATA_PATH = "adx/";
          WORD_VECTORS_PATH = "adx/word2vec.model";
          int batchSize = 6;                  // Number of examples in each minibatch
          int nEpochs = 10;                   // Number of training epochs
          int truncateReviewsToLength = 300;  // Maximum text length; longer texts are truncated
          wordVectors = WordVectorSerializer.fromPair(WordVectorSerializer.loadTxt(new File(WORD_VECTORS_PATH)));
          TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
          tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
          ADXIterator iTrain = new ADXIterator.Builder().dataDirectory(DATA_PATH)
                  .wordVectors(wordVectors).batchSize(batchSize)
                  .truncateLength(truncateReviewsToLength)
                  .tokenizerFactory(tokenizerFactory).train(true).build();
          ADXIterator iTest = new ADXIterator.Builder().dataDirectory(DATA_PATH)
                  .wordVectors(wordVectors).batchSize(batchSize)
                  .truncateLength(truncateReviewsToLength)
                  .tokenizerFactory(tokenizerFactory).train(false).build();
          int inputNeurons = wordVectors.getWordVector(wordVectors.vocab()
                  .wordAtIndex(0)).length;    // 100 in our case
          int outputs = iTrain.getLabels().size();
          System.out.println("inputNeurons is :" + inputNeurons);
          // Set up network configuration
          MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                  .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                  .iterations(1)
                  .updater(Updater.RMSPROP)
                  .regularization(true)
                  .l2(1e-5)
                  .weightInit(WeightInit.XAVIER)
                  .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                  .gradientNormalizationThreshold(1.0)
                  .learningRate(0.0018)
                  .list()
                  .layer(0, new GravesLSTM.Builder().nIn(inputNeurons).nOut(200)
                          .activation(Activation.SOFTSIGN).build())
                  .layer(1, new RnnOutputLayer.Builder()
                          .activation(Activation.SOFTMAX)
                          .lossFunction(LossFunctions.LossFunction.MCXENT)
                          .nIn(200).nOut(outputs).build())
                  .pretrain(false).backprop(true).build();
          MultiLayerNetwork net = new MultiLayerNetwork(conf);
          net.init();
          // Print the score every 200 iterations
          net.setListeners(new ScoreIterationListener(200));
          System.out.println("Starting training");
          for (int i = 0; i < nEpochs; i++) {
              net.fit(iTrain);
              iTrain.reset();
              System.out.println("Epoch " + i + " complete. Starting evaluation:");
              // Run evaluation on the test set
              Evaluation evaluation = new Evaluation();
              while (iTest.hasNext()) {
                  DataSet t = iTest.next();
                  INDArray features = t.getFeatureMatrix();
                  INDArray labels = t.getLabels();
                  INDArray outMask = t.getLabelsMaskArray();
                  INDArray predicted = net.output(features, false);
                  evaluation.evalTimeSeries(labels, predicted, outMask);
              }
              iTest.reset();
              System.out.println(evaluation.stats());
          }
          ModelSerializer.writeModel(net, "adx/" + "NewsModel.net", true);
          System.out.println("----- Example complete -----");
      }
  }

Training progress and results:

  Starting training
  Epoch 0 complete. Starting evaluation:
  Examples labeled as 0 classified by model as 0: 8 times
  Examples labeled as 0 classified by model as 1: 1 times
  Examples labeled as 1 classified by model as 1: 9 times
  ==========================Scores========================================
  Accuracy: 0.9444
  Precision: 0.95
  Recall: 0.9444
  F1 Score: 0.9472
  ========================================================================
  Epoch 1 complete. Starting evaluation:
  Examples labeled as 0 classified by model as 0: 9 times
  Examples labeled as 1 classified by model as 1: 9 times
  ==========================Scores========================================
  Accuracy: 1
  Precision: 1
  Recall: 1
  F1 Score: 1
  ========================================================================
  Epoch 2 complete. Starting evaluation:
  Examples labeled as 0 classified by model as 0: 9 times
  Examples labeled as 1 classified by model as 1: 9 times
  ==========================Scores========================================
  Accuracy: 1
  Precision: 1
  Recall: 1
  F1 Score: 1
  ========================================================================
  Epoch 3 complete. Starting evaluation:
  Examples labeled as 0 classified by model as 0: 9 times
  Examples labeled as 1 classified by model as 1: 9 times
  ==========================Scores========================================
  Accuracy: 1
  Precision: 1
  Recall: 1
  F1 Score: 1
  ========================================================================
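
Once training has finished and the model is saved, it can be reloaded and applied to new, already-segmented text. A minimal sketch (the class name is made up and the input string is one of the samples above; per the iterator's labels list, column 0 of the output is label 0 and column 1 is label 1):

  package com.dianping.recurrent.adx;

  import java.io.File;
  import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
  import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
  import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
  import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
  import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
  import org.deeplearning4j.util.ModelSerializer;
  import org.nd4j.linalg.api.ndarray.INDArray;
  import org.nd4j.linalg.indexing.NDArrayIndex;

  public class PredictOne {
      public static void main(String[] args) throws Exception {
          MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(new File("adx/NewsModel.net"));
          WordVectors wv = WordVectorSerializer.fromPair(
                  WordVectorSerializer.loadTxt(new File("adx/word2vec.model")));
          TokenizerFactory tf = new DefaultTokenizerFactory();
          tf.setTokenPreProcessor(new CommonPreprocessor());
          // The iterator is built only to reuse its loadFeaturesFromString helper
          ADXIterator it = new ADXIterator.Builder().dataDirectory("adx/")
                  .wordVectors(wv).batchSize(1).truncateLength(300)
                  .tokenizerFactory(tf).train(false).build();
          INDArray features = it.loadFeaturesFromString("看头像 加微信", 300);
          INDArray output = net.output(features, false);
          // Class probabilities sit at the last time step: [p(label 0), p(label 1)]
          int lastStep = output.size(2) - 1;
          INDArray probs = output.get(NDArrayIndex.point(0), NDArrayIndex.all(),
                  NDArrayIndex.point(lastStep));
          System.out.println(probs);
      }
  }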


The code above took me an evening to write. Before turning in, one open question: why does the program only run without errors when every class has exactly the same number of samples? I'll debug it properly when I find the time, and pointers from anyone who knows are welcome. My current guess: the round-robin loop in nextDataSet calls categoryData.get(currCategory).getValue().get(newsPosition) without checking the list size, so once one class list is shorter than the other, newsPosition walks past the end of the shorter list and get() throws an IndexOutOfBoundsException.
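
A bounds-checked variant of that loop, as an untested drop-in replacement for the first loop in nextDataSet (it still assumes every data line carries label 0 or 1, since totalNews counts all lines):

  // Untested sketch: skip a category once its list is exhausted instead of
  // assuming every class holds the same number of examples
  int maxBucket = 0;
  for (Pair<String, List<String>> p : categoryData) {
      maxBucket = Math.max(maxBucket, p.getValue().size());
  }
  for (int i = 0; i < num && cursor < totalExamples() && newsPosition < maxBucket; ) {
      List<String> bucket = categoryData.get(currCategory).getValue();
      if (newsPosition < bucket.size()) {
          news.add(bucket.get(newsPosition));
          category[i] = Integer.parseInt(categoryData.get(currCategory).getKey());
          cursor++;
          i++;
      }
      currCategory++;
      if (currCategory >= categoryData.size()) {
          currCategory = 0;   // wrap around and advance to the next position
          newsPosition++;
      }
  }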


If you run into problems, contact me on WeChat: xuxu_ge

