TensorFlow series (3): distributed TensorFlow

Source: Internet | Editor: 程序博客网 | Time: 2024/05/21 14:50

How do you run a TensorFlow model distributed across several machines?

(Originally published on my blog; you are welcome to visit it.)

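0x00. Preface

Complex models can take a very long time to train on a single machine or server. Many research groups and companies do not have servers packed with GPUs. What can you do then? Can several servers train one model together?

This is where distributed TensorFlow comes in.

Below I describe how to deploy TensorFlow on a cluster.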

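0x01. Basic concepts

In distributed TensorFlow, servers play one of two roles: parameter servers (ps) and workers. As the names suggest, a ps stores the model parameters and hands them out. A worker runs the model and exchanges parameters with the ps: it reads the current values and sends back updates.

1. Training modes

TensorFlow commonly uses two parallel training modes: synchronous and asynchronous.

In synchronous mode, the workers read the parameters at the same time. A worker does not update the parameters on its own when it finishes its computation. Instead, the updates from all workers are collected and the parameters are updated once, together.

In asynchronous mode, each worker updates the parameters independently, as soon as it finishes its own step.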
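The difference between the two modes can be illustrated with a toy parameter server, sketched below. It uses a hypothetical one-dimensional loss f(w) = (w - 3)^2 and plain gradient descent, not TensorFlow. The asynchronous version models a worst case: every worker reads w before any of them writes, so each later update uses a stale gradient. This is the "stale gradients" problem that the sync_replicas flag description below refers to.

```python
def grad(w):
    """Gradient of the toy loss f(w) = (w - 3)^2."""
    return 2.0 * (w - 3.0)


def sync_training(num_workers, rounds, lr):
    """Synchronous: every worker computes a gradient on the same w, the ps
    averages them and applies ONE update; global_step grows by 1 per round."""
    w, global_step = 0.0, 0
    for _ in range(rounds):
        grads = [grad(w) for _ in range(num_workers)]
        w -= lr * sum(grads) / num_workers
        global_step += 1
    return w, global_step


def async_training(num_workers, rounds, lr):
    """Asynchronous (worst case): all workers read w first, then each applies its
    own update independently, so every update after the first uses a stale
    gradient; global_step grows by num_workers per round."""
    w, global_step = 0.0, 0
    for _ in range(rounds):
        snapshot = w  # value every worker read at the start of the round
        for _ in range(num_workers):
            w -= lr * grad(snapshot)  # gradient computed on the old snapshot
            global_step += 1
    return w, global_step


w_sync, steps_sync = sync_training(num_workers=2, rounds=10, lr=0.6)
w_async, steps_async = async_training(num_workers=2, rounds=10, lr=0.6)
print(w_sync, steps_sync)
print(w_async, steps_async)
```

With the same learning rate in this toy, the synchronous run converges close to w = 3. The asynchronous run overshoots and diverges because neither worker's update accounts for the other's. The toy does not predict real-world accuracy; it only shows the mechanism.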

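0x02. The official TensorFlow example

The official TensorFlow code is at https://github.com/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py. Below I have added comments to the example code. If you have the machines available, try running it.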

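1. Flag settings

First, define flags with tf.app.flags. The value of each flag can then be set on the command line.

import tempfile  # used below for the temporary training directory
import time      # used below to time the training loop

import tensorflow as tf  # TensorFlow 1.x API (tf.app.flags, tf.train.Server, ...)

flags = tf.app.flags
FLAGS = flags.FLAGS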

Whether to enable synchronous replica training:

flags.DEFINE_boolean("sync_replicas", True,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")

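In synchronous mode, this sets how many replicas' gradients (one batch per replica) are aggregated before each parameter update. It defaults to the number of workers.

flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update "
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")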

Addresses of the ps and worker servers:

flags.DEFINE_string("ps_hosts", "10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")

Definitions of job_name and task_index. These are normally given on the command line, so no values are filled in here.

flags.DEFINE_string("job_name", None, "job name: worker or ps")
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task that performs the variable "
                     "initialization ")

Check that job_name and task_index were provided:

if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index == "":
    raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)
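tf.app.flags belongs to the TensorFlow 1.x API. The sketch below reproduces the same flags and mandatory-argument checks with the standard library's argparse instead, which is a swapped-in technique and not what the article uses. One assumption worth noting: argparse's `type=bool` would turn the string "False" into True, so the hypothetical helper `str2bool` parses booleans explicitly.

```python
import argparse


def str2bool(value):
    """Parse "True"/"False" style values; argparse's type=bool would turn "False" into True."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def parse_flags(argv=None):
    parser = argparse.ArgumentParser(description="distributed training flags (argparse sketch)")
    parser.add_argument("--sync_replicas", type=str2bool, default=True)
    parser.add_argument("--replicas_to_aggregate", type=int, default=None)
    parser.add_argument("--ps_hosts", default="10.10.19.7:2222")
    parser.add_argument("--worker_hosts", default="10.10.19.8:2222,10.10.19.9:2222")
    parser.add_argument("--job_name", choices=["ps", "worker"], default=None)
    parser.add_argument("--task_index", type=int, default=None)
    args = parser.parse_args(argv)
    # Same checks as above: both values are mandatory (parser.error exits with status 2)
    if not args.job_name:
        parser.error("Must specify an explicit `job_name`")
    if args.task_index is None:
        parser.error("Must specify an explicit `task_index`")
    return args


args = parse_flags(["--job_name=worker", "--task_index=1", "--sync_replicas=False"])
print(args.job_name, args.task_index, args.sync_replicas)  # worker 1 False
```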

Parse the ps and worker server lists from the flags:

ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
num_workers = len(worker_spec)
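The host strings turn into a plain mapping from job name to a list of "host:port" entries. A task's index is simply its position in that list, which is why the logs in section 0x03 show "worker -> {0 -> 10.10.19.8:2222, 1 -> 10.10.19.9:2222}". The following pure-Python sketch mirrors that bookkeeping without TensorFlow. The helper names `build_cluster`, `task_address` and `worker_device` are hypothetical.

```python
def build_cluster(ps_hosts, worker_hosts):
    """Same dict that the article passes to tf.train.ClusterSpec: job name -> list of "host:port"."""
    return {"ps": ps_hosts.split(","), "worker": worker_hosts.split(",")}


def task_address(cluster, job_name, task_index):
    """A task's index is simply its position in the job's host list."""
    return cluster[job_name][task_index]


def worker_device(task_index, cpu=0):
    """Device string used for a worker's computation (CPU only, as in the article)."""
    return "/job:worker/task:%d/cpu:%d" % (task_index, cpu)


cluster = build_cluster("10.10.19.7:2222", "10.10.19.8:2222,10.10.19.9:2222")
num_workers = len(cluster["worker"])
for index in range(num_workers):
    is_chief = (index == 0)  # worker task 0 is the chief
    print(index, task_address(cluster, "worker", index), worker_device(index), is_chief)
```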

2. Distributed configuration

Create the cluster object and this task's server:

cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# Worker task 0 is the chief (master) node
is_chief = (FLAGS.task_index == 0)

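Compute resources: only the CPU is used here. A ps server does nothing but block in server.join() and serve the workers.

if FLAGS.job_name == "ps":
    server.join()  # never returns; the ps just serves variables to the workers
cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)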

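Device placement: variables go to the ps, computation stays on this worker.

with tf.device(tf.train.replica_device_setter(
        worker_device=worker_device,
        ps_device="/job:ps/cpu:0",
        cluster=cluster)):
    ...  # the graph-building code in section 3 goes inside this block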
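replica_device_setter decides where each op is placed. Variable ops go to the ps job, and every other op stays on worker_device. When there are several ps tasks, the default strategy spreads variables across them round-robin. The article's cluster has a single ps, so every variable lands on ps task 0. The sketch below is a simplified model of that rule, not the real implementation. The op list is hypothetical, and the device strings are shortened: the real setter also merges in the "/cpu:0" part of ps_device.

```python
def place_ops(ops, num_ps_tasks, worker_device):
    """Simplified model of replica_device_setter's default placement.

    ops: list of (op_name, is_variable) in graph-construction order.
    Variables are assigned to ps tasks round-robin; all other ops stay on
    worker_device. Returns {op_name: device}.
    """
    placement = {}
    next_ps_task = 0
    for name, is_variable in ops:
        if is_variable:
            placement[name] = "/job:ps/task:%d" % next_ps_task
            next_ps_task = (next_ps_task + 1) % num_ps_tasks
        else:
            placement[name] = worker_device
    return placement


# Hypothetical graph: three variables and one compute op
ops = [("global_step", True), ("weights", True), ("bias", True), ("matmul", False)]
one_ps = place_ops(ops, num_ps_tasks=1, worker_device="/job:worker/task:0/cpu:0")
two_ps = place_ops(ops, num_ps_tasks=2, worker_device="/job:worker/task:0/cpu:0")
print(one_ps)
print(two_ps)
```

With two ps tasks, the variables alternate between task 0 and task 1, which spreads parameter traffic over both machines.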

3. Preparing for training

A variable that records the global step count:

global_step = tf.Variable(0, name="global_step", trainable=False)

Synchronous mode requires wrapping the optimizer. Given an optimizer such as opt = tf.train.AdamOptimizer(FLAGS.learning_rate), wrap it as follows:

if FLAGS.sync_replicas:
    # Update the parameters once gradients from this many replicas have been aggregated
    if FLAGS.replicas_to_aggregate is None:
        replicas_to_aggregate = num_workers
    else:
        replicas_to_aggregate = FLAGS.replicas_to_aggregate
    # Wrap the original optimizer
    opt = tf.train.SyncReplicasOptimizer(
        opt,
        replicas_to_aggregate=replicas_to_aggregate,
        total_num_replicas=num_workers,
        name="mnist_sync_replicas")
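In the TF 1.x implementation, SyncReplicasOptimizer keeps an accumulator on the ps for each variable's gradients, with the following rules:

- Each gradient is tagged with the step at which the worker computed it.
- Gradients older than the current global step are dropped as stale.
- Once replicas_to_aggregate gradients have arrived, their average is applied and global_step advances.

The class below is a simplified, single-threaded model of that bookkeeping. It assumes one scalar parameter and plain SGD instead of the wrapped Adam optimizer. `ToySyncAggregator` is a hypothetical name, not a TensorFlow API.

```python
class ToySyncAggregator:
    """Simplified model of SyncReplicasOptimizer's gradient aggregation."""

    def __init__(self, w, lr, replicas_to_aggregate):
        self.w = w
        self.lr = lr
        self.replicas_to_aggregate = replicas_to_aggregate
        self.global_step = 0
        self.pending = []

    def push_gradient(self, g, local_step):
        """A worker submits gradient g computed at local_step.
        Returns False if the gradient was dropped as stale."""
        if local_step < self.global_step:
            return False  # computed on parameters that have since been updated
        self.pending.append(g)
        if len(self.pending) >= self.replicas_to_aggregate:
            average = sum(self.pending) / len(self.pending)
            self.w -= self.lr * average  # one update for the whole group of replicas
            self.global_step += 1
            self.pending = []
        return True


agg = ToySyncAggregator(w=0.0, lr=0.5, replicas_to_aggregate=2)
agg.push_gradient(2.0, local_step=0)  # worker 0: buffered, no update yet
agg.push_gradient(4.0, local_step=0)  # worker 1: two gradients -> apply their mean 3.0
print(agg.w, agg.global_step)  # -1.5 1
accepted = agg.push_gradient(10.0, local_step=0)  # computed before the update -> dropped
print(accepted, agg.w, agg.global_step)  # False -1.5 1
```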

The training op:

train_step = opt.minimize(cross_entropy, global_step=global_step)

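Initialization:

if FLAGS.sync_replicas:
    local_init_op = opt.local_step_init_op
    if is_chief:
        local_init_op = opt.chief_init_op
    ready_for_local_init_op = opt.ready_for_local_init_op
    # Queue runner the chief uses to apply aggregated gradients and issue sync tokens
    chief_queue_runner = opt.get_chief_queue_runner()
    # Op that fills the sync token queue with the initial tokens
    sync_init_op = opt.get_init_tokens_op()
# Initializes all global variables
init_op = tf.global_variables_initializer()
# Temporary training directory
train_dir = tempfile.mkdtemp()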

Create the Supervisor that manages distributed training:

if FLAGS.sync_replicas:
    sv = tf.train.Supervisor(is_chief=is_chief, logdir=train_dir,
                             init_op=init_op, local_init_op=local_init_op,
                             ready_for_local_init_op=ready_for_local_init_op,
                             recovery_wait_secs=1, global_step=global_step)
else:
    sv = tf.train.Supervisor(is_chief=is_chief, logdir=train_dir,
                             init_op=init_op, recovery_wait_secs=1,
                             global_step=global_step)

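Session configuration:

# device_filters: this worker only sees the ps tasks and itself, not the other workers
sess_config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])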
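device_filters restricts which devices a session can see. A filter may be partial, such as "/job:ps", and a device is kept if it matches at least one filter. The worker logs in section 0x03 show exactly these two filters. The sketch below is a toy model of the matching: a filter matches when every field it specifies equals the device's field. It uses the article's short device names, whereas real TensorFlow names also carry "/replica:0" and "/device:CPU:0". The functions are hypothetical.

```python
def parse_device(name):
    """Split "/job:worker/task:1/cpu:0" into {"job": "worker", "task": "1", "cpu": "0"}."""
    fields = {}
    for part in name.strip("/").split("/"):
        key, _, value = part.partition(":")
        fields[key] = value
    return fields


def visible_devices(devices, filters):
    """Keep the devices that match at least one (possibly partial) filter."""
    def matches(device, device_filter):
        d, f = parse_device(device), parse_device(device_filter)
        return all(d.get(key) == value for key, value in f.items())
    return [dev for dev in devices if any(matches(dev, flt) for flt in filters)]


cluster_devices = ["/job:ps/task:0/cpu:0", "/job:worker/task:0/cpu:0", "/job:worker/task:1/cpu:0"]
seen_by_worker_1 = visible_devices(cluster_devices, ["/job:ps", "/job:worker/task:1"])
print(seen_by_worker_1)
```

Worker 1 therefore never depends on worker 0's devices, so one slow or missing worker does not block the others' sessions.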

Prepare the session:

if is_chief:
    print("Worker %d: Initializing session..." % FLAGS.task_index)
else:
    print("Worker %d: Waiting for session to be initialized..." % FLAGS.task_index)
# The chief initializes the session; the other workers wait until it is ready
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index)
if FLAGS.sync_replicas and is_chief:
    # Seed the sync token queue with the initial tokens
    sess.run(sync_init_op)
    # Start the chief queue runner that applies aggregated updates and issues new tokens
    sv.start_queue_runners(sess, [chief_queue_runner])
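The chief has to run sync_init_op and start the chief queue runner because of how SyncReplicasOptimizer paces workers. In the TF 1.x implementation, each worker's training step ends by taking a token from a shared queue; the token's value becomes the worker's next local step. get_init_tokens_op seeds the first tokens, and after every aggregated update the chief queue runner enqueues new tokens carrying the new global step.

The single-threaded toy below models only that handshake. The class and method names are hypothetical, and the token counts (2) are an assumption matching this article's two workers. Where real TensorFlow would block the worker, the toy returns None.

```python
import queue


class ToyTokenQueue:
    """Simplified, single-threaded model of SyncReplicasOptimizer's sync token queue."""

    def __init__(self):
        self._tokens = queue.Queue()

    def enqueue_tokens(self, num_tokens, global_step):
        """Models sync_init_op (before training) and the chief queue runner
        (after each aggregated update): add tokens carrying the global step."""
        for _ in range(num_tokens):
            self._tokens.put(global_step)

    def worker_take_token(self):
        """End of a worker's train step: take a token to learn the new local step.
        Returns None where a real worker would block."""
        try:
            return self._tokens.get_nowait()
        except queue.Empty:
            return None


tokens = ToyTokenQueue()
before_init = tokens.worker_take_token()  # None: without sync_init_op the first step would hang
tokens.enqueue_tokens(2, global_step=0)  # like sess.run(sync_init_op)
first_round = [tokens.worker_take_token(), tokens.worker_take_token()]  # [0, 0]
exhausted = tokens.worker_take_token()  # None: wait for the chief's next update
tokens.enqueue_tokens(2, global_step=1)  # like the chief queue runner after global step 1
next_token = tokens.worker_take_token()  # 1
print(before_init, first_round, exhausted, next_token)
```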

4. Training:

time_begin = time.time()
print("Training begins @ %f" % time_begin)
local_step = 0
while True:
    # Fetch the next training batch
    batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
    train_feed = {x: batch_xs, y_: batch_ys}
    _, step = sess.run([train_step, global_step], feed_dict=train_feed)
    local_step += 1
    now = time.time()
    print("%f: Worker %d: training step %d done (global step: %d)" %
          (now, FLAGS.task_index, local_step, step))
    # Stop on the shared global step, not on this worker's local step
    if step >= FLAGS.train_steps:
        break
time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)
# Validation set
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" %
      (FLAGS.train_steps, val_xent))

0x03. Running on real servers

After changing the ps and worker server IPs in the official example, the modified part of the file looks like this:

#coding:utf-8
# Only the CPU is used
flags.DEFINE_integer("num_gpus", 0,
                     "Total number of gpus for each machine. "
                     "If you don't use GPU, please set it to '0'")
# ps and worker server addresses
flags.DEFINE_string("ps_hosts", "10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")

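Then run one command on each of the three servers in turn: the ps on 10.10.19.7, worker 0 on 10.10.19.8 and worker 1 on 10.10.19.9.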

python distribute_test.py --job_name=ps --task_index=0 --sync_replicas=True

python distribute_test.py --job_name=worker --task_index=0 --sync_replicas=True

python distribute_test.py --job_name=worker --task_index=1 --sync_replicas=True
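The three commands follow one rule: each machine passes its own job name plus its position in that job's host list as --task_index. A mismatch would make the process claim another machine's slot in the cluster. The sketch below derives the commands from the host flags. The helper `launch_commands` is hypothetical; the script name distribute_test.py is the article's.

```python
def launch_commands(ps_hosts, worker_hosts, script="distribute_test.py", sync_replicas=True):
    """Map each "host:port" to the command that must be run on that host.

    The --task_index passed to a host must equal the host's position in its
    job's list, as built by ClusterSpec from the same comma-separated flags.
    """
    commands = {}
    for job_name, hosts in (("ps", ps_hosts), ("worker", worker_hosts)):
        for task_index, host in enumerate(hosts.split(",")):
            commands[host] = "python %s --job_name=%s --task_index=%d --sync_replicas=%s" % (
                script, job_name, task_index, sync_replicas)
    return commands


cmds = launch_commands("10.10.19.7:2222", "10.10.19.8:2222,10.10.19.9:2222")
for host, command in cmds.items():
    print(host, "->", command)
```

The printed commands match the three lines above.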

Output on the ps server:

Extracting /tmp/mnist-data/train-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist-data/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/t10k-labels-idx1-ubyte.gz
job name = ps
task index = 0
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> 10.10.19.8:2222, 1 -> 10.10.19.9:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222

Output on the other two machines (the workers). Worker 0:

Extracting /tmp/mnist-data/train-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist-data/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/t10k-labels-idx1-ubyte.gz
job name = worker
task index = 0
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> 10.10.19.7:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> localhost:2222, 1 -> 10.10.19.9:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222
Worker 0: Initializing session...
I tensorflow/core/distributed_runtime/master_session.cc:1012] Start master session df31e159ecf5dc77 with config:
device_filters: "/job:ps"
device_filters: "/job:worker/task:0"
allow_soft_placement: true
Worker 0: Session initialization complete.
Training begins @ 1500442384.091448
1500442384.150300: Worker 0: training step 1 done (global step: 0)
1500442384.163003: Worker 0: training step 2 done (global step: 0)
1500442384.172685: Worker 0: training step 3 done (global step: 1)
1500442384.182413: Worker 0: training step 4 done (global step: 1)
......
1500442387.524158: Worker 0: training step 269 done (global step: 198)
1500442387.539484: Worker 0: training step 270 done (global step: 199)
1500442387.555133: Worker 0: training step 271 done (global step: 200)
Training ends @ 1500442387.555215
Training elapsed time: 3.463767 s
After 200 training step(s), validation cross entropy = 781.478

Worker 1:

Extracting /tmp/mnist-data/train-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist-data/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/t10k-labels-idx1-ubyte.gz
job name = worker
task index = 1
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> 10.10.19.7:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> 10.10.19.8:2222, 1 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222
Worker 1: Waiting for session to be initialized...
I tensorflow/core/distributed_runtime/master_session.cc:1012] Start master session 540b3d300aac1583 with config:
device_filters: "/job:ps"
device_filters: "/job:worker/task:1"
allow_soft_placement: true
Worker 1: Session initialization complete.
Training begins @ 1500442385.490577
1500442385.520064: Worker 1: training step 1 done (global step: 68)
1500442385.534573: Worker 1: training step 2 done (global step: 69)
1500442385.549380: Worker 1: training step 3 done (global step: 70)
......
1500442387.524271: Worker 1: training step 131 done (global step: 198)
1500442387.539464: Worker 1: training step 132 done (global step: 199)
1500442387.555071: Worker 1: training step 133 done (global step: 200)
Training ends @ 1500442387.555124
Training elapsed time: 2.064547 s
After 200 training step(s), validation cross entropy = 781.478
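The two worker logs can be cross-checked with a little arithmetic. Worker 0 ran 271 local steps and worker 1 ran 133, 404 in total, for 200 global steps. That is about 2 local steps per global step, consistent with replicas_to_aggregate defaulting to num_workers = 2. Worker 1 joined when the global step was already 68, which suggests that worker 0 alone supplied both gradients per update until then.

The sketch below extracts those numbers from the final log lines quoted above. The regex and the helper `last_steps` are illustrative, not part of the article's code.

```python
import re

# Matches lines such as "1500442387.555133: Worker 0: training step 271 done (global step: 200)"
STEP_LINE = re.compile(r"Worker (\d+): training step (\d+) done \(global step: (\d+)\)")


def last_steps(log_lines):
    """Return {worker_index: (local_step, global_step)} from each worker's last step line."""
    last = {}
    for line in log_lines:
        match = STEP_LINE.search(line)
        if match:
            worker, local_step, global_step = (int(g) for g in match.groups())
            last[worker] = (local_step, global_step)
    return last


# Final step lines copied from the two worker logs above
final_lines = [
    "1500442387.555133: Worker 0: training step 271 done (global step: 200)",
    "1500442387.555071: Worker 1: training step 133 done (global step: 200)",
]
stats = last_steps(final_lines)
total_local_steps = sum(local for local, _ in stats.values())
final_global_step = max(glob for _, glob in stats.values())
print(total_local_steps, final_global_step, total_local_steps / final_global_step)  # 404 200 2.02
```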

0x04. Building your own code

Here I tried distributed training on an LSTM, reusing much of the code above.

1. Common code

First, set up TensorFlow's basic flags:

import tempfile
import time

import tensorflow as tf  # TensorFlow 1.x API

flags = tf.app.flags
flags.DEFINE_boolean("sync_replicas", True,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update "
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
flags.DEFINE_string("ps_hosts", "10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("job_name", None, "job name: worker or ps")
flags.DEFINE_integer("task_index", 0,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task that performs the variable "
                     "initialization ")
flags.DEFINE_integer("train_steps", 500,
                     "Number of (global) training steps to perform")
FLAGS = flags.FLAGS

if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index == "":
    raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)

Next, the cluster configuration:

ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
num_workers = len(worker_spec)
cluster = tf.train.ClusterSpec({
    "ps": ps_spec,
    "worker": worker_spec})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
if FLAGS.job_name == "ps":
    server.join()
is_chief = (FLAGS.task_index == 0)
cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)

Configure the compute resources and build the graph inside the device scope:

with tf.device(
        tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=cluster)):
    global_step = tf.Variable(0, name="global_step", trainable=False)
    # tf.placeholder ... (inputs x and labels y)
    # weights / biases
    # the model computation producing `pred` goes here
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    if FLAGS.sync_replicas:
        if FLAGS.replicas_to_aggregate is None:
            replicas_to_aggregate = num_workers
        else:
            replicas_to_aggregate = FLAGS.replicas_to_aggregate
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=replicas_to_aggregate,
            total_num_replicas=num_workers,
            name="mnist_sync_replicas")
    train_step = optimizer.minimize(cost, global_step=global_step)
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    if FLAGS.sync_replicas:
        local_init_op = optimizer.local_step_init_op
        if is_chief:
            local_init_op = optimizer.chief_init_op
        ready_for_local_init_op = optimizer.ready_for_local_init_op
        chief_queue_runner = optimizer.get_chief_queue_runner()
        sync_init_op = optimizer.get_init_tokens_op()
    init_op = tf.global_variables_initializer()
    train_dir = tempfile.mkdtemp()
    if FLAGS.sync_replicas:
        sv = tf.train.Supervisor(
            is_chief=is_chief,
            logdir=train_dir,
            init_op=init_op,
            local_init_op=local_init_op,
            ready_for_local_init_op=ready_for_local_init_op,
            recovery_wait_secs=1,
            global_step=global_step)
    else:
        sv = tf.train.Supervisor(
            is_chief=is_chief,
            logdir=train_dir,
            init_op=init_op,
            recovery_wait_secs=1,
            global_step=global_step)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
    if is_chief:
        print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
        print("Worker %d: Waiting for session to be initialized..." %
              FLAGS.task_index)
    sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
    print("Worker %d: Session initialization complete." % FLAGS.task_index)
    if FLAGS.sync_replicas and is_chief:
        sess.run(sync_init_op)
        sv.start_queue_runners(sess, [chief_queue_runner])

Start training. The structure here matters: a while True loop runs [train_step, global_step] together and exits based on the global step.

    time_begin = time.time()
    print("Training begins @ %f" % time_begin)
    local_step = 0
    while True:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # The LSTM expects input of shape (batch_size, n_steps, n_input)
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        _, step = sess.run([train_step, global_step], feed_dict={x: batch_x, y: batch_y})
        local_step += 1
        now = time.time()
        print("%f: Worker %d: training step %d done (global step: %d)" %
              (now, FLAGS.task_index, local_step, step))
        if step >= FLAGS.train_steps:
            break
    time_end = time.time()
    print("Training ends @ %f" % time_end)
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)
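The loop stops on the global step, which all workers share, rather than on local_step. As a result, the cluster as a whole performs about train_steps updates no matter how the work happens to be split between fast and slow workers. If each worker instead stopped when its own local_step reached train_steps, every worker would run train_steps steps.

The deterministic toy below simulates the asynchronous case, where every finished step increments global_step. It relies on several assumptions:

- Per-step costs are hypothetical time units.
- Ties go to the lower worker index.
- The fetched step is taken to be the post-increment value.

```python
def simulate_async_loop(step_costs, train_steps):
    """Toy of the `while True` loop above in asynchronous mode.

    step_costs[i]: hypothetical time units worker i needs per step.
    Returns (local steps per worker, final global step).
    """
    global_step = 0
    local_steps = [0] * len(step_costs)
    finish_time = list(step_costs)  # when each worker finishes its current step
    running = set(range(len(step_costs)))
    while running:
        # the running worker that finishes next (ties -> lower index)
        w = min(running, key=lambda i: (finish_time[i], i))
        global_step += 1  # minimize(..., global_step=global_step) increments it
        local_steps[w] += 1
        if global_step >= train_steps:
            running.discard(w)  # the `break` in the loop
        else:
            finish_time[w] += step_costs[w]
    return local_steps, global_step


local_steps, final_step = simulate_async_loop(step_costs=[1, 2], train_steps=6)
print(local_steps, final_step)  # [5, 2] 7
```

The fast worker does more of the steps, and both stop once the shared counter passes train_steps. Steps already in flight can push the final global step slightly past the target.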

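2. Running and debugging

Running on the same servers, I found asynchronous mode to be about twice as fast as synchronous mode, with little difference in accuracy.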
