TenserRT(三)PYTORCH 转 ONNX 详解
第三章:PyTorch 转 ONNX 详解 — mmdeploy 0.12.0 文档
torch.onnx — PyTorch 2.0 documentation
torch.onnx.export
细解
计算图导出方法
TorchScript是一种序列化和优化PyTorch模型的格式,将torch.nn.Module模型转换为TorchScript的torch.jit.ScriptModule模型,也是一种中间表示。
torch.onnx.export中使用的模型实际上是torch.jit.ScriptModule。
将torch.nn.Module转化为TorchScript模型(导出计算图)有两种模式:跟踪(trace)和脚本化(script)。
torch.onnx.export输入一个torch.nn.Module,默认会使用跟踪(trace)的方法导出。
import torchclass Model(torch.nn.Module):def __init__(self, n):super().__init__()self.n = nself.conv = torch.nn.Conv2d(3, 3, 3)def forward(self, x):for i in range(self.n):#控制输入张量被卷积的次数x = self.conv(x)return xmodels = [Model(2), Model(3)]# n=2和n=3的模型
model_names = ['model_2', 'model_3']for model, model_name in zip(models, model_names):dummy_input = torch.rand(1, 3, 10, 10)dummy_output = model(dummy_input)model_trace = torch.jit.trace(model, dummy_input)model_script = torch.jit.script(model)#torch.onnx.export默认使用trace,所有不需要先trace# 跟踪法与直接 torch.onnx.export(model, ...)等价# torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx', example_outputs=dummy_output, opset_version = 11)torch.onnx.export(model, dummy_input,f'{model_name}_trace.onnx', example_outputs=dummy_output, opset_version = 11)# 脚本化必须先调用 torch.jit.sciprttorch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx', example_outputs=dummy_output)# 如果是先运行了torch.jit.script,将模型转化成TorchScript,则export函数不需要再运行一遍# 如果输入不是TorchScript,则export需要运行一遍模型# dummy_input和dummy_output表示输入输出张量的数据类型和形状
跟踪法trace中,不同的n得到的ONNX模型结构是不一样的。
脚本法script中,Loop节点表示循环,不同的n可以有相同的结构。
推理引擎对静态图支持更好,不需要显式的将PyTorch模型转换为TorchScript,直接使用torch.onnx.export跟踪法导出即可。
虽然在代码中没有直接将trace的脚本作为export输入,但是可以通过trace来定位export问题是否出现在trace中。
参数讲解
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,input_names=None, output_names=None, aten=False, export_raw_ir=False,operator_export_type=None, opset_version=None, _retain_param_name=True,do_constant_folding=True, example_outputs=None, strip_doc_string=True,dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None,enable_onnx_checker=True, use_external_data_format=False):
- 模型(model):必选
- 输入(args):必选
- 导出的 onnx 文件名(f):必选
- 模型中是否保存权重(export_params):一般模型结构和模型权重放在一个文件里存储,所以默认是true,如果是在不同的框架间传递模型,而不是用于部署,则设置为false。
- 输入/输出张量名称(input_names, output_names):推理引擎一般都需要通过“名称-张量值”的数据对来输入数据,并根据输出张量的名称来获取输出数据,保证ONNX和推理引擎中使用同一套名称。
- opset_version:ONNX算子集版本。
- dynamic_axes:指定输入输出张量的哪些维度是动态的,为了追求效率,ONNX默认所有参与运算的张量都是静态的(张量的形状不发生改变)。可以显式的指明输入输出张量的哪几个维度的大小是可变的
import torchclass Model(torch.nn.Module):def __init__(self):super().__init__()#继承父类构造函数中self.conv = torch.nn.Conv2d(3, 3, 3)def forward(self, x):x = self.conv(x)return xmodel = Model()
dummy_input = torch.rand(1, 3, 10, 10)
model_names = ['model_static.onnx',
'model_dynamic_0.onnx',
'model_dynamic_23.onnx']# dynamic_axes_0 = {#第0维动态
# 'in' : [0],
# 'out' : [0]
# }
dynamic_axes_0 = {'in' : {0: 'batch'},'out' : {0: 'batch'}
}
dynamic_axes_23 = {#第2、3维动态'in' : [2, 3],'out' : [2, 3]
}torch.onnx.export(model, dummy_input, model_names[0], input_names=['in'], output_names=['out'])#没有动态维度
torch.onnx.export(model, dummy_input, model_names[1], input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_0)#第0维动态
torch.onnx.export(model, dummy_input, model_names[2], input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_23)#第2、3维动态
# ONNX 要求每个动态维度都有一个名字,直接这样写会引出一条UserWarning,警告我们通过列表方式设置动态维度的话,系统会自动为它们分配名字
import onnxruntime
import numpy as nporigin_tensor = np.random.rand(1, 3, 10, 10).astype(np.float32)
mult_batch_tensor = np.random.rand(2, 3, 10, 10).astype(np.float32)
big_tensor = np.random.rand(1, 3, 20, 20).astype(np.float32)inputs = [origin_tensor, mult_batch_tensor, big_tensor]
exceptions = dict()
model_names = ['model_static.onnx',#批量或者维度增加就会出错
'model_dynamic_0.onnx',#维度增加就会出错
'model_dynamic_23.onnx']#批量增加就会出错
for model_name in model_names:for i, input in enumerate(inputs):try:ort_session = onnxruntime.InferenceSession(model_name)ort_inputs = {'in': input}ort_session.run(['out'], ort_inputs)#只有在设置了对应的动态维度后才不会出错except Exception as e:exceptions[(i, model_name)] = eprint(f'Input[{i}] on model {model_name} error.')print(exceptions[(1, 'model_static.onnx')])else:print(f'Input[{i}] on model {model_name} succeed.')
使用技巧
torch.onnx.is_in_onnx_export():PyTorch推理时不运行,但是在执行torch.onnx.export()时为真。
import torchclass Model(torch.nn.Module):def __init__(self):super().__init__()self.conv = torch.nn.Conv2d(3, 3, 3)def forward(self, x):x = self.conv(x)if torch.onnx.is_in_onnx_export():# 仅在模型导出时把输出张量的数值限制在[0,1]之间#可以在代码中添加和模型部署相关的逻辑x = torch.clip(x, 0, 1)return x
利用中断张量跟踪的操作
import torch
class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):#item、for、list等方法都会导致ONNX模型不太正确x = x * x[0].item()#跟踪法会把某些取决于输入的中间结果变成常量# .item()把torch中的张量转换成普通的Python遍历return x, torch.Tensor([i for i in x])#遍历torch张量,并用一个列表新建一个torch张量。model = Model()
dummy_input = torch.rand(10)
torch.onnx.export(model, dummy_input, 'a.onnx')
涉及到张量与普通变量转换的逻辑都会导致最终ONNX模型不太正确。
利用这个性质,在保证正确性的前提下令模型中间结果变成常量。
这个技巧尝尝用于模型的静态化上,即令模型中所有张量形状都变成常量。
使用张量为输入(PyTorch版本 < 1.9.0)
PyTorch 对 ONNX 的算子支持
如果torch.onnx.export()正常执行后,另一个容易出现的问题就是算子不兼容。
在转换普通torch.nn.Module模型时:
- Pytorch利用跟踪法执行前向推理,把遇到的算子整合成计算图;
- Pytorch把遇到的算子翻译成ONNX定义的算子。
算子翻译的过程可能遇到的情况:
- 算子可以一对一翻译成ONNX算子。
- 算子没有一对一的ONNX算子,被翻译成一个或多个ONNX算子。
- 算子没有翻译成ONNX的规则。
ONNX 算子文档
onnx/Operators.md at main · onnx/onnx · GitHub
算子变更表格(算子名,算子变更版本号opset_version),第一次变更的版本号,表示算子第一次被支持,且第一个改动记录可以知道当前算子集中该算子的定义规则。
表格中的链接可以说明该算子的输入输出参数规定使用示例。
PyTorch 对 ONNX 算子的映射
pytorch/torch/onnx at master · pytorch/pytorch · GitHub
symbloic_opset{n}.py表示pytorch对应的ONNX算子集版本。
在vscode中限定在torch/onnx文件夹搜索对应算子
按照调用逻辑直接跳转到
@_onnx_symbolic("aten::upsample_bicubic2d",decorate=[_apply_params("upsample_bicubic2d", 4, "cubic")],
)
->@_beartype.beartype
def _interpolate(name: str, dim: int, interpolate_mode: str):return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)->@_beartype.beartype
def _interpolate_helper(name, dim, interpolate_mode):@quantized_args(True, False, False)def symbolic_fn(g, input, output_size, *args):...return symbolic_fn
symbolic_fn中插值算子被映射成多个ONNX算子,一个g.op对应ONNX
return g.op("Resize",input,empty_roi,empty_scales,output_size,coordinate_transformation_mode_s=coordinate_transformation_mode,cubic_coeff_a_f=-0.75, # only valid when mode="cubic"mode_s=interpolate_mode, # nearest, linear, or cubicnearest_mode_s="floor",) # only valid when mode="nearest"
查找对应的ONNXonnx/Operators.md at main · onnx/onnx · GitHub resize算子定义,可以知道对应参数含义。
查询PyTorch到ONNX的映射关系,然后在torch.onnx.export()的opset_version设定一个版本号,然后去PyTorch符号表文件里去查。如果没有对应算子,就需要考虑用其他算子替代,或者自定义算子。
总结
- 跟踪法和脚本化在导出待控制语句的计算图时有什么区别。
- torch.onnx.export()中如何设置input_names, output_names, dynamic_axes。
- 使用torch.onnx.is_in_onnx_export()来使得模型在转换到ONNX时有不同的行为。
- 查询ONNX 算子文档。
- 查询ONNX算子对PyTorch算子支持情况。
- 查询ONNX算子对PyTorch算子使用方式。