An AlexNet Implementation for MNIST

Introduction to some key functions

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)

Parameters:

● data_format: the layout of the input. There are two options, "NHWC" and "NCHW"; the default is "NHWC".

● input: the input, a 4-D (image) tensor whose shape is determined by data_format. When data_format is "NHWC", the shape is [batch, in_height, in_width, in_channels]: the number of images in a training batch, the image height, the image width, and the number of image channels. When data_format is "NCHW", the shape is [batch, in_channels, in_height, in_width].

● filter: the convolution kernel, a 4-D tensor with shape [height, width, in_channels, out_channels]: the kernel height, the kernel width, the kernel depth (which must equal the input's in_channels), and the number of output feature maps (i.e., the number of kernels).

● strides: the stride, a list of 4 ints whose entries correspond one-to-one with the dimensions of data_format, giving the step size along each dimension. With the default "NHWC" format, strides = [batch, in_height, in_width, in_channels]. The batch and in_channels entries must be 1, meaning the kernel can only slide within a single sample and within a single channel's feature map; the in_height and in_width entries are the step sizes of the kernel along the height and width of the feature map, i.e. stride_height and stride_width.

● padding: the padding scheme. "SAME" pads, which can be loosely understood as zero-padding the edges, with the additional requirement that the number of zeros added on the left (top) equals the number added on the right (bottom) or is one fewer. "VALID" does not pad; leftover rows and columns are simply discarded. The exact output sizes are:

"SAME": output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

"VALID": output_spatial_shape[i] = ceil((input_spatial_shape[i] - (spatial_filter_shape[i] - 1)) / strides[i])

The result is a Tensor. This output is what we usually call a feature map, and its shape is again of the form [batch, height, width, channels]. The sketch below illustrates these shape rules.
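A minimal sketch of the shape arithmetic (the 28x28 input, the 3x3 kernel, the stride of 2, and the names x, w, conv_same, conv_valid are made up for this illustration):

import tensorflow as tf

# A toy input: a batch of 1 image, 28x28, 1 channel ("NHWC").
x = tf.placeholder(tf.float32, [1, 28, 28, 1])
# A 3x3 kernel with 1 input channel and 8 output feature maps.
w = tf.Variable(tf.random_normal([3, 3, 1, 8]))

# Stride 2 along height and width; stride must be 1 along batch and channels.
conv_same = tf.nn.conv2d(x, w, strides=[1, 2, 2, 1], padding='SAME')
conv_valid = tf.nn.conv2d(x, w, strides=[1, 2, 2, 1], padding='VALID')

print(conv_same.get_shape())   # (1, 14, 14, 8): ceil(28 / 2) = 14
print(conv_valid.get_shape())  # (1, 13, 13, 8): ceil((28 - (3 - 1)) / 2) = 13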

tf.nn.max_pool(value, ksize, strides, padding, data_format='NHWC', name=None)

Parameters:

● value: the input to the pooling, a 4-D tensor whose shape is determined by data_format; by default the shape is [batch, height, width, channels].

● The remaining parameters have the same meaning as in tf.nn.conv2d.

● ksize: the size of the pooling window, a list of 4 ints, usually [1, height, width, 1]; since we do not want to pool across the batch or across channels, those two entries are set to 1 (see the sketch below).
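Continuing the toy snippet above (again just a sketch, reusing the hypothetical conv_same tensor):

# 2x2 max pooling with stride 2 halves the height and width (SAME padding).
pool = tf.nn.max_pool(conv_same, ksize=[1, 2, 2, 1],
                      strides=[1, 2, 2, 1], padding='SAME')
print(pool.get_shape())  # (1, 7, 7, 8): ceil(14 / 2) = 7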

from __future__ import print_function

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("mnistData", one_hot=True)

# Network hyperparameters
learning_rate = 0.001   # learning rate
training_iters = 200000
batch_size = 64         # 64 images per training step
display_step = 20       # print a record every 20 steps

# Network parameters
n_input = 784    # input dimensionality (img shape: 28*28)
n_classes = 10   # number of labels (digits 0-9)
dropout = 0.75   # dropout keep probability (chance a unit's output is kept)

# Placeholder inputs
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32)

# Convolution
def conv2d(name, l_input, w, b):
    x = tf.nn.conv2d(l_input, w, strides=[1, 1, 1, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x, name=name)

# Max-pooling (downsampling)
def max_pool(name, l_input, k):
    return tf.nn.max_pool(l_input, ksize=[1, k, k, 1], strides=[1, k, k, 1],
                          padding='SAME', name=name)

# Local response normalization
def norm(name, l_input, lsize=4):
    return tf.nn.lrn(l_input, lsize, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name=name)

# All network weights
weights = {
    # 11*11 kernel, 1 input feature map, 64 output feature maps
    'wc1': tf.Variable(tf.random_normal([11, 11, 1, 64])),
    # 5*5 kernel, 64 input feature maps, 192 output feature maps
    'wc2': tf.Variable(tf.random_normal([5, 5, 64, 192])),
    'wc3': tf.Variable(tf.random_normal([3, 3, 192, 384])),
    'wc4': tf.Variable(tf.random_normal([3, 3, 384, 384])),
    'wc5': tf.Variable(tf.random_normal([3, 3, 384, 256])),
    'wd1': tf.Variable(tf.random_normal([4*4*256, 4096])),
    'wd2': tf.Variable(tf.random_normal([4096, 4096])),
    'out': tf.Variable(tf.random_normal([4096, 10]))
}

# Initialize the biases
biases = {
    'bc1': tf.Variable(tf.random_normal([64])),
    'bc2': tf.Variable(tf.random_normal([192])),
    'bc3': tf.Variable(tf.random_normal([384])),
    'bc4': tf.Variable(tf.random_normal([384])),
    'bc5': tf.Variable(tf.random_normal([256])),
    'bd1': tf.Variable(tf.random_normal([4096])),
    'bd2': tf.Variable(tf.random_normal([4096])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}

# Define the whole network
def alex_net(_X, _weights, _biases, _dropout):
    # Reshape the flat vector into a 28*28*1 image tensor
    _X = tf.reshape(_X, shape=[-1, 28, 28, 1])
    # Layer 1: convolution, downsampling, normalization
    conv1 = conv2d('conv1', _X, _weights['wc1'], _biases['bc1'])
    pool1 = max_pool('pool1', conv1, k=2)
    norm1 = norm('norm1', pool1, lsize=4)
    # Layer 2: convolution, downsampling, normalization
    conv2 = conv2d('conv2', norm1, _weights['wc2'], _biases['bc2'])
    pool2 = max_pool('pool2', conv2, k=2)
    norm2 = norm('norm2', pool2, lsize=4)
    # Layer 3: convolution, normalization
    conv3 = conv2d('conv3', norm2, _weights['wc3'], _biases['bc3'])
    norm3 = norm('norm3', conv3, lsize=4)
    # Layer 4: convolution, normalization
    conv4 = conv2d('conv4', norm3, _weights['wc4'], _biases['bc4'])
    norm4 = norm('norm4', conv4, lsize=4)
    # Layer 5: convolution, downsampling, normalization
    conv5 = conv2d('conv5', norm4, _weights['wc5'], _biases['bc5'])
    pool5 = max_pool('pool5', conv5, k=2)
    norm5 = norm('norm5', pool5, lsize=4)
    # Fully connected layer 1: first flatten the feature maps into a vector
    dense1 = tf.reshape(norm5, [-1, _weights['wd1'].get_shape().as_list()[0]])
    dense1 = tf.nn.relu(tf.matmul(dense1, _weights['wd1']) + _biases['bd1'], name='fc1')
    dense1 = tf.nn.dropout(dense1, _dropout)
    # Fully connected layer 2
    dense2 = tf.nn.relu(tf.matmul(dense1, _weights['wd2']) + _biases['bd2'], name='fc2')  # ReLU activation
    dense2 = tf.nn.dropout(dense2, _dropout)
    # Output layer
    out = tf.matmul(dense2, _weights['out']) + _biases['out']  # X^T*W + b
    return out

# Build the model
pred = alex_net(x, weights, biases, keep_prob)

# Loss function and training step
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate the network
# tf.argmax(pred, 1) returns the index of the row-wise maximum
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
# Cast correct_pred to float32 and take the mean; since its entries are
# either 0 or 1, the mean is the fraction of 1s, i.e. the accuracy
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initialize all shared variables
init = tf.global_variables_initializer()

# Launch training
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reaching the maximum number of iterations
    while step * batch_size < training_iters:
        # Fetch batch_size images from the MNIST training set:
        # batch_xs holds the image data, batch_ys the labels
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys, keep_prob: dropout})
        if step % display_step == 0:
            # Compute the accuracy on this batch
            acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.})
            # Compute the loss on this batch
            loss = sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.})
            print("Iter " + str(step*batch_size) + ", Minibatch Loss = " +
                  "{:.6f}".format(loss) + ", Training Accuracy = " + "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")
    # Compute the test accuracy
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: mnist.test.images[:256],
                                        y: mnist.test.labels[:256], keep_prob: 1.}))
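A note on the shape of 'wd1', [4*4*256, 4096]: every max_pool call above uses k=2 with SAME padding, so each one computes ceil(n/2) along height and width. The three pooling layers therefore shrink the 28x28 input to 14x14, then 7x7, then 4x4, and conv5 produces 256 feature maps, so the flattened vector fed into the first fully connected layer has 4*4*256 = 4096 elements. Running the script produces the following output: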
Extracting mnistData\train-images-idx3-ubyte.gz
Extracting mnistData\train-labels-idx1-ubyte.gz
Extracting mnistData\t10k-images-idx3-ubyte.gz
Extracting mnistData\t10k-labels-idx1-ubyte.gz
Iter 1280, Minibatch Loss = 251490.468750, Training Accuracy = 0.56250
Iter 2560, Minibatch Loss = 152387.500000, Training Accuracy = 0.71875
Iter 3840, Minibatch Loss = 67490.921875, Training Accuracy = 0.84375
Iter 5120, Minibatch Loss = 91920.078125, Training Accuracy = 0.79688
Iter 6400, Minibatch Loss = 90993.453125, Training Accuracy = 0.82812
Iter 7680, Minibatch Loss = 68676.859375, Training Accuracy = 0.85938
Iter 8960, Minibatch Loss = 72976.929688, Training Accuracy = 0.84375
Iter 10240, Minibatch Loss = 29980.505859, Training Accuracy = 0.89062
Iter 11520, Minibatch Loss = 20164.855469, Training Accuracy = 0.90625
Iter 12800, Minibatch Loss = 47458.316406, Training Accuracy = 0.84375
Iter 14080, Minibatch Loss = 38644.464844, Training Accuracy = 0.89062
Iter 15360, Minibatch Loss = 54797.839844, Training Accuracy = 0.84375
Iter 16640, Minibatch Loss = 26811.896484, Training Accuracy = 0.89062
Iter 17920, Minibatch Loss = 63796.734375, Training Accuracy = 0.89062
Iter 19200, Minibatch Loss = 63318.792969, Training Accuracy = 0.87500
Iter 20480, Minibatch Loss = 56750.226562, Training Accuracy = 0.87500
Iter 21760, Minibatch Loss = 22058.449219, Training Accuracy = 0.92188
Iter 23040, Minibatch Loss = 30142.957031, Training Accuracy = 0.90625
Iter 24320, Minibatch Loss = 9075.640625, Training Accuracy = 0.96875
Iter 25600, Minibatch Loss = 31460.503906, Training Accuracy = 0.93750
Iter 26880, Minibatch Loss = 19208.484375, Training Accuracy = 0.89062
Iter 28160, Minibatch Loss = 0.000000, Training Accuracy = 1.00000
Iter 29440, Minibatch Loss = 0.000000, Training Accuracy = 1.00000
Iter 30720, Minibatch Loss = 9593.075195, Training Accuracy = 0.95312
Iter 32000, Minibatch Loss = 11140.182617, Training Accuracy = 0.95312
Iter 33280, Minibatch Loss = 47006.328125, Training Accuracy = 0.92188
Iter 34560, Minibatch Loss = 10920.359375, Training Accuracy = 0.96875
Iter 35840, Minibatch Loss = 33868.253906, Training Accuracy = 0.89062
Iter 37120, Minibatch Loss = 20953.609375, Training Accuracy = 0.92188
Iter 38400, Minibatch Loss = 14402.460938, Training Accuracy = 0.89062
Iter 39680, Minibatch Loss = 2783.530273, Training Accuracy = 0.96875
Iter 40960, Minibatch Loss = 6583.197266, Training Accuracy = 0.95312
Iter 42240, Minibatch Loss = 9657.086914, Training Accuracy = 0.95312
Iter 43520, Minibatch Loss = 8889.012695, Training Accuracy = 0.93750
Iter 44800, Minibatch Loss = 10301.249023, Training Accuracy = 0.93750
Iter 46080, Minibatch Loss = 11356.807617, Training Accuracy = 0.93750
Iter 47360, Minibatch Loss = 1037.505859, Training Accuracy = 0.98438
Iter 48640, Minibatch Loss = 37.261719, Training Accuracy = 0.98438
Iter 49920, Minibatch Loss = 5537.283203, Training Accuracy = 0.98438
Iter 51200, Minibatch Loss = 9993.500977, Training Accuracy = 0.93750
Iter 52480, Minibatch Loss = 7684.349609, Training Accuracy = 0.93750
Iter 53760, Minibatch Loss = 21212.550781, Training Accuracy = 0.90625
Iter 55040, Minibatch Loss = 20383.501953, Training Accuracy = 0.96875
Iter 56320, Minibatch Loss = 12894.260742, Training Accuracy = 0.96875
Iter 57600, Minibatch Loss = 15404.772461, Training Accuracy = 0.95312
Iter 58880, Minibatch Loss = 18539.234375, Training Accuracy = 0.90625
Iter 60160, Minibatch Loss = 177.647461, Training Accuracy = 0.98438
Iter 61440, Minibatch Loss = 16868.160156, Training Accuracy = 0.96875
Iter 62720, Minibatch Loss = 19093.220703, Training Accuracy = 0.93750
Iter 64000, Minibatch Loss = 5445.425781, Training Accuracy = 0.96875
Iter 65280, Minibatch Loss = 10812.613281, Training Accuracy = 0.95312
Iter 66560, Minibatch Loss = 7888.953125, Training Accuracy = 0.93750
Iter 67840, Minibatch Loss = 14476.430664, Training Accuracy = 0.96875
Iter 69120, Minibatch Loss = 8190.536621, Training Accuracy = 0.96875
Iter 70400, Minibatch Loss = 4033.034668, Training Accuracy = 0.96875
Iter 71680, Minibatch Loss = 13678.387695, Training Accuracy = 0.92188
Iter 72960, Minibatch Loss = 12257.235352, Training Accuracy = 0.92188
Iter 74240, Minibatch Loss = 12096.169922, Training Accuracy = 0.95312
Iter 75520, Minibatch Loss = 191.411621, Training Accuracy = 0.98438
Iter 76800, Minibatch Loss = 10987.725586, Training Accuracy = 0.92188
Iter 78080, Minibatch Loss = 13003.524414, Training Accuracy = 0.93750
Iter 79360, Minibatch Loss = 14407.958984, Training Accuracy = 0.90625
Iter 80640, Minibatch Loss = 6809.606934, Training Accuracy = 0.98438
Iter 81920, Minibatch Loss = 0.000000, Training Accuracy = 1.00000
Iter 83200, Minibatch Loss = 561.085693, Training Accuracy = 0.98438
Iter 84480, Minibatch Loss = 4651.179688, Training Accuracy = 0.95312
Iter 85760, Minibatch Loss = 9919.875977, Training Accuracy = 0.96875
Iter 87040, Minibatch Loss = 3059.538086, Training Accuracy = 0.96875
Iter 88320, Minibatch Loss = 8951.718750, Training Accuracy = 0.92188
Iter 89600, Minibatch Loss = 24988.746094, Training Accuracy = 0.89062
Iter 90880, Minibatch Loss = 17393.894531, Training Accuracy = 0.92188
Iter 92160, Minibatch Loss = 3295.959717, Training Accuracy = 0.95312
Iter 93440, Minibatch Loss = 9333.998047, Training Accuracy = 0.95312
Iter 94720, Minibatch Loss = 11038.242188, Training Accuracy = 0.92188
Iter 96000, Minibatch Loss = 19766.988281, Training Accuracy = 0.95312
Iter 97280, Minibatch Loss = 8863.857422, Training Accuracy = 0.95312
Iter 98560, Minibatch Loss = 17762.171875, Training Accuracy = 0.90625
Iter 99840, Minibatch Loss = 27.215332, Training Accuracy = 0.98438
Iter 101120, Minibatch Loss = 1896.718262, Training Accuracy = 0.98438
Iter 102400, Minibatch Loss = 5295.691895, Training Accuracy = 0.95312
Iter 103680, Minibatch Loss = 3671.312988, Training Accuracy = 0.95312
Iter 104960, Minibatch Loss = 8974.707031, Training Accuracy = 0.92188
Iter 106240, Minibatch Loss = 2568.282715, Training Accuracy = 0.98438
Iter 107520, Minibatch Loss = 1171.708496, Training Accuracy = 0.96875
Iter 108800, Minibatch Loss = 1360.234863, Training Accuracy = 0.96875
Iter 110080, Minibatch Loss = 11885.685547, Training Accuracy = 0.87500
Iter 111360, Minibatch Loss = 3097.158691, Training Accuracy = 0.96875
Iter 112640, Minibatch Loss = 7685.139648, Training Accuracy = 0.90625
Iter 113920, Minibatch Loss = 697.487305, Training Accuracy = 0.98438
Iter 115200, Minibatch Loss = 15347.265625, Training Accuracy = 0.90625
Iter 116480, Minibatch Loss = 10992.441406, Training Accuracy = 0.93750
Iter 117760, Minibatch Loss = 2353.596191, Training Accuracy = 0.98438
Iter 119040, Minibatch Loss = 5958.698730, Training Accuracy = 0.93750
Iter 120320, Minibatch Loss = 1862.846191, Training Accuracy = 0.96875
Iter 121600, Minibatch Loss = 2203.640869, Training Accuracy = 0.98438
Iter 122880, Minibatch Loss = 15024.573242, Training Accuracy = 0.89062
Iter 124160, Minibatch Loss = 1854.105713, Training Accuracy = 0.98438
Iter 125440, Minibatch Loss = 4552.501953, Training Accuracy = 0.96875
Iter 126720, Minibatch Loss = 4455.558594, Training Accuracy = 0.93750
Iter 128000, Minibatch Loss = 4613.866211, Training Accuracy = 0.95312
Iter 129280, Minibatch Loss = 9531.547852, Training Accuracy = 0.96875
Iter 130560, Minibatch Loss = 7502.578125, Training Accuracy = 0.96875
Iter 131840, Minibatch Loss = 5357.618164, Training Accuracy = 0.96875
Iter 133120, Minibatch Loss = 0.000000, Training Accuracy = 1.00000
Iter 134400, Minibatch Loss = 3056.698730, Training Accuracy = 0.96875
Iter 135680, Minibatch Loss = 2282.407715, Training Accuracy = 0.95312
Iter 136960, Minibatch Loss = 2287.059326, Training Accuracy = 0.96875
Iter 138240, Minibatch Loss = 6769.006836, Training Accuracy = 0.96875
Iter 139520, Minibatch Loss = 1657.745850, Training Accuracy = 0.96875
Iter 140800, Minibatch Loss = 1030.750977, Training Accuracy = 0.95312
Iter 142080, Minibatch Loss = 1803.483398, Training Accuracy = 0.96875
Iter 143360, Minibatch Loss = 2835.777588, Training Accuracy = 0.96875
Iter 144640, Minibatch Loss = 1630.902100, Training Accuracy = 0.95312
Iter 145920, Minibatch Loss = 4224.594238, Training Accuracy = 0.95312
Iter 147200, Minibatch Loss = 0.000000, Training Accuracy = 1.00000
Iter 148480, Minibatch Loss = 10338.882812, Training Accuracy = 0.95312
Iter 149760, Minibatch Loss = 4764.475586, Training Accuracy = 0.96875
Iter 151040, Minibatch Loss = 9170.373047, Training Accuracy = 0.95312
Iter 152320, Minibatch Loss = 1501.253418, Training Accuracy = 0.96875
Iter 153600, Minibatch Loss = 1343.297119, Training Accuracy = 0.98438
Iter 154880, Minibatch Loss = 17587.242188, Training Accuracy = 0.93750
Iter 156160, Minibatch Loss = 2955.808594, Training Accuracy = 0.95312
Iter 157440, Minibatch Loss = 5036.468750, Training Accuracy = 0.92188
Iter 158720, Minibatch Loss = 2451.070068, Training Accuracy = 0.96875
Iter 160000, Minibatch Loss = 3700.370117, Training Accuracy = 0.95312
Iter 161280, Minibatch Loss = 3440.578613, Training Accuracy = 0.93750
Iter 162560, Minibatch Loss = 651.897705, Training Accuracy = 0.98438
Iter 163840, Minibatch Loss = 0.000000, Training Accuracy = 1.00000
Iter 165120, Minibatch Loss = 1975.876343, Training Accuracy = 0.96875
Iter 166400, Minibatch Loss = 10295.979492, Training Accuracy = 0.92188
Iter 167680, Minibatch Loss = 1935.198242, Training Accuracy = 0.98438
Iter 168960, Minibatch Loss = 1371.462891, Training Accuracy = 0.96875
Iter 170240, Minibatch Loss = 9050.325195, Training Accuracy = 0.95312
Iter 171520, Minibatch Loss = 3336.230957, Training Accuracy = 0.96875
Iter 172800, Minibatch Loss = 2149.940186, Training Accuracy = 0.98438
Iter 174080, Minibatch Loss = 1938.491211, Training Accuracy = 0.95312
Iter 175360, Minibatch Loss = 5143.197754, Training Accuracy = 0.95312
Iter 176640, Minibatch Loss = 547.373779, Training Accuracy = 0.96875
Iter 177920, Minibatch Loss = 4239.710938, Training Accuracy = 0.95312
Iter 179200, Minibatch Loss = 1702.791992, Training Accuracy = 0.96875
Iter 180480, Minibatch Loss = 3474.269531, Training Accuracy = 0.95312
Iter 181760, Minibatch Loss = 1654.626221, Training Accuracy = 0.96875
Iter 183040, Minibatch Loss = 0.000000, Training Accuracy = 1.00000
Iter 184320, Minibatch Loss = 327.817871, Training Accuracy = 0.98438
Iter 185600, Minibatch Loss = 2751.619629, Training Accuracy = 0.96875
Iter 186880, Minibatch Loss = 8166.226074, Training Accuracy = 0.92188
Iter 188160, Minibatch Loss = 970.713623, Training Accuracy = 0.96875
Iter 189440, Minibatch Loss = 7472.984863, Training Accuracy = 0.93750
Iter 190720, Minibatch Loss = 3025.575439, Training Accuracy = 0.95312
Iter 192000, Minibatch Loss = 948.524170, Training Accuracy = 0.98438
Iter 193280, Minibatch Loss = 848.313965, Training Accuracy = 0.96875
Iter 194560, Minibatch Loss = 2407.324219, Training Accuracy = 0.95312
Iter 195840, Minibatch Loss = 5774.828125, Training Accuracy = 0.95312
Iter 197120, Minibatch Loss = 419.255859, Training Accuracy = 0.98438
Iter 198400, Minibatch Loss = 457.075439, Training Accuracy = 0.96875
Iter 199680, Minibatch Loss = 1585.383911, Training Accuracy = 0.95312
Optimization Finished!
Testing Accuracy: 0.96875