mnist_cnn

来源:互联网 发布:淘宝拍卖房产靠谱吗 编辑:程序博客网 时间:2024/06/15 13:18
"""Three-stage convolutional network for MNIST, TensorFlow 1.x graph API."""
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np


def init_weights(shape):
    """Return a trainable variable of ``shape`` drawn from N(0, 0.01**2)."""
    return tf.Variable(tf.random_normal(shape, stddev=0.01))


def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
    """Build the forward pass and return unnormalised class logits.

    Three stages of conv(3x3, stride 1, SAME) -> ReLU -> 2x2/2 max-pool
    (SAME) -> dropout, then a ReLU fully-connected layer with dropout and a
    linear output layer.

    Args:
        X: input image batch; ``main`` feeds shape (batch, 28, 28, 1).
        w, w2, w3: conv kernels of shape (3, 3, in_channels, out_channels).
        w4: fully-connected weights. Its first dimension must equal the
            flattened size of the third pooled feature map (128 * 4 * 4 for
            28x28 inputs, since SAME pooling gives 28 -> 14 -> 7 -> 4).
        w_o: output-layer weights of shape (hidden_units, num_classes).
        p_keep_conv: dropout keep probability after each conv stage.
        p_keep_hidden: dropout keep probability after the hidden layer.

    Returns:
        Logits tensor of shape (batch, num_classes).
    """
    l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
    l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    l1 = tf.nn.dropout(l1, p_keep_conv)

    l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
    l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    l2 = tf.nn.dropout(l2, p_keep_conv)

    l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
    l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    # Flatten the pooled maps to whatever width w4 expects.
    l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]])
    l3 = tf.nn.dropout(l3, p_keep_conv)

    l4 = tf.nn.relu(tf.matmul(l3, w4))
    l4 = tf.nn.dropout(l4, p_keep_hidden)

    return tf.matmul(l4, w_o)


def main(data_dir="mnist", epochs=100, batch_size=128, test_size=256):
    """Train the CNN on MNIST and print ``(epoch, test accuracy)`` per epoch.

    Args:
        data_dir: directory passed to ``input_data.read_data_sets``.
        epochs: number of passes over the training set.
        batch_size: training mini-batch size.
        test_size: number of randomly sampled test images scored after each
            epoch, or ``None`` to score the whole test set. A 256-image
            sample is a noisy estimate (one image = ~0.4 percentage points).

    Raises:
        ValueError: if ``test_size`` is not ``None`` and not positive.
    """
    if test_size is not None and test_size <= 0:
        raise ValueError("test_size must be positive or None")

    mnist = input_data.read_data_sets(data_dir, one_hot=True)
    trX, trY = mnist.train.images, mnist.train.labels
    teX, teY = mnist.test.images, mnist.test.labels
    # Images are stored flattened; restore (N, 28, 28, 1) for conv2d.
    trX = trX.reshape(-1, 28, 28, 1)
    teX = teX.reshape(-1, 28, 28, 1)

    X = tf.placeholder(tf.float32, [None, 28, 28, 1])
    Y = tf.placeholder(tf.float32, [None, 10])
    p_keep_conv = tf.placeholder(tf.float32)
    p_keep_hidden = tf.placeholder(tf.float32)

    w = init_weights([3, 3, 1, 32])
    w2 = init_weights([3, 3, 32, 64])
    w3 = init_weights([3, 3, 64, 128])
    w4 = init_weights([128 * 4 * 4, 625])
    w_o = init_weights([625, 10])

    py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden)
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
    train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
    predict_op = tf.argmax(py_x, 1)

    n_train = len(trX)
    eval_batch = 256  # bounds memory per sess.run when scoring many images
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for epoch in range(epochs):
            # Reshuffle every epoch and keep the final partial batch: the
            # old fixed-order zip of full-batch bounds never trained on the
            # last n_train % batch_size examples.
            order = np.random.permutation(n_train)
            for start in range(0, n_train, batch_size):
                idx = order[start:start + batch_size]
                sess.run(train_op, feed_dict={X: trX[idx], Y: trY[idx],
                                              p_keep_conv: 0.8,
                                              p_keep_hidden: 0.5})

            if test_size is None:
                test_idx = np.arange(len(teX))
            else:
                test_idx = np.random.permutation(len(teX))[:test_size]

            # Dropout disabled (keep prob 1.0) for evaluation.
            hits = 0
            for start in range(0, len(test_idx), eval_batch):
                chunk = test_idx[start:start + eval_batch]
                pred = sess.run(predict_op, feed_dict={X: teX[chunk],
                                                       p_keep_conv: 1.0,
                                                       p_keep_hidden: 1.0})
                hits += np.sum(pred == np.argmax(teY[chunk], axis=1))
            print(epoch, hits / float(len(test_idx)))


if __name__ == "__main__":
    main()

(0, 0.94921875)
(1, 0.98046875)
(2, 0.98046875)
(3, 0.97265625)
(4, 0.9921875)
(5, 0.984375)
(6, 0.97265625)
(7, 0.984375)
(8, 0.9921875)
(9, 0.984375)
(10, 1.0)
(11, 0.99609375)
(12, 0.984375)
(13, 0.984375)
(14, 0.99609375)
(15, 0.9921875)
(16, 0.9921875)
(17, 0.9921875)
(18, 0.98828125)
(19, 0.98046875)
(20, 0.98828125)
(21, 0.99609375)
(22, 0.9921875)
(23, 0.99609375)
(24, 0.9921875)
(25, 0.99609375)
(26, 1.0)
(27, 0.9921875)
(28, 0.9921875)
(29, 0.9921875)
(30, 0.9921875)
(31, 0.9921875)
(32, 0.984375)
(33, 0.9921875)
(34, 0.9765625)
(35, 0.99609375)
(36, 0.9921875)
(37, 0.98828125)
(38, 0.9921875)
(39, 0.99609375)
(40, 0.98828125)
(41, 0.9921875)
(42, 0.9921875)
(43, 1.0)
(44, 0.99609375)
(45, 0.9921875)
(46, 0.9921875)
(47, 0.98828125)
(48, 0.98828125)
(49, 1.0)
(50, 0.9921875)
(51, 0.9921875)
(52, 0.99609375)
(53, 0.99609375)
(54, 0.9921875)
(55, 0.9921875)
(56, 0.98828125)
(57, 0.99609375)
(58, 0.98828125)
(59, 0.98828125)
(60, 0.9921875)
(61, 0.99609375)
(62, 1.0)
(63, 0.98046875)
(64, 0.98828125)
(65, 0.9921875)
(66, 0.9921875)
(67, 0.99609375)
(68, 0.98828125)
(69, 0.99609375)
(70, 0.99609375)
(71, 1.0)
(72, 0.9921875)
(73, 1.0)
(74, 0.98828125)
(75, 0.9921875)
(76, 0.99609375)
(77, 1.0)
(78, 1.0)
(79, 0.9921875)
(80, 0.9921875)
(81, 0.98828125)
(82, 0.99609375)
(83, 0.99609375)
(84, 0.9921875)
(85, 0.9765625)
(86, 0.9921875)
(87, 1.0)
(88, 0.98828125)
(89, 0.98828125)
(90, 0.9921875)
(91, 0.9921875)
(92, 0.98828125)
(93, 0.9921875)
(94, 0.984375)
(95, 0.9921875)
(96, 0.98828125)
(97, 0.99609375)
(98, 0.9921875)
(99, 0.98828125)