5.7. 开发 Custom Passes
在 PopRT 中, Pass 被用于模型转换和优化等各种处理. 可以通过添加 Custom Passes 来拓展 PopRT 的功能.
Custom Passes 的实现/使用需要遵循 PopRT Pass 的一般规范, 详见 Passes.
5.7.1. 实现 Custom Passes
要在 PopRT 中添加 Custom Passes, 需要编写至少一个 Python 文件. 示例代码如下:
1# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
2import onnx
3
4from poprt.passes import Pass, register
5
6
7@register('replace_add_with_sub')
8class ReplaceAddwithSub(Pass):
9 """Replace Add with Sub."""
10
11 def __init__(self):
12 super().__init__()
13
14 # define the transform
15 def run_transform(
16 self,
17 graph: onnx.GraphProto,
18 is_main_graph: bool,
19 ) -> onnx.GraphProto:
20 for node in graph.node:
21 if node.op_type == 'Add':
22 node.op_type = 'Sub'
23 return graph
24
25 # define the run method
26 def run(self, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
27 onnx_model.graph.CopyFrom(
28 self.traverse_graph(onnx_model.graph, self.run_transform)
29 )
30 return onnx_model
Download custom_shape_inference.py
如示例代码所示, 实现了一个 ReplaceAddwithSub
的 Pass
, 该 Pass
的功能是将 ONNX 模型中的所有 Add
节点替换为 Sub
节点.
Pass
被注册为 replace_add_with_sub
.
Note
一个 Python 文件可以包含多个 Custom Pass, 但是每个 Custom Pass 的注册名字必须是唯一的.
5.7.2. 使用 Custom Passes
可以在 PopRT CLI 或者 Python API 中使用 Custom Passes.
在 PopRT CLI 中使用 Custom Passes
可使用参数 --custom_pass_config
指定包含 Custom Passes 的 Python 文件(多个文件使用 ,
隔开). 同时使用 --passes
指定要执行的 Pass.
列出注册的 Custom Passes:
1poprt \
2 --list_all_passes \
3 --custom_pass_config replace_add_with_sub.py |
4 grep replace_add_with_
5
6# replace_add_with_sub
Download custom_shape_inference.py
应用 Custom Passes:
poprt \
--input_model model.onnx \
--passes replace_add_with_sub \
--custom_pass_config replace_add_with_sub.py
Note
通过 CLI 用 Custom Passes 时不支持向 Custom Passes 传递参数.
多个 Custom Pass 之间可能存在依赖关系, PopRT 将会按照
--passes
指定的 Pass 顺序执行.--passes
指定的 Pass 会在 PopRT 默认的一些转换流程完成之后执行. 即输入的 ONNX 模型不是输入的原始模型.
在 Python API 中使用 Custom Passes
在应用代码中导入 Custom Passes 的模块即可注册 Custom Passes.
1import replace_add_with_sub # noqa
2
3import poprt
4
5assert 'replace_add_with_sub' in poprt.get_registered_passes()
Download custom_shape_inference.py
Custom Pass 使用方式和普通的 Pass 相同. 参考 Passes.