TF-slim 调用slim提供的网络模型训练自己的数据
来源:互联网 发布:开淘宝能赚钱吗 编辑:程序博客网 时间:2024/06/11 02:37
参考:
1、https://github.com/tensorflow/models/blob/master/research/slim/nets/
2、http://blog.csdn.net/wc781708249/article/details/78414314
3、http://blog.csdn.net/wc781708249/article/details/78414028
说明:
使用slim提供的alexnet与TF-slim快速搭建cnn 相结合,实现调用alexnet运行mnist数据
其他模型也可以通过该方式进行调用
1、下载alexnet.py(位于参考 1 的 slim/nets 目录下),并放到与训练脚本相同的目录中,以便脚本可以直接 import alexnet
2、调用alexnet
需要注意的地方:mnist数据集的shape是28x28,而alexnet要求的输入shape是224x224,为此使用tf.image.resize_image_with_crop_or_pad()将28x28的图像居中并在四周补零,得到224x224的输入(注意:该函数只做居中裁剪/补零,并不会对图像进行缩放)
不需要修改数据的通道数,即 c=1、3 或 4 都适用
或者修改alexnet模型以适用于28x28的数据
# Flat MNIST images: one row of 28*28*1 floats per sample.
x = tf.placeholder(tf.float32, [None, 28*28*1], 'x')
# Restore the NHWC image layout: [n, 28, 28, 1].
image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
# Bring the images to 224x224 as alexnet expects. This call centrally
# zero-pads (or crops); it does not rescale the image content.
image_shaped_input = tf.image.resize_image_with_crop_or_pad(image_shaped_input, 224, 224)
完整代码:
#!/usr/bin/env python3# -*- coding: UTF-8 -*-"""调用slim 提供的网络来运行自己的数据这里调用alexnet 网络使用的数据集 mnist参考:1、https://github.com/tensorflow/models/blob/master/research/slim/nets2、http://blog.csdn.net/wc781708249/article/details/78414028"""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport tensorflow as tfimport tensorflow.contrib.slim as slim# slim = tf.contrib.slimfrom tensorflow.examples.tutorials.mnist import input_dataimport argparseimport sysimport alexnet # 导入alexnetclass Conv_model(object): # def __init__(self, X, Y, weights, biases, learning_rate, keep): def __init__(self, Y, learning_rate): # super(Conv_model, self).__init__(X,Y,w,b,learning_rate) # 返回父类的对象 # 或者 model.Model.__init__(self,X,Y,w,b,learning_rate) # self.X = X self.Y = Y # self.weights = weights # self.biases = biases self.learning_rate = learning_rate # self.keep = keep ''' def conv2d(self, x, W, b, strides=1): # Conv2D wrapper, with bias and relu activation x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME') x = tf.nn.bias_add(x, b) # strides中间两个为1 表示x,y方向都不间隔取样 return tf.nn.relu(x) def maxpool2d(self, x, k=2): # MaxPool2D wrapper return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='SAME') # strides中间两个为2 表示x,y方向都间隔1个取样 def inference(self, name='conv', activation='softmax'): # 重写inference函数 with tf.name_scope(name): conv1 = self.conv2d(self.X, self.weights['wc1'], self.biases['bc1']) conv1 = self.maxpool2d(conv1, k=2) # shape [N,1,1,32] conv1 = tf.nn.lrn(conv1, depth_radius=5, bias=2.0, alpha=1e-3, beta=0.75) conv1 = tf.nn.dropout(conv1, self.keep) fc1 = tf.reshape(conv1, [-1, self.weights['wd1'].get_shape().as_list()[0]]) fc1 = tf.add(tf.matmul(fc1, self.weights['wd1']), self.biases['bd1']) fc1 = tf.nn.relu(fc1) fc1 = tf.nn.dropout(fc1, self.keep) y = tf.add(tf.matmul(fc1, self.weights['out']), self.biases['out']) if activation == 'softmax': y = tf.nn.softmax(y) return y ''' def loss(self, 
pred_value, MSE_error=False, one_hot=True): if MSE_error: return tf.reduce_mean(tf.reduce_sum( tf.square(pred_value - self.Y), reduction_indices=[1])) else: if one_hot: return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( labels=self.Y, logits=pred_value)) else: return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( labels=tf.cast(self.Y, tf.int32), logits=pred_value)) def evaluate(self, pred_value, one_hot=True): if one_hot: correct_prediction = tf.equal(tf.argmax(pred_value, 1), tf.argmax(self.Y, 1)) # correct_prediction = tf.nn.in_top_k(pred_value, Y, 1) else: correct_prediction = tf.equal(tf.argmax(pred_value, 1), tf.cast(self.Y, tf.int64)) return tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) def train(self, cross_entropy): global_step = tf.Variable(0, trainable=False) return tf.train.GradientDescentOptimizer(self.learning_rate).minimize(cross_entropy, global_step=global_step)class Inputs(object): def __init__(self,file_path,batch_size,one_hot=True): self.file_path=file_path self.batch_size=batch_size self.mnist=input_data.read_data_sets(self.file_path, one_hot=one_hot) def inputs(self): batch_xs, batch_ys = self.mnist.train.next_batch(self.batch_size) return batch_xs, batch_ys def test_inputs(self): return self.mnist.test.images[:200],self.mnist.test.labels[:200]FLAGS=Nonedef train(): input_model = Inputs(FLAGS.data_dir, FLAGS.batch_size, one_hot=FLAGS.one_hot) with tf.name_scope('input'): x = tf.placeholder(tf.float32, [None, 28*28*1],'x') y_ = tf.placeholder(tf.float32, [None,10],'y_') keep=tf.placeholder(tf.float32) is_training= tf.placeholder(tf.bool, name='MODE') image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) # shape [n,28,28,1] # alexnet要求的数据shape是 224x224 image_shaped_input=tf.image.resize_image_with_crop_or_pad(image_shaped_input,224,224) # shape[n,224,224,1] # with slim.arg_scope(cifarnet_arg_scope()): # with slim.arg_scope(inception_resnet_v2_arg_scope()): # y, _ = 
cifarnet(images=image_shaped_input,num_classes=10,is_training=is_training,dropout_keep_prob=keep) # 上面的修改成 with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): y, _ = alexnet.alexnet_v2(inputs=image_shaped_input,num_classes=10,is_training=is_training,dropout_keep_prob=keep) model=Conv_model(y_,FLAGS.learning_rate) cross_entropy = model.loss(y, MSE_error=False, one_hot=FLAGS.one_hot) train_op = model.train(cross_entropy) accuracy = model.evaluate(y, one_hot=FLAGS.one_hot) init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) with tf.Session() as sess: sess.run(init) for step in range(FLAGS.num_steps): batch_xs, batch_ys = input_model.inputs() train_op.run({x: batch_xs, y_: batch_ys,keep:0.7,is_training:True}) if step % FLAGS.disp_step == 0: acc=accuracy.eval({x: batch_xs, y_: batch_ys,keep:1.,is_training:False}) print("step", step, 'acc', acc, 'loss', cross_entropy.eval({x: batch_xs, y_: batch_ys,keep:1.,is_training:False})) # test acc test_x, test_y = input_model.test_inputs() acc = accuracy.eval({x: test_x, y_: test_y,keep:1.,is_training:False}) print('test acc', acc)def main(_): # if tf.gfile.Exists(FLAGS.log_dir): # tf.gfile.DeleteRecursively(FLAGS.log_dir) # if not tf.gfile.Exists(FLAGS.log_dir): # tf.gfile.MakeDirs(FLAGS.log_dir) train()if __name__=="__main__": # 设置必要参数 parser = argparse.ArgumentParser() parser.add_argument('--num_steps', type=int, default=1000, help = 'Number of steps to run trainer.') parser.add_argument('--disp_step', type=int, default=100, help='Number of steps to display.') parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate.') parser.add_argument('--batch_size', type=int, default=128, help='Number of mini training samples.') parser.add_argument('--one_hot', type=bool, default=True, help='One-Hot Encoding.') parser.add_argument('--data_dir', type=str, default='./MNIST_data', help = 'Directory for storing input data') parser.add_argument('--log_dir', type=str, 
default='./log_dir', help='Summaries log directory') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
阅读全文
0 0
- TF-slim 调用slim提供的网络模型训练自己的数据
- 使用tf-slim的inception_resnet_v2预训练模型进行图像分类
- 使用tf-slim的ResNet V1 152和ResNet V2 152预训练模型进行图像分类
- tensorflow slim【TF-Slim】
- TF-slim
- 使用TF-Slim:在TensorFlow中定义复杂模型的高层库
- 用tf.slim微调vgg模型时遇到的小坑
- 使用 TF-Slim 设计复杂网络
- SLIM模型
- 【Tensorflow slim】 slim.arg_scope的用法
- slim的httppost数据的解析
- slim的httppost数据的解析
- TF-slim学习
- tf.contrib.slim
- tf.contrib.slim
- TF-Slim简介
- Tensorflow之TF-Slim
- tf.slim使用方法
- 【LEFT JOIN 实战记录】统计查询-按主办处室区县查询纳入分析研判库
- KMP算法模板(字符串匹配问题)
- 黑马商城项目_制作导航条的圆点
- xml工具箱phpcms
- jquery内容过滤选择器:内容过滤选择器它是根据元素内部文本内容进行选中。
- TF-slim 调用slim提供的网络模型训练自己的数据
- 阶段总结——软件工程视频(二)
- Java compiler level does not match the version of the installed Java project facet.
- Java8 常用FunctionInterface使用方法
- ROS 学习系列 -- Roomba, Xtion Pro机器人制作地图在Android手机无法实时观测地图
- 51nod 1779 逆序对统计【状压DP】
- hadoop踩坑记2--伪分布式部署
- Spectrum数字化仪为东京大学最高的室内磁场中心提供核心部件
- MACOS下解决git push error: Permission to XXX.git denied to XXX