使用pytorch预训练模型分类与特征提取

来源:互联网 发布:非广播多路访问网络 编辑:程序博客网 时间:2024/06/05 23:01

    pytorch应该是深度学习框架里面比较好使用的了,相比于tensorflow,mxnet。可能在用户上稍微少一点,有的时候出问题不好找文章。下面就使用pytorch预训练模型做分类和特征提取,pytorch文档可以参考:pytorch docs  , 模型是imagenet2012训练的标签可参考:imagenet2012 labels  ,模型预测的下标按从上到下,起始(n01440764)为0

   

#encoding=utf-8import osimport numpy as npimport torchimport torch.nnimport torchvision.models as modelsfrom torch.autograd import Variable import torch.cudaimport torchvision.transforms as transformsfrom PIL import Imageimg_to_tensor = transforms.ToTensor()def make_model():    resmodel=models.resnet34(pretrained=True)    resmodel.cuda()#将模型从CPU发送到GPU,如果没有GPU则删除该行    return resmodel#分类def inference(resmodel,imgpath):    resmodel.eval()#必需,否则预测结果是错误的        img=Image.open(imgpath)    img=img.resize((224,224))    tensor=img_to_tensor(img)        tensor=tensor.resize_(1,3,224,224)    tensor=tensor.cuda()#将数据发送到GPU,数据和模型在同一个设备上运行                result=resmodel(Variable(tensor))    result_npy=result.data.cpu().numpy()#将结果传到CPU,并转换为numpy格式    max_index=np.argmax(result_npy[0])        return max_index    #特征提取def extract_feature(resmodel,imgpath):    resmodel.fc=torch.nn.LeakyReLU(0.1)    resmodel.eval()        img=Image.open(imgpath)    img=img.resize((224,224))    tensor=img_to_tensor(img)        tensor=tensor.resize_(1,3,224,224)    tensor=tensor.cuda()                result=resmodel(Variable(tensor))    result_npy=result.data.cpu().numpy()        return result_npy[0]    if __name__=="__main__":    model=make_model()    imgpath='path_to_img/xxx.jpg'    print inference(model,imgpath)    print extract_feature(model, imgpath)    


阅读全文
0 0
原创粉丝点击