First Look at TensorFlow

Source: Internet | Published: 微云 mac 同步版 | Editor: 程序博客网 | Time: 2024/06/08 06:18

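Hmm, I may find some time to give TensorFlow a systematic introduction, because I suspect not everyone is entirely clear on concepts like the dataflow graph and the tensor (at least I'm not, embarrassingly).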
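As a rough picture of the dataflow-graph idea, here is a toy sketch in plain Python. It is not TensorFlow code: `Node`, `Placeholder`, `Constant` and `run` are hypothetical names invented for this illustration. It only mimics the TensorFlow 1.x graph-mode behaviour that the script below relies on: building an expression records operations as graph nodes without computing anything, and values flow along the edges only when a node is evaluated with values fed in for its placeholders (in TensorFlow those flowing values are tensors, i.e. multi-dimensional arrays).

```python
# Toy dataflow graph in plain Python (hypothetical; NOT TensorFlow's API).
# Building nodes only records the computation; run() does the evaluation.

class Node:
    """A graph node: an operation applied to the values of its input nodes."""
    def __init__(self, op, *inputs):
        self.op = op          # function that computes this node's value
        self.inputs = inputs  # upstream nodes (the graph's edges)

    def __add__(self, other):
        return Node(lambda a, b: a + b, self, other)

    def __mul__(self, other):
        return Node(lambda a, b: a * b, self, other)


class Placeholder(Node):
    """A graph input whose value is supplied at run time through feed_dict."""
    def __init__(self, name):
        super().__init__(None)
        self.name = name


class Constant(Node):
    """A node that always produces the same value."""
    def __init__(self, value):
        super().__init__(lambda: value)


def run(node, feed_dict, cache=None):
    """Evaluate `node`, recursively evaluating its inputs first."""
    if cache is None:
        cache = {}
    if node in cache:  # each node is evaluated at most once per run
        return cache[node]
    if isinstance(node, Placeholder):
        value = feed_dict[node]
    else:
        value = node.op(*(run(n, feed_dict, cache) for n in node.inputs))
    cache[node] = value
    return value


# Build the graph y = w * x + b; no arithmetic happens on these lines.
x = Placeholder("x")
w = Constant(2.0)
b = Constant(1.0)
y = w * x + b

print(run(y, {x: 3.0}))   # 7.0
print(run(y, {x: 10.0}))  # 21.0
```

The same graph is built once and evaluated many times with different fed-in values, which is the pattern behind `sess.run(..., feed_dict=...)` in the training loop.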


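However, if you are just getting started with ML, you can treat TensorFlow as a library of functions for now and build models (usually neural networks, right? ^.^) simply by calling those functions.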


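Below is a basic MNIST handwritten-digit recognition system (from the mlp course at the University of Edinburgh). The model is a single-layer softmax regression: one matrix multiplication plus a bias, trained on the softmax cross-entropy with plain gradient descent.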
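The script in this post reads its batches through `data_providers.MNISTDataProvider` from the course code and notes that you should substitute your own data reader. As far as its training and evaluation loops are concerned, a reader only needs to be iterable over `(input_batch, target_batch)` pairs, where each input is a flat 784-value image and each target a 10-value one-hot vector (matching the two placeholders), and to expose `num_batches`. The class below is a minimal stand-in under those assumptions; `ListDataProvider` is a hypothetical name, and loading the real MNIST files into the `inputs`/`labels` lists is not shown.

```python
import random

# Hypothetical stand-in for data_providers.MNISTDataProvider (NOT the course's
# implementation). It mimics only what the script uses: iteration over
# (input_batch, target_batch) pairs and a num_batches attribute.

class ListDataProvider:
    def __init__(self, inputs, labels, batch_size, num_classes=10, shuffle=True, seed=0):
        assert len(inputs) == len(labels)
        self.inputs = inputs          # each item: a flat list of 784 pixel values
        self.labels = labels          # each item: an integer class label 0-9
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        # Only full batches are produced, so every batch has batch_size items.
        self.num_batches = len(inputs) // batch_size

    def _one_hot(self, label):
        vec = [0.0] * self.num_classes
        vec[label] = 1.0
        return vec

    def __iter__(self):
        # A generator, so the provider can be looped over again every epoch.
        order = list(range(len(self.inputs)))
        if self.shuffle:
            self.rng.shuffle(order)
        for k in range(self.num_batches):
            idx = order[k * self.batch_size:(k + 1) * self.batch_size]
            input_batch = [self.inputs[i] for i in idx]
            target_batch = [self._one_hot(self.labels[i]) for i in idx]
            yield input_batch, target_batch


# Usage with a tiny fake dataset: 120 blank "images" with labels 0-9.
demo_data = ListDataProvider([[0.0] * 784 for _ in range(120)],
                             [i % 10 for i in range(120)], batch_size=50)
print(demo_data.num_batches)  # 2 (the last 20 examples are dropped)
for input_batch, target_batch in demo_data:
    print(len(input_batch), len(input_batch[0]), len(target_batch[0]))  # 50 784 10
```

Dropping the incomplete final batch keeps every batch the same size. That matters here because the script averages per-batch mean errors by dividing by `num_batches`; with unequal batches, that average would no longer equal the per-example mean.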


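# Written for the TensorFlow 1.x API (in TF 2.x, use tf.compat.v1 with eager execution disabled)
import tensorflow as tf
# Data reader from the Edinburgh mlp course code (assumed import path); replace with your own data reader
from mlp import data_providers

# Graph inputs: flattened 28x28 images and one-hot labels for the 10 digit classes
inputs = tf.placeholder(tf.float32, [None, 784], 'inputs')
targets = tf.placeholder(tf.float32, [None, 10], 'targets')
# Single-layer softmax regression: outputs (logits) = inputs x weights + biases
weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([10]))
outputs = tf.matmul(inputs, weights) + biases
# Per-example softmax cross-entropy, averaged over the batch
# (TF >= 1.0 requires the labels= and logits= keyword arguments here)
per_datapoint_errors = tf.nn.softmax_cross_entropy_with_logits(labels=targets, logits=outputs)
error = tf.reduce_mean(per_datapoint_errors)
# Accuracy: fraction of examples whose highest-scoring class matches the label
per_datapoint_pred_is_correct = tf.equal(tf.argmax(outputs, 1), tf.argmax(targets, 1))
accuracy = tf.reduce_mean(tf.cast(per_datapoint_pred_is_correct, tf.float32))
# An op that applies one gradient-descent update to weights and biases each time it is run
train_step = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(error)
sess = tf.InteractiveSession()
init_op = tf.global_variables_initializer()
sess.run(init_op)
# Batches of 50 examples from the training and validation sets
train_data = data_providers.MNISTDataProvider('train', batch_size=50)
valid_data = data_providers.MNISTDataProvider('valid', batch_size=50)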


num_epoch = 5
for e in range(num_epoch):
    running_error = 0.
    for input_batch, target_batch in train_data:
        # Apply one update and fetch this batch's mean error in the same call
        _, batch_error = sess.run(
            [train_step, error],
            feed_dict={inputs: input_batch, targets: target_batch})
        running_error += batch_error
    # Average of the per-batch errors over the epoch
    running_error /= train_data.num_batches
    print('End of epoch {0}: running error average = {1:.2f}'.format(e + 1, running_error))


def get_error_and_accuracy(data):
    """Calculate average error and classification accuracy across a dataset.

    Args:
        data: Data provider which iterates over input-target batches in dataset.

    Returns:
        Tuple with first element scalar value corresponding to average error
        across all batches in dataset and second value corresponding to
        average classification accuracy across all batches in dataset.
    """
    err = 0.
    acc = 0.
    for input_batch, target_batch in data:
        # Fetch both metrics in a single run call instead of two
        batch_err, batch_acc = sess.run(
            [error, accuracy],
            feed_dict={inputs: input_batch, targets: target_batch})
        err += batch_err
        acc += batch_acc
    err /= data.num_batches
    acc /= data.num_batches
    return err, acc




print('Train data: Error={0:.2f} Accuracy={1:.2f}'
      .format(*get_error_and_accuracy(train_data)))
print('Valid data: Error={0:.2f} Accuracy={1:.2f}'
      .format(*get_error_and_accuracy(valid_data)))
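To see what the graph in the script actually computes, here is the same softmax regression written out in plain Python for one mini-batch: logits `x W + b`, softmax, batch-mean cross-entropy, argmax accuracy, and one gradient-descent update. This is an illustrative sketch, not TensorFlow's implementation: TensorFlow derives gradients automatically, whereas here the standard closed-form gradient of softmax cross-entropy with respect to the logits, `p - t`, is coded by hand. The helper names and the toy data are made up for this sketch.

```python
import math

# Plain-Python sketch of the softmax-regression graph in the script
# (illustrative only; TensorFlow computes the gradients automatically).

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def forward(x, W, b):
    """outputs = x W + b for one example (x: D values, W: D x K, b: K values)."""
    return [sum(x[d] * W[d][k] for d in range(len(x))) + b[k] for k in range(len(b))]

def cross_entropy(probs, target):
    """-sum_k t_k * log(p_k); terms with t_k == 0 are skipped."""
    return -sum(t * math.log(p) for p, t in zip(probs, target) if t > 0)

def sgd_step(inputs, targets, W, b, learning_rate):
    """Update W and b in place by one gradient-descent step on the batch-mean
    cross-entropy; return the batch-mean error computed before the update."""
    n, dim, k_classes = len(inputs), len(W), len(b)
    grad_W = [[0.0] * k_classes for _ in range(dim)]
    grad_b = [0.0] * k_classes
    total_error = 0.0
    for x, t in zip(inputs, targets):
        p = softmax(forward(x, W, b))
        total_error += cross_entropy(p, t)
        # Gradient of softmax cross-entropy w.r.t. the logits is p - t
        delta = [p_k - t_k for p_k, t_k in zip(p, t)]
        for d in range(dim):
            for k in range(k_classes):
                grad_W[d][k] += x[d] * delta[k] / n
        for k in range(k_classes):
            grad_b[k] += delta[k] / n
    for d in range(dim):
        for k in range(k_classes):
            W[d][k] -= learning_rate * grad_W[d][k]
    for k in range(k_classes):
        b[k] -= learning_rate * grad_b[k]
    return total_error / n

def batch_accuracy(inputs, targets, W, b):
    """Fraction of examples where argmax(outputs) == argmax(targets)."""
    correct = 0
    for x, t in zip(inputs, targets):
        out = forward(x, W, b)
        correct += out.index(max(out)) == t.index(max(t))
    return correct / len(inputs)

# Toy 2-feature, 2-class batch (made-up data), zero-initialised like tf.zeros
inputs = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.2, 0.8]]
targets = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
W = [[0.0, 0.0], [0.0, 0.0]]
b = [0.0, 0.0]
first = sgd_step(inputs, targets, W, b, learning_rate=0.5)
for _ in range(50):
    last = sgd_step(inputs, targets, W, b, learning_rate=0.5)
print(round(first, 4))  # 0.6931, i.e. ln 2
print(last < first, batch_accuracy(inputs, targets, W, b))
```

With all-zero initial weights, as produced by `tf.zeros` in the script, every class gets probability 1/K, so the first batch error is ln K: ln 2 ≈ 0.6931 for the two toy classes here, and ln 10 ≈ 2.30 for the ten MNIST digits.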

