DL4J fit(DataSetIterator iterator) Source Code Reading (4): Dropout


The preOutput method is the heart of the model's forward propagation.

public INDArray preOutput(boolean training) {
    applyDropOutIfNecessary(training);
    INDArray b = getParam(DefaultParamInitializer.BIAS_KEY);
    INDArray W = getParam(DefaultParamInitializer.WEIGHT_KEY);

    //Input validation:
    if (input.rank() != 2 || input.columns() != W.rows()) {
        if (input.rank() != 2) {
            throw new DL4JInvalidInputException("Input that is not a matrix; expected matrix (rank 2), got rank "
                            + input.rank() + " array with shape " + Arrays.toString(input.shape()));
        }
        throw new DL4JInvalidInputException("Input size (" + input.columns() + " columns; shape = "
                        + Arrays.toString(input.shape())
                        + ") is invalid: does not match layer input size (layer # inputs = " + W.size(0) + ")");
    }

    if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut() > 0) {
        W = Dropout.applyDropConnect(this, DefaultParamInitializer.WEIGHT_KEY);
    }

    INDArray ret = input.mmul(W).addiRowVector(b);

    if (maskArray != null) {
        applyMask(ret);
    }

    return ret;
}
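Setting dropout, DropConnect, and masking aside, the core of this method is the affine transform ret = input · W + b, computed by input.mmul(W).addiRowVector(b). A minimal standalone sketch of just that step, with made-up shapes (4 examples, 3 inputs, 5 units), to make the shape check above concrete:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

INDArray input = Nd4j.rand(4, 3);              // minibatch: 4 examples, 3 features each
INDArray W = Nd4j.rand(3, 5);                  // weights: 3 inputs -> 5 units (input.columns() == W.rows())
INDArray b = Nd4j.rand(1, 5);                  // one bias per unit
INDArray ret = input.mmul(W).addiRowVector(b); // shape [4, 5]: one row of pre-activations per example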

The method first calls applyDropOutIfNecessary(training) to decide whether dropout should be applied to the input.

protected void applyDropOutIfNecessary(boolean training) {
    if (conf.getLayer().getDropOut() > 0 && !conf.isUseDropConnect() && training && !dropoutApplied) {
        input = input.dup();
        Dropout.applyDropout(input, conf.getLayer().getDropOut());
        dropoutApplied = true;
    }
}

Dropout is applied only when all of the following conditions hold:

  1. The current layer is configured with dropout > 0.
  2. The configuration does not use DropConnect (conf.isUseDropConnect() is false); DropConnect is a setting commonly seen with convolutional networks.
  3. We are in the training phase, i.e. training is true; dropout is never applied during prediction.
  4. Dropout has not already been applied in this pass (dropoutApplied is false).

If all of these conditions are met, the current input is first copied with dup() (short for "duplicate"); the copy is then handed to Dropout.applyDropout, shown after the sketch below.
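The dup() matters because applyDropout writes its result back into the very array it is given; copying first keeps the original minibatch data intact. A small illustration of the aliasing difference, assuming only the public ND4J API:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

INDArray a = Nd4j.create(new double[] {1, 2, 3});
INDArray alias = a;      // same underlying buffer: in-place ops on alias also change a
INDArray copy = a.dup(); // independent buffer: in-place ops on copy leave a intact
copy.muli(0);            // muli multiplies in place; only the copy is zeroed here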

/**
 * Apply dropout to the given input
 * and return the drop out mask used
 *
 * @param input the input to do drop out on
 * @param dropout the drop out probability
 */
public static void applyDropout(INDArray input, double dropout) {
    if (Nd4j.getRandom().getStatePointer() != null) {
        Nd4j.getExecutioner().exec(new DropOutInverted(input, dropout));
    } else {
        Nd4j.getExecutioner().exec(new LegacyDropOutInverted(input, dropout));
    }
}

There are many ways to implement dropout; from this code we can see that DL4J implements it by masking, i.e. zeroing out part of, the current layer's input.
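The name DropOutInverted points at the standard "inverted dropout" trick: at training time, zero each input element at random and scale the survivors up by the inverse of the keep probability, so the expected activation is unchanged and no rescaling is needed at inference. A plain-Java sketch of the technique (an illustration only, not DL4J's native implementation):

import java.util.Random;

/** Inverted dropout on a plain array; p is the keep probability. */
static double[] invertedDropout(double[] x, double p, Random rng) {
    double[] out = new double[x.length];
    for (int i = 0; i < x.length; i++) {
        // keep each element with probability p and scale it by 1/p; otherwise drop to zero
        out[i] = (rng.nextDouble() < p) ? x[i] / p : 0.0;
    }
    return out;
}

Back in applyDropout, the branch between the two op classes hinges on getStatePointer():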

/**
 * This method returns pointer to RNG state structure.
 * Please note: DefaultRandom implementation returns NULL here, making it impossible to use with RandomOps
 *
 * @return
 */
@Override
public Pointer getStatePointer() {
    return statePointer;
}

The purpose of getStatePointer() is not fully spelled out, but the comment gives the key hint: DefaultRandom returns null here, so the null check in applyDropout effectively asks whether a native RNG (usable with RandomOps) is available, falling back to the legacy op otherwise. Next, let's look at the two implementations:

  1. DropOutInverted
/**
 * Inverted DropOut implementation as Op
 *
 * @author raver119@gmail.com
 */
public class DropOutInverted extends BaseRandomOp {
    private double p;

    public DropOutInverted() {
    }

    public DropOutInverted(@NonNull INDArray x, double p) {
        this(x, x, p, x.lengthLong());
    }

    public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p) {
        this(x, z, p, x.lengthLong());
    }

    public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p, long n) {
        this.p = p;
        init(x, null, z, n);
    }

    @Override
    public int opNum() {
        return 2;
    }

    @Override
    public String name() {
        return "dropout_inverted";
    }

    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        super.init(x, y, z, n);
        this.extraArgs = new Object[] {p};
    }
}
  2. LegacyDropOutInverted
/**
 * Inverted DropOut implementation as Op
 *
 * PLEASE NOTE: This is legacy DropOutInverted implementation, please consider using op with the same name from randomOps
 *
 * @author raver119@gmail.com
 */
public class LegacyDropOutInverted extends BaseTransformOp {
    private double p;

    public LegacyDropOutInverted() {
    }

    public LegacyDropOutInverted(INDArray x, double p) {
        super(x);
        this.p = p;
        init(x, null, x, x.length());
    }

    public LegacyDropOutInverted(INDArray x, INDArray z, double p) {
        super(x, z);
        this.p = p;
        init(x, null, z, x.length());
    }

    public LegacyDropOutInverted(INDArray x, INDArray z, double p, long n) {
        super(x, z, n);
        this.p = p;
        init(x, null, z, n);
    }

    @Override
    public int opNum() {
        return 44;
    }

    @Override
    public String name() {
        return "legacy_dropout_inverted";
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, double other) {
        return null;
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, float other) {
        return null;
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, IComplexNumber other) {
        return null;
    }

    @Override
    public float op(float origin, float other) {
        return 0;
    }

    @Override
    public double op(double origin, double other) {
        return 0;
    }

    @Override
    public double op(double origin) {
        return 0;
    }

    @Override
    public float op(float origin) {
        return 0;
    }

    @Override
    public IComplexNumber op(IComplexNumber origin) {
        return null;
    }

    @Override
    public Op opForDimension(int index, int dimension) {
        INDArray xAlongDimension = x.vectorAlongDimension(index, dimension);
        if (y() != null)
            return new LegacyDropOutInverted(xAlongDimension, z.vectorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
        else
            return new LegacyDropOutInverted(xAlongDimension, z.vectorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
    }

    @Override
    public Op opForDimension(int index, int... dimension) {
        INDArray xAlongDimension = x.tensorAlongDimension(index, dimension);
        if (y() != null)
            return new LegacyDropOutInverted(xAlongDimension, z.tensorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
        else
            return new LegacyDropOutInverted(xAlongDimension, z.tensorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
    }

    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        super.init(x, y, z, n);
        this.extraArgs = new Object[] {p, (double) n};
    }
}

This dropout code is not easy to follow on its own, so let's trace the computation with single-step debugging. One structural note up front: DropOutInverted is a RandomOp (opNum 2) executed against a native RNG, while LegacyDropOutInverted is a BaseTransformOp (opNum 44) whose per-element op() overrides are stubs returning 0 or null, because the real work happens in the native backend.
In this run, the dropout op in use is DropOutInverted, entered through the following constructor:

public DropOutInverted(@NonNull INDArray x, double p) {
    this(x, x, p, x.lengthLong());
}

The current input x is:
[-10.0,-9.99,-9.98,-9.97,-9.96,-9.95,-9.94,-9.93,-9.92,-9.91,-9.9,-9.89,-9.88,-9.87,-9.86,-9.85,-9.84,-9.83,-9.82,-9.81]
Its shape is [20, 1], i.e. a 20 × 1 column vector, so x.lengthLong() also returns 20. The parameter p takes the layer's dropout value, so here p = 0.5.
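For reference, this debug state can be reproduced with something like the following (a hedged sketch: the values step by 0.01 from -10.0, and the ND4J factory calls are assumed from this nd4j version):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

double[] vals = new double[20];
for (int i = 0; i < vals.length; i++)
    vals[i] = -10.0 + 0.01 * i;                    // -10.00, -9.99, ..., -9.81
INDArray x = Nd4j.create(vals, new int[] {20, 1}); // a 20 x 1 column vector
long n = x.lengthLong();                           // 20, as observed in the debugger

The this(x, x, p, x.lengthLong()) call then delegates to another constructor: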

public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p, long n) {
    this.p = p;
    init(x, null, z, n);
}

This constructor calls init(); at this point z and x refer to the same array.

@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
    super.init(x, y, z, n);
    this.extraArgs = new Object[] {p};
}

At this step the parameters are as follows:

x = [-10.00, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91, -9.90, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81]
y = null
z = [-10.00, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91, -9.90, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81]
n = 20
p = 0.5

Execution then jumps to the parent class's init() method:

@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
    this.x = x;
    this.y = y;
    this.z = z;
    this.n = n;
}

The parent method simply assigns the member variables.
With initialization complete, execution proceeds to Nd4j.getExecutioner().exec(new DropOutInverted(input, dropout));.

/**
 * This method executes specified RandomOp using default RNG available via Nd4j.getRandom()
 *
 * @param op
 */
@Override
public INDArray exec(RandomOp op) {
    return exec(op, Nd4j.getRandom());
}
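Incidentally, this entry point makes it easy to run the op by hand outside the layer code when experimenting (a hedged usage sketch; the DropOutInverted import path may vary across nd4j versions):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; // path assumed for this nd4j version
import org.nd4j.linalg.factory.Nd4j;

INDArray x = Nd4j.rand(20, 1);
// executes DropOutInverted in place: x becomes its own dropped-out, rescaled version
Nd4j.getExecutioner().exec(new DropOutInverted(x, 0.5));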

As the Javadoc says, the op is executed as a RandomOp against the default RNG from Nd4j.getRandom() (this applies to DropOutInverted; the legacy transform op takes a different path). The call then chains into the next exec() overload:

/**
 * This method executes specific
 * RandomOp against specified RNG
 *
 * @param op
 * @param rng
 */
@Override
public INDArray exec(RandomOp op, Random rng) {
    if (rng.getStateBuffer() == null)
        throw new IllegalStateException(
                        "You should use one of NativeRandom classes for NativeOperations execution");

    long st = profilingHookIn(op);

    validateDataType(Nd4j.dataType(), op);

    if (op.x() != null && op.y() != null && op.z() != null) {
        // triple arg call
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (FloatPointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.y().data().addressPointer(),
                            (IntPointer) op.y().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.extraArgsDataBuff().addressPointer());
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (DoublePointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.y().data().addressPointer(),
                            (IntPointer) op.y().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.extraArgsDataBuff().addressPointer());
        }
    } else if (op.x() != null && op.z() != null) {
        // double arg call
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (FloatPointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.extraArgsDataBuff().addressPointer());
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (DoublePointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.extraArgsDataBuff().addressPointer());
        }
    } else {
        // single arg call
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (FloatPointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.extraArgsDataBuff().addressPointer());
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (DoublePointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.extraArgsDataBuff().addressPointer());
        }
    }

    profilingHookOut(op, st);

    return op.z();
}

This method first calls validateDataType(Nd4j.dataType(), op) to check that the current data type is supported, then picks a branch based on which of the op's members x, y, and z are non-null. From the debug information above, our x and z are non-null while y is null, so we take the second branch; and since Nd4j.dataType() is DataBuffer.Type.FLOAT, the following call is executed in this environment:

loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                (FloatPointer) op.x().data().addressPointer(),
                (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                (FloatPointer) op.z().data().addressPointer(),
                (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                (FloatPointer) op.extraArgsDataBuff().addressPointer());

The concrete implementation behind this call lives in native code, reached through a JNI binding:

public native void execRandomFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum,
                @Cast("Nd4jPointer") Pointer state,
                FloatPointer x, IntPointer xShapeBuffer,
                FloatPointer z, IntPointer zShapeBuffer,
                FloatPointer extraArguments);

After this method runs, op.z() is returned; at this point the debugger shows the following value for z:

[-20.00, -19.98, -19.96, 0.00, 0.00, -19.90, -19.88, 0.00, -19.84, -19.82, 0.00, 0.00, -19.76, -19.74, -19.72, -19.70, -19.68, -19.66, -19.64, 0.00]

Because z and x were the same array going in, the op has effectively transformed x in place into the values above. That completes DL4J's applyDropOutIfNecessary(training) step, and execution returns to the body of preOutput(). (The natural reading of these numbers is inverted dropout: each position is randomly zeroed, and every surviving value is divided by the keep probability, here 0.5, which is why the survivors are exactly double their inputs.)
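That reading checks out arithmetically: every surviving element of z is exactly the corresponding x divided by 0.5 (for example, -9.99 / 0.5 = -19.98 and -9.95 / 0.5 = -19.90), and roughly half of the 20 positions were zeroed. Here is a small snippet that reproduces z from x, using the keep mask read off from the debugger output:

double p = 0.5; // the layer's dropout (keep) probability
double[] x = {-10.00, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91,
              -9.90, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81};
// keep mask read off from the observed z: true = survived, false = zeroed
boolean[] kept = {true, true, true, false, false, true, true, false, true, true,
                  false, false, true, true, true, true, true, true, true, false};
double[] z = new double[x.length];
for (int i = 0; i < x.length; i++)
    z[i] = kept[i] ? x[i] / p : 0.0; // matches the debugger output above exactly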
