pytorch 使用预训练层

来源:互联网 发布:什么是软件危机 编辑:程序博客网 时间:2024/06/06 02:52

pytorch 使用预训练层

将其他地方训练好的网络,用到新的网络里面

  • pytorch 使用预训练层
    • 加载预训练网络
    • 加载新网络
    • 更新新网络参数


加载预训练网络

1.原先已经训练好一个网络 AutoEncoder_FC()
2.首先加载该网络,读取其存储的参数
3.设置一个参数集

cnnpre = AutoEncoder_FC()cnnpre.load_state_dict(torch.load('autoencoder_FC.pkl')['state_dict'])cnnpre_dict =cnnpre.state_dict()

加载新网络

1.设置新的网络
2.设置新网络参数集

cnn= AutoEncoder()cnn_dict = cnn.state_dict()

更新新网络参数

1.将两个参数集比对,存在的网络参数保留
2.使用保留下的参数更新新网络参数集
3.加载新网络参数集到新网络中

cnnpre_dict = {k: v for k, v in cnnpre_dict.items() if k in cnn_dict}cnn_dict.update(cnnpre_dict)cnn.load_state_dict(cnn_dict)
原创粉丝点击