cnn、rnn相结合进行文本分类

来源:互联网 发布:淘宝团队建设 编辑:程序博客网 时间:2024/05/16 17:48

主要参考代码思路:

  1. https://github.com/jiegzhan/multi-class-text-classification-cnn-rnn

cnn和rnn结合一起进行文本分类主要思路如下:

  1. data --> batch iter --> cnn input --> embedding --> 卷积 --> 池化 --> rnn 输入 --> rnn cell(下面的代码中实际使用的是 GRU cell)--> softmax


在前面的博客中已经提到如何把文本数据转化成一个 batch iter 的形式,下面贴上关于 cnn-rnn 文本分类的一些代码:



基本配置:

  1. class TCNNRNNConfig(object):
  2. # 模型参数
  3. embedding_dim = 64 # 词向量维度
  4. seq_length = 300 # 序列长度
  5. num_classes = 2 # 类别数
  6. num_filters = 256 # 卷积核数目
  7. kernel_size = 5 # 卷积核尺寸
  8. vocab_size = 130000 # 词汇表达小
  9. max_pool_size=4 #最大的pool层
  10. hidden_dim = 128 # 全连接层神经元
  11. dropout_keep_prob = 0.8 # dropout保留比例
  12. learning_rate = 1e-3 # 学习率
  13. hidden_unit=256 #lstm神经元的个数
  14. batch_size = 128 # 每批训练大小
  15. num_epochs = 20 # 总迭代轮次
  16. print_per_batch = 100 # 每多少轮输出一次结果
  17. multi_kernel_size = '3,4,5'
  18. l2_reg_lambda = 0.0

模型代码:

  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorflow as tf
  4. import numpy as np
  5. class TextCnnRnn(object):
  6. def __init__(self,config):
  7. self.config=config
  8. self.input_x=tf.placeholder(tf.int32,[None, self.config.seq_length],name="input_x")
  9. self.input_y=tf.placeholder(tf.float32,[None, self.config.num_classes],name="inpyt_y")
  10. self.keep_prob=tf.placeholder(tf.float32,None,name='keep_prob')
  11. self.pad = tf.placeholder(tf.float32, [None, 1, self.config.embedding_dim, 1], name='pad')
  12. self.l2_loss = tf.constant(0.0)
  13. self.real_len = tf.placeholder(tf.int32, [None], name='real_len')
  14. self.filter_sizes = list(map(int, self.config.multi_kernel_size.split(",")))
  15. self.cnnrnn()
  16. def input_embedding(self):
  17. """词嵌套"""
  18. with tf.device('/cpu:0'):
  19. embedding =tf.get_variable("embedding",[self.config.vocab_size,self.config.embedding_dim])
  20. _input = tf.nn.embedding_lookup(embedding, self.input_x)
  21. _input_expanded = tf.expand_dims(_input, -1)
  22. return _input_expanded
  23. def cnnrnn(self):
  24. emb=self.input_embedding()
  25. pooled_concat = []
  26. reduced = np.int32(np.ceil((self.config.seq_length) * 1.0 / self.config.max_pool_size))
  27. for i, filter_size in enumerate(self.filter_sizes):
  28. with tf.name_scope('conv-maxpool-%s' % filter_size):
  29. # Zero paddings so that the convolution output have dimension batch x sequence_length x emb_size x channel
  30. num_prio = (filter_size - 1) // 2
  31. num_post = (filter_size - 1) - num_prio
  32. pad_prio = tf.concat([self.pad] * num_prio, 1)
  33. pad_post = tf.concat([self.pad] * num_post, 1)
  34. emb_pad = tf.concat([pad_prio, emb, pad_post], 1)
  35. filter_shape = [filter_size, self.config.embedding_dim, 1, self.config.num_filters]
  36. W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name='W')
  37. b = tf.Variable(tf.constant(0.1, shape=[self.config.num_filters]), name='b')
  38. conv = tf.nn.conv2d(emb_pad, W, strides=[1, 1, 1, 1], padding='VALID', name='conv')
  39. h = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')
  40. # Maxpooling over the outputs
  41. pooled = tf.nn.max_pool(h, ksize=[1, self.config.max_pool_size, 1, 1], strides=[1, self.config.max_pool_size, 1, 1], padding='SAME',
  42. name='pool')
  43. pooled = tf.reshape(pooled, [-1, reduced, self.config.num_filters])
  44. pooled_concat.append(pooled)
  45. pooled_concat = tf.concat(pooled_concat, 2)
  46. pooled_concat = tf.nn.dropout(pooled_concat, self.keep_prob)
  47. # lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self.config.hidden_unit)
  48. # lstm_cell = tf.nn.rnn_cell.GRUCell(num_units=self.config.hidden_unit)
  49. lstm_cell = tf.contrib.rnn.GRUCell(num_units=self.config.hidden_unit)
  50. # lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.dropout_keep_prob)
  51. lstm_cell = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=self.keep_prob)
  52. self._initial_state = lstm_cell.zero_state(self.config.batch_size, tf.float32)
  53. # inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(1, reduced, pooled_concat)]
  54. inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(pooled_concat, num_or_size_splits=int(reduced), axis=1)]
  55. # outputs, state = tf.nn.rnn(lstm_cell, inputs, initial_state=self._initial_state, sequence_length=self.real_len)
  56. #outputs, state = tf.contrib.rnn.static_rnn(lstm_cell, inputs, initial_state=self._initial_state,
  57. # sequence_length=self.real_len)
  58. outputs, state=tf.nn.static_rnn( lstm_cell, inputs,self._initial_state,sequence_length=self.real_len)
  59. # Collect the appropriate last words into variable output (dimension = batch x embedding_size)
  60. output = outputs[0]
  61. with tf.variable_scope('Output'):
  62. tf.get_variable_scope().reuse_variables()
  63. one = tf.ones([1, self.config.hidden_unit], tf.float32)
  64. for i in range(1, len(outputs)):
  65. ind = self.real_len < (i + 1)
  66. ind = tf.to_float(ind)
  67. ind = tf.expand_dims(ind, -1)
  68. mat = tf.matmul(ind, one)
  69. output = tf.add(tf.multiply(output, mat), tf.multiply(outputs[i], 1.0 - mat))
  70. with tf.name_scope('score'):
  71. self.W = tf.Variable(tf.truncated_normal([self.config.hidden_unit, self.config.num_classes], stddev=0.1), name='W')
  72. b = tf.Variable(tf.constant(0.1, shape=[self.config.num_classes]), name='b')
  73. self.l2_loss += tf.nn.l2_loss(W)
  74. self.l2_loss += tf.nn.l2_loss(b)
  75. self.scores = tf.nn.xw_plus_b(output, self.W, b, name='scores')
  76. self.pred_y = tf.nn.softmax(self.scores, name="pred_y")
  77. tf.add_to_collection('pred_network', self.pred_y)
  78. self.predictions = tf.argmax(self.scores, 1, name='predictions')
  79. with tf.name_scope('loss'):
  80. losses = tf.nn.softmax_cross_entropy_with_logits(labels=self.input_y,
  81. logits=self.scores) # only named arguments accepted
  82. self.loss = tf.reduce_mean(losses) + self.config.l2_reg_lambda * self.l2_loss
  83. with tf.name_scope("optimize"):
  84. # 优化器
  85. optimizer = tf.train.AdamOptimizer(
  86. learning_rate=self.config.learning_rate)
  87. self.optim = optimizer.minimize(self.loss)
  88. with tf.name_scope('accuracy'):
  89. correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
  90. self.acc = tf.reduce_mean(tf.cast(correct_predictions, "float"), name='accuracy')
  91. with tf.name_scope('num_correct'):
  92. correct = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
  93. self.num_correct = tf.reduce_sum(tf.cast(correct, 'float'))


run代码:

  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. from cnn_rnn_model import TextCnnRnn
  4. from configuration import TCNNRNNConfig
  5. from data_utils_cut import preocess_file,batch_iter
  6. import time
  7. import tensorflow as tf
  8. import os
  9. import numpy as np
  10. from datetime import timedelta
  11. trainpath="/Users/shuubiasahi/Desktop/tensorflow/adx/"
  12. def run_epoch(cnnrnnmodel=True):
  13. # 载入数据
  14. print('Loading data...')
  15. start_time = time.time()
  16. x_train, y_train, words = preocess_file(data_path=trainpath+"cnn.txt")
  17. if cnnrnnmodel:
  18. print('Using CNNRNN model...')
  19. config = TCNNRNNConfig()
  20. config.vocab_size = len(words)
  21. print("vocab_size is:", config.vocab_size)
  22. model = TextCnnRnn(config)
  23. tensorboard_dir = '/Users/shuubiasahi/Desktop/tensorflow/boardlog'
  24. end_time = time.time()
  25. time_dif = end_time - start_time
  26. time_dif = timedelta(seconds=int(round(time_dif)))
  27. print('Time usage:', time_dif)
  28. print('Constructing TensorFlow Graph...')
  29. session = tf.Session()
  30. session.run(tf.global_variables_initializer())
  31. saver = tf.train.Saver()
  32. # 配置 tensorboard
  33. tf.summary.scalar("loss", model.loss)
  34. tf.summary.scalar("accuracy", model.acc)
  35. if not os.path.exists(tensorboard_dir):
  36. os.makedirs(tensorboard_dir)
  37. merged_summary = tf.summary.merge_all()
  38. writer = tf.summary.FileWriter(tensorboard_dir)
  39. writer.add_graph(session.graph)
  40. # 生成批次数据
  41. print('Generating batch...')
  42. batch_train = batch_iter(list(zip(x_train, y_train)),
  43. config.batch_size, config.num_epochs)
  44. def feed_data(batch):
  45. """准备需要喂入模型的数据"""
  46. x_batch, y_batch = zip(*batch)
  47. feed_dict = {
  48. model.input_x: x_batch,
  49. model.input_y: y_batch,
  50. model.real_len:real_len(x_batch)
  51. }
  52. return feed_dict, len(x_batch)
  53. def real_len(batches):
  54. return [np.ceil(np.argmin(batch + [0]) * 1.0 / config.max_pool_size) for batch in batches]
  55. def evaluate(x_, y_):
  56. """
  57. 模型评估
  58. 一次运行所有的数据会OOM,所以需要分批和汇总
  59. """
  60. batch_eval = batch_iter(list(zip(x_, y_)), 128, 1)
  61. total_loss = 0.0
  62. total_acc = 0.0
  63. cnt = 0
  64. for batch in batch_eval:
  65. feed_dict, cur_batch_len = feed_data(batch)
  66. feed_dict[model.keep_prob] = 1.0
  67. loss, acc = session.run([model.loss, model.acc],
  68. feed_dict=feed_dict)
  69. total_loss += loss * cur_batch_len
  70. total_acc += acc * cur_batch_len
  71. cnt += cur_batch_len
  72. return total_loss / cnt, total_acc / cnt
  73. # 训练与验证
  74. print('Training and evaluating...')
  75. start_time = time.time()
  76. print_per_batch = config.print_per_batch
  77. for i, batch in enumerate(batch_train):
  78. feed_dict, lenbatch = feed_data(batch)
  79. feed_dict[model.keep_prob] = config.dropout_keep_prob
  80. feed_dict[model.pad]=np.zeros([lenbatch, 1, config.embedding_dim, 1])
  81. if i % 5 == 0: # 每5次将训练结果写入tensorboard scalar
  82. s = session.run(merged_summary, feed_dict=feed_dict)
  83. writer.add_summary(s, i)
  84. if i % print_per_batch == print_per_batch - 1: # 每200次输出在训练集和验证集上的性能
  85. loss_train, acc_train = session.run([model.loss, model.acc],
  86. feed_dict=feed_dict)
  87. #loss, acc = evaluate(x_val, y_val) 验证机暂时不需要
  88. # 时间
  89. end_time = time.time()
  90. time_dif = end_time - start_time
  91. time_dif = timedelta(seconds=int(round(time_dif)))
  92. msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},'\
  93. + ' Time: {3}'
  94. print(msg.format(i + 1, loss_train, acc_train, time_dif))
  95. # if i%10==0 and i>0:
  96. # graph=tf.graph_util.convert_variables_to_constants(session,session.graph_def,["keep_prob","input_x","score/pred_y"])
  97. # tf.train.write_graph(graph,".","/Users/shuubiasahi/Desktop/tensorflow/modelsavegraph/graph.db",as_text=False)
  98. if i%500==0 and i>0:
  99. graph = tf.graph_util.convert_variables_to_constants(session, session.graph_def,
  100. ["keep_prob","real_len","pad", "input_x", "score/pred_y"])
  101. if cnnrnnmodel:
  102. tf.train.write_graph(graph, ".", trainpath+"graphcnnrnn.model",
  103. as_text=False)
  104. print("模型在第{0}步已经保存".format(i))
  105. session.run(model.optim, feed_dict=feed_dict) # 运行优化
  106. # 最后在测试集上进行评估
  107. session.close()
  108. if __name__ == '__main__':
  109. run_epoch()



简单的结果分析:

  1. Using CNNRNN model...
  2. vocab_size is: 160238
  3. Time usage: 0:00:35
  4. Constructing TensorFlow Graph...
  5. 2017-10-30 23:22:18.426329: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
  6. 2017-10-30 23:22:18.426342: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
  7. 2017-10-30 23:22:18.426346: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
  8. 2017-10-30 23:22:18.426351: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
  9. Generating batch...
  10. Training and evaluating...
  11. Iter: 100, Train Loss: 0.66, Train Acc: 71.09%, Time: 0:02:47
  12. Iter: 200, Train Loss: 0.65, Train Acc: 61.72%, Time: 0:05:38


迭代几百步之后,相比单纯使用 cnn 或 bi-lstm,实际效果差了很多。可能文本本身的特征已经足够明显,再用这种组合结构反而会使效果变差吧。cnn 相当于一个超级 n-gram,bi-lstm 则从正反两个方向捕捉文本上下文的信息进行输出。之前在 GitHub 上看到别人做文本分类,cnn、bilstm 这两种模型效果是最佳的。由于电脑原因并没有迭代很多步,哪天用 gpu 试试吧。
原创粉丝点击