初识TensorFlow
来源:互联网 发布:微云 mac 同步版 编辑:程序博客网 时间:2024/06/08 06:18
唔,我可能会抽个时间把Tensorflow系统的介绍一下,因为数据流图这个概念,张量的概念等我想并不是每个人都十分清楚(至少我不是很清楚,囧)
不过对于大多数初步接触ML的同学来说,大可以把Tensorflow暂时当成是一个函数包,直接调用函数搭建模型(一般都是神经网络对不对^.^),就可以了。
下面是一个基础的MNIST手写数字识别系统 (from mlp in the University of Edinburgh)
import tensorflow as tf
inputs = tf.placeholder(tf.float32, [None, 784], 'inputs')
targets = tf.placeholder(tf.float32, [None, 10], 'targets')
weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([10]))
outputs = tf.matmul(inputs, weights) + biases
per_datapoint_errors = tf.nn.softmax_cross_entropy_with_logits(outputs, targets)
error = tf.reduce_mean(per_datapoint_errors)
per_datapoint_pred_is_correct = tf.equal(tf.argmax(outputs, 1), tf.argmax(targets, 1))
accuracy = tf.reduce_mean(tf.cast(per_datapoint_pred_is_correct, tf.float32))
train_step = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(error)
sess = tf.InteractiveSession()
init_op = tf.global_variables_initializer()
sess.run(init_op)
train_data = data_providers.MNISTDataProvider('train', batch_size=50)\\here should be your data reader
valid_data = data_providers.MNISTDataProvider('valid', batch_size=50)
num_epoch = 5
for e in range(num_epoch):
running_error = 0.
for input_batch, target_batch in train_data:
_, batch_error = sess.run(
[train_step, error],
feed_dict={inputs: input_batch, targets: target_batch})
running_error += batch_error
running_error /= train_data.num_batches
print('End of epoch {0}: running error average = {1:.2f}'.format(e + 1, running_error))
def get_error_and_accuracy(data):
"""Calculate average error and classification accuracy across a dataset.
Args:
data: Data provider which iterates over input-target batches in dataset.
Returns:
Tuple with first element scalar value corresponding to average error
across all batches in dataset and second value corresponding to
average classification accuracy across all batches in dataset.
"""
err = 0
acc = 0
for input_batch, target_batch in data:
err += sess.run(error, feed_dict={inputs: input_batch, targets: target_batch})
acc += sess.run(accuracy, feed_dict={inputs: input_batch, targets: target_batch})
err /= data.num_batches
acc /= data.num_batches
return err, acc
print('Train data: Error={0:.2f} Accuracy={1:.2f}'
.format(*get_error_and_accuracy(train_data)))
print('Valid data: Error={0:.2f} Accuracy={1:.2f}'
.format(*get_error_and_accuracy(valid_data)))
- tensorflow初识
- 初识TensorFlow
- TensorFlow初识
- 初识TensorFlow
- 初识TensorFlow
- 初识tensorflow
- Tensorflow(一)- 初识tensorflow
- TensorFlow学习系列(一):初识TensorFlow
- [深度学习]-初识 TensorFlow (Python)
- 零基础学TensorFlow(二):初识TensorFlow
- 初识Tensorflow,基本概念及简单示例
- 深度学习的应用以及初识Tensorflow
- 1.TensorFlow初识 TF实现线性回归模型
- Tensorflow学习笔记(一):初识TensorFlow——实现线性回归
- tensorflow
- TensorFlow
- TensorFlow
- tensorflow
- SQL--基础语句2
- JAVA设计模式之单例模式
- Angular最新教程-第三节在谷歌浏览器中调试Angular
- 在linux内核中读写文件
- .apply()用法和call()的区别
- 初识TensorFlow
- JAVAEE之JSP
- Javascript生成全局唯一标识符(GUID,UUID)的方法
- 017 ACM/ICPC Asia Regional Shenyang Online 1001 后缀数组+单调队列
- redhat上安装mysql
- Maven优雅的添加第三方Jar包
- vim粘贴代码 如果有注释,那么粘贴后就惨不忍睹,类似于:
- sudo授权指定用户在指定主机上运行某些命令
- android端 socket长连接 架构