Dl4j-fit(DataSetIterator iterator) source reading (4): dropout
The preOutput step is the heart of the model's forward pass.
```java
public INDArray preOutput(boolean training) {
    applyDropOutIfNecessary(training);
    INDArray b = getParam(DefaultParamInitializer.BIAS_KEY);
    INDArray W = getParam(DefaultParamInitializer.WEIGHT_KEY);

    // Input validation:
    if (input.rank() != 2 || input.columns() != W.rows()) {
        if (input.rank() != 2) {
            throw new DL4JInvalidInputException("Input that is not a matrix; expected matrix (rank 2), got rank "
                    + input.rank() + " array with shape " + Arrays.toString(input.shape()));
        }
        throw new DL4JInvalidInputException("Input size (" + input.columns() + " columns; shape = "
                + Arrays.toString(input.shape()) + ") is invalid: does not match layer input size (layer # inputs = "
                + W.size(0) + ")");
    }

    if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut() > 0) {
        W = Dropout.applyDropConnect(this, DefaultParamInitializer.WEIGHT_KEY);
    }

    INDArray ret = input.mmul(W).addiRowVector(b);

    if (maskArray != null) {
        applyMask(ret);
    }

    return ret;
}
```
It first calls applyDropOutIfNecessary(training) to decide whether dropout should be applied:
```java
protected void applyDropOutIfNecessary(boolean training) {
    if (conf.getLayer().getDropOut() > 0 && !conf.isUseDropConnect() && training && !dropoutApplied) {
        input = input.dup();
        Dropout.applyDropout(input, conf.getLayer().getDropOut());
        dropoutApplied = true;
    }
}
```
Dropout is applied only when all of the following hold:
- the current layer is configured with dropout > 0;
- drop connect is not enabled in the configuration (a setting more often seen in convolutional networks);
- we are in the training phase, i.e. training is true — dropout is never applied during inference;
- dropout has not already been applied in this pass.
If all of these conditions are met, the input is first copied with dup() (short for duplicate) and the copy is passed to the next function:
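Why the dup() call matters can be illustrated with a tiny plain-Java analogy (plain arrays and a hypothetical helper standing in for INDArray and the in-place dropout op):

```java
public class DupSketch {
    // applyDropout mutates its argument in place, so the layer copies the
    // input first; otherwise the upstream activations would be corrupted.
    static double[] zeroFirst(double[] a) {  // stand-in for an in-place op
        a[0] = 0.0;
        return a;
    }

    public static void main(String[] args) {
        double[] input = {1.0, 2.0, 3.0};
        double[] copy = input.clone();   // analogous to INDArray.dup()
        zeroFirst(copy);                 // only the copy is modified
        System.out.println(input[0] + " " + copy[0]);
    }
}
```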
```java
/**
 * Apply dropout to the given input
 * and return the drop out mask used
 *
 * @param input   the input to do drop out on
 * @param dropout the drop out probability
 */
public static void applyDropout(INDArray input, double dropout) {
    if (Nd4j.getRandom().getStatePointer() != null) {
        Nd4j.getExecutioner().exec(new DropOutInverted(input, dropout));
    } else {
        Nd4j.getExecutioner().exec(new LegacyDropOutInverted(input, dropout));
    }
}
```
There are many ways to implement dropout; from this source reading, dl4j's approach is to zero out part of the current layer's input.
```java
/**
 * This method returns pointer to RNG state structure.
 *
 * Please note: DefaultRandom implementation returns NULL here,
 * making it impossible to use with RandomOps
 *
 * @return
 */
@Override
public Pointer getStatePointer() {
    return statePointer;
}
```
The exact purpose of getStatePointer() is not obvious from the comment alone; here it merely decides which of the two implementations is dispatched. Both are examined next.
- DropOutInverted
```java
/**
 * Inverted DropOut implementation as Op
 *
 * @author raver119@gmail.com
 */
public class DropOutInverted extends BaseRandomOp {
    private double p;

    public DropOutInverted() {}

    public DropOutInverted(@NonNull INDArray x, double p) {
        this(x, x, p, x.lengthLong());
    }

    public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p) {
        this(x, z, p, x.lengthLong());
    }

    public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p, long n) {
        this.p = p;
        init(x, null, z, n);
    }

    @Override
    public int opNum() {
        return 2;
    }

    @Override
    public String name() {
        return "dropout_inverted";
    }

    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        super.init(x, y, z, n);
        this.extraArgs = new Object[] {p};
    }
}
```
- LegacyDropOutInverted
```java
/**
 * Inverted DropOut implementation as Op
 *
 * PLEASE NOTE: This is legacy DropOutInverted implementation, please consider
 * using op with the same name from randomOps
 *
 * @author raver119@gmail.com
 */
public class LegacyDropOutInverted extends BaseTransformOp {
    private double p;

    public LegacyDropOutInverted() {}

    public LegacyDropOutInverted(INDArray x, double p) {
        super(x);
        this.p = p;
        init(x, null, x, x.length());
    }

    public LegacyDropOutInverted(INDArray x, INDArray z, double p) {
        super(x, z);
        this.p = p;
        init(x, null, z, x.length());
    }

    public LegacyDropOutInverted(INDArray x, INDArray z, double p, long n) {
        super(x, z, n);
        this.p = p;
        init(x, null, z, n);
    }

    @Override
    public int opNum() {
        return 44;
    }

    @Override
    public String name() {
        return "legacy_dropout_inverted";
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, double other) {
        return null;
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, float other) {
        return null;
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, IComplexNumber other) {
        return null;
    }

    @Override
    public float op(float origin, float other) {
        return 0;
    }

    @Override
    public double op(double origin, double other) {
        return 0;
    }

    @Override
    public double op(double origin) {
        return 0;
    }

    @Override
    public float op(float origin) {
        return 0;
    }

    @Override
    public IComplexNumber op(IComplexNumber origin) {
        return null;
    }

    @Override
    public Op opForDimension(int index, int dimension) {
        INDArray xAlongDimension = x.vectorAlongDimension(index, dimension);
        if (y() != null)
            return new LegacyDropOutInverted(xAlongDimension, z.vectorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
        else
            return new LegacyDropOutInverted(xAlongDimension, z.vectorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
    }

    @Override
    public Op opForDimension(int index, int... dimension) {
        INDArray xAlongDimension = x.tensorAlongDimension(index, dimension);
        if (y() != null)
            return new LegacyDropOutInverted(xAlongDimension, z.tensorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
        else
            return new LegacyDropOutInverted(xAlongDimension, z.tensorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
    }

    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        super.init(x, y, z, n);
        this.extraArgs = new Object[] {p, (double) n};
    }
}
```
This op is hard to follow from the code alone, so single-step debugging is used here to trace the computation. In this run the dropout type is DropOutInverted, and the constructor invoked is:
```java
public DropOutInverted(@NonNull INDArray x, double p) {
    this(x, x, p, x.lengthLong());
}
```
The input x at this point is:
```
[-10.0, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91, -9.9, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81]
```
Its shape is [20, 1], i.e. a 20 x 1 column vector, so x.lengthLong() returns 20. The parameter p now carries the layer's dropout value, here p = 0.5. The this(...) call then delegates to another constructor:
```java
public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p, long n) {
    this.p = p;
    init(x, null, z, n);
}
```
This constructor in turn calls init(); at this point z and x are the same array.
```java
@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
    super.init(x, y, z, n);
    this.extraArgs = new Object[] {p};
}
```
At this step the parameters are as follows:
```
x = [-10.00, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91, -9.90, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81]
y = null
z = [-10.00, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91, -9.90, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81]
n = 20
p = 0.5
```
Execution then jumps to the parent class's init() method:
```java
@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
    this.x = x;
    this.y = y;
    this.z = z;
    this.n = n;
}
```
The parent method simply assigns the member variables.
With initialization complete, execution continues with Nd4j.getExecutioner().exec(new DropOutInverted(input, dropout)):
```java
/**
 * This method executes specified RandomOp using default RNG available via Nd4j.getRandom()
 *
 * @param op
 */
@Override
public INDArray exec(RandomOp op) {
    return exec(op, Nd4j.getRandom());
}
```
As the comment notes, the two dropout classes are specialized RandomOps. The call is then forwarded to the next exec() overload:
```java
/**
 * This method executes specific
 * RandomOp against specified RNG
 *
 * @param op
 * @param rng
 */
@Override
public INDArray exec(RandomOp op, Random rng) {
    if (rng.getStateBuffer() == null)
        throw new IllegalStateException(
                        "You should use one of NativeRandom classes for NativeOperations execution");

    long st = profilingHookIn(op);

    validateDataType(Nd4j.dataType(), op);

    if (op.x() != null && op.y() != null && op.z() != null) {
        // triple arg call
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (FloatPointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.y().data().addressPointer(),
                            (IntPointer) op.y().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.extraArgsDataBuff().addressPointer());
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (DoublePointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.y().data().addressPointer(),
                            (IntPointer) op.y().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.extraArgsDataBuff().addressPointer());
        }
    } else if (op.x() != null && op.z() != null) {
        // double arg call
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (FloatPointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.extraArgsDataBuff().addressPointer());
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (DoublePointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.extraArgsDataBuff().addressPointer());
        }
    } else {
        // single arg call
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (FloatPointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.extraArgsDataBuff().addressPointer());
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (DoublePointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.extraArgsDataBuff().addressPointer());
        }
    }

    profilingHookOut(op, st);

    return op.z();
}
```
This function first calls validateDataType(Nd4j.dataType(), op) to validate the current data type. It then chooses a branch based on which of the op's members x, y, and z are non-null. From the debug output above, x and z are both non-null, so the second branch is taken; and since Nd4j.dataType() here is DataBuffer.Type.FLOAT, the following statement is executed:
```java
loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                (FloatPointer) op.x().data().addressPointer(),
                (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                (FloatPointer) op.z().data().addressPointer(),
                (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                (FloatPointer) op.extraArgsDataBuff().addressPointer());
```
The concrete implementation lives on the native side of a JNI call:
```java
public native void execRandomFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum,
                @Cast("Nd4jPointer") Pointer state, FloatPointer x, IntPointer xShapeBuffer, FloatPointer z,
                IntPointer zShapeBuffer, FloatPointer extraArguments);
```
After this method runs, z is returned; the debug output now shows z as:
```
[-20.00, -19.98, -19.96, 0.00, 0.00, -19.90, -19.88, 0.00, -19.84, -19.82, 0.00, 0.00, -19.76, -19.74, -19.72, -19.70, -19.68, -19.66, -19.64, 0.00]
```
Since z was aliased to x on input, running this op effectively transforms x in place into the values above. That completes dl4j's applyDropOutIfNecessary(training) step, and execution returns to the body of preOutput(). (From the output the mechanism can be inferred: positions are zeroed at random, and the surviving values are divided by the dropout value p — here 0.5, which is exactly why -10.00 became -20.00. Scaling by 1/p at training time keeps the expected activation equal to the original input, which is the defining trait of inverted dropout.)
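The inferred behavior can be checked against a small plain-Java sketch (an illustration only, not the actual native kernel; `p` is treated as the keep probability, consistent with p = 0.5 doubling the surviving values):

```java
import java.util.Random;

public class InvertedDropout {
    // Keep each element with probability p and scale it by 1/p;
    // otherwise zero it. Operates like the observed in-place transform.
    static double[] apply(double[] input, double p, Random rng) {
        double[] out = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            out[i] = (rng.nextDouble() < p) ? input[i] / p : 0.0;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = new double[20];
        for (int i = 0; i < 20; i++) x[i] = -10.0 + 0.01 * i;  // same ramp as the debug input
        double[] z = apply(x, 0.5, new Random(42));
        for (double v : z) System.out.printf("%.2f ", v);       // mix of 0.00 and doubled values
        System.out.println();
    }
}
```

Every nonzero output is exactly the corresponding input divided by p; which positions survive depends on the RNG, matching the irregular zero pattern in the debug output.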