Keras限制GPU显存使用

来源:互联网 发布:淘宝旺铺店招 编辑:程序博客网 时间:2024/05/28 17:08

深度学习门槛越来越低,尤其是Keras这样的高层次API加入以后,简单几行代码就能构建网络并得到不错的效果。

最近工作需要,开始使用Keras写3DConv做3D数据分类。

数据是肺部CT,目标是对检测网络得到的结果使用3D Conv网络做2分类。2D的检测网络获得肺部结节的ROI,但是由于二维上结节目标与肺部正常组织结构(如血管,气管等)特征类似,所以检测结果中包含了大量“假阳”,所以在检测网络的结果基础上使用3D网络,实际测试对“假阳”有很好抑制效果。

由于检测网络使用caffe框架,而3D网络又用以TensorFlow作为后端的Keras编写,结果TensorFlow跟caffe不兼容,同一个Python进程中不能同时包含这两个框架的环境。

于是将两部分代码分别放到不同的显卡上。TensorFlow如果不加限制会占满服务器上的所有显存,需要加以限制,查看了GitHub上的源码和解决方案如下:

首先是对keras配置文件进行修改,如果是在Linux环境下,也就是修改:

~/.keras/keras.json

这个文件。

修改内容如下:

{    "epsilon": 1e-07,    "floatx": "float32",    "image_data_format": "channels_last",    "backend": "tensorflow",    "gpu_options": {        "allow_growth": false,        "per_process_gpu_memory_fraction": 0.5,        "visible_device_list": "1"    }}
其中
per_process_gpu_memory_fraction
表示使用显存的比例;

visible_device_list
表示使用第几块显卡,0开始;

然后就需要修改后端代码:

在python库路径中找到keras位置,Linux下一般也就是:/usr/local/lib/python2.7/dist-packages/keras-2.****

这样一个路径,修改TensorFlow后端代码:

def get_session():    """Returns the TF session to be used by the backend.    If a default TensorFlow session is available, we will return it.    Else, we will return the global Keras session.    If no global Keras session exists at this point:    we will create a new global session.    Note that you can manually set the global session    via `K.set_session(sess)`.    # Returns        A TensorFlow session.    """    global _SESSION    if tf.get_default_session() is not None:        session = tf.get_default_session()    else:        if _SESSION is None:            _keras_base_dir = os.path.expanduser('~')            if not os.access(_keras_base_dir, os.W_OK):                 _keras_base_dir = '/tmp'            _keras_dir = os.path.join(_keras_base_dir, '.keras')            _config_path = os.path.expanduser(os.path.join(_keras_dir,                                                            'keras.json'))            if os.path.exists(_config_path):                try:                    _config = json.load(open(_config_path))                except ValueError:                    _config = {}            _options = _config.get('gpu_options', None)            _allow_growth = _options.get('allow_growth', False)            _mem_frac = _options.get('per_process_gpu_memory_fraction', 1.0)            _visible_device_list = _options.get('visible_device_list', None)            _gpu_options = tf.GPUOptions(allow_growth=_allow_growth,                                         per_process_gpu_memory_fraction=_mem_frac,                                         visible_device_list=_visible_device_list)            if not os.environ.get('OMP_NUM_THREADS'):                config = tf.ConfigProto(allow_soft_placement=True,                                        gpu_options=_gpu_options)            else:                num_thread = int(os.environ.get('OMP_NUM_THREADS'))                config = tf.ConfigProto(intra_op_parallelism_threads=num_thread,                                        allow_soft_placement=True,                                        gpu_options=_gpu_options)            _SESSION = tf.Session(config=config)        session = _SESSION    if not _MANUAL_VAR_INIT:        with session.graph.as_default():            _initialize_variables()    return session

需要额外import 一个json库用来解析keras.json配置文件。

其实也就是对TensorFlow的session属性进行一下初始化。