初识caffe之python mnist训练

来源:互联网 发布:织梦58公司 编辑:程序博客网 时间:2024/05/17 01:50

1、下载数据

caffe源码中提供了下载mnist数据的脚本get_mnist.sh,但是因为是windows环境下,脚本无法运行,只能手动下载。进入caffe_root\data\mnist目录,打开get_mnist.sh脚本,如下:

#!/usr/bin/env sh
# This scripts downloads the mnist data and unzips it.


DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"


echo "Downloading..."


for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
do
    if [ ! -e $fname ]; then
        wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz
        gunzip ${fname}.gz
    fi
done

根据脚本提供的线索现在数据即可。

不过这样的数据是不能够直接使用的,需要将他们转换为lmdb或leveldb格式才可以导入caffe进行训练和识别,那么如何转换数据格式呢?下面来解决这个问题:

在路径caffe_root\examples\mnist路径下面有一个脚本 create_mnist.sh。它就是用来将原始数据转换为lmdb或leveldb的。因为我用的是转lmdb格式,而且caffe默认的也是lmdb格式,因此这里就以lmdb格式为例子进行描述。

同样的,因为create_mnist.sh为shell脚本,不能在windows下 运行,因此得想办法解决。我的办法是打开看看,脚本内容如下:

#!/usr/bin/env sh
# This script converts the mnist data into lmdb/leveldb format,
# depending on the value assigned to $BACKEND.
set -e


EXAMPLE=examples/mnist
DATA=data/mnist
BUILD=build/examples/mnist


BACKEND="lmdb"


echo "Creating ${BACKEND}..."


rm -rf $EXAMPLE/mnist_train_${BACKEND}
rm -rf $EXAMPLE/mnist_test_${BACKEND}


$BUILD/convert_mnist_data.bin $DATA/train-images-idx3-ubyte \
  $DATA/train-labels-idx1-ubyte $EXAMPLE/mnist_train_${BACKEND} --backend=${BACKEND}
$BUILD/convert_mnist_data.bin $DATA/t10k-images-idx3-ubyte \
  $DATA/t10k-labels-idx1-ubyte $EXAMPLE/mnist_test_${BACKEND} --backend=${BACKEND}


echo "Done."

容易看出,该shell脚本是调用了convert_mnist_data.bin应用来进行转换的,而这个应用同样不能在windows下运行。为了解决这样问题,我找到了生成convert_mnist_data.bin的源码,即在caffe_root\examples\mnist目录下的convert_mnist_data.cpp程序。程序很长,要读懂并不简单,好的是我们并不需要完全读懂它,只要注意到注释会用就好了:

// Usage:
//    convert_mnist_data [FLAGS] input_image_file input_label_file
//                        output_db_file

这两行注释很好的解释了函数的使用方法。既然知道了调用方法我们就要开始调用它生成lmdb数据了:

要运行它就得先编译,为了它再去配工程 显然是不可取的。其实在编译caffe的时候它已经被便宜过了的。就在caffe_root\build\install\bin目录下面,有一个convert_mnist_data.ext可执行文件。下面就来生产lmdb数据:

cmd进入caffe_root目录,执行如下命令:

build\install\bin\convert_mnist_data.exe caffe_root\\data\\mnist\\train-images.idx3-ubyte  caffe_root\\data\\mnist\\train-images.idx1-ubyte caffe_root\\examples\\mnist\\mnist_train_lmdb

build\install\bin\convert_mnist_data.exe caffe_root\\data\\mnist\\t10k-images.idx3-ubyte  caffe_root\\data\\mnist\\t10k-images.idx1-ubyte caffe_root\\examples\\mnist\\mnist_test_lmdb

如上两行命令分别生产了训练和测试数据。

以下的工作就是需要进行训练和测试了。训练和测试我是根据博客http://www.cnblogs.com/linyuanzhou/p/6012231.html来进行的,主要的代码就是:

import caffecaffe.set_mode_cpu()solver = caffe.SGDSolver('examples/mnist/lenet_solver.prototxt')solver.solve()

代码虽然只有四句话,坑可是不少啊。主要的问题是因为,linux路径和windos路径不兼容的问题。需要修改包含路径的代码。第一个要改的就是上面代码段中的第三行,需要将

'examples/mnist/lenet_solver.prototxt'改为caffe_root\\examples\\mnist\\lenet_solver.prototxt
然后进入
lenet_solver.prototxt文件,发现里面有net: "examples/mnist/lenet_train_test.prototxt"需要该为net: "D:\\caffe_windows_cpu\\examples\\mnist\\lenet_train_test.prototxt"
显然它调用了lenet_train_test.prototxt文件,再进入lenet_train_test.prototxt文件,将里面source路径进行更改。分别为训练路径该为:D:\\caffe_windows_cpu\\examples\\mnist\\mnist_train_lmdb
测试路径改为:D:\\caffe_windows_cpu\\examples\\mnist\\mnist_test_lmdb
至此,配置基本完成,可以进行训练了。进入python环境,输入如下代码,即可以开始训练:
import caffecaffe.set_mode_cpu()solver = caffe.SGDSolver('D:\\caffe_windows_cpu\\examples\\mnist\\lenet_solver.prototxt')solver.solve()
训练完成可以看到,准确率为0.99左右。

0 0
原创粉丝点击