Tensorflow: recurrent neural network char-level 1
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import collections
from tensorflow.models.rnn import rnn, rnn_cell
from tensorflow.models.rnn import seq2seq

# Data preparation, following @karpathy's min-char-rnn
data = open('ThreeMusketeers.txt').read()
chars = list(set(data))
data_size, vocab_size = len(data), len(chars)
print 'data has %d characters, %d unique.' % (data_size, vocab_size)

char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for i, ch in enumerate(chars)}

# Show the five most frequent characters
counter = collections.Counter(data)
counter = sorted(counter.items(), key=lambda x: -x[1])
for i in xrange(5):
    print counter[i]

corpus = [char_to_ix[c] for c in data]

# Sampling is done one character at a time, so batch and sequence are both 1
batch_size = 1
seq_length = 1
hidden_size = 128
num_layers = 2
max_grad_norm = 5.0

an_lstm = rnn_cell.BasicLSTMCell(hidden_size)
multi_lstm = rnn_cell.MultiRNNCell([an_lstm] * num_layers)

x = tf.placeholder(tf.int32, [batch_size, seq_length])
y = tf.placeholder(tf.int32, [batch_size, seq_length])
init_state = multi_lstm.zero_state(batch_size, tf.float32)

with tf.variable_scope('rnn'):
    softmax_w = tf.get_variable('softmax_w', [hidden_size, vocab_size])
    softmax_b = tf.get_variable('softmax_b', [vocab_size])
    with tf.device('/cpu:0'):
        embedding = tf.get_variable('embedding', [vocab_size, hidden_size])
        inputs = tf.nn.embedding_lookup(embedding, x)
        # split [batch_size, seq_length, hidden_size] into seq_length
        # tensors of shape [batch_size, 1, hidden_size], then drop dim 1
        inputs = tf.split(1, seq_length, inputs)
        inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

# def loop(prev):
#     prev = tf.nn.xw_plus_b(prev, softmax_w, softmax_b)
#     prev_symbol = tf.stop_gradient(tf.arg_max(prev, 1))
#     return tf.nn.embedding_lookup(embedding, prev_symbol)

outputs, last_state = seq2seq.rnn_decoder(inputs, init_state, multi_lstm,
                                          loop_function=None, scope='rnn')
# outputs is a list of 2-D tensors, each of shape [batch_size, hidden_size];
# len(outputs) is seq_length.
# First, concatenate the hidden-layer outputs that belong to the same sequence:
out_conca = tf.concat(1, outputs)  # [batch_size, hidden_size*seq_length]
# Second, to push everything through the softmax (fully connected) layer,
# reshape so that the second dimension is hidden_size again:
output = tf.reshape(out_conca, [-1, hidden_size])      # [batch_size*seq_length, hidden_size]
score = tf.nn.xw_plus_b(output, softmax_w, softmax_b)  # [batch_size*seq_length, vocab_size]
probs = tf.nn.softmax(score)                           # [batch_size*seq_length, vocab_size]

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# Training hyper-parameters (unused in this sampling script)
epoch = 20
batch_size = 100
snapshot = 5
save_step = 1

# Restore the weights saved by the training run
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('net_snapshot/')
saver.restore(sess, ckpt.model_checkpoint_path)
print ckpt.model_checkpoint_path

def weighted_pick(weights):
    # inverse-CDF sampling: draw index i with probability proportional to weights[i]
    t = np.cumsum(weights)
    s = np.sum(weights)
    return np.searchsorted(t, np.random.rand(1) * s).tolist()[0]

# Warm the state up on the priming text, feeding every character except the
# last one; the last character becomes the first input of the generation loop
prime = 'The'
state = sess.run(multi_lstm.zero_state(1, tf.float32))
for c in prime[:-1]:
    ix = np.zeros((1, 1))
    ix[0, 0] = char_to_ix[c]
    state = sess.run(last_state, feed_dict={x: ix, init_state: state})

def char_filter(pred):
    # keep letters and a small whitelist of punctuation, drop everything else
    cache = ['!', '.', '?', '"', ' ', ',', '%']
    if pred >= 'a' and pred <= 'z':
        return pred
    if pred >= 'A' and pred <= 'Z':
        return pred
    if pred in cache:
        return pred
    return ''

ret = prime
char = prime[-1]
num = 1000
for n in xrange(num):
    ix = np.zeros((1, 1))
    ix[0, 0] = char_to_ix[char]
    probsval, state = sess.run([probs, last_state],
                               feed_dict={x: ix, init_state: state})
    # half of the time take the argmax, otherwise sample from the distribution
    if np.random.rand(1) > 0.5:
        sample = np.argmax(probsval[0])
    else:
        sample = weighted_pick(probsval[0])
    pred = ix_to_char[sample]
    ret += char_filter(pred)
    char = ret[-1]

print ret
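The concat-then-reshape step above is easy to get wrong, so here is a minimal NumPy sketch of the same bookkeeping (the tiny shapes are invented for the demo): the decoder returns seq_length tensors of shape [batch_size, hidden_size], which are concatenated along axis 1 and then reshaped so that every row is one time step's hidden vector, ready for the shared softmax layer.

import numpy as np

batch_size, seq_length, hidden_size = 2, 3, 4

# stand-ins for the seq_length decoder outputs, each [batch_size, hidden_size];
# output t is filled with the value t so the final row order is easy to check
outputs = [np.full((batch_size, hidden_size), t, dtype=np.float32)
           for t in xrange(seq_length)]

out_conca = np.concatenate(outputs, axis=1)   # [batch_size, hidden_size*seq_length]
output = out_conca.reshape(-1, hidden_size)   # [batch_size*seq_length, hidden_size]

print out_conca.shape         # (2, 12)
print output.shape            # (6, 4)
print output[:seq_length, 0]  # [ 0.  1.  2.] -- time steps stay in order per batch element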
Sampled characters:
The and the caral there turrid seing, acting, and I scapter the cardy and this tow in the carmant of the cardantg wit. The sesper the carmisted the camplessing unce you will the cardow shanded the entereled the haved is thuld a conturute of you wnocch forlack the dal would the cardersty a saiking the contter the man the this lackul not to the carster which he wat madame you was abver D he womple lead me the conleshs of the cardinous of the closed there obe in his had past, acking that time the cardinal same trrang the cant of the cat of the sir thim the deave, and the clome the ca in aftented of throse one and all be, and proming him are to they undersince, she dy hering thought o man the conter to musked the cardinal be at the porsess of a seet the for the cardinan and as the omemen exter the cardinal and to the wall take proming this, and the cands, and then? Tetter. Tno tay will see to see thougle. The chaned them then for a sra
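The loop that produced this text flips a coin between greedy argmax decoding and weighted_pick, which is plain inverse-CDF sampling over the softmax output. A small self-contained sanity check of that routine (the three-symbol distribution is invented for the demo):

import numpy as np

def weighted_pick(weights):
    # cumulative sums turn the weights into bins on [0, s); a uniform
    # draw scaled by s lands in bin i with probability weights[i]/s
    t = np.cumsum(weights)
    s = np.sum(weights)
    return np.searchsorted(t, np.random.rand(1) * s).tolist()[0]

probs = np.array([0.7, 0.2, 0.1])
picks = [weighted_pick(probs) for _ in xrange(10000)]
print [picks.count(i) / 10000.0 for i in xrange(3)]  # roughly [0.7, 0.2, 0.1]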
References:
https://github.com/sjchoi86/tensorflow-101/blob/master/notebooks/char_rnn_train_tutorial.ipynb
https://github.com/sherjilozair/char-rnn-tensorflow