tensorflow tutorials(三):用tensorflow建立逻辑回归模型

来源:互联网 发布:java base64 字符串 编辑:程序博客网 时间:2024/06/05 14:25


声明:版权所有,转载请联系作者并注明出处




"""Logistic (softmax) regression on MNIST using the TensorFlow 1.x graph API.

A single 784x10 weight matrix maps each flattened image to 10 class logits.
Training minimizes the mean softmax cross-entropy with plain gradient descent
and prints test-set accuracy after every epoch.
"""
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

BATCH_SIZE = 128
LEARNING_RATE = 0.05
NUM_EPOCHS = 100


def init_weights(shape):
    """Return a trainable variable of `shape` drawn from a normal with stddev 0.01."""
    return tf.Variable(tf.random_normal(shape, stddev=0.01))


def model(X, W):
    """Return the unnormalized class logits ``X @ W``.

    This is the same linear model used for linear regression; the softmax is
    applied inside the cost function, not here.
    """
    return tf.matmul(X, W)


mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
train_X, train_Y = mnist.train.images, mnist.train.labels
test_X, test_Y = mnist.test.images, mnist.test.labels

X = tf.placeholder("float", [None, 784])  # flattened images, 784 features each
Y = tf.placeholder("float", [None, 10])   # one-hot labels (one_hot=True above)

# Shared weight matrix, as in linear regression.
W = init_weights([784, 10])
py_x = model(X, W)

# Mean cross-entropy; softmax is applied internally. TF >= 1.0 rejects
# positional arguments for this function, so they are named explicitly.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cost)
# Build the prediction op once. Creating tf.argmax inside the epoch loop
# would add a fresh node to the graph on every iteration.
predict_op = tf.argmax(py_x, 1)

with tf.Session() as sess:
    # Replacement for the deprecated tf.initialize_all_variables().
    tf.global_variables_initializer().run()
    for i in range(NUM_EPOCHS):
        # Walk the whole training set in minibatches. The final batch may be
        # shorter than BATCH_SIZE instead of being silently dropped.
        for start in range(0, len(train_X), BATCH_SIZE):
            end = start + BATCH_SIZE
            sess.run(train_op,
                     feed_dict={X: train_X[start:end], Y: train_Y[start:end]})
        accuracy = np.mean(np.argmax(test_Y, axis=1) ==
                           sess.run(predict_op, feed_dict={X: test_X}))
        print(i, accuracy)
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
(0, 0.8841)
(1, 0.89680000000000004)
(2, 0.90310000000000001)
(3, 0.90739999999999998)
(4, 0.90939999999999999)
(5, 0.91090000000000004)
(6, 0.91210000000000002)
(7, 0.91310000000000002)
(8, 0.91490000000000005)
(9, 0.91569999999999996)
(10, 0.91590000000000005)
(11, 0.91700000000000004)
(12, 0.91720000000000002)
(13, 0.91739999999999999)
(14, 0.91769999999999996)
(15, 0.91800000000000004)
(16, 0.91849999999999998)
(17, 0.91910000000000003)
(18, 0.91959999999999997)
(19, 0.91990000000000005)
(20, 0.91979999999999995)
(21, 0.91990000000000005)
(22, 0.92030000000000001)
(23, 0.92030000000000001)
(24, 0.9204)
(25, 0.92110000000000003)
(26, 0.92090000000000005)
(27, 0.92120000000000002)
(28, 0.92130000000000001)
(29, 0.92159999999999997)
(30, 0.92179999999999995)
(31, 0.92200000000000004)
(32, 0.92179999999999995)
(33, 0.92159999999999997)
(34, 0.92149999999999999)
(35, 0.92159999999999997)
(36, 0.92149999999999999)
(37, 0.9214)
(38, 0.92200000000000004)
(39, 0.92200000000000004)
(40, 0.92220000000000002)
(41, 0.92200000000000004)
(42, 0.92190000000000005)
(43, 0.92200000000000004)
(44, 0.92190000000000005)
(45, 0.92179999999999995)
(46, 0.92200000000000004)
(47, 0.92200000000000004)
(48, 0.92220000000000002)
(49, 0.92220000000000002)
(50, 0.92200000000000004)
(51, 0.92220000000000002)
(52, 0.92230000000000001)
(53, 0.92220000000000002)
(54, 0.92220000000000002)
(55, 0.9224)
(56, 0.92249999999999999)
(57, 0.92279999999999995)
(58, 0.92290000000000005)
(59, 0.92300000000000004)
(60, 0.92310000000000003)
(61, 0.9234)
(62, 0.9234)
(63, 0.9234)
(64, 0.92369999999999997)
(65, 0.92359999999999998)
(66, 0.92369999999999997)
(67, 0.92369999999999997)
(68, 0.92379999999999995)
(69, 0.92359999999999998)
(70, 0.92359999999999998)
(71, 0.9234)
(72, 0.92349999999999999)
(73, 0.9234)
(74, 0.92369999999999997)
(75, 0.92369999999999997)
(76, 0.92369999999999997)
(77, 0.92359999999999998)
(78, 0.92369999999999997)
(79, 0.92369999999999997)
(80, 0.92369999999999997)
(81, 0.92359999999999998)
(82, 0.92390000000000005)
(83, 0.92379999999999995)
(84, 0.92369999999999997)
(85, 0.92369999999999997)
(86, 0.92379999999999995)
(87, 0.92379999999999995)
(88, 0.92390000000000005)
(89, 0.92390000000000005)
(90, 0.92390000000000005)
(91, 0.92369999999999997)
(92, 0.92359999999999998)
(93, 0.92359999999999998)
(94, 0.92369999999999997)
(95, 0.92369999999999997)
(96, 0.92379999999999995)
(97, 0.92379999999999995)
(98, 0.92379999999999995)
(99, 0.92379999999999995)

1 0
原创粉丝点击