5.7. 开发 Custom ONNX Passes
Pass 是 PopRT 相对核心的一个模块, 目前模型转换和优化都是以 Pass 的形式实现的, Pass 需要按照正确的顺序才能实现预期的功能. 目前 Pass 支持以 C++ 或者 Python 实现, 本教程将讲解如何以 Python 实现一个 Pass 并在 PopRT 使用.
在阅读本教程之前, 用户需要对以下内容有一定的了解:
Note
所有 Custom ONNX Pass 均在 PopRT 内置转换流程之后运行. 换句话说, Custom ONNX Pass 针对的是 PopRT 使用默认或自选参数转换后得到的 ONNX 模型, 而非原模型.
建议在编写 Custom ONNX Pass 使用 PopRT 转换得到 Custom ONNX Pass 的目标 ONNX 模型, 针对目标模型编写 Custom ONNX Pass.
多个 Custom ONNX Pass 要注意相互之间的顺序依赖.
5.7.1. 实现 Custom ONNX Pass
要在 PopRT 中实现一个 Custom ONNX Pass, 需要编写至少一个 Python 文件. 示例代码如下:
import onnx
from poprt.base_pass import Pass, register
@register('replace_add_with_sub')
class ReplaceAddwithSub(Pass):
'''Replace Add with Sub.'''
def __init__(self):
super().__init__()
# must implement __call__ method, onnx in onnx out
def __call__(self, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
for node in onnx_model.graph.node:
if node.op_type == 'Add':
node.op_type = 'Sub'
return onnx_model
# optional implement, return instance of ReplaceAddwithSub
def get_pass():
return ReplaceAddwithSub()
实现 Custom ONNX Pass, 通常需要实现以下三部分:
继承
Pass
并实现其__call__
方法.__call__
方法遵循ONNX in ONNX out
的原则. 在__call__
方法里, 可以实现想要的功能.注册
Pass
到PassManager
.实现
get_pass
方法, 这部分是可选的, 取决于使用哪种方法在 PopRT 里执行 Custom ONNX Pass. 后面会介绍使能 Custom ONNX Pass 的方法.
5.7.2. 在 PopRT 中使用 Custom ONNX Pass
目前有两种方法可以在 PopRT 中使用 Custom ONNX Pass.
方法一: 通过 --custom_pass_config
指定
python -m poprt.cli \
--input_model single_add.onnx \
--input_shape input_ids=4,384 \
--custom_pass_config examples/custom_pass/replace_add_with_sub.py
Note
以这种方法执行的 Custom ONNX Pass 必须实现
get_pass
这个方法.--custom_pass_config
可指定多个 Python 文件,Pass
的调用顺序和文件顺序相同. 例如:--custom_pass_config test_pass.py,test_pass_2.py,test_pass_3.py
每个 Python 文件仅支持包含一个
Pass
.
方法二: 通过将 Python 文件放在 PopRT 安装目录下
这种方法不需要实现 get_pass
方法, 但是还需要执行以下步骤:
将 Python 文件放置到
poprt/passes
目录下, 可以通过以下方法获取目录:
import poprt print(poprt.passes.__path__)
在
poprt/passes/__init__.py
文件中import
自己实现的Pass
, 示例代码如下:
from .replace_add_with_sub import ReplaceAddwithSub
在执行完以上步骤后, 可以通过
--passes
来执行刚加入的Pass
, 示例如下:
python -m poprt.cli \ --input_model single_add.onnx \ --passes replace_add_with_sub
Note
Pass
的调用顺序和--passes
指定的顺序相同.Python 文件支持放入多个
Pass
, 每个Pass
均需要加入poprt/passes/__init__.py
.