models生成与加载
来源:互联网 发布:网络舞曲dj串烧视频 编辑:程序博客网 时间:2024/05/21 05:39
TensorFlow练习2: 对评论进行分类
本帖是前一贴的补充:
- 使用大数据,了解怎么处理数据不能一次全部加载到内存的情况。如果你内存充足,当我没说
- 训练好的模型的保存和使用
- 使用的模型没变,还是简单的feedforward神经网络(update:添加CNN模型)
- 如果你要运行本帖代码,推荐使用GPU版本或强大的VPS,我使用小笔记本差点等吐血
- 后续有关于中文的练习《TensorFlow练习13: 制作一个简单的聊天机器人》《TensorFlow练习7: 基于RNN生成古诗词》《TensorFlow练习18: 根据姓名判断性别》
在正文开始之前,我画了一个机器学习模型的基本开发流程图:
使用的数据集
使用的数据集:http://help.sentiment140.com/for-students/ (情绪分析)
数据集包含1百60万条推特,包含消极、中性和积极tweet。不知道有没有现成的微博数据集。
数据格式:移除表情符号的CSV文件,字段如下:
- 0 – the polarity of the tweet (0 = negative, 2 = neutral, 4 = positive)
- 1 – the id of the tweet (2087)
- 2 – the date of the tweet (Sat May 16 23:58:44 UTC 2009)
- 3 – the query (lyx). If there is no query, then this value is NO_QUERY.
- 4 – the user that tweeted (robotickilldozr)
- 5 – the text of the tweet (Lyx is cool)
training.1600000.processed.noemoticon.csv(238M)
testdata.manual.2009.06.14.csv(74K)
数据预处理
上面代码把原始数据转为training.csv、和tesing.csv,里面只包含label和tweet。lexcion.pickle文件保存了词汇表。
如果数据文件太大,不能一次加载到内存,可以把数据导入数据库
Dask可处理大csv文件
开始漫长的训练
上面程序占用内存600M,峰值1G。
运行:
训练模型保存为model.ckpt。
使用训练好的模型
上面使用简单的feedfroward模型,下面使用CNN模型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# https://github.com/Lab41/sunny-side-up
importos
importrandom
importtensorflowastf
importpickle
importnumpyasnp
fromnltk.tokenizeimportword_tokenize
fromnltk.stemimportWordNetLemmatizer
f=open('lexcion.pickle','rb')
lex=pickle.load(f)
f.close()
defget_random_line(file,point):
file.seek(point)
file.readline()
returnfile.readline()
# 从文件中随机选择n条记录
defget_n_random_line(file_name,n=150):
lines=[]
file=open(file_name,encoding='latin-1')
total_bytes=os.stat(file_name).st_size
foriinrange(n):
random_point=random.randint(0,total_bytes)
lines.append(get_random_line(file,random_point))
file.close()
returnlines
defget_test_dataset(test_file):
withopen(test_file,encoding='latin-1')asf:
test_x=[]
test_y=[]
lemmatizer=WordNetLemmatizer()
forlineinf:
label=line.split(':%:%:%:')[0]
tweet=line.split(':%:%:%:')[1]
words=word_tokenize(tweet.lower())
words=[lemmatizer.lemmatize(word)forwordinwords]
features=np.zeros(len(lex))
forwordinwords:
ifwordinlex:
features[lex.index(word)]=1
test_x.append(list(features))
test_y.append(eval(label))
returntest_x,test_y
test_x,test_y=get_test_dataset('tesing.csv')
##############################################################################
input_size=len(lex)
num_classes=3
X=tf.placeholder(tf.int32,[None,input_size])
Y=tf.placeholder(tf.float32,[None,num_classes])
dropout_keep_prob=tf.placeholder(tf.float32)
batch_size=90
defneural_network():
# embedding layer
withtf.device('/cpu:0'),tf.name_scope("embedding"):
embedding_size=128
W=tf.Variable(tf.random_uniform([input_size,embedding_size],-1.0,1.0))
embedded_chars=tf.nn.embedding_lookup(W,X)
embedded_chars_expanded=tf.expand_dims(embedded_chars,-1)
# convolution + maxpool layer
num_filters=128
filter_sizes=[3,4,5]
pooled_outputs=[]
fori,filter_sizeinenumerate(filter_sizes):
withtf.name_scope("conv-maxpool-%s"%filter_size):
filter_shape=[filter_size,embedding_size,1,num_filters]
W=tf.Variable(tf.truncated_normal(filter_shape,stddev=0.1))
b=tf.Variable(tf.constant(0.1,shape=[num_filters]))
conv=tf.nn.conv2d(embedded_chars_expanded,W,strides=[1,1,1,1],padding="VALID")
h=tf.nn.relu(tf.nn.bias_add(conv,b))
pooled=tf.nn.max_pool(h,ksize=[1,input_size-filter_size+1,1,1],strides=[1,1,1,1],padding='VALID')
pooled_outputs.append(pooled)
num_filters_total=num_filters*len(filter_sizes)
h_pool=tf.concat(3,pooled_outputs)
h_pool_flat=tf.reshape(h_pool,[-1,num_filters_total])
# dropout
withtf.name_scope("dropout"):
h_drop=tf.nn.dropout(h_pool_flat,dropout_keep_prob)
# output
withtf.name_scope("output"):
W=tf.get_variable("W",shape=[num_filters_total,num_classes],initializer=tf.contrib.layers.xavier_initializer())
b=tf.Variable(tf.constant(0.1,shape=[num_classes]))
output=tf.nn.xw_plus_b(h_drop,W,b)
returnoutput
deftrain_neural_network():
output=neural_network()
optimizer=tf.train.AdamOptimizer(1e-3)
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(output,Y))
grads_and_vars=optimizer.compute_gradients(loss)
train_op=optimizer.apply_gradients(grads_and_vars)
saver=tf.train.Saver(tf.global_variables())
withtf.Session()assess:
sess.run(tf.global_variables_initializer())
lemmatizer=WordNetLemmatizer()
i=0
whileTrue:
batch_x=[]
batch_y=[]
#if model.ckpt文件已存在:
# saver.restore(session, 'model.ckpt') 恢复保存的session
try:
lines=get_n_random_line('training.csv',batch_size)
forlineinlines:
label=line.split(':%:%:%:')[0]
tweet=line.split(':%:%:%:')[1]
words=word_tokenize(tweet.lower())
words=[lemmatizer.lemmatize(word)forwordinwords]
features=np.zeros(len(lex))
forwordinwords:
ifwordinlex:
features[lex.index(word)]=1 # 一个句子中某个词可能出现两次,可以用+=1,其实区别不大
batch_x.append(list(features))
batch_y.append(eval(label))
_,loss_=sess.run([train_op,loss],feed_dict={X:batch_x,Y:batch_y,dropout_keep_prob:0.5})
print(loss_)
exceptExceptionase:
print(e)
ifi%10==0:
predictions=tf.argmax(output,1)
correct_predictions=tf.equal(predictions,tf.argmax(Y,1))
accuracy=tf.reduce_mean(tf.cast(correct_predictions,"float"))
accur=sess.run(accuracy,feed_dict={X:test_x[0:50],Y:test_y[0:50],dropout_keep_prob:1.0})
print('准确率:',accur)
i+=1
train_neural_network()
使用了CNN模型之后,准确率有了显著提升。
阅读全文
0 0
- models生成与加载
- 3D models 加载
- Django本地加载models
- Models--自动生成实体层代码
- 自动生成Sqlalchemy的models文件
- Django中反向生成models.py
- Rails中scaffold与models的区别
- Nodejs连接mysql与models对应2
- Nodejs连接mysql与models对应
- Structuring Your TensorFlow Models-翻译与学习
- 如何搭建MVC3与配置models层
- Asp.net core 通过Models 生成数据库的方法
- Django根据现有数据库,自动生成models模型文件
- 利用sqlacodegen生成models.py数据库模型文件
- wp图片加载形式/生成属性 content与resource
- mybatis的延迟加载与代码生成工具(MBG)
- mybatis的延迟加载与代码生成工具
- 07 Anykey图像优化及文字头像生成与加载
- c++多线程重点难点(六)CriticalSection
- redis中 SETBIT命令 和 BITCOUNT命令
- 消息队列 使用场景
- 水经微图位置标注功能说明
- 京东基于Spark的风控系统架构实践和技术细节
- models生成与加载
- Cannot set the value of read-only property 'outputFile' for ApkVariantOutputImpl_Decorated
- 使用公式C=(5/9)(F-32)打印下列华氏温度与摄氏温度对照表。
- 关于spring注解的配置文件说明context:annotation-config和context:component-scan
- 数据库索引 类型
- error:Found shared references to a collection:
- 作业2:打印出所有的"水仙花数",所谓"水仙花数"是指一个三位数,其各位数字立方和等于该数本身。例如:153 是一个"水仙花数",因为153=1的三次方+5的三次方+3的三次方。
- 【错误】安装Vim出现错误kde-config-telepathy-accounts > 15.04
- Python判断中文字符串是否相等