pytorch在cpu上加载预先训练好的GPU模型以及GPU上加载CPU上训练的Model
2018-12-18 10:24
1056 查看
有时候我们在CPU上训练的模型,因为一些原因,切换到GPU上,或者在GPU上训练的模型,因为条件限制,切换到CPU上。 GPU上训练模型时,将权重加载到CPU的最佳方式是什么?今天我们来讨论一下:
提取模型到指定的设备
从官方文档中我们可以看到如下方法
torch.load('tensors.pt') # 把所有的张量加载到CPU中 torch.load('tensors.pt', map_location=lambda storage, loc: storage) # 把所有的张量加载到GPU 1中 torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) # 把张量从GPU 1 移动到 GPU 0 torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
在cpu上加载预先训练好的GPU模型,有一种强制所有GPU张量在CPU中的方式:
torch.load('my_file.pt', map_location=lambda storage, loc: storage)
Q:上述代码只有在模型在一个GPU上训练时才起作用。如果我在多个GPU上训练我的模型,保存它,然后尝试在CPU上加载,我得到这个错误:KeyError: ‘unexpected key “module.conv1.weight” in state_dict’ 如何解决?
您可能已经使用模型保存了模型nn.DataParallel,该模型将模型存储在该模型中module,而现在您正试图加载模型DataParallel。您可以nn.DataParallel在网络中暂时添加一个加载目的,也可以加载权重文件,创建一个没有module前缀的新的有序字典,然后加载它。
参考:
# original saved file with DataParallel state_dict = torch.load('myfile.pth.tar') # create new OrderedDict that does not contain `module.` from collections import OrderedDict new_state_dict = OrderedDict() for k, v in state_dict.items(): name = k[7:] # remove `module.` new_state_dict[name] = v # load params model.load_state_dict(new_state_dict)
相关文章推荐
- pytorch 分布式训练GPU模型转CPU
- tensorflow-gpu 和cpu使用训练ssd模型感想(显卡内存不足解决办法)
- pytorch多GPU训练以及多线程加载数据
- tensorflow: 保存和加载模型, 参数;以及使用预训练参数方法
- PyTorch加载预训练模型的问题
- tensorflow 在cpu的环境中无法导入gpu训练好的模型(Make sure the device specification refers to a valid device.)
- Tensorflow实战学习(十六)【CNN实现、数据集、TFRecord、加载图像、模型、训练、调试】
- 【PyTorch图像语义分割】4. 使用训练好的模型测试
- (2) 用DPM(Deformable Part Model,voc-release4.01)算法在INRIA数据集上训练自己的人体检測模型
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TF Saver 保存/加载训练好模型(网络+参数)的那些事儿
- tensorflow 使用多块GPU同时训练多个模型
- TensorFlow在训练模型时指定GPU进行训练
- 用Deformable Part Model(DPM)voc-release3.1训练自己的模型
- windows下可运行的mat转xml,VOC-release4.01 DPM训练的model(mat)转为OpenCV latentsvm可以加载的model(xml)
- OpenAI推新程序包:GPU适应十倍大模型仅需增加20%训练时间
- Tensorflow加载预训练模型和保存模型
- Torch load model from gpu to cpu, so can convert to pytorch
- 用tensorflow框架和Mnist手写字体,训练cnn模型以及测试一张手写字体