Keras load_model error

Source: the Internet | Editor: 程序博客网 | Time: 2024/06/05 18:04

When saving a Keras model with model.save() or a ModelCheckpoint callback, for example:

from keras.callbacks import ModelCheckpoint

check_point = ModelCheckpoint('./check_point/weights.{epoch:02d}-{val_loss:.2f}.hdf5',
                              monitor='val_loss', verbose=1,
                              save_best_only=False, save_weights_only=False, period=20)
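The checkpoint file name used later (weights.79-0.31.hdf5) comes from the placeholders in this path: ModelCheckpoint fills them in with Python's str.format, using the epoch number and the logged metrics such as val_loss. A minimal sketch of just that formatting step, with made-up epoch and metric values:

```python
# Hypothetical values standing in for one epoch's number and its logged metrics
template = "./check_point/weights.{epoch:02d}-{val_loss:.2f}.hdf5"
logs = {"val_loss": 0.3123, "loss": 0.2871}

# {epoch:02d} zero-pads to two digits, {val_loss:.2f} rounds to two decimals;
# unused keys such as "loss" are simply ignored by str.format
print(template.format(epoch=79, **logs))  # ./check_point/weights.79-0.31.hdf5
print(template.format(epoch=5, **logs))   # ./check_point/weights.05-0.31.hdf5
```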


and then reloading the model to run predictions on data, like this:

from keras.models import load_model

model = load_model('check_point/weights.79-0.31.hdf5')
model.summary()
model.predict_on_batch(depth_list)  # depth_list: the batch of input data to predict on

the following error is raised:

Traceback (most recent call last):
  File "/home/jia/Desktop/My_hand_pose/evaluation.py", line 17, in <module>
    model = load_model('check_point/weights.79-0.31.hdf5')
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/models.py", line 246, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/models.py", line 314, in model_from_config
    return layer_module.deserialize(config, custom_objects=custom_objects)
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
    printable_module_name='layer')
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 140, in deserialize_keras_object
    list(custom_objects.items())))
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/engine/topology.py", line 2450, in from_config
    process_layer(layer_data)
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/engine/topology.py", line 2419, in process_layer
    custom_objects=custom_objects)
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
    printable_module_name='layer')
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 142, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/engine/topology.py", line 1242, in from_config
    return cls(**config)
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/layers/advanced_activations.py", line 38, in __init__
    self.alpha = K.cast_to_floatx(alpha)  #modified as https://github.com/fchollet/keras/issues/7107
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/keras/backend/common.py", line 108, in cast_to_floatx
    return np.asarray(x, dtype=_FLOATX)
  File "/home/jia/.virtualenvs/keras_tf/local/lib/python2.7/site-packages/numpy/core/numeric.py", line 531, in asarray
    return array(a, dtype, copy=False, order=order)
TypeError: float() argument must be a string or a number

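Working back from the bottom of the traceback, the problem is in the advanced activation layers: LeakyReLU in advanced_activations.py cannot cast the alpha value it reads back from the saved model. That value arrives as a dict instead of a plain number, so float() raises the TypeError.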
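To see why the bottom frames fail, the sketch below mimics cast_to_floatx (per the traceback, it calls np.asarray(x, dtype=_FLOATX)) and feeds it a dict instead of a number. Two assumptions: float32 stands in for Keras's default floatx, and the dict shape {"value": 0.3} is inferred from the alpha['value'] fix; the real saved config may hold extra keys.

```python
import numpy as np

FLOATX = "float32"  # assumption: Keras's default floatx


def cast_to_floatx(x):
    # Same call as the last Keras frame in the traceback (backend/common.py)
    return np.asarray(x, dtype=FLOATX)


print(cast_to_floatx(0.3))  # a plain number converts without trouble

saved_alpha = {"value": 0.3}  # hypothetical: alpha as read back from the saved config
try:
    cast_to_floatx(saved_alpha)
except TypeError as err:
    # float() cannot convert a dict, which is the TypeError in the traceback
    print("TypeError:", err)
```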

The fix is to modify the advanced_activations.py file in keras/layers as follows.
Change this:

self.alpha = K.cast_to_floatx(alpha)
to:

try:
    self.alpha = K.cast_to_floatx(alpha)
except TypeError:
    # alpha restored from a saved model is a dict; the number sits under 'value'
    self.alpha = K.cast_to_floatx(alpha['value'])
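The patched try/except can be checked on its own, without Keras installed. The sketch below is a hypothetical standalone version of that logic: cast_alpha is an illustrative name, not a Keras function, and np.asarray with float32 again stands in for K.cast_to_floatx (assuming the default floatx).

```python
import numpy as np

FLOATX = "float32"  # assumption: Keras's default floatx


def cast_alpha(alpha):
    """Hypothetical helper mirroring the patched LeakyReLU.__init__ logic."""
    try:
        # Normal case: alpha is a plain number, e.g. LeakyReLU(alpha=0.3)
        return np.asarray(alpha, dtype=FLOATX)
    except TypeError:
        # Reloaded case: alpha came back as a dict holding the number under 'value'
        return np.asarray(alpha["value"], dtype=FLOATX)


print(cast_alpha(0.3))
print(cast_alpha({"value": 0.3}))
```

Both calls yield the same float32 value. Keep in mind that an edit inside site-packages is overwritten whenever Keras is reinstalled or upgraded, so the change has to be reapplied after that.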


After rerunning, the model now loads correctly.


Reference: https://github.com/fchollet/keras/issues/7107 (huge thanks to the expert who posted this fix!!!)