【keras-DeepLearning_Models】_obtain_input_shape() got an unexpected keyword argument 'include_top'

来源:互联网 发布:如何做美工设计 编辑:程序博客网 时间:2024/06/03 10:21

前言:


最近想跑一些主流的网络感受感受。从github上找到了 deep-learning-models 提供的几个模型,包括:inception-v2, inception-v3, resnet50, vgg16, vgg19 等等。这些代码都是基于 keras 框架,正好我最近有在学 tensorflow 和 keras,所以很想跑跑这些代码。


心动不如行动,万事俱备,只欠把代码跑起来。此时,出现了一些常见的问题,也正好借此机会整理下来。问题如下:


1)_obtain_input_shape() got an unexpected keyword argument 'include_top'

2)Exception: URL fetch failure on https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5: None -- [Errno 110] Connection timed out


本文主要分析和整理了第一个问题的解决方案。


第二个问题就是:在下载模型参数文件的过程中,可能url对应的地址被墙了,导致下载不了。解决办法就是另想办法把 resnet50_weights_tf_dim_ordering_tf_kernels.h5 这个文件下载下来。如果有需要的话,可在本文留言。


Reference:

1)github 源码:https://github.com/fchollet/deep-learning-models

2)模型参数资源:https://github.com/fchollet/deep-learning-models/releases

3)相关博客:http://blog.csdn.net/sinat_26917383/article/details/72982230

4)本文使用的测试数据如下所示:



elephant.jpg



本文使用的 tensorflow 和 keras 的版本:


- tensorflow:

>>> import tensorflow as tf>>> tf.__version__'1.1.0'

- keras:

import keras>>> print keras.__version__2.0.9

本文实践步骤如下,以 "resnet50.py" 为例:


1)下载 github 源码,源码中使用一张名为“elephant.jpg”的图像作为测试。

2)下载测试数据集,上文有提供链接。

3)在源码的目录下执行如下命令:

python resnet50.py

程序报错,如下所示:

Traceback (most recent call last):  File "resnet50.py", line 289, in <module>    model = ResNet50(include_top=True, weights='imagenet')  File "resnet50.py", line 193, in ResNet50    include_top=include_top)TypeError: _obtain_input_shape() got an unexpected keyword argument 'include_top'


导致程序报错的原因分析:


1)keras.__version__ == 2.0.9 中,函数 _obtain_input_shape() 的形式:

def _obtain_input_shape(input_shape,                        default_size,                        min_size,                        data_format,                        require_flatten,                        weights=None):

2)deep-learning-models 案例中,调用 _obtain_input_shape() 函数的方式如下:

# Determine proper input shapeinput_shape = _obtain_input_shape(input_shape,                                  default_size=299,                                  min_size=71,                                  data_format=K.image_data_format(),                                  include_top=include_top)

显然,本文使用的环境中,函数 _obtain_input_shape() 的形参中,没有关键字 include_top 而是 require_flatten。而案例中的代码使用的是关键字 include_top,这显然是不可取的。综上,只要把案例中的 inlcude_top 关键字换成 require_flatten 即可。如下所示:

# Determine proper input shapeinput_shape = _obtain_input_shape(input_shape,                                  default_size=224,                                  min_size=197,                                  data_format=K.image_data_format(),                                  require_flatten=include_top)

再次执行命令,就可以成功运行案例代码,如下图所示:





本文的第二个问题:无法正常下载模型参数


模型参数“resnet50_weights_tf_dim_ordering_tf_kernels.h5”下载地址如下:

链接:http://pan.baidu.com/s/1dE1Lh5J 密码:puke


需修改的代码如下:

# WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'

# load weightsif weights == 'imagenet':    # if include_top:    #     weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels.h5',    #                             WEIGHTS_PATH,    #                             cache_subdir='models',    #                             md5_hash='a7b3fe01876f51b976af0dea6bc144eb')    # else:    #     weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',    #                             WEIGHTS_PATH_NO_TOP,    #                             cache_subdir='models',    #                             md5_hash='a268eb855778b3df3c7506639542a6af')    weights_path = './resnet50_weights_tf_dim_ordering_tf_kernels.h5'    model.load_weights(weights_path)

运行结果如下所示:

Predicted: [[(u'n01871265', u'tusker', 0.65325415), (u'n02504458', u'African_elephant', 0.29492217), (u'n02504013', u'Indian_elephant', 0.048155606), (u'n02422106', u'hartebeest', 0.001847562), (u'n02397096', u'warthog', 0.00034257883)]]

由结果可知:p(tusker) = 0.65, p(African_elephant) = 0.29, p(Indian_elephant) = 0.048 ... ... 其中 tusker 的概率是最高的,所以识别结果为 tusker(有长牙的动物,如:象,野猪等)



阅读全文
0 0