使用pytorch预训练模型分类与特征提取
来源:互联网 发布:非广播多路访问网络 编辑:程序博客网 时间:2024/06/05 23:01
pytorch应该是深度学习框架里面比较好使用的了,相比于tensorflow,mxnet。可能在用户上稍微少一点,有的时候出问题不好找文章。下面就使用pytorch预训练模型做分类和特征提取,pytorch文档可以参考:pytorch docs , 模型是imagenet2012训练的标签可参考:imagenet2012 labels ,模型预测的下标按从上到下,起始(n01440764)为0
#encoding=utf-8import osimport numpy as npimport torchimport torch.nnimport torchvision.models as modelsfrom torch.autograd import Variable import torch.cudaimport torchvision.transforms as transformsfrom PIL import Imageimg_to_tensor = transforms.ToTensor()def make_model(): resmodel=models.resnet34(pretrained=True) resmodel.cuda()#将模型从CPU发送到GPU,如果没有GPU则删除该行 return resmodel#分类def inference(resmodel,imgpath): resmodel.eval()#必需,否则预测结果是错误的 img=Image.open(imgpath) img=img.resize((224,224)) tensor=img_to_tensor(img) tensor=tensor.resize_(1,3,224,224) tensor=tensor.cuda()#将数据发送到GPU,数据和模型在同一个设备上运行 result=resmodel(Variable(tensor)) result_npy=result.data.cpu().numpy()#将结果传到CPU,并转换为numpy格式 max_index=np.argmax(result_npy[0]) return max_index #特征提取def extract_feature(resmodel,imgpath): resmodel.fc=torch.nn.LeakyReLU(0.1) resmodel.eval() img=Image.open(imgpath) img=img.resize((224,224)) tensor=img_to_tensor(img) tensor=tensor.resize_(1,3,224,224) tensor=tensor.cuda() result=resmodel(Variable(tensor)) result_npy=result.data.cpu().numpy() return result_npy[0] if __name__=="__main__": model=make_model() imgpath='path_to_img/xxx.jpg' print inference(model,imgpath) print extract_feature(model, imgpath)
阅读全文
0 0
- 使用pytorch预训练模型分类与特征提取
- 使用mxnet的预训练模型(pretrained model)分类与特征提取
- 【caffe:从一个预训练模型中提取特征】
- 使用keras预训练VGG16模型参数分类图像并提取特征
- pytorch训练imagenet分类
- pytorch 使用预训练层
- pytorch 如何加载部分预训练模型
- 使用opensmile提取音频的特征,得到特征向量,并扔进libsvm中进行分类训练测试
- Tensorflow保存模型,恢复模型,使用训练好的模型进行预测和提取中间输出(特征)
- Tensorflow保存模型,恢复模型,使用训练好的模型进行预测和提取中间输出(特征)【转】
- caffe学习笔记2_用一个预训练模型提取特征
- caffe练习实例(3)——用预训练模型提取特征
- caffe提取已训练好模型的特征
- 代码笔记:caffereid利用训练好的模型提取特征
- caffe根据训练出的模型提取特征
- caffe 用训练好的模型提取图片特征(使用自带classify.py和classifier.py)
- PyTorch学习之路(level1)——训练一个图像分类模型
- 使用Keras预训练模型ResNet50进行图像分类
- Oozie安装时放置Mysql驱动包的总结(网上最全)
- post发布表单判断
- memcached win64位服务端安装和java客户端实例
- Android学习三、SurfaceView的学习
- java笔记-多线程join用法
- 使用pytorch预训练模型分类与特征提取
- js获取文件的后缀名方法
- Paint MaskFilter类进行处理、颜色RGB的滤镜处理
- CSDN日报20170601 ——《程序猿职业生涯的迷惘与野望》
- 后代选择器与子选择器
- Python例题8-3~8-4 T恤
- C#学习笔记(八)—–LinqToSql和Entity Framework(上)
- Java程序运行机制
- 记一次oracle创建一个新数据库,并导入正式环境数据库备份的dmp包过程