深度学习-pipeline

来源:互联网 发布:优酷mac下的视频在哪里 编辑:程序博客网 时间:2024/06/05 04:31

这里的pipeline就是多了下载数据,解压数据,根据标签把数据分到不同目录的工作,其实也没什么新鲜的,贴代码

/** *  This code example is featured in this youtube video *  https://www.youtube.com/watch?v=ECA6y6ahH5E * ** This differs slightly from the Video Example, * The Video example had the data already downloaded * This example includes code that downloads the data * * * The Data Directory mnist_png will have two child directories training and testing//手写图片目录有两个目录,一个训练,一个测试 * The training and testing directories will have directories 0-9 with//这俩目录有0-9的28*28图片 * 28 * 28 PNG images of handwritten images * * * *  The data is downloaded from *  wget http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz *  followed by tar xzvf mnist_png.tar.gz * * * *  This examples builds on the MnistImagePipelineExample *  by adding a Neural Net */public class MnistImagePipelineExampleAddNeuralNet {    private static Logger log = LoggerFactory.getLogger(MnistImagePipelineExampleAddNeuralNet.class);    /** Data URL for downloading */从哪下的数据    public static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";    /** Location to save and extract the training/testing data *///数据位置    public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_Mnist/");    public static void main(String[] args) throws Exception {        // image information        // 28 * 28 grayscale        // grayscale implies single channel        int height = 28;//28*28的灰度图,1个过滤器,随机数生成器,128个图片为一批,10个类别,步数为1        int width = 28;        int channels = 1;        int rngseed = 123;        Random randNumGen = new Random(rngseed);        int batchSize = 128;        int outputNum = 10;        int numEpochs = 1;         /*        This class downloadData() downloads the data        stores the data in java's tmpdir        15MB download compressed        It will take 158MB of space when uncompressed        The data can be downloaded manually here        http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz         *///下面的方法下载数据到tmpdir目录,原始数据15MB解压后158MB,也可以手动下        downloadData();//调用downloadData方法下载文件目标路径        // Define the File Paths//定义训练和测试的路径        File trainData = new File(DATA_PATH + "/mnist_png/training");        File testData = new File(DATA_PATH + "/mnist_png/testing");        // Define the File Paths        //File trainData = new File("/tmp/mnist_png/training");        //File testData = new File("/tmp/mnist_png/testing");        // Define the FileSplit(PATH, ALLOWED FORMATS,random)//定义训练测试图片划分        FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS,randNumGen);        FileSplit test = new FileSplit(testData,NativeImageLoader.ALLOWED_FORMATS,randNumGen);        // Extract the parent path as the image label//用上级路径名作为标签        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();        ImageRecordReader recordReader = new ImageRecordReader(height,width,channels,labelMaker);//图片读取器        // Initialize the record reader        // add a listener, to extract the name        recordReader.initialize(train);//初始化        //recordReader.setListeners(new LogRecordListener());        // DataSet Iterator        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader,batchSize,1,outputNum);//数据迭代器        // Scale pixel values to 0-1        DataNormalization scaler = new ImagePreProcessingScaler(0,1);//规范化器        scaler.fit(dataIter);//收集统计信息        dataIter.setPreProcessor(scaler);//规范到0,1区间        // Build Our Neural Network        log.info("**** Build Model ****");//构建模型和之前的一样了,还是老套路        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()            .seed(rngseed)            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)            .iterations(1)            .learningRate(0.006)            .updater(Updater.NESTEROVS).momentum(0.9)            .regularization(true).l2(1e-4)            .list()            .layer(0, new DenseLayer.Builder()                .nIn(height * width)                .nOut(100)                .activation("relu")                .weightInit(WeightInit.XAVIER)                .build())            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)                .nIn(100)                .nOut(outputNum)                .activation("softmax")                .weightInit(WeightInit.XAVIER)                .build())            .pretrain(false).backprop(true)            .setInputType(InputType.convolutional(height,width,channels))//setInputType的作用有3个,1.添加卷积层和神经网络层的预处理转换 2.做配置校验 3.如果有必要基于上层网络设置下层的输入数,其实主要就是设置输入层图片大小和过滤器数量            .build();        MultiLayerNetwork model = new MultiLayerNetwork(conf);        // The Score iteration Listener will log        // output to show how well the network is training        model.setListeners(new ScoreIterationListener(10));//设置评分监听器        log.info("*****TRAIN MODEL********");        for(int i = 0; i<numEpochs; i++){//定型网络            model.fit(dataIter);        }        log.info("******EVALUATE MODEL******");        recordReader.reset();//清空读取器配置        // The model trained on the training dataset split        // now that it has trained we evaluate against the        // test data of images the network has not seen        recordReader.initialize(test);//用测试数据初始化        DataSetIterator testIter = new RecordReaderDataSetIterator(recordReader,batchSize,1,outputNum);//生成测试迭代数据        scaler.fit(testIter);//获取统计信息        testIter.setPreProcessor(scaler);//规范化        /*        log the order of the labels for later use        In previous versions the label order was consistent, but random        In current verions label order is lexicographic        preserving the RecordReader Labels order is no        longer needed left in for demonstration        purposes        */        log.info(recordReader.getLabels().toString());//打印类别标签        // Create Eval object with 10 possible classes        Evaluation eval = new Evaluation(outputNum);//创建评估器        // Evaluate the network        while(testIter.hasNext()){//评估预测标签和真实标签            DataSet next = testIter.next();            INDArray output = model.output(next.getFeatureMatrix());            // Compare the Feature Matrix from the model            // with the labels from the RecordReader            eval.eval(next.getLabels(),output);        }        log.info(eval.stats());//打印评估结果    }     /*    Everything below here has nothing to do with your RecordReader,    or DataVec, or your Neural Network    The classes downloadData, getMnistPNG(),    and extractTarGz are for downloading and extracting the data     */    private static void downloadData() throws Exception {        //Create directory if required//需要的话创建目录        File directory = new File(DATA_PATH);//这个目录是C:\Users\ADMINI~1\AppData\Local\Temp\dl4j_Mnist\        if(!directory.exists()) directory.mkdir();//如果不存在创建        //Download file:        String archizePath = DATA_PATH + "/mnist_png.tar.gz";//下载路径        File archiveFile = new File(archizePath);//路径文件        String extractedPath = DATA_PATH + "mnist_png";//提取图片        File extractedFile = new File(extractedPath);//图片文件        if( !archiveFile.exists() ){//如果路径文件存在            System.out.println("Starting data download (15MB)...");            getMnistPNG();调用getMnistPNG函数下载文件            //Extract tar.gz file to output directory            extractTarGz(archizePath, DATA_PATH);//提取图片到目录        } else {//如果路径文件存在            //Assume if archive (.tar.gz) exists, then data has already been extracted            System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath());            if( !extractedFile.exists()){//如果路径文件不存在                //Extract tar.gz file to output directory                extractTarGz(archizePath, DATA_PATH);//提取图片到目录            } else {                System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());            }        }    }    private static final int BUFFER_SIZE = 4096;//缓存大小,单位是B    private static void extractTarGz(String filePath, String outputPath) throws IOException {//这相当于解压或者图片        int fileCount = 0;//初始化文件数目录数        int dirCount = 0;        System.out.print("Extracting files");        try(TarArchiveInputStream tais = new TarArchiveInputStream(//解压存档输入包装Gzip压缩输入包装缓存输入包装文件输入流,传文件路径            new GzipCompressorInputStream( new BufferedInputStream( new FileInputStream(filePath))))){            TarArchiveEntry entry;//声明解压内容变量            /** Read the tar entries using the getNextEntry method **///读解压内容使用迭代方法            while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {                //System.out.println("Extracting file: " + entry.getName());                //Create directories as required                if (entry.isDirectory()) {//如果解压是一个文件夹,创建目录,目录数自增                    new File(outputPath + entry.getName()).mkdirs();                    dirCount++;                }else {//如果是文件                    int count;                    byte data[] = new byte[BUFFER_SIZE];//缓冲数组                    FileOutputStream fos = new FileOutputStream(outputPath + entry.getName());//声明文件输出流                    BufferedOutputStream dest = new BufferedOutputStream(fos,BUFFER_SIZE);//载入缓存输出流                    while ((count = tais.read(data, 0, BUFFER_SIZE)) != -1) {//循环读取BUFFER_SIZE大小放到data数组,并返回读取数据大小                        dest.write(data, 0, count);//输出流把缓存数组写到文件                    }                    dest.close();                    fileCount++;//累计文件数量                }                if(fileCount % 1000 == 0) System.out.print(".");//1000个文件输出一个.            }        }        System.out.println("\n" + fileCount + " files and " + dirCount + " directories extracted to: " + outputPath);//最后输出路径下有多少个目录,多少个文件    }    public static void getMnistPNG() throws IOException {        String tmpDirStr = System.getProperty("java.io.tmpdir");//拼下载的文件名        String archizePath = DATA_PATH + "/mnist_png.tar.gz";        if (tmpDirStr == null) {//没有就报错            throw new IOException("System property 'java.io.tmpdir' does specify a tmp dir");        }        String url = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";        File f = new File(archizePath);//下载的文件        File dir = new File(tmpDirStr);//下载目录        if (!f.exists()) {//如果文件不存在            HttpClientBuilder builder = HttpClientBuilder.create();//创建客户端            CloseableHttpClient client = builder.build();            try (CloseableHttpResponse response = client.execute(new HttpGet(url))) {//客户端下载                HttpEntity entity = response.getEntity();//获取内容                if (entity != null) {//如果有内容                    try (FileOutputStream outstream = new FileOutputStream(f)) {//把内容写入输出流                        entity.writeTo(outstream);                        outstream.flush();                        outstream.close();                    }                }            }            System.out.println("Data downloaded to " + f.getAbsolutePath());//打印        } else {            System.out.println("Using existing directory at " + f.getAbsolutePath());        }    }}

0 0
原创粉丝点击