python caffe 在师兄的代码上修改成自己风格的代码
来源:互联网 发布:淘宝刷评论免费送衣服 编辑:程序博客网 时间:2024/04/29 23:50
首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试。我的代码是堆积成的,被师兄嘲笑说写脚本。好吧!我的代码只有我懂,哈哈! 希望以后代码能写得工整点。现在还是让我先懂。这里,我做了一个简单的任务:0,1,2三个数字的分类。准确率:0.9806666666666667
(部分)代码分为:
1 train_net.py
1 #import some module 2 import time 3 import os 4 import numpy as np 5 import sys 6 import cv2 7 sys.path.append("/home/wang/Downloads/caffe-master/python") 8 import caffe 9 #from prepare_data import DataConfig10 #from data_config import DataConfig11 12 #configure GPU mode13 ''' uncommend below line to use gpu '''14 caffe.set_mode_gpu()15 16 # about dataset17 ##dataset = Dataset('/home/wang/Downloads/object/extract/')18 ##dataset = dataset.Split('train')19 ##data_config = DataConfig(dataset)20 ##data_config.SetBatchSize(256)21 data_config='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/'22 23 24 25 #configure solve.prototxt26 solver = caffe.SGDSolver('models/solver.prototxt')27 28 # load pretrain model29 print('load pretrain model')30 solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel')31 32 solver.net.layers[0].SetDataConfig(data_config)33 34 for i in range(1, 10000):35 # Make one SGD update36 solver.step(5)37 if i % 100 == 0:38 solver.net.save('tmp.caffemodel')39 ''' TODO: test code '''
2 test_net.py
1 #import setup 2 import time 3 import os 4 import random 5 import sys 6 sys.path.append("/home/wang/Downloads/caffe-master/python") 7 import caffe 8 import cv2 9 import numpy as np10 import random11 12 13 from utils import PrepareImage14 #from dataset import Dataset15 from test_data import test_data_pre 16 17 test_num_once=1018 19 20 ''' uncommend below line to use gpu '''21 # caffe.set_mode_gpu()22 23 # dataset24 #dataset = Dataset('/home/wang/Downloads/object/extract/')25 #dataset = dataset.Split('test')26 27 # load net28 net = caffe.Net('models/deploy.prototxt', caffe.TEST)29 30 31 # load train model32 print('load pretrain model')33 net.copy_from('tmp.caffemodel')34 35 #test all samples one by one36 data_pre='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'37 #(imgPaths, gt_label) = dataset[int(random.random()*num_obj)]38 (imgPaths, gt_label)=test_data_pre(data_pre) 39 num_img = len(imgPaths)40 correct_num=041 for idx in range(num_img):42 img = cv2.imread(imgPaths[idx])43 img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)44 tmp_img = img.copy() # for display45 img = PrepareImage(img, (227, 227))46 net.blobs['data'].reshape(test_num_once, 3, 227, 227)47 net.blobs['data'].data[...] = img48 #net.blobs['data'].data[i,:,:,:] = img49 net.forward()50 score = net.blobs['cls_prob'].data51 if score.argmax()==gt_label[idx]:52 correct_num=correct_num+153 if idx%100==0:54 print("Please wait some minutes...")55 correct_rate=correct_num*1.0/num_img56 print('The correct rate is :',correct_rate)57 58 59
3 test_data.py
1 import os 2 import numpy as np 3 from random import randint 4 import cv2 5 from utils import PrepareImage,CatImage 6 #class data: 7 #path should be /home/ 8 def test_data_pre(path): 9 img_list=[]10 image_num=len(os.listdir(path+'/0'))+len(os.listdir(path+'/1'))+len(os.listdir(path+'/2')) 11 label = np.zeros(image_num, dtype=np.float32) 12 13 i=014 for idf in range(3): 15 idf_str=str(idf)16 path1=path+idf_str17 tmp_path=os.listdir(path1)18 for idi in range(len(tmp_path)): 19 img_path=path1+'/'+tmp_path[idi] 20 img_list.append(img_path)21 label[i]=idf22 i=i+123 return ( img_list,label)
4 pre_data.py
1 import os 2 import numpy as np 3 from random import randint 4 import cv2 5 from utils import PrepareImage,CatImage 6 #class data: 7 #path should be /home/ 8 def prepare_data(path,batchsize): 9 #tmp_path=os.listdir(path)10 img_list=[]11 label = np.zeros(batchsize, dtype=np.float32)12 for i in range(batchsize): 13 #randomly select one file14 idf=randint(0,2)15 idf_str=str(idf)16 path1=path+idf_str17 tmp_path=os.listdir(path1)18 19 #randomly select one image 20 idi=randint(0,len(tmp_path)-1)21 #img = cv2.imread(imgPaths[idx])22 img_path=path1+'/'+tmp_path[idi]23 img=cv2.imread(img_path)24 25 img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)26 flip = randint(0, 1)>027 if flip > 0:28 img = img[:, ::-1, :] # flip left to right29 30 img=PrepareImage(img, (227,227))31 img_list.append(img)32 label[i]=idf33 imgData = CatImage(img_list)34 return (imgData,label)
5 utils.py
1 import os 2 import cv2 3 import numpy as np 4 5 def PrepareImage(im, size): 6 im = cv2.resize(im, (size[0], size[1])) 7 im = im.transpose(2, 0, 1) 8 im = im.astype(np.float32, copy=False) 9 return im10 11 def CatImage(im_list):12 max_shape = np.array([im.shape for im in im_list]).max(axis=0)13 blob = np.zeros((len(im_list), 3, max_shape[1], max_shape[2]), dtype=np.float32)14 # set to mean value15 blob[:, 0, :, :] = 102.980116 blob[:, 1, :, :] = 115.946517 blob[:, 2, :, :] = 122.7717 18 for i, im in enumerate(im_list):19 blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im20 return blob
6 layer/data_layer.py
1 import caffe 2 import numpy as np 3 4 #import data_config 5 #import prepare_data 6 from pre_data import prepare_data 7 8 class DataLayer(caffe.Layer): 9 10 def SetDataConfig(self, data_config):11 self._data_config = data_config12 13 def GetDataConfig(self):14 return self._data_config15 16 def setup(self, bottom, top):17 # data blob18 top[0].reshape(1, 3, 227, 227)19 #top[0].reshape(1, 3, 34, 44)20 # label type21 top[1].reshape(1, 1)22 23 def reshape(self, bootom, top):24 pass25 26 def forward(self, bottom, top):27 #(imgs, label) = self._data_config.next()28 path=self.GetDataConfig()29 (imgs,label)=prepare_data(path,128)30 (N, C, W, H) = imgs.shape31 # image data32 top[0].reshape(N, C, W, H)33 top[0].data[...] = imgs34 # object type label35 top[1].reshape(N)36 top[1].data[...] = label37 38 def backward(self, top, propagate_down, bottom):39 pass
7 layer/__init__.py
import data_layer
还有一些caffe中经典的东西没放进来。
代码和数据:
阅读全文
0 0
- python caffe 在师兄的代码上修改成自己风格的代码
- Python 的代码风格
- sublime text编辑器修改python代码的缩进设风格
- 在Cocoapods上发布自己的代码
- Python--良好的代码风格
- 在caffe上跑自己的数据
- 在caffe上跑自己的数据
- 在caffe上跑自己的数据
- 在caffe上跑自己的数据
- 在caffe上跑自己的数据
- 在caffe上跑自己的数据
- 在caffe上跑自己的数据
- 自己修改的MBProgressHUD 代码
- 代码笔记:caffe-reid自己增加的caffe.proto
- 自己的代码风格——代码注释
- 在工程中查找自己修改的所有代码
- 将自己写的Python代码打包放到PyPI上
- Github:在Github上创建自己的代码仓库
- 客户端xshell连接linux中vim不能正常使用小键盘的问题
- 对Fragment加载的理解
- 2018届校招面试知识点
- logistic回归
- 一 蓝牙概述
- python caffe 在师兄的代码上修改成自己风格的代码
- 解决jetty运行ClassNotFoundException
- 第一个ssm项目
- 超酷超炫特效动画
- MQ消息架构设计一(到底什么时候该使用MQ?)
- leetcode 160. Intersection of Two Linked Lists
- (三)u-boot启动流程分析(C语言部分board_f.c)
- linux deamon函数使用方法说明
- 矩阵的物理意义