keras中models的Squential类的源码简介

来源:互联网 发布:淘宝小二电话是多少 编辑:程序博客网 时间:2024/06/07 16:00

keras中最重要的就是models的Sequential类了,下面我结合源码以及自己的理解,尽可能的去学习并能够说明白,源代码太多,先贴一个fit函数的实现:

    def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],            validation_split=0., validation_data=None, shuffle=True,            class_weight=None, sample_weight=None, **kwargs):        '''        Args:            x: 表示输入可以是narray, 如果是多个输入,也可以是[narray, narray], 必须有            y: labels,only a narray, 必须有            batch_size: mini batch表示多少次更新一次权重,默认是32            nb_epoch: 需要迭代多少次去训练这个模型,默认是10            verbose: 是不是输出打印log到标准输出,默认是打印            callbacks: 回调函数(暂时不是很理解这个地方怎么用)            validation_split: 测试数据的比例,默认是0            validation_data: 测试数据,tuple(input , lable)默认是空            shuffle:不懂            class_weight:不懂            sample_weight:不懂            **kwargs: 只有一个候选项就是 'show_accuracy'        Returns:        '''        '''Trains the model for a fixed number of epochs.        # Arguments            x: input data, as a Numpy array or list of Numpy arrays                (if the model has multiple inputs).            y: labels, as a Numpy array.            batch_size: integer. Number of samples per gradient update.            nb_epoch: integer, the number of epochs to train the model.            verbose: 0 for no logging to stdout,                1 for progress bar logging, 2 for one log line per epoch.            callbacks: list of `keras.callbacks.Callback` instances.                List of callbacks to apply during training.                See [callbacks](/callbacks).            validation_split: float (0. < x < 1).                Fraction of the data to use as held-out validation data.            validation_data: tuple (X, y) to be used as held-out                validation data. Will override validation_split.            shuffle: boolean or str (for 'batch').                Whether to shuffle the samples at each epoch.                'batch' is a special option for dealing with the                limitations of HDF5 data; it shuffles in batch-sized chunks.            class_weight: dictionary mapping classes to a weight value,                used for scaling the loss function (during training only).            sample_weight: Numpy array of weights for                the training samples, used for scaling the loss function                (during training only). You can either pass a flat (1D)                Numpy array with the same length as the input samples                (1:1 mapping between weights and samples),                or in the case of temporal data,                you can pass a 2D array with shape (samples, sequence_length),                to apply a different weight to every timestep of every sample.                In this case you should make sure to specify                sample_weight_mode="temporal" in compile().        # Returns            A `History` object. Its `History.history` attribute is            a record of training loss values and metrics values            at successive epochs, as well as validation loss values            and validation metrics values (if applicable).        '''        if self.model is None:            raise Exception('The model needs to be compiled before being used.')        if 'show_accuracy' in kwargs:            kwargs.pop('show_accuracy')            warnings.warn('The "show_accuracy" argument is deprecated, '                          'instead you should pass the "accuracy" metric to '                          'the model at compile time:\n'                          '`model.compile(optimizer, loss, '                          'metrics=["accuracy"])`')        if kwargs:            raise Exception('Received unknown keyword arguments: ' +                            str(kwargs))        return self.model.fit(x, y,                              batch_size=batch_size,                              nb_epoch=nb_epoch,                              verbose=verbose,                              callbacks=callbacks,                              validation_split=validation_split,                              validation_data=validation_data,                              shuffle=shuffle,                              class_weight=class_weight,                              sample_weight=sample_weight)
主要是学会怎么使用,因为这段代码放到整个类中去看才有意义,所以,后续继续补充吧, 发现欠了好多债了,后续需要补充的东西太多了,逼我把源码看完的节奏。

0 0