TensorFlow series (3): Distributed TensorFlow
Source: Internet | Editor: 程序博客网 | Time: 2024/05/21 14:50
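How can a TensorFlow model be run in a distributed way across multiple machines?
(Originally published on my blog, which you are welcome to visit.)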
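0x00. Preface
Training a fairly complex model on a local machine or a single server can take a long time. Many research groups and companies have no servers packed with GPUs. Can several ordinary servers train one model together?
This is what distributed TensorFlow is for.
Below I describe how to deploy TensorFlow on a cluster.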
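0x01. Basic concepts
In distributed TensorFlow, servers are divided into two kinds: parameter servers (ps) and compute servers (workers). As the names suggest, a ps stores the model parameters and distributes them. A worker runs the model and exchanges parameters and their updates with the ps.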
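The following toy, single-process sketch in plain Python (no TensorFlow) makes this division of labor concrete. A hypothetical `ParameterServer` class stores the parameters. A worker pulls them, computes a gradient on its data, and pushes the gradient back. The class and method names are illustrative assumptions, not TensorFlow APIs. In real distributed TensorFlow this exchange happens over gRPC between separate processes.

```python
class ParameterServer:
    """Toy parameter server: holds parameters and applies pushed gradients."""

    def __init__(self, params, lr=0.1):
        self.params = dict(params)
        self.lr = lr

    def pull(self):
        # Workers receive a copy of the current parameters.
        return dict(self.params)

    def push(self, grads):
        # Plain SGD update using the gradient sent by a worker.
        for name, g in grads.items():
            self.params[name] -= self.lr * g


def worker_step(ps, data):
    """One worker iteration for the model y = w * x with squared loss."""
    w = ps.pull()["w"]
    # d/dw of mean((w*x - y)^2) is mean(2 * (w*x - y) * x)
    grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
    ps.push({"w": grad})


ps = ParameterServer({"w": 0.0})
data = [(1.0, 2.0), (2.0, 4.0)]  # samples from y = 2x
for _ in range(50):
    worker_step(ps, data)
print(round(ps.params["w"], 4))  # prints 2.0
```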
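1. Training modes
TensorFlow commonly supports two parallel training modes: synchronous and asynchronous.
In synchronous mode, the workers read the parameters at the same time. After finishing a batch, a worker does not update the parameters on its own. The workers wait until all of them have finished, and the parameters are then updated once with the aggregated gradients.
In asynchronous mode, each worker updates the parameters independently, without waiting for the others.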
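The difference between the two modes can be shown with a toy simulation in plain Python (not TensorFlow). Two workers fit the same one-parameter model.

- In the sync version, both compute gradients from the same parameter value, and a single averaged update is applied.
- In the async version, each worker's gradient is applied as soon as it is ready. The second gradient was computed from a value that is already stale.

The model, the learning rate, and the worst-case interleaving (both workers read the same snapshot) are assumptions chosen for illustration.

```python
def grad(w, x, y):
    # Gradient of the squared error (w*x - y)^2 with respect to w.
    return 2 * (w * x - y) * x


def sync_round(w, shards, lr):
    # Every worker reads the same w; the gradients are averaged and applied once.
    grads = [grad(w, x, y) for x, y in shards]
    return w - lr * sum(grads) / len(grads)


def async_round(w, shards, lr):
    # Every worker read the same snapshot of w, but each one applies its own
    # update as soon as it finishes, so the later update uses a stale w.
    snapshot = w
    for x, y in shards:
        w = w - lr * grad(snapshot, x, y)
    return w


shards = [(1.0, 2.0), (2.0, 4.0)]  # one sample per worker, from y = 2x
print(sync_round(0.0, shards, 0.1))   # one averaged update
print(async_round(0.0, shards, 0.1))  # two independent updates
```

Here the async round moves `w` twice as far, because both gradients were computed at `w = 0`. With more workers or a larger learning rate, such stale updates can overshoot. This is the "stale gradients" problem that the help text of the `sync_replicas` flag below refers to.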
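0x02. The official TensorFlow example
The official TensorFlow code is at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py. It is written against the TensorFlow 1.x API. Below I have added some comments to the example code; if you have the resources, try running it.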
1. Flag settings
First, define flags with `tf.app.flags`, so that their values can be set on the command line when the script is run.
```python
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS
```
Whether to enable synchronous parallel training:
```python
flags.DEFINE_boolean("sync_replicas", True,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")
```
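In synchronous mode, how many gradients (one per worker batch) are aggregated before each parameter update. The default is the number of workers.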
```python
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update "
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
```
Addresses of the ps and worker servers:
```python
flags.DEFINE_string("ps_hosts", "10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")
```
Definitions of `job_name` and `task_index`. These are normally given on the command line and do not need to be filled in by hand.
```python
flags.DEFINE_string("job_name", None, "job name: worker or ps")
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task that performs the variable "
                     "initialization ")
```
Check that `job_name` and `task_index` were provided:
```python
if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index == "":
    raise ValueError("Must specify an explicit `task_index`")

print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)
```
Parse the ps and worker server lists from the flags:
```python
ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
num_workers = len(worker_spec)
```
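The parsing step is ordinary string handling. As a TensorFlow-free sketch, the helpers below (hypothetical names, not part of TensorFlow) do two things:

- turn the two flag strings into a `{"ps": [...], "worker": [...]}` dictionary of the kind passed to `tf.train.ClusterSpec` in the next step;
- look up the address of a given task, where task i of a job is the i-th entry in that job's list.

```python
def parse_cluster(ps_hosts, worker_hosts):
    """Build the job -> [host:port] mapping from comma-separated flag values."""
    return {
        "ps": [h.strip() for h in ps_hosts.split(",") if h.strip()],
        "worker": [h.strip() for h in worker_hosts.split(",") if h.strip()],
    }


def task_address(cluster, job_name, task_index):
    # Task i of a job is simply the i-th address listed for that job.
    return cluster[job_name][task_index]


cluster = parse_cluster("10.10.19.7:2222", "10.10.19.8:2222,10.10.19.9:2222")
num_workers = len(cluster["worker"])
print(num_workers)                         # 2
print(task_address(cluster, "worker", 1))  # 10.10.19.9:2222
```

This task numbering matches the gRPC log lines in section 0x03, e.g. `{0 -> 10.10.19.8:2222, 1 -> 10.10.19.9:2222}` for the worker job.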
2. Distributed configuration
Create the TensorFlow cluster object and the server:
```python
cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})
server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)
# Is this the chief (master) worker?
is_chief = (FLAGS.task_index == 0)
```
Compute resources: only the CPU is used here. If this process is a ps server, it simply calls `server.join()` and waits to serve the workers.
```python
if FLAGS.job_name == "ps":
    server.join()  # blocks; the ps only serves parameters to the workers

cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
```
Device placement:
```python
with tf.device(tf.train.replica_device_setter(
        worker_device=worker_device,
        ps_device="/job:ps/cpu:0",
        cluster=cluster)):
    # The model graph (variables, loss, optimizer) is built inside this block.
    ...
```
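By default, `tf.train.replica_device_setter` assigns variable ops to the ps tasks in round-robin order and leaves all other ops on `worker_device`. The plain-Python sketch below mimics that placement rule so you can see where each op would land. It is a simplified model, not the TensorFlow implementation. The op names and the two-ps-task cluster are hypothetical.

```python
def make_device_setter(num_ps_tasks, worker_device, ps_job="/job:ps"):
    """Return a function mapping an op type to a device string (toy model)."""
    next_task = [0]  # round-robin cursor shared by the closure
    variable_ops = {"Variable", "VariableV2", "VarHandleOp"}

    def place(op_type):
        if op_type in variable_ops:
            task = next_task[0]
            next_task[0] = (task + 1) % num_ps_tasks
            return "%s/task:%d" % (ps_job, task)
        return worker_device  # non-variable ops stay on this worker

    return place


# Hypothetical cluster with two ps tasks, seen from worker 0.
place = make_device_setter(2, "/job:worker/task:0/cpu:0")
ops = [("w1", "VariableV2"), ("b1", "VariableV2"), ("matmul", "MatMul"), ("w2", "VariableV2")]
placements = [(name, place(op_type)) for name, op_type in ops]
for name, device in placements:
    print(name, device)
```

With the single ps task used in this article's cluster, every variable would land on task 0 of the ps job.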
3. Training preparation
A global step counter:
```python
global_step = tf.Variable(0, name="global_step", trainable=False)
```
Synchronous mode requires wrapping the optimizer. Given a base optimizer such as `opt = tf.train.AdamOptimizer(FLAGS.learning_rate)`, we have:
```python
if FLAGS.sync_replicas:
    # Number of gradients to aggregate before each parameter update
    if FLAGS.replicas_to_aggregate is None:
        replicas_to_aggregate = num_workers
    else:
        replicas_to_aggregate = FLAGS.replicas_to_aggregate
    # Wrap the base optimizer in a synchronous one
    opt = tf.train.SyncReplicasOptimizer(
        opt,
        replicas_to_aggregate=replicas_to_aggregate,
        total_num_replicas=num_workers,
        name="mnist_sync_replicas")
```
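Conceptually, `SyncReplicasOptimizer` collects gradients in an accumulator. It applies an update, and advances `global_step`, only once `replicas_to_aggregate` gradients have arrived, and the update uses their average. The toy class below (a hypothetical `GradientAccumulator` in plain Python) models only that counting logic. It ignores the token queue and the dropping of stale gradients that the real optimizer also performs.

```python
class GradientAccumulator:
    """Toy model of sync aggregation: apply the mean of every N gradients."""

    def __init__(self, replicas_to_aggregate, lr=0.1):
        self.n = replicas_to_aggregate
        self.lr = lr
        self.pending = []
        self.w = 0.0
        self.global_step = 0

    def push(self, g):
        self.pending.append(g)
        if len(self.pending) == self.n:
            mean = sum(self.pending) / self.n
            self.w -= self.lr * mean
            self.pending = []
            self.global_step += 1  # one global step per aggregated update
        return self.global_step


acc = GradientAccumulator(replicas_to_aggregate=2)
steps = [acc.push(g) for g in [1.0, 3.0, 2.0, 2.0]]
print(steps)  # [0, 1, 1, 2]
```

With `replicas_to_aggregate` defaulting to `num_workers = 2`, two gradients are needed per global step. This is consistent with the worker 0 log in section 0x03: at the start, the global step advanced by one for every two local steps.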
The training op:
```python
train_step = opt.minimize(cross_entropy, global_step=global_step)
```
Initialization:
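```python
import tempfile

if FLAGS.sync_replicas:
    local_init_op = opt.local_step_init_op
    if is_chief:
        local_init_op = opt.chief_init_op
    ready_for_local_init_op = opt.ready_for_local_init_op
    # Queue runner the chief uses to apply the aggregated updates
    chief_queue_runner = opt.get_chief_queue_runner()
    # Op that fills the sync token queue (run once by the chief)
    sync_init_op = opt.get_init_tokens_op()

# Initializer for all global variables
init_op = tf.global_variables_initializer()
# Temporary training (log/checkpoint) directory
train_dir = tempfile.mkdtemp()
```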
Create the supervisor for distributed training:
```python
if FLAGS.sync_replicas:
    sv = tf.train.Supervisor(is_chief=is_chief,
                             logdir=train_dir,
                             init_op=init_op,
                             local_init_op=local_init_op,
                             ready_for_local_init_op=ready_for_local_init_op,
                             recovery_wait_secs=1,
                             global_step=global_step)
else:
    sv = tf.train.Supervisor(is_chief=is_chief,
                             logdir=train_dir,
                             init_op=init_op,
                             recovery_wait_secs=1,
                             global_step=global_step)
```
Session configuration:
```python
sess_config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    # This session only sees the ps job and this worker's own task
    device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
```
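`device_filters` limits which devices this worker's session can see: only the ps job and the worker's own task. The sketch below is a simplified plain-Python model of that filtering, not TensorFlow's actual matching code. In this model, a filter matches a device when every field the filter specifies (such as job and task) is equal. The device names are illustrative.

```python
def parse_device(name):
    """'/job:worker/task:1/cpu:0' -> {'job': 'worker', 'task': '1', 'cpu': '0'}"""
    fields = {}
    for part in name.strip("/").split("/"):
        key, _, value = part.partition(":")
        fields[key] = value
    return fields


def visible_devices(devices, filters):
    def matches(dev, flt):
        d, f = parse_device(dev), parse_device(flt)
        return all(d.get(k) == v for k, v in f.items())
    return [dev for dev in devices if any(matches(dev, f) for f in filters)]


devices = ["/job:ps/task:0/cpu:0", "/job:worker/task:0/cpu:0", "/job:worker/task:1/cpu:0"]
task_index = 1
filters = ["/job:ps", "/job:worker/task:%d" % task_index]
print(visible_devices(devices, filters))
# ['/job:ps/task:0/cpu:0', '/job:worker/task:1/cpu:0']
```

The filters also show up in the worker logs in section 0x03 as the `device_filters` entries of the master session config.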
Prepare to run:
```python
if is_chief:
    print("Worker %d: Initializing session..." % FLAGS.task_index)
else:
    print("Worker %d: Waiting for session to be initialized..." % FLAGS.task_index)

# The chief initializes the session; the other workers wait for it
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index)

if FLAGS.sync_replicas and is_chief:
    # Fill the sync token queue
    sess.run(sync_init_op)
    # Start the chief's queue runner
    sv.start_queue_runners(sess, [chief_queue_runner])
```
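`prepare_or_wait_for_session` behaves differently depending on the role:

- The chief (task 0) runs the initialization ops.
- The other workers poll until the model is ready, waiting `recovery_wait_secs` between checks.

The threading sketch below models that handshake in plain Python, using a hypothetical shared `model_ready` event. It only illustrates the ordering, not the Supervisor's real readiness checks.

```python
import threading
import time

model_ready = threading.Event()  # stands in for "variables are initialized"
log = []
log_lock = threading.Lock()


def record(msg):
    with log_lock:
        log.append(msg)


def run_worker(task_index, recovery_wait_secs=0.01):
    is_chief = (task_index == 0)
    if is_chief:
        record("Worker %d: Initializing session..." % task_index)
        time.sleep(0.05)  # pretend to run init_op
        model_ready.set()
    else:
        record("Worker %d: Waiting for session to be initialized..." % task_index)
        while not model_ready.is_set():
            time.sleep(recovery_wait_secs)  # poll until the chief is done
    record("Worker %d: Session initialization complete." % task_index)


# Start the non-chief first, as can happen on a real cluster.
threads = [threading.Thread(target=run_worker, args=(i,)) for i in (1, 0)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("\n".join(log))
```

Whatever the thread scheduling, worker 1 cannot report completion before the chief has started initializing.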
4. Start training:
```python
import time

time_begin = time.time()
print("Training begins @ %f" % time_begin)

local_step = 0
while True:
    batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
    train_feed = {x: batch_xs, y_: batch_ys}

    _, step = sess.run([train_step, global_step], feed_dict=train_feed)
    local_step += 1

    now = time.time()
    print("%f: Worker %d: training step %d done (global step: %d)" %
          (now, FLAGS.task_index, local_step, step))

    if step >= FLAGS.train_steps:
        break

time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)

# Evaluate on the validation set
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" %
      (FLAGS.train_steps, val_xent))
```
0x03. Running it on the servers
After changing the ps and worker server IPs in the official example, the relevant part of the file looks like this:
```python
#coding:utf-8
# Only the CPU is used
flags.DEFINE_integer("num_gpus", 0,
                     "Total number of gpus for each machine. "
                     "If you don't use GPU, please set it to '0'")
# Addresses of the ps and worker servers
flags.DEFINE_string("ps_hosts", "10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")
```
Run one of the following commands on each of the three servers:
```shell
# on 10.10.19.7 (ps)
python distribute_test.py --job_name=ps --task_index=0 --sync_replicas=True
# on 10.10.19.8 (worker 0)
python distribute_test.py --job_name=worker --task_index=0 --sync_replicas=True
# on 10.10.19.9 (worker 1)
python distribute_test.py --job_name=worker --task_index=1 --sync_replicas=True
```
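With more machines, typing these commands by hand becomes error-prone. A small plain-Python helper (hypothetical, not part of TensorFlow) can generate one command per task from the same host lists. The rule it encodes: each process gets its job name, and its position in that job's list becomes `--task_index`.

```python
def launch_commands(ps_hosts, worker_hosts, script="distribute_test.py", sync=True):
    """Return (host, command) pairs: one process per ps/worker task."""
    cmds = []
    for job, hosts in (("ps", ps_hosts), ("worker", worker_hosts)):
        for index, host in enumerate(hosts.split(",")):
            cmd = "python %s --job_name=%s --task_index=%d --sync_replicas=%s" % (
                script, job, index, sync)
            cmds.append((host.split(":")[0], cmd))  # drop the port for the host name
    return cmds


for host, cmd in launch_commands("10.10.19.7:2222", "10.10.19.8:2222,10.10.19.9:2222"):
    print(host, "->", cmd)
```

For the three-server cluster above, this reproduces exactly the three commands listed.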
On the ps server you can see the following output:
```
Extracting /tmp/mnist-data/train-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist-data/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/t10k-labels-idx1-ubyte.gz
job name = ps
task index = 0
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> 10.10.19.8:2222, 1 -> 10.10.19.9:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222
```
On the other two servers (the workers), worker 0 prints:
```
Extracting /tmp/mnist-data/train-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist-data/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/t10k-labels-idx1-ubyte.gz
job name = worker
task index = 0
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> 10.10.19.7:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> localhost:2222, 1 -> 10.10.19.9:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222
Worker 0: Initializing session...
I tensorflow/core/distributed_runtime/master_session.cc:1012] Start master session df31e159ecf5dc77 with config: 
device_filters: "/job:ps"
device_filters: "/job:worker/task:0"
allow_soft_placement: true
Worker 0: Session initialization complete.
Training begins @ 1500442384.091448
1500442384.150300: Worker 0: training step 1 done (global step: 0)
1500442384.163003: Worker 0: training step 2 done (global step: 0)
1500442384.172685: Worker 0: training step 3 done (global step: 1)
1500442384.182413: Worker 0: training step 4 done (global step: 1)
......
1500442387.524158: Worker 0: training step 269 done (global step: 198)
1500442387.539484: Worker 0: training step 270 done (global step: 199)
1500442387.555133: Worker 0: training step 271 done (global step: 200)
Training ends @ 1500442387.555215
Training elapsed time: 3.463767 s
After 200 training step(s), validation cross entropy = 781.478
```
Worker 1 prints:
```
Extracting /tmp/mnist-data/train-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist-data/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/t10k-labels-idx1-ubyte.gz
job name = worker
task index = 1
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> 10.10.19.7:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> 10.10.19.8:2222, 1 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222
Worker 1: Waiting for session to be initialized...
I tensorflow/core/distributed_runtime/master_session.cc:1012] Start master session 540b3d300aac1583 with config: 
device_filters: "/job:ps"
device_filters: "/job:worker/task:1"
allow_soft_placement: true
Worker 1: Session initialization complete.
Training begins @ 1500442385.490577
1500442385.520064: Worker 1: training step 1 done (global step: 68)
1500442385.534573: Worker 1: training step 2 done (global step: 69)
1500442385.549380: Worker 1: training step 3 done (global step: 70)
......
1500442387.524271: Worker 1: training step 131 done (global step: 198)
1500442387.539464: Worker 1: training step 132 done (global step: 199)
1500442387.555071: Worker 1: training step 133 done (global step: 200)
Training ends @ 1500442387.555124
Training elapsed time: 2.064547 s
After 200 training step(s), validation cross entropy = 781.478
```
0x04. Building your own code
Here I tried distributed training on an LSTM, reusing much of the code above.
1. Common code
First, set up TensorFlow's basic flags:
```python
import tempfile
import time

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_boolean("sync_replicas", True,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update "
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
flags.DEFINE_string("ps_hosts", "10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("job_name", None, "job name: worker or ps")
flags.DEFINE_integer("task_index", 0,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task that performs the variable "
                     "initialization ")
flags.DEFINE_integer("train_steps", 500,
                     "Number of (global) training steps to perform")
FLAGS = flags.FLAGS

if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index == "":
    raise ValueError("Must specify an explicit `task_index`")

print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)
```
Then the cluster configuration:
```python
ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
num_workers = len(worker_spec)

cluster = tf.train.ClusterSpec({
    "ps": ps_spec,
    "worker": worker_spec})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
if FLAGS.job_name == "ps":
    server.join()

is_chief = (FLAGS.task_index == 0)
cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
```
Configure the compute resources and build the graph:
```python
with tf.device(
        tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=cluster)):
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # tf.placeholder...
    # weights \ biases
    # The LSTM computation goes here (it produces `pred`)
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    if FLAGS.sync_replicas:
        if FLAGS.replicas_to_aggregate is None:
            replicas_to_aggregate = num_workers
        else:
            replicas_to_aggregate = FLAGS.replicas_to_aggregate
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=replicas_to_aggregate,
            total_num_replicas=num_workers,
            name="mnist_sync_replicas")

    train_step = optimizer.minimize(cost, global_step=global_step)

    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    if FLAGS.sync_replicas:
        local_init_op = optimizer.local_step_init_op
        if is_chief:
            local_init_op = optimizer.chief_init_op
        ready_for_local_init_op = optimizer.ready_for_local_init_op
        chief_queue_runner = optimizer.get_chief_queue_runner()
        sync_init_op = optimizer.get_init_tokens_op()

    init_op = tf.global_variables_initializer()
    train_dir = tempfile.mkdtemp()

    if FLAGS.sync_replicas:
        sv = tf.train.Supervisor(
            is_chief=is_chief,
            logdir=train_dir,
            init_op=init_op,
            local_init_op=local_init_op,
            ready_for_local_init_op=ready_for_local_init_op,
            recovery_wait_secs=1,
            global_step=global_step)
    else:
        sv = tf.train.Supervisor(
            is_chief=is_chief,
            logdir=train_dir,
            init_op=init_op,
            recovery_wait_secs=1,
            global_step=global_step)

    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])

    if is_chief:
        print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
        print("Worker %d: Waiting for session to be initialized..." % FLAGS.task_index)

    sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
    print("Worker %d: Session initialization complete." % FLAGS.task_index)

    if FLAGS.sync_replicas and is_chief:
        sess.run(sync_init_op)
        sv.start_queue_runners(sess, [chief_queue_runner])
```
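Now start training; this code still runs inside the `with tf.device(...)` block above. The way the `while True` loop and the fetch list `[train_step, global_step]` are organized here matters. The loop stops when the shared global step reaches `train_steps`, not when this worker's own `local_step` does. As a result, every worker stops once the cluster as a whole has performed the required number of updates.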
```python
time_begin = time.time()
print("Training begins @ %f" % time_begin)

local_step = 0
while True:
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    # Reshape each flat image into a sequence for the LSTM
    batch_x = batch_x.reshape((batch_size, n_steps, n_input))
    _, step = sess.run([train_step, global_step], feed_dict={x: batch_x, y: batch_y})
    local_step += 1

    now = time.time()
    print("%f: Worker %d: training step %d done (global step: %d)" %
          (now, FLAGS.task_index, local_step, step))

    # Stop on the shared global step, not on the local step count
    if step >= FLAGS.train_steps:
        break

time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)
```
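Workers run at different speeds and can join at different times; in the logs in section 0x03, worker 1 started at global step 68. The thread-based sketch below shows why testing the shared step still works. It is plain Python:

- A hypothetical `SharedCounter` stands in for `global_step`.
- There is no aggregation; each local step adds one global step, as in async mode.

Two workers of different speeds stop at the same global target while doing different amounts of local work.

```python
import threading
import time


class SharedCounter:
    """Stands in for the global_step variable that lives on the ps."""

    def __init__(self):
        self.value = 0
        self.lock = threading.Lock()

    def increment(self):
        with self.lock:
            self.value += 1
            return self.value


def train(worker_id, counter, train_steps, delay, local_steps):
    local_step = 0
    while True:
        time.sleep(delay)           # pretend to compute one batch
        step = counter.increment()  # like sess.run([train_step, global_step])
        local_step += 1
        if step >= train_steps:     # stop on the shared step
            break
    local_steps[worker_id] = local_step


counter = SharedCounter()
local_steps = {}
threads = [threading.Thread(target=train, args=(i, counter, 50, d, local_steps))
           for i, d in enumerate([0.001, 0.003])]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter.value, local_steps)
```

The faster worker does more local steps. The counter can overshoot the target slightly, because a worker only notices the limit on its own next update.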
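2. Running and debugging
I ran this on the servers in the same way; asynchronous mode is selected with `--sync_replicas=False`. Asynchronous mode was about twice as fast as synchronous mode, with little difference in accuracy.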
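A back-of-the-envelope model shows why asynchronous mode can be faster. It assumes each synchronous global step needs one gradient from every worker, so it waits for the slowest one, while in asynchronous mode each worker applies updates at its own pace. The per-batch times below are made-up assumptions, not measurements from the runs above, and communication cost is ignored. The model does not explain the exact speedup reported above.

```python
def sync_time(step_times, global_steps):
    # Each global step waits for every worker, i.e. for the slowest one.
    return global_steps * max(step_times)


def async_time(step_times, global_steps):
    # Workers update independently, so their update rates add up.
    rate = sum(1.0 / t for t in step_times)
    return global_steps / rate


times = [0.010, 0.015]  # hypothetical seconds per batch for worker 0 and worker 1
print(round(sync_time(times, 200), 3))   # 3.0
print(round(async_time(times, 200), 3))  # 1.2
```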