Tensorflow + Mnist (两层CNN,两层全连接)
来源:互联网 发布:看门狗2低配置怎么优化 编辑:程序博客网 时间:2024/06/08 15:49
前面几天中断了好几天,装了个linux,搭建了一下深度学习环境。入坑tensorflow,算是目前相当方便的一个平台了。环境的搭建我有单独写了个博客。
我搭建的环境 ubuntu16.04LTS + python3.6+tensorflow1.2
我的硬件环境:
i7-4720HQ @2.60ghz*8 + 950m
直接上手tensorflow的入门教程,mnist手写字符的识别,tensorflow的官方文档写了一个手写字符识别的入门CNN网络,但是没有画出网络结构,相信对于初学者还是有点难以理解的。
我这里画了一个草图
总的来说就是两层卷积(第一层包括一个卷积(32个5×5的kernel)+一个池化,第二层包括一个卷积(64个5×5的kernel)+一个池化)两层全连接,前面3层的激活函数都是采用了relu:max(0,x),最后一层用softmax输出10类目标
#!/usr/bin/env python3# -*- coding: utf-8 -*-"""Created on Sun Jun 25 11:59:53 2017@author: matthew"""import input_dataimport tensorflow as tfmnist = input_data.read_data_sets("MNIST_data/", one_hot=True)sess = tf.InteractiveSession()# build softmaxx = tf.placeholder("float",shape = [None,784])y_ = tf.placeholder("float",shape = [None,10])#initialize weights and biasdef weight_variable(shape): initial = tf.truncated_normal(shape,stddev = 0.1) return tf.Variable(initial)def bias_variable(shape): initial = tf.constant(0.1,shape = shape) return tf.Variable(initial)#conv and poolingdef conv2d(x,w): return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding = 'SAME')def max_pool_2x2(x): return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding = 'SAME')#first layer,a conv layer and a pooling layer#conv layer:[5,5,1,32](the size of kernel)w_conv1 = weight_variable([5,5,1,32])b_conv1 = bias_variable([32])x_image = tf.reshape(x, [-1,28,28,1])#-1 denotes orignal sizeh_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)h_pool1 = max_pool_2x2(h_conv1)#second layer,a conv layer and a pooling layer#kernal size[5,5,32,64]w_conv2 = weight_variable([5,5,32,64])b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)h_pool2 = max_pool_2x2(h_conv2)#third layer,fclayer#now after two pooling, the size of image is 7*7w_fc1 = weight_variable([7*7*64,1024])b_fc1 = bias_variable([1024])#mat2vech_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)#add drop to fc layerkeep_prob = tf.placeholder("float")h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)w_fc2 = weight_variable([1024, 10])b_fc2 = bias_variable([10])y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)#cost funccross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)correct_prediction = tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))sess.run(tf.global_variables_initializer())for i in range(20000): batch = mnist.train.next_batch(50) if i%100 == 0: train_accuracy = accuracy.eval(feed_dict = {x:batch[0],y_:batch[1],keep_prob:1.0}) print ("step %d,training accuracy %g"%(i,train_accuracy)) train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})print ("test accuracy %g"%accuracy.eval(feed_dict = {x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))
这是tensorflow官方提供的mnist数据下载文件
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Functions for downloading and reading MNIST data."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport gzipimport osimport tempfileimport numpyfrom six.moves import urllibfrom six.moves import xrange # pylint: disable=redefined-builtinimport tensorflow as tffrom tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
值得一提的是,drop_out技巧,用在全连接层,随机的减去全连接层当中的一些连接,加强网络范化能力,防止过拟合。
在我的电脑上跑了快一个小时,最终的范化误差再0.993。
阅读全文
1 0
- Tensorflow + Mnist (两层CNN,两层全连接)
- Tensorflow CNN(两层卷积+全连接+softmax)
- Tensorflow-浅层CNN(MNIST数据集)
- tensorflow CNN for mnist
- TensorFlow入门-MNIST & CNN
- tensorflow & mnist & CNN
- 暑期 tensorflow+CNN+mnist
- TensorFlow MNIST CNN LeNet5模型
- 【TensorFlow】MNIST(使用CNN)
- tensorflow入门 ubuntu mnist cnn例程
- Tensorflow训练CNN网络识别mnist
- Tensorflow中mnist数据使用CNN训练
- tensorflow中CNN对mnist识别
- tensorflow进行MNIST手写数字识别-CNN
- Tensorflow实战-CNN网络Mnist识别
- Tensorflow Cnn mnist 的一些细节
- tensorflow学习之---CNN识别MNIST
- tensorflow之用CNN识别MNIST
- Nginx+Tomcat 动静分离实现负载均衡
- ubifs 提取
- C/C++修饰符static、const、extern
- 欢迎使用CSDN-markdown编辑器
- js判断IE与FIREFOX浏览器的方法
- Tensorflow + Mnist (两层CNN,两层全连接)
- WebService之工作原理 一
- eclipse中对相同变量的高亮显示
- python升级引起的pip执行错误
- 帧内预测模式RDO
- 5-2 Reversing Linked List
- Java语言基础——标识符
- c语言的指针数组与数组指针
- R语言包路径