Tensorflow学习笔记4:分布式Tensorflow
来源:互联网 发布:淘宝透明拉链袋 编辑:程序博客网 时间:2024/05/16 17:21
Tensorflow学习笔记4:分布式Tensorflow
简介
Tensorflow API提供了Cluster、Server以及Supervisor来支持模型的分布式训练。
关于Tensorflow的分布式训练介绍可以参考Distributed Tensorflow。简单的概括说明如下:
- Tensorflow分布式Cluster由多个Task组成,每个Task对应一个tf.train.Server实例,作为Cluster的一个单独节点;
- 多个相同作用的Task可以被划分为一个job,例如ps job作为参数服务器只保存Tensorflow model的参数,而worker job则作为计算节点只执行计算密集型的Graph计算。
- Cluster中的Task会相对进行通信,以便进行状态同步、参数更新等操作。
Tensorflow分布式集群的所有节点执行的代码是相同的。分布式任务代码具有固定的模式:
# 第1步:命令行参数解析,获取集群的信息ps_hosts和worker_hosts,以及当前节点的角色信息job_name和task_index# 第2步:创建当前task结点的Servercluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)# 第3步:如果当前节点是ps,则调用server.join()无休止等待;如果是worker,则执行第4步。if FLAGS.job_name == "ps": server.join()# 第4步:则构建要训练的模型# build tensorflow graph model# 第5步:创建tf.train.Supervisor来管理模型的训练过程# Create a "supervisor", which oversees the training process.sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), logdir="/tmp/train_logs")# The supervisor takes care of session initialization and restoring from a checkpoint.sess = sv.prepare_or_wait_for_session(server.target)# Loop until the supervisor shuts downwhile not sv.should_stop() # train model
Tensorflow分布式训练代码框架
根据上面说到的Tensorflow分布式训练代码固定模式,如果要编写一个分布式的Tensorlfow代码,其框架如下所示。
import tensorflow as tf# Flags for defining the tf.train.ClusterSpectf.app.flags.DEFINE_string("ps_hosts", "", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("worker_hosts", "", "Comma-separated list of hostname:port pairs")# Flags for defining the tf.train.Servertf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")FLAGS = tf.app.flags.FLAGSdef main(_): ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts(",") # Create a cluster from the parameter server and worker hosts. cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) # Create and start a server for the local task. server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == "ps": server.join() elif FLAGS.job_name == "worker": # Assigns ops to the local worker by default. with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): # Build model... loss = ... global_step = tf.Variable(0) train_op = tf.train.AdagradOptimizer(0.01).minimize( loss, global_step=global_step) saver = tf.train.Saver() summary_op = tf.merge_all_summaries() init_op = tf.initialize_all_variables() # Create a "supervisor", which oversees the training process. sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), logdir="/tmp/train_logs", init_op=init_op, summary_op=summary_op, saver=saver, global_step=global_step, save_model_secs=600) # The supervisor takes care of session initialization and restoring from # a checkpoint. sess = sv.prepare_or_wait_for_session(server.target) # Start queue runners for the input pipelines (if any). sv.start_queue_runners(sess) # Loop until the supervisor shuts down (or 1000000 steps have completed). step = 0 while not sv.should_stop() and step < 1000000: # Run a training step asynchronously. # See `tf.train.SyncReplicasOptimizer` for additional details on how to # perform *synchronous* training. _, step = sess.run([train_op, global_step])if __name__ == "__main__": tf.app.run()
对于所有Tensorflow分布式代码,可变的只有两点:
- 构建tensorflow graph模型代码;
- 每一步执行训练的代码
分布式MNIST任务
我们通过修改tensorflow/tensorflow提供的mnist_softmax.py来构造分布式的MNIST样例来进行验证。修改后的代码请参考mnist_dist.py。
我们同样通过tensorlfow的Docker image来启动一个容器来进行验证。
$ docker run -d -v /path/to/your/code:/tensorflow/mnist --name tensorflow tensorflow/tensorflow
启动tensorflow之后,启动4个Terminal,然后通过下面命令进入tensorflow容器,切换到/tensorflow/mnist目录下
$ docker exec -ti tensorflow /bin/bash$ cd /tensorflow/mnist
然后在四个Terminal中分别执行下面一个命令来启动Tensorflow cluster的一个task节点,
# Start ps 0python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=ps --task_index=0# Start ps 1python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=ps --task_index=1# Start worker 0python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=worker --task_index=0# Start worker 1python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=worker --task_index=1
具体效果自己验证哈。
0 0
- Tensorflow学习笔记4:分布式Tensorflow
- tensorflow学习笔记(十九):分布式Tensorflow
- tensorflow学习笔记(十九):分布式Tensorflow
- TensorFlow学习笔记4
- tensorflow学习笔记(二十):分布式注意事项
- 学习笔记TF061:分布式TensorFlow,分布式原理、最佳实践
- tensorflow 学习笔记(4)-basic_example
- tensorflow26《TensorFlow实战Google深度学习框架》笔记-10-03 分布式TensorFlow code
- 分布式TensorFlow
- 分布式tensorflow
- 分布式tensorflow
- 分布式 tensorflow
- TensorFlow学习笔记-1
- TensorFlow学习笔记
- TensorFlow 深度学习笔记
- TensorFlow学习笔记1
- tensorflow-Alexnet学习笔记
- TensorFlow学习笔记
- linux串口编程 3种方式
- linux使用脚本杀死指定名称的进程
- JNI学习笔记 C++传递结构体、String、数组对象给JavaC++传递结构体、String、数组对象给Java
- 与Java的初次相遇
- 传智168期JavaEE hibernate 姜涛 day34~day35(2017年2月27日16:50:17)
- Tensorflow学习笔记4:分布式Tensorflow
- python解析csv文件 提取数据
- 网络编程学习笔记
- 堕落之源------我的第一篇博客
- IOS Category 与 Extension区别
- 禁止ScrollView的childview自动滑动到底部
- css样式的加载顺序及覆盖顺序
- c++强制类型转换
- 内存空间 逻辑地址空间 相对地址 绝对地址