5.7. 开发 Custom Passes

在 PopRT 中, Pass 被用于模型转换和优化等各种处理. 可以通过添加 Custom Passes 来拓展 PopRT 的功能.

Custom Passes 的实现/使用需要遵循 PopRT Pass 的一般规范, 详见 Passes.

5.7.1. 实现 Custom Passes

要在 PopRT 中添加 Custom Passes, 需要编写至少一个 Python 文件. 示例代码如下:

Listing 5.12 replace_add_with_sub.py
 1# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
 2import onnx
 3
 4from poprt.passes import Pass, register
 5
 6
 7@register('replace_add_with_sub')
 8class ReplaceAddwithSub(Pass):
 9    """Replace Add with Sub."""
10
11    def __init__(self):
12        super().__init__()
13
14    # define the transform
15    def run_transform(
16        self,
17        graph: onnx.GraphProto,
18        is_main_graph: bool,
19    ) -> onnx.GraphProto:
20        for node in graph.node:
21            if node.op_type == 'Add':
22                node.op_type = 'Sub'
23        return graph
24
25    # define the run method
26    def run(self, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
27        onnx_model.graph.CopyFrom(
28            self.traverse_graph(onnx_model.graph, self.run_transform)
29        )
30        return onnx_model

Download custom_shape_inference.py

如示例代码所示, 实现了一个 ReplaceAddwithSubPass, 该 Pass 的功能是将 ONNX 模型中的所有 Add 节点替换为 Sub 节点. Pass 被注册为 replace_add_with_sub.

Note

一个 Python 文件可以包含多个 Custom Pass, 但是每个 Custom Pass 的注册名字必须是唯一的.

5.7.2. 使用 Custom Passes

可以在 PopRT CLI 或者 Python API 中使用 Custom Passes.

在 PopRT CLI 中使用 Custom Passes

可使用参数 --custom_pass_config 指定包含 Custom Passes 的 Python 文件(多个文件使用 , 隔开). 同时使用 --passes 指定要执行的 Pass.

列出注册的 Custom Passes:

Listing 5.13 load_custom_passes_in_cli.sh
1poprt \
2    --list_all_passes \
3    --custom_pass_config replace_add_with_sub.py |
4    grep replace_add_with_
5
6# replace_add_with_sub

Download custom_shape_inference.py

应用 Custom Passes:

poprt \
    --input_model model.onnx \
    --passes replace_add_with_sub \
    --custom_pass_config replace_add_with_sub.py

Note

  • 通过 CLI 用 Custom Passes 时不支持向 Custom Passes 传递参数.

  • 多个 Custom Pass 之间可能存在依赖关系, PopRT 将会按照 --passes 指定的 Pass 顺序执行.

  • --passes 指定的 Pass 会在 PopRT 默认的一些转换流程完成之后执行. 即输入的 ONNX 模型不是输入的原始模型.

在 Python API 中使用 Custom Passes

在应用代码中导入 Custom Passes 的模块即可注册 Custom Passes.

Listing 5.14 load_custom_passes.py
1import replace_add_with_sub  # noqa
2
3import poprt
4
5assert 'replace_add_with_sub' in poprt.get_registered_passes()

Download custom_shape_inference.py

Custom Pass 使用方式和普通的 Pass 相同. 参考 Passes.