torch

来源:互联网 发布:怎么查找淘宝店铺 编辑:程序博客网 时间:2024/05/17 22:44
-- Train a small convolutional network on a CIFAR-10 subset with Torch7 `nn`.
--
-- Everything below is deliberately assigned to globals (trainset, testset,
-- classes, mean, stdv, net, criterion, trainer) so that later sections of the
-- script / notebook can reuse them.
require 'paths'
require 'nn'

--- Load the train and test sets.
do
  local train_path = '/home/xuhang/torch/myfiles/mydata/cifar10-train.t7'
  local test_path = '/home/xuhang/torch/myfiles/mydata/cifar10-test.t7'

  -- Fail early with a readable message instead of a raw torch.load error.
  assert(paths.filep(train_path), 'training set not found: ' .. train_path)
  assert(paths.filep(test_path), 'test set not found: ' .. test_path)

  trainset = torch.load(train_path)
  testset = torch.load(test_path)
end

classes = {
  'airplane', 'automobile', 'bird', 'cat', 'deer',
  'dog', 'frog', 'horse', 'ship', 'truck',
}

--- Give trainset the two things nn.StochasticGradient asks of a dataset:
--- an index operator returning {input, target} and a size() method.
setmetatable(trainset, {
  -- Only consulted for keys that are not stored on the table itself, so
  -- trainset.data / trainset.label / trainset.size are unaffected.
  __index = function(t, i)
    return { t.data[i], t.label[i] }
  end,
})

-- Convert the image data to a DoubleTensor before doing arithmetic on it.
trainset.data = trainset.data:double()

--- Number of training samples (size of the first dimension of data).
function trainset:size()
  return self.data:size(1)
end

--- Normalize data: per channel, subtract the mean and divide by the standard
--- deviation. The statistics are kept so the test set can be normalized the
--- same way.
mean = {}
stdv = {}
for i = 1, 3 do
  -- Slice selecting channel i of every sample. The in-place add/div below
  -- are expected to modify trainset.data through this slice (Torch7 range
  -- indexing returns a view on the same storage).
  local channel = trainset.data[{ {}, { i }, {}, {} }]

  mean[i] = channel:mean()
  print('Channel ' .. i .. ', Mean: ' .. mean[i])
  channel:add(-mean[i])

  -- Standard deviation is measured after the mean has been removed.
  stdv[i] = channel:std()
  print('Channel ' .. i .. ', Standard Deviation:' .. stdv[i])
  channel:div(stdv[i])
end

--- Build the network.
net = nn.Sequential()
-- The first convolution takes 3 input channels instead of 1
-- (previously: nn.SpatialConvolution(1, 6, 5, 5)).
net:add(nn.SpatialConvolution(3, 6, 5, 5))  -- 3 input planes -> 6 output planes, 5x5 kernel
net:add(nn.ReLU())
net:add(nn.SpatialMaxPooling(2, 2, 2, 2))   -- 2x2 window, stride 2
net:add(nn.SpatialConvolution(6, 16, 5, 5)) -- 6 input planes -> 16 output planes, 5x5 kernel
net:add(nn.ReLU())
net:add(nn.SpatialMaxPooling(2, 2, 2, 2))
-- Flatten to a 16*5*5 vector; this size assumes 32x32 input images
-- (TODO confirm against the dataset).
net:add(nn.View(16 * 5 * 5))
net:add(nn.Linear(16 * 5 * 5, 120))
net:add(nn.ReLU())
net:add(nn.Linear(120, 84))
net:add(nn.ReLU())
net:add(nn.Linear(84, 10))                  -- 10 outputs, one per entry in `classes`
net:add(nn.LogSoftMax())                    -- log-probabilities, as ClassNLLCriterion expects

--- Train with plain stochastic gradient descent.
criterion = nn.ClassNLLCriterion()
trainer = nn.StochasticGradient(net, criterion)
trainer.learningRate = 0.001
trainer.maxIteration = 5
trainer:train(trainset)

-- Test: evaluate the trained network on a sample from the test set

-- Classify one test image with the trained network.
--
-- Relies on globals prepared earlier: testset, classes, net, and the
-- per-channel training statistics mean / stdv.

-- Normalize the test set with the statistics computed on the training set.
testset.data = testset.data:double()
for i = 1, 3 do
  -- Slice selecting channel i of every sample; the in-place ops below are
  -- expected to modify testset.data through it.
  local channel = testset.data[{ {}, { i }, {}, {} }]
  channel:add(-mean[i])
  channel:div(stdv[i])
end

do
  local sample_index = 100 -- which test image to inspect

  -- Ground-truth class name of the sample.
  print(classes[testset.label[sample_index]])

  -- itorch only exists inside an iTorch notebook; skip the image display
  -- elsewhere. rawget avoids tripping strict-global checkers.
  local itorch_lib = rawget(_G, 'itorch')
  if itorch_lib then
    itorch_lib.image(testset.data[sample_index])
  end

  -- The network ends in LogSoftMax, so exp() turns its output into
  -- probabilities (Torch7 applies exp in place on `predicted`).
  predicted = net:forward(testset.data[sample_index])
  print(predicted:exp())
end

-- Sort in descending order: gailv ("probability") holds the sorted values,
-- label the matching class indices, so entry 1 is the top prediction.
gailv, label = torch.sort(predicted, true)
print(gailv[1])
print(label[1])
0 0
原创粉丝点击