keras —— 常见问题汇总
来源:互联网 发布:华硕黄静事件 知乎 编辑:程序博客网 时间:2024/05/21 10:14
如何引用keras?
如果keras对您的研究有帮助,请在出版物中引用。BibTeX例子如下:
@misc{chollet2015keras, title={Keras}, author={Chollet, Fran\c{c}ois and others}, year={2015}, publisher={GitHub}, howpublished={\url{https://github.com/fchollet/keras}},}
如何在GPU上运行keras?
如果运行在TensorFlow后端上,代码会自动运行在检测到的GPU上。
如果运行在Theano后端上,可使用方法有:
1、使用Theano标志
THEANO_FLAGS=device=gpu,floatX=float32 python my_keras_script.py
'gpu'根据你的设备更改识别(例如gpu0, gpu1等等)
2、设置.theanorc
3、在代码最前面手动设置theano.config.device,theano.config.floatX
import theanotheano.config.device = 'gpu'theano.config.floatX = 'float32'
不推荐用pickle或者cPickle来保存Keras模型。
可以使用model.save(filepath)将Keras模型保存到HDF5文件,包含:
-模型结构,可重建模型
-模型权重
-训练设置(损失、优化器)
-优化器状态,可恢复训练
然后使用keras.models.load_model(filepath)重新实例化模型。load_model会使用保存的训练设置重新编译模型。
例如:
from keras.models import load_modelmodel.save('my_model.h5') # creates a HDF5 file 'my_model.h5'del model # deletes the existing model# returns a compiled model# identical to the previous onemodel = load_model('my_model.h5')
如果你只需要保存模型的结构,不需要权重和训练设置,则可以:
# save as JSONjson_string = model.to_json()# save as YAMLyaml_string = model.to_yaml()
生成的JSON/YAML文件具有可读性,如需要可以手工修改。
也可以从这个数据中重新构建新模型:
# model reconstruction from JSON:from keras.models import model_from_jsonmodel = model_from_json(json_string)# model reconstruction from YAMLfrom keras.models import model_from_yamlmodel = model_from_yaml(yaml_string)
如果只需要保存模型的权重,可以在HDF5文件中使用以下代码。注意你需要已安装HDF5和h5py。
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5', by_name=True)
举例如下:
"""Assume original model looks like this: model = Sequential() model.add(Dense(2, input_dim=3, name="dense_1")) model.add(Dense(3, name="dense_2")) ... model.save_weights(fname)"""# new modelmodel = Sequential()model.add(Dense(2, input_dim=3, name="dense_1")) # will be loadedmodel.add(Dense(10, name="new_dense")) # will not be loaded# load weights from first model; will only affect the first layer, dense_1.model.load_weights(fname, by_name=True)
如何获得中间层的输出?
一个简单的方法是构建一个新模型输出你感兴趣的层。
from keras.models import Modelmodel = ... # create the original modellayer_name = 'my_layer'intermediate_layer_model = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)intermediate_output = intermediate_layer_model.predict(data)
或者构建一个Keras函数返回给定输入的特定层的输出。
from keras import backend as K# with a Sequential modelget_3rd_layer_output = K.function([model.layers[0].input], [model.layers[3].output])layer_output = get_3rd_layer_output([X])[0]
注意,如果模型在训练和测试阶段表现不同(例如使用Dropout,BatchNormalization等),你需要将训练阶段标志传入函数。
get_3rd_layer_output = K.function([model.layers[0].input, K.learning_phase()], [model.layers[3].output])# output in test mode = 0layer_output = get_3rd_layer_output([X, 0])[0]# output in train mode = 1layer_output = get_3rd_layer_output([X, 1])[0]
可以用model.train_on_batch(X, y)和model.test_on_batch(X, y)进行批次训练。
或者可以写一个生成器生成训练数据的批次然后使用方法model.fit_generator(data_generator, steps_per_epoch, epochs)
应用可参考https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py
如果验证损失不再下降如果打断训练?
可以使用EarlyStopping回调
from keras.callbacks import EarlyStoppingearly_stopping = EarlyStopping(monitor='val_loss', patience=2)model.fit(X, y, validation_split=0.2, callbacks=[early_stopping])
验证分割如何计算?
如果你设置model.fit申明validation_split为0.1,那么验证集使用最后10%的数据(如果在提取验证数据前没有将数据打乱)。同一验证集用于所有阶段(在同一fit调用)。
训练时打乱数据吗?
是的,如果model.fit中的shuffle设为True(默认值),训练数据在每个阶段会随机打乱。
如何在每个结算记录训练/验证损失/准确度?
model.fit方法返回一个History回调,有一个history属性包含了连续损失及其他度量的列表。
hist = model.fit(X, y, validation_split=0.2)print(hist.history)
冻结意味着把它排除在训练外,例如权重不会更新。这在细调模型中有用,或者对文本输入使用固定嵌入。
可以传递trainable申明(boolean)到层构建器设定层为non-trainable。
frozen_layer = Dense(32, trainable=False)
额外的,可以在实例化后设置一个层的trainable属性为True或者False。需要调用compile()才使其生效。
x = Input(shape=(32,))layer = Dense(32)layer.trainable = Falsey = layer(x)frozen_model = Model(x, y)# in the model below, the weights of `layer` will not be updated during trainingfrozen_model.compile(optimizer='rmsprop', loss='mse')layer.trainable = Truetrainable_model = Model(x, y)# with this model the weights of the layer will be updated during training# (which will also affect the above model since it uses the same layer instance)trainable_model.compile(optimizer='rmsprop', loss='mse')frozen_model.fit(data, labels) # this does NOT update the weights of `layer`trainable_model.fit(data, labels) # this updates the weights of `layer`
如何使用状态RNNs?
X # this is our input data, of shape (32, 21, 16)# we will feed it to our model in sequences of length 10model = Sequential()model.add(LSTM(32, input_shape=(10, 16), batch_size=32, stateful=True))model.add(Dense(16, activation='softmax'))model.compile(optimizer='rmsprop', loss='categorical_crossentropy')# we train the network to predict the 11th timestep given the first 10:model.train_on_batch(X[:, :10, :], np.reshape(X[:, 10, :], (32, 16)))# the state of the network has changed. We can feed the follow-up sequences:model.train_on_batch(X[:, 10:20, :], np.reshape(X[:, 20, :], (32, 16)))# let's reset the states of the LSTM layer:model.reset_states()# another way to do it in this case:model.layers[0].reset_states()
如何在Keras中使用预训练模型?
我们提供以下图像分类模型的代码和预训练权重:
- Xception
- VGG16
- VGG19
- ResNet50
- Inception v3
from keras.applications.xception import Xceptionfrom keras.applications.vgg16 import VGG16from keras.applications.vgg19 import VGG19from keras.applications.resnet50 import ResNet50from keras.applications.inception_v3 import InceptionV3model = VGG16(weights='imagenet', include_top=True)
import h5pywith h5py.File('input/file.hdf5', 'r') as f: X_data = f['X_data'] model.predict(X_data)
也可以使用keras.utils.io_utils中的HDF5Matrix类。
Keras设置文件储存在哪里?
默认的文件夹在$HOME/.keras/
如果因为没有权限创建上述文件夹,则可能在/tmp/.keras/
- keras —— 常见问题汇总
- keras中文文档笔记3——常见问题与解答
- Keras 常见问题
- 【Elasticsearch】常见问题汇总——持续更新
- keras学习笔记3——Merge、GPU调用、快速开始及常见问题
- Keras——Tensorflow
- lstm——keras
- keras 损失函数汇总
- keras神经网络常见问题-mse, nmse
- 【Android入门】——模拟器的创建及常见问题汇总
- 二叉树问题汇总(2)—常见问题
- 常见问题汇总
- 常见问题汇总
- 常见问题汇总
- 常见问题汇总
- keras学习笔记2——Keras模块概述
- keras学习随笔03——常用keras layers
- keras中文文档笔记9——关于keras层
- jsp
- 我的实习面经(Android开发,已拿阿里,华为,CVTE Offer)
- crond和crontab
- JPA学习(三):java持久化查询语言JPQL--介绍、基础语法
- GNU-ld链接脚本浅析
- keras —— 常见问题汇总
- 字节对齐的意义
- 关于crond和crontab
- jQuery Ajax 实例 ($.ajax、$.post、$.get)
- Unity3D引擎-摄像机控制CameraControl
- GDOI2017再次酱油记
- IPC进程间通信主题之信号量
- struts2+jdbc+oracle查询(模糊查询+页面查询)
- Python学习笔记——类属性和实例属性