TensorFlow 模型模板
来源:互联网 发布:js文字大小变化效果 编辑:程序博客网 时间:2024/05/20 10:22
参考:http://www.cnblogs.com/wang-kai/p/6479960.html
#!/usr/bin/env python3# -*- coding: UTF-8 -*-"""说明数据:mnist模型建立 Model数据的输入 Inputs模型保存与提取 Save_and_load_mode模型可视化 TensorBoard"""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport osimport argparseimport sysclass Model(object): def __init__(self,X,Y,w,b,learning_rate): self.X=X self.Y=Y self.w=w self.b=b self.learning_rate=learning_rate def inference(self,activation='softmax'): if activation=='softmax': pred=tf.nn.softmax(tf.matmul(self.X, self.w) + self.b) else: pred=tf.nn.bias_add(tf.matmul(self.X, self.w),self.b) return pred def loss(self,pred_value,MSE_error=False,one_hot=True): if MSE_error:return tf.reduce_mean(tf.reduce_sum( tf.square(pred_value-self.Y),reduction_indices=[1])) else: if one_hot: return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( labels=self.Y, logits=pred_value)) else: return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( labels=tf.cast(self.Y, tf.int32), logits=pred_value)) def evaluate(self,pred_value,one_hot=True): if one_hot: correct_prediction = tf.equal(tf.argmax(pred_value, 1), tf.argmax(self.Y, 1)) # correct_prediction= tf.nn.in_top_k(pred_value, Y, 1) else: correct_prediction = tf.equal(tf.argmax(pred_value, 1), tf.cast(self.y, tf.int64)) return tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) def train(self,cross_entropy): global_step = tf.Variable(0, trainable=False) return tf.train.GradientDescentOptimizer(self.learning_rate).minimize(cross_entropy, global_step=global_step)class Inputs(object): def __init__(self,file_path,batch_size,one_hot=True): self.file_path=file_path self.batch_size=batch_size self.mnist=input_data.read_data_sets(self.file_path, one_hot=one_hot) def inputs(self): batch_xs, batch_ys = self.mnist.train.next_batch(self.batch_size) return batch_xs, batch_ys def test_inputs(self): return self.mnist.test.images,self.mnist.test.labelsclass 
Save_and_load_mode(object): def __init__(self,logdir,sess): self.saver = tf.train.Saver() self.logdir=logdir # 保存模型位置 self.sess=sess def save_model(self,step): if not os.path.exists(self.logdir):os.makedirs(self.logdir) self.saver.save(self.sess, os.path.join(self.logdir,'model.ckpt'), global_step=step) def load_model(self): # 验证之前是否已经保存了检查点文件 ckpt = tf.train.get_checkpoint_state(self.logdir) if ckpt and ckpt.model_checkpoint_path: self.saver.restore(self.sess, ckpt.model_checkpoint_path) return True else: return Falseclass TensorBoard(object): def __init__(self): pass def variable_summaries(self,var): """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" with tf.name_scope('summaries'): mean = tf.reduce_mean(var) tf.summary.scalar('mean', mean) with tf.name_scope('stddev'): stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) tf.summary.scalar('stddev', stddev) tf.summary.scalar('max', tf.reduce_max(var)) tf.summary.scalar('min', tf.reduce_min(var)) tf.summary.histogram('histogram', var) def image_summary(self,name,tensor,max_outputs=10): tf.summary.image(name, tensor, max_outputs) def hist_summary(self,name,values): tf.summary.histogram(name, values) def scalar_summary(self,name,tensor): tf.summary.scalar(name, tensor) def merge_all_summary(self): return tf.summary.merge_all() def FileWriter_summary(self,log_dir,graph=None): return tf.summary.FileWriter(log_dir,graph)FLAGS = Nonedef train(): tb_model=TensorBoard() # Input layer with tf.name_scope('input'): x = tf.placeholder(tf.float32, [None, 28*28*1],'x') y_ = tf.placeholder(tf.float32, [None,10],'y_') with tf.name_scope('input_reshape'): image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) tb_model.image_summary('input', image_shaped_input, 10) # Output layer with tf.name_scope('line_layer'): with tf.name_scope('weights'): w = tf.Variable(tf.random_normal([28*28*1, 10])) # 二分类 tb_model.variable_summaries(w) with tf.name_scope('biases'): b = tf.Variable(tf.random_normal([10])) 
tb_model.variable_summaries(b) input_model=Inputs(FLAGS.data_dir,FLAGS.batch_size,one_hot=FLAGS.one_hot) model=Model(x,y_,w,b,FLAGS.learning_rate) with tf.name_scope('Wx_plus_b'): y=model.inference(activation='softmax') tb_model.hist_summary('pred',y) with tf.name_scope('total_loss'): cross_entropy=model.loss(y,MSE_error=False,one_hot=FLAGS.one_hot) tb_model.scalar_summary('cross_entropy', cross_entropy) train_op=model.train(cross_entropy) with tf.name_scope('accuracy'): accuracy=model.evaluate(y,one_hot=FLAGS.one_hot) tb_model.scalar_summary('accuracy', accuracy) init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) with tf.Session() as sess: # Merge all the summaries and write them out to /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default) merged = tb_model.merge_all_summary() if not os.path.exists(os.path.join(FLAGS.log_dir + '/train')): os.makedirs(os.path.join(FLAGS.log_dir + '/train')) if not os.path.exists(os.path.join(FLAGS.log_dir + '/test')): os.makedirs(os.path.join(FLAGS.log_dir + '/test')) train_writer = tb_model.FileWriter_summary(os.path.join(FLAGS.log_dir + '/train'),sess.graph) test_writer = tb_model.FileWriter_summary(os.path.join(FLAGS.log_dir + '/test')) save=Save_and_load_mode(FLAGS.log_dir,sess) if not save.load_model():init.run() for step in range(FLAGS.num_steps): batch_xs, batch_ys = input_model.inputs() train_op.run({x: batch_xs, y_: batch_ys}) if step % FLAGS.disp_step == 0: acc=accuracy.eval({x: batch_xs, y_: batch_ys}) print("step", step, 'acc', acc, 'loss', cross_entropy.eval({x: batch_xs, y_: batch_ys})) train_result = merged.eval({x: batch_xs, y_: batch_ys}) train_writer.add_summary(train_result, step) test_x, test_y = input_model.test_inputs() acc = accuracy.eval({x: test_x, y_: test_y}) print("step", step, 'acc', acc) test_result = merged.eval({x: test_x, y_: test_y}) test_writer.add_summary(test_result, step) save.save_model(step) """ # test acc test_x,test_y=input_model.test_inputs() 
print('test acc', acc) """def main(_): # if tf.gfile.Exists(FLAGS.log_dir): # tf.gfile.DeleteRecursively(FLAGS.log_dir) if not tf.gfile.Exists(FLAGS.log_dir): tf.gfile.MakeDirs(FLAGS.log_dir) train()if __name__=="__main__": # 设置必要参数 parser = argparse.ArgumentParser() parser.add_argument('--num_steps', type=int, default=1000, help = 'Number of steps to run trainer.') parser.add_argument('--disp_step', type=int, default=100, help='Number of steps to display.') parser.add_argument('--learning_rate', type=float, default=0.5, help='Learning rate.') parser.add_argument('--batch_size', type=int, default=128, help='Number of mini training samples.') parser.add_argument('--one_hot', type=bool, default=True, help='One-Hot Encoding.') parser.add_argument('--data_dir', type=str, default='./MNIST_data/', help = 'Directory for storing input data') parser.add_argument('--log_dir', type=str, default='./log_dir', help='Summaries log directory') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)# 启动TensorBoard: tensorboard --logdir=path/to/log-directory# tensorboard --logdir='log_dir'
阅读全文
0 0
- Tensorflow-model模板
- tensorflow pretrained model
- Tensorflow Model Persistence
- Tensorflow-save model
- Tensorflow-model API
- Build basic model in Tensorflow
- CodeSmith Model层模板
- tensorflow学习笔记(十一):seq2seq Model
- ResNet-TensorFlow Model Zoo代码理解
- TensorFlow学习笔记8----TensorFlow Linear Model Tutorial
- 模板模式(template model)解析例子
- 设计模式之模板模式(Template Model)
- laravel框架model类简单模板
- RNN代码解读之char-RNN with TensorFlow(model.py)
- CS 20SI|Lecture 4 Structure your TensorFlow model
- Tensorflow: How to restore a previously saved model (python)
- Tensorflow使用pip安装后没有model目录
- 利用opencv3读取tensorflow model,对图像进行分类
- ndarray数据类型
- Xcode(IOS)百度地图开发起步篇
- java.lang.String 下面的toUpperCase()和toLowerCase()方法
- 四行命令 Ubuntu安装Oracle JDK8 且自动配置环境变量 [翻译]
- 定制Oreo 初步
- Tensorflow-model模板
- iOS textView 高度自适应
- 12.Nginx启动流程之ngx_init_cycle
- AngularJS最佳实践: 请小心使用 ng-repeat 中的 $index
- android锯齿问题,解决方案
- servlet会话技术Cookie&Session
- 设置label上文字显示不同大小、颜色、字体类型
- 单例模式:饿汉式、懒汉式、静态内部类式
- SpringMvc中的拦截器