caffe的简单python脚本

来源:互联网 发布:js闪动文字 编辑:程序博客网 时间:2024/06/07 04:05

caffe的简单python脚本

#-*-coding:utf-8-*-import caffeimport matplotlib.pyplot as plt#网络构建caffe.set_mode_cpu() # 设置caffe为cpu模式,也可设成gpu模式model_def = 'my_alexnet_test.prototxt' model_weights =  'models_rist1/lrcn_four_words_iter_10000.caffemodel'net = caffe.Net(model_def,      # 定义模型结构                 model_weights,  # 包含模型训练权重                caffe.TEST)     # 使用测试模式(训练中不能执行dropout)#读取图像image = caffe.io.load_image('./images/COCO_train2014_000000246146.jpg') #图像处理transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})transformer.set_transpose('data', (2,0,1))    # 将图像通道数设置为outermost的维数transformer.set_raw_scale('data', 255)        # 像素值从[0,1]变换为[0,255]transformer.set_channel_swap('data', (2,1,0)) # 交换通道,RGB->BGR#transformer.set_mean('data', channel_mean)transformed_image = transformer.preprocess('data', image) #图像可视化plt.imshow(image)plt.show()#前向net.blobs['data'].data[...] = transformed_imageoutput = net.forward() net.forward()# 循环打印每一层名字和相应维度#for layer_name, blob in net.blobs.iteritems():#    print layer_name + '\t' + str(blob.data.shape)#打印输出#for layer_name, param in net.params.iteritems():#    print layer_name + '\t' + str(param[0].data.shape), str(param[1].data.shape) #print '**********'#filters = net.params['conv1'][0].data#打印层的输出print transformed_imageprint '**********'#print transformed_image#feat = net.blobs['conv1'].data[0, :36]#print feat 


原创粉丝点击