5.7. 开发 Custom ONNX Passes

Pass 是 PopRT 相对核心的一个模块, 目前模型转换和优化都是以 Pass 的形式实现的, Pass 需要按照正确的顺序才能实现预期的功能. 目前 Pass 支持以 C++ 或者 Python 实现, 本教程将讲解如何以 Python 实现一个 Pass 并在 PopRT 使用.

在阅读本教程之前, 用户需要对以下内容有一定的了解:

Note

  • ONNX

  • ONNX Python API

  • 所有 Custom ONNX Pass 均在 PopRT 内置转换流程之后运行. 换句话说, Custom ONNX Pass 针对的是 PopRT 使用默认或自选参数转换后得到的 ONNX 模型, 而非原模型.

  • 建议在编写 Custom ONNX Pass 使用 PopRT 转换得到 Custom ONNX Pass 的目标 ONNX 模型, 针对目标模型编写 Custom ONNX Pass.

  • 多个 Custom ONNX Pass 要注意相互之间的顺序依赖.

5.7.1. 实现 Custom ONNX Pass

要在 PopRT 中实现一个 Custom ONNX Pass, 需要编写至少一个 Python 文件. 示例代码如下:

import onnx
from poprt.base_pass import Pass, register


@register('replace_add_with_sub')
class ReplaceAddwithSub(Pass):
    '''Replace Add with Sub.'''

    def __init__(self):
        super().__init__()

    # must implement __call__ method, onnx in onnx out
    def __call__(self, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        for node in onnx_model.graph.node:
            if node.op_type == 'Add':
                node.op_type = 'Sub'

        return onnx_model


# optional implement, return instance of ReplaceAddwithSub
def get_pass():
    return ReplaceAddwithSub()

实现 Custom ONNX Pass, 通常需要实现以下三部分:

  • 继承 Pass 并实现其 __call__ 方法. __call__ 方法遵循 ONNX in ONNX out 的原则. 在 __call__ 方法里, 可以实现想要的功能.

  • 注册 PassPassManager.

  • 实现 get_pass 方法, 这部分是可选的, 取决于使用哪种方法在 PopRT 里执行 Custom ONNX Pass. 后面会介绍使能 Custom ONNX Pass 的方法.

5.7.2. 在 PopRT 中使用 Custom ONNX Pass

目前有两种方法可以在 PopRT 中使用 Custom ONNX Pass.

方法一: 通过 --custom_pass_config 指定

python -m poprt.cli \
    --input_model single_add.onnx \
    --input_shape input_ids=4,384 \
    --custom_pass_config examples/custom_pass/replace_add_with_sub.py

Note

  • 以这种方法执行的 Custom ONNX Pass 必须实现 get_pass 这个方法.

  • --custom_pass_config 可指定多个 Python 文件, Pass 的调用顺序和文件顺序相同. 例如:

    --custom_pass_config test_pass.py,test_pass_2.py,test_pass_3.py
    
  • 每个 Python 文件仅支持包含一个 Pass.

方法二: 通过将 Python 文件放在 PopRT 安装目录下

这种方法不需要实现 get_pass 方法, 但是还需要执行以下步骤:

  1. 将 Python 文件放置到 poprt/passes 目录下, 可以通过以下方法获取目录:

import poprt

print(poprt.passes.__path__)
  1. poprt/passes/__init__.py 文件中 import 自己实现的 Pass, 示例代码如下:

from .replace_add_with_sub import ReplaceAddwithSub
  1. 在执行完以上步骤后, 可以通过 --passes 来执行刚加入的 Pass, 示例如下:

python -m poprt.cli \
    --input_model single_add.onnx \
    --passes replace_add_with_sub

Note

  • Pass 的调用顺序和 --passes 指定的顺序相同.

  • Python 文件支持放入多个 Pass, 每个 Pass 均需要加入 poprt/passes/__init__.py.