最近使用Pytorch 0.4.0 进行模型训练,之后使用一个转模型的工具时,报了一个错,就是标题里面的
_rebuild_tensor_v2
相关的错误。最后发现是本地使用的pytorch的版本是0.3.0,和0.4.0模型上不兼容。各论坛上的解决方案都是说pytorch版本不向后兼容,建议升级pytorch。无奈我这里不方便升级pytorch版本。那么问题就来了,有没有什么不需要修改pytorch源码,或是不升级pytorch,又能让老版本的pytorch读取新版本模型的方案呢?
当然是有的,而且工作量很小。
一、Pytorch模型存储和读取的流程
首先,我们使用pytorch存储模型会使用 torch.save
这个函数,直接将模型的state_dict()
保存下来。类似下面的代码:
torch.save({
'state_dict': model.state_dict(),
'other': other_data
},
'model.pth'
)
读取参数的代码也十分简单:
model = Model()
with open('model.pth', 'rb') as fp:
param = torch.load(fp)
model.load_state_dict(param['state_dict'])
而低版本的pytorch就是在 load_state_dict
这里报了错。
二、State Dict
我们首先要知道,model.state_dict()
的返回值究竟是什么。
这里我直接给出结论:
model.state_dict()
的返回值是一个 collections.OrderedDict
对象,它的键是一个字符串,它的值是Tensor的对象。所以造成兼容性问题的其实是Tensor对象的不兼容。
那么是不是可以将Tensor转化成一个新的非Pytorch内置的数据类型呢?这样就可以避免兼容性问题。
numpy.ndarray
就是我们需要的中间态。
三、模型转换
首先,我们需要将state_dict的参数转换成 numpy.ndarray
保存下来。这里使用高版本的pytorch。
import torch
from collections import OrderedDict
import pickle
with open('model.pth', 'rb') as fp:
param = torch.load(fp)
state_dict = param['state\_dict']
numpy_state_dict = OrderedDict()
for key, tensor in state_dict.items():
numpy_state_dict[key] = tensor.cpu().numpy()
with open('state_dict.pic', 'wb') as fp:
pickle.dump(numpy_state_dict, fp)
之后,用低版本的pytorch载入这个numpy的state_dict
。
import pickle
import torch
with open('numpy_state_dict.pic', 'rb') as fp:
state_dict = pickle.load(fp)
# numpy.ndarray -> tensor
for key, ndarr in state_dict.items():
state_dict[key] = torch.Tensor(ndarr)
model = Model()
model.load_state_dict(state_dict)
四、总结
对于这个问题,还有很多的解决方案,这里是比较简单的一种。
PS. 这是目前为止,写的最快的一篇博客了。。。
转载请注明出处,谢谢!