Jupyter notebook test: evaluating a trained Caffe model on the test images

来源:互联网 发布:php web上传 编辑:程序博客网 时间:2024/05/01 18:01
import os
import sys

import numpy as np
import scipy.io as sio

try:
    # Pillow (the maintained PIL fork) only exposes the module as PIL.Image.
    from PIL import Image
except ImportError:
    # Fallback for the legacy standalone PIL, which installed a top-level
    # ``Image`` module.
    import Image

# Make the Caffe Python bindings from the repository checkout importable.
caffe_root = '../../../../'
sys.path.insert(0, caffe_root + 'python')
import caffe  # noqa: E402  (must follow the sys.path tweak above)

# Evaluate a trained Caffe model on every image in ``test_dir``.
# For each image the prediction is the sum of the first channel of the
# network's 'score' output. It is compared against the 'number' field of
# the matching ground-truth .mat file (GT_<image stem>.mat in ``gt_dir``).
# Per-image results are kept in module-level arrays (pred, gt, error) and
# the list ``pa`` so later notebook cells can inspect them.
data_root = '../../../../data/A_resize/'
test_dir = data_root + 'test_data/images/'
gt_dir = data_root + 'test_data/ground_truth/'
model_root = './'

# Deterministic order (os.listdir order is arbitrary); skip sub-directories
# and other non-file entries that Image.open would choke on.
pathdir = sorted(
    name for name in os.listdir(test_dir)
    if os.path.isfile(os.path.join(test_dir, name))
)
num_images = len(pathdir)

error_all = 0
error = np.zeros((1, num_images), dtype=np.float64)
pred = np.zeros((1, num_images), dtype=np.float64)
gt = np.zeros((1, num_images), dtype=np.float64)
pa = []

# Configure the device and load the network ONCE. Doing this inside the loop
# re-reads the weights from disk for every image.
caffe.set_mode_gpu()
caffe.set_device(0)
net = caffe.Net(model_root + 'deploy.prototxt',
                './network_iter_4320000.caffemodel', caffe.TEST)

for i, img_file_name in enumerate(pathdir):
    pa.append(img_file_name)

    # convert('RGB') replicates a grayscale image into three identical
    # channels (as the original manual copy did) and also normalizes
    # palette / CMYK / RGBA images to exactly three channels.
    img = Image.open(os.path.join(test_dir, img_file_name)).convert('RGB')
    img = np.array(img, dtype=np.float32)
    img = img[:, :, ::-1]            # RGB -> BGR channel order
    img = img.transpose((2, 0, 1))   # H x W x C -> C x H x W

    # The data blob is N x C x H x W; resize it to this image and fill it.
    net.blobs['data'].reshape(1, *img.shape)
    net.blobs['data'].data[...] = img

    # Run the net. The prediction is the sum of the first output channel
    # of 'score' (not an argmax).
    net.forward()
    pred[0, i] = np.sum(net.blobs['score'].data[0][0, :, :])

    # Ground truth lives in GT_<stem>.mat. splitext handles extensions of
    # any length (e.g. '.jpeg'), unlike slicing off three characters.
    label_name = 'GT_' + os.path.splitext(img_file_name)[0] + '.mat'
    label = sio.loadmat(os.path.join(gt_dir, label_name))
    # NOTE(review): assumes the .mat layout image_info[0, 0]['number'] holds
    # a single scalar count — confirm against the dataset's GT files.
    gt[0, i] = label['image_info'][0, 0]['number']

    error[0, i] = abs(gt[0, i] - pred[0, i])
    error_all = error_all + error[0, i]

# Report the aggregate error that the loop accumulates.
if num_images:
    mae = error_all / num_images
    rmse = np.sqrt(np.mean(np.square(error)))
    print('MAE: %.4f  RMSE: %.4f  (%d images)' % (mae, rmse, num_images))
else:
    print('No test images found in %s' % test_dir)
0 0
原创粉丝点击