Deeplearning4j 实战(4):Deep AutoEncoder进行Mnist压缩的Spark实现

来源:互联网 发布:手机自动录像软件 编辑:程序博客网 时间:2024/05/20 21:57

图像压缩,在图像的检索、图像传输等领域都有着广泛的应用。事实上,图像的压缩,我觉得也可以算是一种图像特征的提取方法。如果从这个角度来看的话,那么在理论上利用这些压缩后的数据去做图像的分类,图像的检索也是可以的。图像压缩的算法有很多种,这里面只说基于神经网络结构进行的图像压缩。但即使把范围限定在神经网络这个领域,其实还是有很多网络结构进行选择。比如:

1.传统的DNN,也就是加深全连接结构网络的隐层的数量,以还原原始图像为输出,以均方误差作为整个网络的优化方向。

2.DBN,基于RBM的网络栈,构成的深度置信网络,每一层RBM对数据进行压缩,以KL散度为损失函数,最后以MSE进行优化

3.VAE,变分自编码器,也是非常流行的一种网络结构。后续也会写一些自己测试的效果。

这里主要讲第二种,也就是基于深度置信网络对图像进行压缩。这种模型是一种多层RBM的结构,可以参考的论文就是G.Hinton教授的paper:《Reducing the Dimensionality of Data with Neural Network》。这里简单说下RBM原理。RBM,中文叫做受限玻尔兹曼机。所谓的受限,指的是同一层的节点之间不存在边将其相连。RBM自身分成Visible和Hidden两层。它利用输入数据本身,首先进行数据的压缩或扩展,然后再以压缩或扩展的数据为输入,以重构原始输入为目标进行反向权重的更新。因此是一种无监督的结构。如果我没记错,这种结构本身也是Hinton提出来的。将RBM进行多层的堆叠,就形成深度置信网络,用于编码或压缩的时候,被成为Deep Autoencoder。

下面就具体来说说基于开源库Deeplearning4j的Deep Autoencoder的实现,以及在Spark上进行训练的过程和结果。

1.创建Maven工程,加入Deeplearning4j的相关jar包依赖,具体如下

<properties>    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>    <nd4j.version>0.7.1</nd4j.version>  <dl4j.version>0.7.1</dl4j.version>  <datavec.version>0.7.1</datavec.version>  <scala.binary.version>2.10</scala.binary.version>  </properties>   <dependencies>   <dependency>     <groupId>org.nd4j</groupId>     <artifactId>nd4j-native</artifactId>      <version>${nd4j.version}</version>   </dependency>   <dependency>    <groupId>org.deeplearning4j</groupId>   <artifactId>dl4j-spark_2.11</artifactId>    <version>${dl4j.version}</version></dependency>   <dependency>            <groupId>org.datavec</groupId>            <artifactId>datavec-spark_${scala.binary.version}</artifactId>            <version>${datavec.version}</version>     </dependency><dependency>        <groupId>org.deeplearning4j</groupId>        <artifactId>deeplearning4j-core</artifactId>        <version>${dl4j.version}</version>     </dependency>     <dependency>    <groupId>org.nd4j</groupId>    <artifactId>nd4j-kryo_${scala.binary.version}</artifactId>    <version>${nd4j.version}</version></dependency></dependencies>


2.启动Spark任务,传入必要的参数,从HDFS上读取Mnist数据集(事先已经将数据以DataSet的形式保存在HDFS上,至于如何将Mnist数据集以DataSet的形式存储在HDFS上,之前的博客有说明,这里就直接使用了)

        if( args.length != 6 ){            System.err.println("Input Format:<inputPath> <numEpoch> <modelSavePah> <lr> <numIter> <numBatch>");            return;        }        SparkConf conf = new SparkConf()                        .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator")                        .setAppName("Deep AutoEncoder (Java)");        JavaSparkContext jsc = new JavaSparkContext(conf);        final String inputPath = args[0];        final int numRows = 28;        final int numColumns = 28;        int seed = 123;        int batchSize = Integer.parseInt(args[5]);        int iterations = Integer.parseInt(args[4]);        final double lr = Double.parseDouble(args[3]);        //        JavaRDD<DataSet> javaRDDMnist = jsc.objectFile(inputPath);        JavaRDD<DataSet> javaRDDTrain = javaRDDMnist.map(new Function<DataSet, DataSet>() {            @Override            public DataSet call(DataSet next) throws Exception {                return new DataSet(next.getFeatureMatrix(),next.getFeatureMatrix());            }        });

由于事先我们已经将Mnist数据集以DataSet的形式序列化保存在HDFS上,因此我们一开始就直接反序列化读取这些数据并保存在RDD中就可以了。接下来,我们构建训练数据集,由于Deep Autoencoder中,是以重构输入图片为目的的,所以feature和label其实都是原始图片。此外,程序一开始的时候,就已经将学习率、迭代次数等等传进来了。

3.设计Deep Autoencoder的网络结构,具体代码如下:

       MultiLayerConfiguration netconf = new NeuralNetConfiguration.Builder()                .seed(seed)                .iterations(iterations)                .learningRate(lr)                .learningRateScoreBasedDecayRate(0.5)                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)                .updater(Updater.ADAM).adamMeanDecay(0.9).adamVarDecay(0.999)                .list()                .layer(0, new RBM.Builder()                              .nIn(numRows * numColumns)                              .nOut(1000)                              .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)                              .visibleUnit(VisibleUnit.IDENTITY)                              .hiddenUnit(HiddenUnit.IDENTITY)                              .activation("relu")                              .build())                .layer(1, new RBM.Builder()                              .nIn(1000)                              .nOut(500)                              .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)                              .visibleUnit(VisibleUnit.IDENTITY)                              .hiddenUnit(HiddenUnit.IDENTITY)                              .activation("relu")                              .build())                .layer(2, new RBM.Builder()                              .nIn(500)                              .nOut(250)                              .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)                              .visibleUnit(VisibleUnit.IDENTITY)                              .hiddenUnit(HiddenUnit.IDENTITY)                              .activation("relu")                              .build())                //.layer(3, new RBM.Builder().nIn(250).nOut(100).lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())                //.layer(4, new RBM.Builder().nIn(100).nOut(30).lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build()) //encoding stops                //.layer(5, new RBM.Builder().nIn(30).nOut(100).lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build()) //decoding starts                //.layer(6, new RBM.Builder().nIn(100).nOut(250).lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())                .layer(3, new RBM.Builder()                              .nIn(250)                              .nOut(500)                              .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)                              .visibleUnit(VisibleUnit.IDENTITY)                              .hiddenUnit(HiddenUnit.IDENTITY)                              .activation("relu")                              .build())                .layer(4, new RBM.Builder()                              .nIn(500)                              .nOut(1000)                              .visibleUnit(VisibleUnit.IDENTITY)                              .hiddenUnit(HiddenUnit.IDENTITY)                              .activation("relu")                              .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build())                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation("relu").nIn(1000).nOut(numRows*numColumns).build())                .pretrain(true).backprop(true)                .build();

这里需要说明下几点。第一,和Hinton老先生论文里的结构不太一样的是,我并没有把图像压缩到30维这么小。但是这肯定是可以进行尝试的。第二,Visible和Hidden的转换函数用的是Identity,而不是和论文中的Gussian和Binary。第三,学习率是可变的。在Spark集群上训练,初始的学习率可以设置得大一些,比如0.1,然后,在代码中有个机制,就是当损失函数不再下降或者下降不再明白的时候,减半学习率,也就是减小步长,试图使模型收敛得更好。第四,更新机制用的是ADAM。当然,以上这些基本都是超参数的范畴,大家可以有自己的理解和调优过程。

4.训练网络并在训练过程中进行效果的查看

      ParameterAveragingTrainingMaster trainMaster = new ParameterAveragingTrainingMaster.Builder(batchSize)                                                            .workerPrefetchNumBatches(0)                                                            .saveUpdater(true)                                                            .averagingFrequency(5)                                                            .batchSizePerWorker(batchSize)                                                            .build();        MultiLayerNetwork net = new MultiLayerNetwork(netconf);        //net.setListeners(new ScoreIterationListener(1));        net.init();        SparkDl4jMultiLayer sparkNetwork = new SparkDl4jMultiLayer(jsc, net, trainMaster);        sparkNetwork.setListeners(Collections.<IterationListener>singletonList(new ScoreIterationListener(1)));        int numEpoch = Integer.parseInt(args[1]);        for( int i = 0; i < numEpoch; ++i ){            sparkNetwork.fit(javaRDDTrain);            System.out.println("----- Epoch " + i + " complete -----");            MultiLayerNetwork trainnet = sparkNetwork.getNetwork();            System.out.println("Epoch " + i + " Score: " + sparkNetwork.getScore());            List<DataSet> listDS = javaRDDTrain.takeSample(false, 50);            for( DataSet ds : listDS ){                INDArray testFeature = ds.getFeatureMatrix();                INDArray testRes = trainnet.output(testFeature);                System.out.println("Euclidean Distance: " + testRes.distance2(testFeature));            }            DataSet first = listDS.get(0);            INDArray testFeature = first.getFeatureMatrix();            double[] doubleFeature = testFeature.data().asDouble();            INDArray testRes = trainnet.output(testFeature);            double[] doubleRes = testRes.data().asDouble();            for( int j = 0; j < doubleFeature.length && j < doubleRes.length; ++j ){                double f = doubleFeature[j];                double t = doubleRes[j];                System.out.print(f + ":" + t + "  ");            }            System.out.println();                    }
这里的逻辑其实都比较的明白。首先,申请一个参数服务对象,这个主要是用来负责对各个节点上计算的梯度进行聚合和更新,也是一种机器学习在集群上实现优化的策略。下面则是对数据集进行多轮训练,并且在每一轮训练完以后,我们随机抽样一些数据,计算他们预测的值和原始值的欧式距离。然后抽取其中一张图片,输出每个像素点,原始的值和预测的值。以此,在训练过程中,直观地评估训练的效果。当然,每一轮训练后,损失函数的得分也要打印出来看下,如果一直保持震荡下降,那么就是可以的。


5.Spark集训训练的过程和结果展示

Spark训练过程中,stage的web ui:



从图中可以看出,aggregate是做参数更新时候进行的聚合操作,这个action在基于Spark的大规模机器学习算法中也是很常用的。至于有takeSample的action,主要是之前所说的,在训练的过程中会抽取一部分数据来看效果。下面的图就是直观的比较
训练过程中,数据的直观比对


这张图是刚开始训练的时候,欧式距离会比较大,当经过100~200轮的训练后,欧式距离平均在1.0左右。也就是说,每个像素点原始值和预测值的差值在0.035左右,应该说比较接近了。最后来看下可视化界面展现的图以及他们的距离计算

原始图片和重构图片对比以及他们之间的欧式距离



第一张图左边的原始图,右边是用训练好的Deep Autoencoder预测的或者说重构的图:图有点小,不过仔细看,发现基本还是很像的,若干像素点上明暗不太一样。不过总体还算不错。下面的图,是两者欧式距离的计算,差值在1.4左右。

最后做一些回顾:
用堆叠RBM构成DBN做图像压缩,在理论上比单纯增加全连阶层的效果应该会好些,毕竟每一层RBM本身可以利用自身可以重构输入数据的特点进行更为有效的压缩。从实际的效果来看,应该也是还算看得过去。其实图像压缩本身如果足够高效,那么对图像检索的帮助也是很大。所以Hinton老先生的一篇论文就是利用Deep AutoEncoder对图像进行压缩后再进行检索,论文中把这个效果和用欧式距离还有PCA提取的图片特征进行了比较,论文中的结果是用Deep AutoEncoder的进行压缩后在做检索的效果最佳。不过,这里还是得说明,在论文中RBM的Hidden的转换函数是binary,因为作者希望压缩出来的结果是0,1二进制的。这样,检索图片的时候,计算Hamming距离就可以了。而且这样即使以后图片的数量急剧增加,检索的时间不会显著增加,因为计算Hamming距离可以说计算机是非常快的,底层电路做异或运算就可以了。但是,我自己觉得,虽然压缩成二进制是个好方法,检索时间也很短。但是二进制的表现力是否有所欠缺呢?毕竟非0即1,和用浮点数表示的差别,表现力上面应该是差蛮多的。所以,具体是否可以在图像检索系统依赖这样的方式,还有待进一步实验。另外就是,上面在构建多层RBM的时候,其实有很多超参数可以调整,包括可以增加RBM的层数,来做进一步的压缩等等,就等有时间再慢慢研究了。还有,Spark提交的命令这里没有写,不过在只之前的文章里有提到,需要的同学可以参考。至于模型的保存,都有相应的接口可以调用,这里就不赘述了。。。

0 0
原创粉丝点击