Keras限制GPU显存使用
来源:互联网 发布:淘宝旺铺店招 编辑:程序博客网 时间:2024/05/28 17:08
深度学习门槛越来越低,尤其是Keras这样的高层次API加入以后,简单几行代码就能构建网络并得到不错的效果。
最近工作需要,开始使用Keras写3DConv做3D数据分类。
数据是肺部CT,目标是对检测网络得到的结果使用3D Conv网络做2分类。2D的检测网络获得肺部结节的ROI,但是由于二维上结节目标与肺部正常组织结构(如血管,气管等)特征类似,所以检测结果中包含了大量“假阳”,所以在检测网络的结果基础上使用3D网络,实际测试对“假阳”有很好抑制效果。
由于检测网络使用caffe框架,而3D网络又用以TensorFlow作为后端的Keras编写,结果TensorFlow跟caffe不兼容,同一个Python进程中不能同时包含这两个框架的环境。
于是将两部分代码分别放到不同的显卡上。TensorFlow如果不加限制会占满服务器上的所有显存,需要加以限制,查看了GitHub上的源码和解决方案如下:
首先是对keras配置文件进行修改,如果是在Linux环境下,也就是修改:
~/.keras/keras.json
这个文件。
修改内容如下:
{ "epsilon": 1e-07, "floatx": "float32", "image_data_format": "channels_last", "backend": "tensorflow", "gpu_options": { "allow_growth": false, "per_process_gpu_memory_fraction": 0.5, "visible_device_list": "1" }}其中
per_process_gpu_memory_fraction表示使用显存的比例;
visible_device_list表示使用第几块显卡,0开始;
然后就需要修改后端代码:
在python库路径中找到keras位置,Linux下一般也就是:/usr/local/lib/python2.7/dist-packages/keras-2.****
这样一个路径,修改TensorFlow后端代码:
def get_session(): """Returns the TF session to be used by the backend. If a default TensorFlow session is available, we will return it. Else, we will return the global Keras session. If no global Keras session exists at this point: we will create a new global session. Note that you can manually set the global session via `K.set_session(sess)`. # Returns A TensorFlow session. """ global _SESSION if tf.get_default_session() is not None: session = tf.get_default_session() else: if _SESSION is None: _keras_base_dir = os.path.expanduser('~') if not os.access(_keras_base_dir, os.W_OK): _keras_base_dir = '/tmp' _keras_dir = os.path.join(_keras_base_dir, '.keras') _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json')) if os.path.exists(_config_path): try: _config = json.load(open(_config_path)) except ValueError: _config = {} _options = _config.get('gpu_options', None) _allow_growth = _options.get('allow_growth', False) _mem_frac = _options.get('per_process_gpu_memory_fraction', 1.0) _visible_device_list = _options.get('visible_device_list', None) _gpu_options = tf.GPUOptions(allow_growth=_allow_growth, per_process_gpu_memory_fraction=_mem_frac, visible_device_list=_visible_device_list) if not os.environ.get('OMP_NUM_THREADS'): config = tf.ConfigProto(allow_soft_placement=True, gpu_options=_gpu_options) else: num_thread = int(os.environ.get('OMP_NUM_THREADS')) config = tf.ConfigProto(intra_op_parallelism_threads=num_thread, allow_soft_placement=True, gpu_options=_gpu_options) _SESSION = tf.Session(config=config) session = _SESSION if not _MANUAL_VAR_INIT: with session.graph.as_default(): _initialize_variables() return session
需要额外import 一个json库用来解析keras.json配置文件。
其实也就是对TensorFlow的session属性进行一下初始化。阅读全文
0 0
- Keras限制GPU显存使用
- Keras指定使用GPU
- Keras指定使用GPU
- keras使用GPU
- keras系列︱keras是如何指定显卡且限制显存用量
- Keras指定使用GPU运算
- tensorflow中使用指定的GPU及GPU显存
- keras指定运行时显卡及限制GPU用量
- tensorflow使用GPU训练时的显存占用问题
- tensorflow使用GPU训练时的显存占用问题
- tensorflow使用GPU训练时的显存占用问题
- tensorflow GPU显存控制
- Keras设定GPU使用内存大小(Tensorflow backend)
- Keras以及Tensorflow强制使用CPU,GPU
- Keras设定GPU使用内存大小(Tensorflow backend)
- (转)tensorflow中使用指定的GPU及GPU显存
- GPU显卡,显存位宽
- TensorFlow使用GPU训练网络时多块显卡的显存使用问题
- 微信小程序服务类目及资质要求
- RelativeLayout 下的控件叠加
- 面试题27:二叉搜索树与双向链表
- 友情链接
- python __slots__
- Keras限制GPU显存使用
- 主题四 指针和数组(上)----23.C语言中的字符串
- C# WinForm控件美化扩展系列之实现点击收缩的SplitContainer控件
- Java中的单例模式
- Android APP 加固思路
- [LeetCode] Brick Wall
- JSP数据交换
- 遇到bug json解析问题
- memcached实现分布式缓存