tensorflow中对于模型的参数都必须声明为变量

来源:互联网 发布:圆通热敏打印软件 编辑:程序博客网 时间:2024/05/22 06:16

1、tensorflow中所有的定义都只是声明,只有在session中run的时候,才会被执行。

谨记:对于模型中所有的参数都必须要使用variable来定义。可以使用tf.truncated_normal()来定义随机初始话,但是必须将随机初始化的值赋给variable。不然,每次需要访问参数的时候,都会驱动tf.truncated_normal()。

正确的写法:

import numpy as npimport tensorflow as tf sess = tf.Session()params = tf.Variable(tf.truncated_normal([4, 5]))indices = tf.constant([2, 0])output = tf.gather(params, indices)sess.run(tf.global_variables_initializer())print (sess.run(params))print (sess.run(output))sess.close()


错误的写法:

import numpy as npimport tensorflow as tf sess = tf.Session()params = tf.truncated_normal([4, 5])indices = tf.constant([2, 0])output = tf.gather(params, indices)  print (sess.run(params))print (sess.run(output))sess.close()

说明:param其实也只是生成随机数的操作,这个操作被驱动了2次,一次是sess.run(params),,一次是sess.run(output)。


阅读全文
0 0
原创粉丝点击