DL4J fit(DataSetIterator iterator) Source Code Walkthrough (Part 1)


Reading the source of fit(DataSetIterator iterator)

1 Network model

//Create the network
int numInput = 1;
int numOutputs = 1;
int nHidden = 2;
MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(learningRate)
        .weightInit(WeightInit.XAVIER)
        .updater(Updater.SGD)     //To configure: .updater(new Nesterovs(0.9))
        .list()
        .layer(0, new DenseLayer.Builder().nIn(numInput).nOut(nHidden)
                .activation(Activation.RELU).dropOut(0.5)
                .build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY)
                .nIn(nHidden).nOut(numOutputs).build())   //nIn must match the previous layer's nOut
        .pretrain(false).backprop(true).build());

We call net.fit(iterator); and step through the source.
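For context, here is a minimal sketch of how such an iterator might be built from in-memory data and handed to the net from section 1. The data, variable names, and the 0.9.x package layout of ListDataSetIterator are assumptions for illustration, not taken from the original example:

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

// Hypothetical toy data: learn y = 2x from four samples, shape [4, 1].
INDArray features = Nd4j.create(new double[] {1, 2, 3, 4}, new int[] {4, 1});
INDArray labels   = Nd4j.create(new double[] {2, 4, 6, 8}, new int[] {4, 1});
DataSet allData = new DataSet(features, labels);

// Batch size 2 -> the iterator yields two mini-batches per epoch.
DataSetIterator iterator = new ListDataSetIterator(allData.asList(), 2);
net.fit(iterator);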

2 fit(DataSetIterator iterator)

@Override
public void fit(DataSetIterator iterator) {
    DataSetIterator iter;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        iter = new AsyncDataSetIterator(iterator, 2);
    } else {
        iter = iterator;
    }

    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }

    if (layerWiseConfigurations.isPretrain()) {
        pretrain(iter);
        if (iter.resetSupported()) {
            iter.reset();
        }
    }

    if (layerWiseConfigurations.isBackprop()) {
        update(TaskUtils.buildTask(iter));
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        while (iter.hasNext()) {
            DataSet next = iter.next();
            if (next.getFeatureMatrix() == null || next.getLabels() == null)
                break;

            boolean hasMaskArrays = next.hasMaskArrays();

            if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(),
                                next.getLabelsMaskArray());
            } else {
                if (hasMaskArrays)
                    setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                setInput(next.getFeatureMatrix());
                setLabels(next.getLabels());
                if (solver == null) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
                solver.optimize();
            }

            if (hasMaskArrays)
                clearLayerMaskArrays();

            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    } else if (layerWiseConfigurations.isPretrain()) {
        log.warn("Warning: finetune is not applied.");
    }

    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}

2.1 iterator.asyncSupported()

if (iterator.asyncSupported()) {
    iter = new AsyncDataSetIterator(iterator, 2);
} else {
    iter = iterator;
}

This checks whether the supplied iterator supports asynchronous prefetching; if it does, it is wrapped in an AsyncDataSetIterator. When you implement a DataSetIterator yourself, asyncSupported() typically just returns false.
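As an illustration, here is a hedged sketch of a custom iterator that opts out of async prefetching by subclassing the stock ListDataSetIterator. The class name is hypothetical and the import assumes the 0.9.x package layout:

import java.util.Collection;

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;

// Hypothetical custom iterator: returning false here means fit() uses the iterator
// as-is instead of wrapping it in an AsyncDataSetIterator, so no background
// prefetch thread is created.
public class MyDataSetIterator extends ListDataSetIterator {

    public MyDataSetIterator(Collection<DataSet> data, int batchSize) {
        super(data, batchSize);
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }
}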

2.2 trainingListeners.size() > 0

if (trainingListeners.size() > 0) {
    for (TrainingListener tl : trainingListeners) {
        tl.onEpochStart(this);
    }
}

The trainingListeners field has no explanation in the API docs or the surrounding source; as the name suggests, it holds training listeners, which are notified here at the start of each epoch (and at the end, further down). In typical usage this collection is empty and the block does nothing.
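For comparison, the commonly used ScoreIterationListener only implements IterationListener, not the TrainingListener sub-interface, so attaching it does not populate trainingListeners. A hedged sketch, assuming the 0.9.x listener API:

import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

// Prints the loss every 10 iterations. Because ScoreIterationListener is only an
// IterationListener (not a TrainingListener), it is not collected into
// trainingListeners, so the onEpochStart()/onEpochEnd() hooks in fit() still fire for nothing.
net.setListeners(new ScoreIterationListener(10));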

2.3 layerWiseConfigurations.isBackprop()

Next, the code checks whether the network is configured to use backpropagation; in the usual case (as in the configuration of section 1) this is true.

if (layerWiseConfigurations.isBackprop()) {
    update(TaskUtils.buildTask(iter));
    // If the iterator has no next element and supports reset,
    if (!iter.hasNext() && iter.resetSupported()) {
        // reset it so iteration starts from the beginning.
        iter.reset();
    }
    // While the iterator still has elements,
    while (iter.hasNext()) {
        // fetch the next mini-batch of training data.
        DataSet next = iter.next();
        // If the feature matrix or the label matrix of this batch is null, stop training.
        if (next.getFeatureMatrix() == null || next.getLabels() == null)
            break;

        // Check whether this batch carries mask arrays. Masks are mostly used with RNNs:
        // sequences of unequal length are zero-padded to the same length, and the mask
        // marks which time steps contain real data.
        boolean hasMaskArrays = next.hasMaskArrays();

        // Check the backprop type of the network architecture. TruncatedBPTT (truncated
        // backpropagation through time) is commonly used for RNNs and helps mitigate
        // vanishing gradients over long sequences.
        if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
            doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(),
                            next.getLabelsMaskArray());
        } else {
            // Apply the mask arrays, if any.
            if (hasMaskArrays)
                setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
            // Set the feature matrix as the network input.
            setInput(next.getFeatureMatrix());
            // Set the labels.
            setLabels(next.getLabels());
            // Lazily create the Solver. Its class comment describes it as a
            // "generic purpose solver"; roughly, it wraps the optimizer that performs
            // the actual parameter updates.
            if (solver == null) {
                // Build the Solver from the network configuration.
                solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            }
            solver.optimize();
        }

        if (hasMaskArrays)
            clearLayerMaskArrays();

        Nd4j.getMemoryManager().invokeGcOccasionally();
    }
} else if (layerWiseConfigurations.isPretrain()) {
    log.warn("Warning: finetune is not applied.");
}

2.3.1 update(TaskUtils.buildTask(iter));

Next, the following statement is executed:

update(TaskUtils.buildTask(iter));

Reading further into the source, this builds a Task object that describes the training job in terms of the current network model (its network type and architecture).

  1. First, a Task is built from the iter that was passed in. The function called is:

    public static Task buildTask(DataSetIterator dataSetIterator) {
        return new Task();
    }

    The Task class is annotated with lombok's @Data and @NoArgsConstructor.
    A freshly constructed instance therefore prints as Task(networkType=null, architectureType=null, numFeatures=0, numLabels=0, numSamples=0); see the sketch of the class after this list.

  2. Then the update function is executed:

     private void update(Task task) {
         if (!initDone) {
             // initDone starts out false; flipping it here marks that this one-time
             // initialization/reporting step has been completed.
             initDone = true;
             Heartbeat heartbeat = Heartbeat.getInstance();
             // Fill in the Task's fields from the network model's architecture.
             task = ModelSerializer.taskByModel(this);
             Environment env = EnvironmentUtils.buildEnvironment();
             heartbeat.reportEvent(Event.STANDALONE, env, task);
         }
     }

    Expanding ModelSerializer.taskByModel(this): this function fills in the Task's fields according to the type and architecture of the model that is passed in.

    public static Task taskByModel(Model model) {
        Task task = new Task();
        try {
            // Set a default architecture type first; if the network (e.g. one built only from
            // DenseLayers) matches none of the cases below, this default is what remains.
            task.setArchitectureType(Task.ArchitectureType.RECURRENT);
            // If the model is a computation graph,
            if (model instanceof ComputationGraph) {
                // record the network type.
                task.setNetworkType(Task.NetworkType.ComputationalGraph);
                ComputationGraph network = (ComputationGraph) model;
                try {
                    // If the network has at least one layer,
                    if (network.getLayers() != null && network.getLayers().length > 0) {
                        // iterate over the layers.
                        for (Layer layer : network.getLayers()) {
                            // RBM (restricted Boltzmann machine)
                            if (layer instanceof RBM
                                            || layer instanceof org.deeplearning4j.nn.layers.feedforward.rbm.RBM) {
                                task.setArchitectureType(Task.ArchitectureType.RBM);
                                break;
                            }
                            if (layer.type().equals(Layer.Type.CONVOLUTIONAL)) {
                                // convolutional network
                                task.setArchitectureType(Task.ArchitectureType.CONVOLUTION);
                                break;
                            } else if (layer.type().equals(Layer.Type.RECURRENT)
                                            || layer.type().equals(Layer.Type.RECURSIVE)) {
                                // recurrent network
                                task.setArchitectureType(Task.ArchitectureType.RECURRENT);
                                break;
                            }
                        }
                    } else
                        task.setArchitectureType(Task.ArchitectureType.UNKNOWN);
                } catch (Exception e) {
                    // do nothing here
                }
            } else if (model instanceof MultiLayerNetwork) {
                // If the model is a MultiLayerNetwork,
                task.setNetworkType(Task.NetworkType.MultilayerNetwork);
                MultiLayerNetwork network = (MultiLayerNetwork) model;
                try {
                    if (network.getLayers() != null && network.getLayers().length > 0) {
                        for (Layer layer : network.getLayers()) {
                            if (layer instanceof RBM
                                            || layer instanceof org.deeplearning4j.nn.layers.feedforward.rbm.RBM) {
                                task.setArchitectureType(Task.ArchitectureType.RBM);
                                break;
                            }
                            if (layer.type().equals(Layer.Type.CONVOLUTIONAL)) {
                                task.setArchitectureType(Task.ArchitectureType.CONVOLUTION);
                                break;
                            } else if (layer.type().equals(Layer.Type.RECURRENT)
                                            || layer.type().equals(Layer.Type.RECURSIVE)) {
                                task.setArchitectureType(Task.ArchitectureType.RECURRENT);
                                break;
                            }
                        }
                    } else
                        task.setArchitectureType(Task.ArchitectureType.UNKNOWN);
                } catch (Exception e) {
                    // do nothing here
                }
            }
            return task;
        } catch (Exception e) {
            task.setArchitectureType(Task.ArchitectureType.UNKNOWN);
            task.setNetworkType(Task.NetworkType.DenseNetwork);
            return task;
        }
    }

    Note: initDone is a field of MultiLayerNetwork, and its initial value is false.

    @Setter
    protected boolean initDone = false;
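As referenced in item 1 above, here is a simplified sketch of the Task report class from ND4J's heartbeat/reporting code. The field list and enum constants are inferred from the toString() output and from taskByModel(), so treat it as an approximation rather than a verbatim copy; lombok's @Data generates the getters, setters, and the toString() seen above, and @NoArgsConstructor generates the empty constructor that buildTask() calls:

import lombok.Data;
import lombok.NoArgsConstructor;

// Approximate sketch of the Task report class; fields and enum constants are taken
// from the toString() output and taskByModel() above, not copied verbatim.
@Data
@NoArgsConstructor
public class Task {
    public enum NetworkType { MultilayerNetwork, ComputationalGraph, DenseNetwork }
    public enum ArchitectureType { CONVOLUTION, RECURRENT, RBM, UNKNOWN }

    private NetworkType networkType;
    private ArchitectureType architectureType;
    private int numFeatures;
    private int numLabels;
    private int numSamples;
}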

2.3.2 Solver

/**
 * Generic purpose solver
 *
 * @author Adam Gibson
 */
public class Solver {
    private NeuralNetConfiguration conf;
    private Collection<IterationListener> listeners;
    private Model model;
    private ConvexOptimizer optimizer;
    private StepFunction stepFunction;

    public void optimize() {
        if (optimizer == null)
            optimizer = getOptimizer();
        optimizer.optimize();
    }

    public ConvexOptimizer getOptimizer() {
        if (optimizer != null)
            return optimizer;
        switch (conf.getOptimizationAlgo()) {
            case LBFGS:
                optimizer = new LBFGS(conf, stepFunction, listeners, model);
                break;
            case LINE_GRADIENT_DESCENT:
                optimizer = new LineGradientDescent(conf, stepFunction, listeners, model);
                break;
            case CONJUGATE_GRADIENT:
                optimizer = new ConjugateGradient(conf, stepFunction, listeners, model);
                break;
            case STOCHASTIC_GRADIENT_DESCENT:
                optimizer = new StochasticGradientDescent(conf, stepFunction, listeners, model);
                break;
            default:
                throw new IllegalStateException("No optimizer found");
        }
        return optimizer;
    }

    public void setListeners(Collection<IterationListener> listeners) {
        this.listeners = listeners;
        if (optimizer != null)
            optimizer.setListeners(listeners);
    }

    public static class Builder {
        private NeuralNetConfiguration conf;
        private Model model;
        private List<IterationListener> listeners = new ArrayList<>();

        public Builder configure(NeuralNetConfiguration conf) {
            this.conf = conf;
            return this;
        }

        public Builder listener(IterationListener... listeners) {
            this.listeners.addAll(Arrays.asList(listeners));
            return this;
        }

        public Builder listeners(Collection<IterationListener> listeners) {
            this.listeners.addAll(listeners);
            return this;
        }

        public Builder model(Model model) {
            this.model = model;
            return this;
        }

        public Solver build() {
            Solver solver = new Solver();
            solver.conf = conf;
            solver.stepFunction = StepFunctions.createStepFunction(conf.getStepFunction());
            solver.model = model;
            solver.listeners = listeners;
            return solver;
        }
    }
}

That is the source of the Solver class; now look at how fit() actually uses it:

solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
  1. configure(), listeners(), and model() are called first to hand over the MultiLayerNetwork's configuration, its listeners, and the model itself; build() then instantiates the Solver from these settings.
  2. Beyond that, note the stepFunction field in particular: on this first call conf.getStepFunction() is null, so stepFunction is also null at this point.
  3. After that, solver.optimize() is executed.
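To tie this back to the configuration in section 1: the switch in getOptimizer() keys off the optimization algorithm set on the builder, so for this network solver.optimize() delegates to StochasticGradientDescent.optimize(). A small hedged check (getter names assume the 0.9.x API):

import org.deeplearning4j.nn.api.OptimizationAlgorithm;

// The value set via .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
// in section 1 is stored on each layer's NeuralNetConfiguration; the Solver reads it
// from conf() and therefore instantiates a StochasticGradientDescent optimizer.
OptimizationAlgorithm algo = net.getLayerWiseConfigurations().getConf(0).getOptimizationAlgo();
System.out.println(algo);   // prints STOCHASTIC_GRADIENT_DESCENT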