我试图让一些现有的 pytorch 模型支持 TorchScript jit 编译器,但我遇到了非原始类型成员的问题。
这个小例子说明了这个问题:
import torch
@torch.jit.script
class Factory(object):
def __init__(self):
pass
def create(self, x: float) -> torch.Tensor:
return torch.tensor([x])
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.factory: Factory = Factory()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.factory.create(0)
mod = torch.jit.script(Foo())
运行时,jit编译报错
RuntimeError:
module has no attribute 'factory':
at example.py:17:15
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.factory.create(0)
~~~~~~~~~~~~ <--- HERE
我已经测试过 Factory
类可用于 forward
方法内的 jit,但是当我将它存储为成员时它不承认它。为什么是这样?有什么办法可以让 jit 编译器将这种成员保存到编译后的模块中?
最佳答案
这是 PyTorch 中的一个错误,在您发布问题后很快就解决了:https://discuss.pytorch.org/t/jit-scripted-attributes-inside-module/60645 , https://github.com/pytorch/pytorch/issues/27495 .
更新 PyTorch 应该可以解决这个问题。
https://stackoverflow.com/questions/58998441/
相关文章:
node.js - Sequelize findAndCountAll 并计算包含的模型
r - 代码相当于 RStudio 查看器 Pane 中的 'broom' 图标?
python - Jupyter Lab 交互图像展示 : issue with widgets a
react-native - 在 bazel 构建中运行 react-native cli
flutter - 将 mapEventToState 与传入流一起使用的最佳实践?
git - Devops Azure 没有足够的钩子(Hook)权限(对于协作者)
reactjs - React 中的 forwardingRef 与回调 refs 有什么区别?