Torch7学习笔记[2] ---神经网络的主体框架

来源:互联网 发布:servlet修改表单数据 编辑:程序博客网 时间:2024/05/18 12:43

参考资料:
https://github.com/soumith/cvpr2015/blob/master/Deep%20Learning%20with%20Torch.ipynb
将整个框架分为以下几个模块:
1、options设置
2、train、test预处理以及读取
3、net结构以及criterion建立
4、train设置
5、test设置
6、保存中间结果以及断点开始(待完善)
y以上每个功能模块单独由一个文件完成,整个结构分为7个文件
main.lua
opt.lua
dataloder.lua
model.lua
train.lua
test.lua
checkpont.lua(待完善)

require 'torch'require 'nn'require 'optim'local DataLoder = require 'dataloder'  --load the dataloder.lualocal opts = require 'opt'local Model = require 'model'local Test = require 'test'local checkpoints = require 'checkpoint'local Trainer = require 'train'torch.setdefaulttensortype = ('torch.FloatTensor')  --torch.setnumthreads(1)torch.manualSeed(opt.manualSeed)cutorch.manualSeedAll(opt.manualSeed)local opt = opts.parse(arg)  --load the optionslocal trainset,testset = DataLoder.creat(opt) --load the datasetlocal model,criterion = Model.setup(opt) --load the model,criterionif(opt.type == 'cuda')   then  --turn on gpu:model-criterion-data-label    model = model:cuda()     criterion = criterion:cuda()    trainset.data = trainset.data:cuda()    trainset.label = trainset.label:cuda()    testset.data = testset.data:cuda()    testset.label = testset.label:cuda()endfunction trainset:size() --prepare for training     return self.data:size(1) endlocal trainer = Trainer(model,criterion,opt)bestModel = falsefor epoch = 1,opt.max_epoch do    local current_error = trainer:train(epoch,trainset)    --save the current station    --checkpoints.save(epoch, model, trainer.optimState, bestModel, opt)endlocal correct_rate = Test.run(opt,testset,model)print(correct_rate)

运行程序时,直接在文件所在目录终端执行:th main.lua 即可运行程序。若需改变options,例如gpu运行:th main.lua –type cuda

0 0
原创粉丝点击