Submitting distributed TensorFlow jobs with Slurm


Purpose

TensorFlow's single-machine multi-GPU mode is easy to use by following the official tutorial, but for deploying a distributed TensorFlow training job across multiple nodes of a cluster there is no good official example. The only option is a rather naive one: designate ps nodes and worker nodes by hand, then run the corresponding program on each node separately to achieve multi-machine cooperative training. In a cluster environment with many nodes this is very inconvenient. This article uses the Slurm cluster resource manager to dispatch a distributed TensorFlow training job automatically.
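For reference, the "naive" approach described above looks roughly like the following sketch (TensorFlow 1.x distributed API; the hostnames, port, and script name are illustrative placeholders): the cluster layout is hard-coded, and you must log in to every node and start the matching process by hand.

import sys
import tensorflow as tf

# Hard-coded cluster layout, identical on every node.
cluster = tf.train.ClusterSpec({
    "ps":     ["node01:2222"],
    "worker": ["node02:2222", "node03:2222"],
})

# Started manually on each node, e.g.:
#   python naive_train.py ps 0        (on node01)
#   python naive_train.py worker 0    (on node02)
#   python naive_train.py worker 1    (on node03)
job_name, task_index = sys.argv[1], int(sys.argv[2])
server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)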

Implementation

# Defines a function that reads the cluster compute resources Slurm allocated to a submitted job.
# It takes two parameters: ps_number is the number of parameter-server nodes needed;
# by default all remaining nodes become workers.
# A node acting as ps could also act as a worker, but we avoid that to prevent port conflicts.
# port_number is the port used for inter-node communication in this job. If a ps node also
# started a worker process, each process would need a distinct port; for simplicity we assume
# num_nodes > 1 and never place a worker and a ps on the same node.

import os
import re

def tf_config_from_slurm(ps_number, port_number=2222):
    """
    Creates the configuration for a distributed TensorFlow session
    from environment variables provided by the Slurm cluster
    management system.

    @param ps_number: number of parameter servers to run
    @param port_number: port number to be used for communication
    @return: a tuple (cluster_spec, task_name, task_index)
    """
    nodelist = os.environ["SLURM_JOB_NODELIST"]
    nodename = os.environ["SLURMD_NODENAME"]
    nodelist = _expand_nodelist(nodelist)
    num_nodes = int(os.getenv("SLURM_JOB_NUM_NODES"))

    if len(nodelist) != num_nodes:
        raise ValueError("Number of Slurm nodes {} does not equal {}".format(
            len(nodelist), num_nodes))
    if nodename not in nodelist:
        raise ValueError("Nodename ({}) not in nodelist ({}). This should not happen!".format(
            nodename, nodelist))
    if ps_number > num_nodes:
        raise ValueError("Number of ps nodes is larger than the number of nodes given by Slurm!")

    # The first ps_number nodes become parameter servers; the rest become workers.
    ps_nodes = [node for i, node in enumerate(nodelist) if i < ps_number]
    worker_nodes = [node for i, node in enumerate(nodelist) if i >= ps_number]

    if nodename in ps_nodes:
        my_job_name = "ps"
        my_task_index = ps_nodes.index(nodename)
    else:
        my_job_name = "worker"
        my_task_index = worker_nodes.index(nodename)

    worker_sockets = [":".join([node, str(port_number)]) for node in worker_nodes]
    ps_sockets = [":".join([node, str(port_number)]) for node in ps_nodes]
    cluster = {"worker": worker_sockets, "ps": ps_sockets}
    return cluster, my_job_name, my_task_index

def _pad_zeros(iterable, length):
    return (str(t).rjust(length, '0') for t in iterable)

def _expand_ids(ids):
    # Expands a Slurm id list such as "01-03,05" into ["01", "02", "03", "05"].
    ids = ids.split(',')
    result = []
    for id in ids:
        if '-' in id:
            begin, end = id.split('-')
            result.extend(_pad_zeros(range(int(begin), int(end) + 1), len(end)))
        else:
            result.append(id)
    return result

def _expand_nodelist(nodelist):
    # Expands a compressed Slurm nodelist such as "node[01-03,05]" into hostnames.
    prefix, ids = re.findall(r"(.*)\[(.*)\]", nodelist)[0]
    ids = _expand_ids(ids)
    result = [prefix + str(id) for id in ids]
    return result

def _worker_task_id(nodelist, nodename):
    return nodelist.index(nodename)
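As a quick sanity check, the helper can be exercised locally by faking the Slurm environment variables (the hostnames below are made up; in a real run Slurm sets these itself, and you would start one copy of the script per node, e.g. with srun python train.py inside an sbatch allocation):

import os

# Fake a 4-node allocation; Slurm would normally set these.
os.environ["SLURM_JOB_NODELIST"] = "node[01-04]"
os.environ["SLURM_JOB_NUM_NODES"] = "4"
os.environ["SLURMD_NODENAME"] = "node02"

cluster, job_name, task_index = tf_config_from_slurm(ps_number=1)
print(cluster)
# {'worker': ['node02:2222', 'node03:2222', 'node04:2222'], 'ps': ['node01:2222']}
print(job_name, task_index)  # worker 0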

Building the network model in TensorFlow

# Fetch the cluster compute resources allocated by Slurm together with the job name of the
# current node, configure the ClusterSpec, and start the server.
# Note that a ps node has to keep receiving messages from the workers to complete the
# synchronized parameter updates, so its server must block in join() rather than exit directly.

import sys
import tensorflow as tf

cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=3)
cluster_spec = tf.train.ClusterSpec(cluster)
server = tf.train.Server(server_or_cluster_def=cluster_spec,
                         job_name=my_job_name,
                         task_index=my_task_index)

if my_job_name == 'ps':
    server.join()
    sys.exit(0)
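From this point only the workers keep running. A hedged sketch of how the worker branch might continue, using between-graph replication with the TensorFlow 1.x API (the linear model, batch size, learning rate, and step limit are placeholders, not part of the original post):

import numpy as np

# Pin variables to the ps tasks and ops to this worker.
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % my_task_index,
        cluster=cluster_spec)):
    x = tf.placeholder(tf.float32, shape=[None, 10])
    w = tf.get_variable("w", shape=[10, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
    global_step = tf.train.get_or_create_global_step()
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=global_step)

# MonitoredTrainingSession handles session creation, chief duties,
# and waiting for the ps servers to come up.
with tf.train.MonitoredTrainingSession(
        master=server.target,
        is_chief=(my_task_index == 0),
        hooks=[tf.train.StopAtStepHook(last_step=1000)]) as sess:
    while not sess.should_stop():
        sess.run(train_op, feed_dict={x: np.random.randn(32, 10)})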
To be updated as this work is further refined.