python caffe 在师兄的代码上修改成自己风格的代码

来源:互联网 发布:淘宝刷评论免费送衣服 编辑:程序博客网 时间:2024/04/29 23:50

首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试;我的代码则是一段段堆起来的,被师兄笑称是在“写脚本”。好吧!我的代码只有我自己懂,哈哈!希望以后能把代码写得工整一点,但现在还是先让自己弄懂。这里我做了一个简单的任务:对数字 0、1、2 三类图像进行分类,测试准确率为 0.9806666666666667。

(部分)代码分为:

1 train_net.py

复制代码
 # train_net.py
"""Fine-tune a Caffe classifier for the 0/1/2 digit task.

Loads the solver from models/solver.prototxt, initialises the weights from the
pretrained CaffeNet model, points the Python data layer at the training image
root and runs SGD, periodically saving the weights to tmp.caffemodel.
"""
# import some modules
import time
import os
import numpy as np
import sys
import cv2
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe

# Run on the GPU; comment out the line below to fall back to Caffe's
# default mode.
caffe.set_mode_gpu()

# Root directory of the training images; the data layer reads one
# sub-directory per class ("0", "1", "2").
data_config = '/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/'

# Outer loop bound, SGD steps per outer iteration, and how often (in outer
# iterations) the weights are written to disk.
NUM_ITERS = 10000
STEPS_PER_ITER = 5
SNAPSHOT_EVERY = 100

# configure solver.prototxt
solver = caffe.SGDSolver('models/solver.prototxt')

# load pretrained model
print('load pretrain model')
solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel')

# Layer 0 of the training net is the Python DataLayer; tell it where the
# training images live.
solver.net.layers[0].SetDataConfig(data_config)

for i in range(1, NUM_ITERS):
    # Make STEPS_PER_ITER SGD updates.
    solver.step(STEPS_PER_ITER)
    if i % SNAPSHOT_EVERY == 0:
        solver.net.save('tmp.caffemodel')

# The last outer iterations do not end on a multiple of SNAPSHOT_EVERY, so
# save once more to keep the final weights instead of a stale snapshot.
solver.net.save('tmp.caffemodel')
# TODO: test code
复制代码

2 test_net.py

复制代码
 # test_net.py
"""Evaluate the trained network on every test image and print the accuracy.

Test images live in one sub-directory per class (see test_data_pre); they are
pushed through the network in mini-batches of ``test_num_once`` images.
"""
# import setup
import time
import os
import random
import sys
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
import cv2
import numpy as np

from utils import PrepareImage
from test_data import test_data_pre

# Number of images sent through the network in one forward pass.
test_num_once = 10

# Uncomment the line below to use the GPU.
# caffe.set_mode_gpu()

# load net
net = caffe.Net('models/deploy.prototxt', caffe.TEST)

# load trained model
print('load pretrain model')
net.copy_from('tmp.caffemodel')

# Collect every test image path together with its ground-truth label.
data_pre = '/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
(imgPaths, gt_label) = test_data_pre(data_pre)
num_img = len(imgPaths)
if num_img == 0:
    # Avoid a ZeroDivisionError when computing the accuracy below.
    raise SystemExit('No test images found under ' + data_pre)

correct_num = 0
for start in range(0, num_img, test_num_once):
    stop = min(start + test_num_once, num_img)
    batch = []
    for idx in range(start, stop):
        img = cv2.imread(imgPaths[idx])
        if img is None:
            raise IOError('Cannot read image: ' + imgPaths[idx])
        # Same channel conversion as the training pipeline (pre_data.py).
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        batch.append(PrepareImage(img, (227, 227)))
        if idx % 100 == 0:
            print("Please wait some minutes...")
    # Size the input blob to the real batch (the last one may be short)
    # instead of broadcasting one image into test_num_once identical copies.
    net.blobs['data'].reshape(len(batch), 3, 227, 227)
    net.blobs['data'].data[...] = np.array(batch, dtype=np.float32)
    net.forward()
    # Assumes cls_prob's first axis is the batch axis (one row of class
    # scores per image); the prediction is the argmax of each row.
    score = net.blobs['cls_prob'].data
    pred = score.reshape(len(batch), -1).argmax(axis=1)
    correct_num += int(np.sum(pred == gt_label[start:stop]))
correct_rate = correct_num * 1.0 / num_img
print('The correct rate is :', correct_rate)
复制代码

3 test_data.py

复制代码
 # test_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage, CatImage


def test_data_pre(path, num_classes=3):
    """List every image under ``path`` together with its class label.

    ``path`` must contain one sub-directory per class, named "0", "1", ...,
    str(num_classes - 1); every file in sub-directory ``k`` gets label ``k``.

    Args:
        path: root directory of the data set (a trailing slash is optional).
        num_classes: number of class sub-directories to read (default 3).

    Returns:
        A tuple ``(img_list, label)``: ``img_list`` is a list of file paths
        and ``label`` is a float32 array of the same length holding the class
        index of each path.
    """
    img_list = []
    labels = []
    for idf in range(num_classes):
        # os.path.join works whether or not ``path`` ends with a slash.
        class_dir = os.path.join(path, str(idf))
        # List each directory once so paths and labels always stay aligned.
        for name in os.listdir(class_dir):
            img_list.append(os.path.join(class_dir, name))
            labels.append(idf)
    label = np.array(labels, dtype=np.float32)
    return (img_list, label)
复制代码

 

4 pre_data.py

复制代码
 # pre_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage, CatImage


def prepare_data(path, batchsize, num_classes=3):
    """Randomly sample an augmented training batch from ``path``.

    ``path`` must contain one sub-directory per class, named "0", "1", ...,
    str(num_classes - 1). For each of the ``batchsize`` samples, a class is
    drawn uniformly at random, then an image is drawn uniformly from that
    class. The image is channel-converted, flipped left-to-right with
    probability 1/2, resized to 227x227 and converted to a CHW float32 array.

    Args:
        path: root directory of the training images (trailing slash optional).
        batchsize: number of images to draw.
        num_classes: number of class sub-directories (default 3).

    Returns:
        A tuple ``(imgData, label)``: ``imgData`` is the stacked batch built
        by CatImage, and ``label`` is a float32 array with the class index of
        each sample.

    Raises:
        IOError: if a selected file cannot be decoded as an image.
    """
    # List every class directory once per batch instead of once per sample.
    class_dirs = [os.path.join(path, str(idf)) for idf in range(num_classes)]
    class_files = [os.listdir(d) for d in class_dirs]

    img_list = []
    label = np.zeros(batchsize, dtype=np.float32)
    for i in range(batchsize):
        # randomly select one class
        idf = randint(0, num_classes - 1)
        files = class_files[idf]

        # randomly select one image of that class
        idi = randint(0, len(files) - 1)
        img_path = os.path.join(class_dirs[idf], files[idi])
        img = cv2.imread(img_path)
        if img is None:
            raise IOError('Cannot read image: ' + img_path)

        # cv2.imread returns BGR, so this conversion effectively swaps the
        # data to RGB; kept as-is because test_net.py applies the same one.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        # random horizontal flip for augmentation
        if randint(0, 1) > 0:
            img = img[:, ::-1, :]  # flip left to right

        img_list.append(PrepareImage(img, (227, 227)))
        label[i] = idf
    imgData = CatImage(img_list)
    return (imgData, label)
复制代码

5 utils.py

复制代码
 # utils.py: image helpers shared by the training and test pipelines.
import os
import cv2
import numpy as np


def PrepareImage(im, size):
    """Resize an image and return it channels-first as float32.

    Args:
        im: image array as produced by cv2 (height x width x channels).
        size: two-element target size, forwarded to cv2.resize as its
            ``dsize`` argument in the order (size[0], size[1]).

    Returns:
        The resized image transposed to channels x height x width, float32.
    """
    target = (size[0], size[1])
    resized = cv2.resize(im, target)
    channels_first = resized.transpose(2, 0, 1)
    return channels_first.astype(np.float32, copy=False)


def CatImage(im_list):
    """Stack channels-first images into one float32 blob.

    The blob has shape (len(im_list), 3, H, W), where H and W are the
    largest height and width among the images. Each image is pasted into
    the top-left corner of its slot; uncovered pixels keep the per-channel
    fill values 102.9801, 115.9465 and 122.7717.
    """
    peak = np.array([im.shape for im in im_list]).max(axis=0)
    max_h, max_w = peak[1], peak[2]
    blob = np.empty((len(im_list), 3, max_h, max_w), dtype=np.float32)
    # set every channel to its mean value before pasting the images
    for channel, fill in enumerate((102.9801, 115.9465, 122.7717)):
        blob[:, channel, :, :] = fill
    for slot, im in zip(blob, im_list):
        slot[:, :im.shape[1], :im.shape[2]] = im
    return blob
复制代码

6 layer/data_layer.py

复制代码
 # layer/data_layer.py
import caffe
import numpy as np

from pre_data import prepare_data


class DataLayer(caffe.Layer):
    """Python data layer that feeds randomly sampled training batches.

    The image root directory must be supplied with SetDataConfig() before
    the first forward pass. Each forward pass draws ``batch_size`` images
    via pre_data.prepare_data, and emits them as top[0] (images) and
    top[1] (labels).
    """

    # Number of images drawn per forward pass.
    batch_size = 128

    def SetDataConfig(self, data_config):
        """Set the image root directory the layer samples from."""
        self._data_config = data_config

    def GetDataConfig(self):
        """Return the image root directory set by SetDataConfig()."""
        return self._data_config

    def setup(self, bottom, top):
        # Placeholder shapes; forward() reshapes to the real batch.
        # data blob
        top[0].reshape(1, 3, 227, 227)
        # label blob
        top[1].reshape(1, 1)

    def reshape(self, bootom, top):
        # Shapes are set in forward(), since they depend on the sampled batch.
        pass

    def forward(self, bottom, top):
        if getattr(self, '_data_config', None) is None:
            raise RuntimeError(
                'DataLayer: call SetDataConfig() before the first forward pass')
        (imgs, label) = prepare_data(self.GetDataConfig(), self.batch_size)
        # CatImage stacks channels-first images: (batch, channels, height, width).
        (N, C, H, W) = imgs.shape
        # image data
        top[0].reshape(N, C, H, W)
        top[0].data[...] = imgs
        # object type label
        top[1].reshape(N)
        top[1].data[...] = label

    def backward(self, top, propagate_down, bottom):
        # The layer has no bottom blobs, so there is nothing to propagate.
        pass
复制代码

7 layer/__init__.py

# Explicit relative import: works on Python 2.5+ and Python 3, whereas the
# implicit relative form `import data_layer` is Python-2-only.
from . import data_layer

还有一些 Caffe 中常用的文件(如 models/solver.prototxt、models/deploy.prototxt 以及预训练模型 bvlc_reference_caffenet.caffemodel)没有放进来。

代码和数据:

 

原创粉丝点击