Torch7 教程 Supervised Learning CNN

来源:互联网 发布:苹果手机mac修改器 编辑:程序博客网 时间:2024/06/06 12:40

全部代码放在:https://github.com/guoyilin/CNN_Torch7

在搭建好Torch7之后,我们开始进行监督式Supervised Learning for CNN, Torch7提供了代码和一些说明文件:

http://code.madbits.com/wiki/doku.php?id=tutorial_supervised_1_data 和http://torch.cogbits.com/doc/tutorials_supervised/说的比较详细。

结合http://ufldl.stanford.edu/wiki/index.php/Feature_extraction_using_convolution了解CNN的做法,最关键的是要熟悉http://ufldl.stanford.edu/wiki/index.php/Backpropagation_Algorithm 算法的主要做法。bp算法的目的是为了一次性计算所有的参数导数,该算法利用了chain rule进行error的后向传播。这篇文章写了bp算法: http://neuralnetworksanddeeplearning.com/chap2.html, 写的比较详细。

如果背景不熟悉,可以看看Linear Classification, Neutral Network, SGD算法。

由于该教程使用了torch自己的数据格式,因此如果你要使用自己的数据,需要预先转换下。这里我训练的是图像分类,因此可以使用

https://github.com/clementfarabet/graphicsmagick 进行数据的加载。
如下是加载图像的代码:
[plain] view plaincopy在CODE上查看代码片派生到我的代码片
  1. height = 200  
  2. width = 200  
  3. --see if the file exists  
  4. function file_exists(file)  
  5.   local f = io.open(file, "rb")  
  6.   if f then f:close() end  
  7.   return f ~= nil  
  8. end  
  9.   
  10. function read_file (file)  
  11.   if not file_exists(file) then return {} end  
  12.   lines = {}  
  13.   for line in io.lines(file) do  
  14.     lines[#lines + 1] = line  
  15.   end  
  16.   return lines  
  17. end  
  18.   
  19. -- read all label name. hash them to id.  
  20. labels_id = {}  
  21. label_lines = read_file('labels.txt')  
  22. for i = 1, #label_lines do  
  23.   labels_id[label_lines[i]] = i  
  24. end  
  25.   
  26. -- read train data. iterate train.txt  
  27.   
  28. local train_lines = read_file("train.txt")  
  29. local train_features = torch.Tensor(#train_lines, 3, height, width) -- dimension: sample number, YUV, height, width  
  30. local train_labels = torch.Tensor(#train_lines) -- dimension: sample number  
  31.   
  32. for i = 1, #train_lines do  
  33.   local image = gm.Image("/train_images/" .. train_lines[i])  
  34.   image:size(width, height)  
  35.   img_yuv = image:toTensor('float', 'YUV', 'DHW')  
  36.   --print(img_yuv:size())  
  37.   --print(img_yuv:size())  
  38.   train_features[i] = img_yuv  
  39.   local label_name = train_lines[i]:match("([^,]+)/([^,]+)")  
  40.   train_labels[i] = labels_id[label_name]  
  41.   --print(train_labels[i])  
  42.   if(i % 100 == 0) then  
  43.     print("train data: " .. i)  
  44.   end  
  45. end  
  46.   
  47. trainData = {  
  48.   data = train_features:transpose(3,4),  
  49.   labels = train_labels,  
  50.   --size = function() return #train_lines end  
  51.   size = function() return #train_lines end  
  52. }  
  53.   
  54. -- read test data. iterate test.txt  
  55. local test_lines = read_file("test.txt")  
  56.   
  57. local test_features = torch.Tensor(#test_lines, 3, height, width) -- dimension: sample number, YUV, height, width  
  58. local test_labels = torch.Tensor(#test_lines) -- dimension: sample number  
  59.   
  60. for i = 1, #test_lines do  
  61.   -- if image size is zero, gm.Imge may throw error, we need to dispose it later.  
  62.   local image = gm.Image("test_images/" .. test_lines[i])  
  63.   --print(test_lines[i])  
  64.   
  65.   image:size(width, height)  
  66.   local img_yuv = image:toTensor('float', 'YUV', 'DHW')  
  67.   --print(img_yuv:size())  
  68.   test_features[i] = img_yuv  
  69.   local label_name = test_lines[i]:match("([^,]+)/([^,]+)")  
  70.   test_labels[i] = labels_id[label_name]  
  71.   --print(test_labels[i])  
  72.   if(i % 100 == 0) then  
  73.     print("test data: " .. i)  
  74.   end  
  75. end  
  76.   
  77. testData = {  
  78.   data = test_features:transpose(3,4),  
  79.   labels = test_labels,  
  80.   --size = function() return #test_lines end  
  81.   size = function() return #test_lines end  
  82. }  
  83. trsize = #train_lines  
  84. tesize = #test_lines  

由于图像的大小从32*32变成了200*200, 因此需要修改相应的model中的每一层的大小。
假定其他层没有变化,最后一层需要修改:
[plain] view plaincopy在CODE上查看代码片派生到我的代码片
  1. -- stage 3 : standard 2-layer neural network  
  2.  model:add(nn.Reshape(nstates[2]*47*47))  
  3.  model:add(nn.Linear(nstates[2]*47*47, nstates[3]))  
  4.  model:add(nn.Tanh())  
  5.  model:add(nn.Linear(nstates[3], noutputs))  

版权声明:本文为博主原创文章,未经博主允许不得转载。

0 0
原创粉丝点击
热门问题 老师的惩罚 人脸识别 我在镇武司摸鱼那些年 重生之率土为王 我在大康的咸鱼生活 盘龙之生命进化 天生仙种 凡人之先天五行 春回大明朝 姑娘不必设防,我是瞎子 20个月宝宝发烧流鼻血怎么办 60天宝宝老鼻塞怎么办 10个月宝宝头被撞到流鼻血怎么办 狗狗受凉吐了怎么办 狗狗咳嗽流鼻涕一直不好怎么办 宝宝感冒咳嗽流鼻涕发烧怎么办 狗狗感冒咳嗽流鼻涕怎么办 9岁儿童咳嗽鼻塞怎么办 三个月大的狗狗流鼻涕怎么办 3个月小狗干呕流鼻涕怎么办 狗狗流鼻涕怎么办有浓 小狗狗感冒了怎么办呢 狗狗一直擤鼻涕怎么办 小狗感冒流黄鼻涕怎么办 六个月宝宝鼻塞流鼻涕怎么办 小狗吃太多吐了怎么办 狗狗晕车一直吐怎么办 狗狗已经晕车了怎么办 狗狗得犬瘟怎么办 泰迪坐车吐了怎么办 小孩感冒流鼻涕带血怎么办 孩子鼻子流鼻涕有血丝怎么办 鼻子过敏流鼻涕有血丝怎么办 孕妇感冒头痛鼻涕带血怎么办 孕妇感冒鼻塞鼻涕带血怎么办 孕晚期感冒流鼻涕打喷嚏怎么办 孕晚期感冒鼻塞流鼻涕怎么办 怀孕初期鼻涕一直流怎么办 孩子一直流鼻水怎么办 9个月宝宝流鼻涕怎么办 8个月婴儿流鼻涕怎么办 3岁宝宝鼻塞咳嗽怎么办 又感冒又咳嗽了怎么办 鼻塞有一个月了怎么办 感冒一直流清水鼻涕怎么办 孩子受凉流清水鼻涕怎么办 一遇冷空气就打喷嚏流鼻涕怎么办 打喷嚏鼻塞流清鼻涕怎么办 哺乳期鼻子不通气有鼻涕怎么办 宝宝热伤风流清鼻涕怎么办 哺乳期感冒流鼻涕怎么办最有效