我在读pyTorch文档(一)

来源:互联网 发布:mysql索引上创建 编辑:程序博客网 时间:2024/05/22 16:56

Cuda

  • 在Cuda上创建变量的两个方法:

    1. 直接在GPU上创建:x = torch.cuda.FloatTensor(1);
    2. 在CPU上创建然后转移到GPU上:x = torch.FloatTensor(1).cuda();
  • 多GPU使用:

    1. x = torch.FloatTensor(1).cuda(async=True), 通过async=True可以将数据从CPU到GPU的传输与计算重叠,不过当数据量小的时候貌似没什么用;

训练模型保存

  • 1. 只保存和加载模型参数:

    保存:torch.save(model.state_dict(), PATH)
    加载:model = ModelClass(args, * kwargs) + model.load_state_dict(torch.load(PATH))

  • 2. 保存整个模型:

    保存:torch.save(model, PATH)
    加载:model = torch.load(PATH)