_rebuild_tensor_v2?pytorch版本间模型兼容性脱坑实践

最近使用Pytorch 0.4.0 进行模型训练,之后使用一个转模型的工具时,报了一个错,就是标题里面的_rebuild_tensor_v2相关的错误。最后发现是本地使用的pytorch的版本是0.3.0,和0.4.0模型上不兼容。各论坛上的解决方案都是说pytorch版本不向后兼容,建议升级pytorch。无奈我这里不方便升级pytorch版本。那么问题就来了,有没有什么不需要修改pytorch源码,或是不升级pytorch,又能让老版本的pytorch读取新版本模型的方案呢?

当然是有的,而且工作量很小。

一、Pytorch模型存储和读取的流程

首先,我们使用pytorch存储模型会使用 torch.save 这个函数,直接将模型的state_dict()保存下来。类似下面的代码:

torch.save({
    'state_dict': model.state_dict(),
    'other': other_data
  },
  'model.pth'
)

读取参数的代码也十分简单:

model = Model()

with open('model.pth', 'rb') as fp:
    param = torch.load(fp)

model.load_state_dict(param['state_dict'])

而低版本的pytorch就是在 load_state_dict 这里报了错。

二、State Dict

我们首先要知道,model.state_dict()的返回值究竟是什么。

这里我直接给出结论:

model.state_dict() 的返回值是一个 collections.OrderedDict 对象,它的键是一个字符串,它的值是Tensor的对象。所以造成兼容性问题的其实是Tensor对象的不兼容。

那么是不是可以将Tensor转化成一个新的非Pytorch内置的数据类型呢?这样就可以避免兼容性问题。

numpy.ndarray 就是我们需要的中间态。

三、模型转换

首先,我们需要将state_dict的参数转换成 numpy.ndarray 保存下来。这里使用高版本的pytorch。

import torch
from collections import OrderedDict
import pickle

with open('model.pth', 'rb') as fp:
    param = torch.load(fp)

state_dict = param['state\_dict']
numpy_state_dict = OrderedDict()

for key, tensor in state_dict.items():
    numpy_state_dict[key] = tensor.cpu().numpy()

with open('state_dict.pic', 'wb') as fp:
    pickle.dump(numpy_state_dict, fp)

之后,用低版本的pytorch载入这个numpy的state_dict

import pickle
import torch

with open('numpy_state_dict.pic', 'rb') as fp:
    state_dict = pickle.load(fp)

# numpy.ndarray -> tensor
for key, ndarr in state_dict.items():
    state_dict[key] = torch.Tensor(ndarr)

model = Model()
model.load_state_dict(state_dict)

四、总结

对于这个问题,还有很多的解决方案,这里是比较简单的一种。

PS. 这是目前为止,写的最快的一篇博客了。。。

转载请注明出处,谢谢!