5.1. Passes
A Pass is the PopRT component used to optimize ONNX models. To implement a custom Pass, see Custom Passes.
5.1.1. Pass abstraction
Every Pass implemented in PopRT inherits from poprt.Pass and shares a unified interface:
- class poprt.Pass(*args, **kwargs)
Abstract Base Class for Passes.
A new Pass could look like this:
import onnx
from poprt.passes import register, Pass

@register('dummy_pass')
class Dummy(Pass):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def run(self, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        print(f"producer_name: {onnx_model.producer_name}")
        return onnx_model
- Return type
None
- run(onnx_model)
Run the Pass. Subclasses should override this method.
- Parameters
onnx_model (ModelProto) – input ONNX model
- Returns
the optimized ONNX model
- Return type
ModelProto
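To illustrate the contract described above (an abstract run() that subclasses override, and instances that are called directly on a model, as in poprt.get_pass(...)(onnx_model)), here is a minimal stdlib-only sketch. It is not PopRT's implementation: PassSketch, ToyModel and DropIdentity are hypothetical names, and ToyModel stands in for onnx.ModelProto so the example runs without onnx or poprt installed.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyModel:
    """Hypothetical stand-in for onnx.ModelProto: a producer name and op types."""
    producer_name: str
    op_types: List[str] = field(default_factory=list)


class PassSketch(ABC):
    """Hypothetical stand-in for poprt.Pass: subclasses must override run()."""

    def __init__(self, *args, **kwargs) -> None:
        # Keep constructor arguments so run() can read its configuration.
        self.args = args
        self.kwargs = kwargs

    @abstractmethod
    def run(self, onnx_model: ToyModel) -> ToyModel:
        """Return the optimized model."""

    def __call__(self, onnx_model: ToyModel) -> ToyModel:
        # Calling the instance runs the pass.
        return self.run(onnx_model)


class DropIdentity(PassSketch):
    """Hypothetical example pass: removes Identity ops, leaving the input untouched."""

    def run(self, onnx_model: ToyModel) -> ToyModel:
        kept = [op for op in onnx_model.op_types if op != "Identity"]
        return ToyModel(onnx_model.producer_name, kept)


model = ToyModel("demo", ["MatMul", "Identity", "Gelu"])
optimized = DropIdentity()(model)
print(optimized.op_types)  # ['MatMul', 'Gelu']
```

Because run() is abstract, instantiating PassSketch directly raises TypeError, which forces every concrete pass to provide its own run().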
Passes are registered automatically. All registered Passes can be listed with the following CLI command:
poprt --list_all_passes
or in Python:
import poprt
passes = poprt.get_registered_passes()
print(passes.keys())
Use the Python interface poprt.get_pass() to apply a single Pass (the Passes registered in PopRT are listed in Built-in passes), and use poprt.PassManager to apply several Passes.
- poprt.get_pass(name, *args, **kwargs)
Get a Pass by registered name.
- Parameters
name (str) – registered name of a Pass
- Returns
a Pass instance
- Return type
Pass
Example:
import poprt

# get a Pass with parameters
onnx_model = poprt.get_pass('float_to_half', skip_op_types=['Gelu'])(onnx_model)
poprt.get_pass('model_overview')(onnx_model)
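The name-based lookup above (a decorator records each Pass class under a name, and get_pass(name, *args, **kwargs) instantiates it with the given arguments) can be sketched with a plain dictionary. This is an assumption about the general pattern, not PopRT's actual registry: _REGISTRY, register, get_registered_passes, get_pass and the pass "skip_aware_half" below are hypothetical stand-ins, and a model is modeled as a plain list of op types.

```python
from typing import Callable, Dict, List, Type


class PassSketch:
    """Minimal hypothetical base class: instances are callable and delegate to run()."""

    def __init__(self, *args, **kwargs) -> None:
        self.args = args
        self.kwargs = kwargs

    def run(self, onnx_model):
        raise NotImplementedError

    def __call__(self, onnx_model):
        return self.run(onnx_model)


# Hypothetical global registry: registered name -> Pass class.
_REGISTRY: Dict[str, Type[PassSketch]] = {}


def register(name: str) -> Callable[[Type[PassSketch]], Type[PassSketch]]:
    """Class decorator that records a Pass class under `name`."""
    def decorator(cls: Type[PassSketch]) -> Type[PassSketch]:
        if name in _REGISTRY:
            raise ValueError(f"pass {name!r} is already registered")
        _REGISTRY[name] = cls
        return cls
    return decorator


def get_registered_passes() -> Dict[str, Type[PassSketch]]:
    # Return a copy so callers cannot mutate the registry by accident.
    return dict(_REGISTRY)


def get_pass(name: str, *args, **kwargs) -> PassSketch:
    """Look up a class by name and instantiate it with the given arguments."""
    return _REGISTRY[name](*args, **kwargs)


@register("skip_aware_half")
class SkipAwareHalf(PassSketch):
    """Hypothetical pass: tags op types as fp16 unless listed in skip_op_types."""

    def run(self, onnx_model: List[str]) -> List[str]:
        skip = set(self.kwargs.get("skip_op_types", []))
        return [op if op in skip else f"{op}:fp16" for op in onnx_model]


ops = ["MatMul", "Gelu", "Add"]
print(list(get_registered_passes()))  # ['skip_aware_half']
print(get_pass("skip_aware_half", skip_op_types=["Gelu"])(ops))
# ['MatMul:fp16', 'Gelu', 'Add:fp16']
```

Registering at class-definition time is what makes registration automatic: importing the module that defines a Pass is enough for it to appear in the registry.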
- class poprt.PassManager(used_passes=[], gather_ir_passes=False)
Manage Passes.
- Parameters
used_passes (List[Union[str, Pass]]) – passes that will be used
gather_ir_passes (bool) – gather ONNX IR passes and execute them in one turn.
- Return type
None
Example:
import poprt

pm = poprt.PassManager(
    [
        'model_overview',
        'float_to_half',
        poprt.get_pass('model_overview'),
    ]
)
pm.run(onnx_model)
- run(onnx_model)
Apply the passes to the ONNX model.
- Parameters
onnx_model (ModelProto) – ONNX model to be optimized
- Return type
ModelProto
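The PassManager example mixes registered names with Pass instances in one list. A minimal sketch of how such a manager could resolve and chain its entries follows, assuming passes are simply applied in list order with each output feeding the next pass. PassManagerSketch, REGISTRY, DropIdentity and AppendMarker are hypothetical, a model is a list of op types, and the gather_ir_passes option is deliberately not modeled.

```python
from typing import Callable, Dict, List, Optional, Union

# Hypothetical stand-in for onnx.ModelProto: an ordered list of op types.
Model = List[str]


class PassSketch:
    """Minimal hypothetical Pass: instances are callable and delegate to run()."""

    def run(self, onnx_model: Model) -> Model:
        raise NotImplementedError

    def __call__(self, onnx_model: Model) -> Model:
        return self.run(onnx_model)


class DropIdentity(PassSketch):
    def run(self, onnx_model: Model) -> Model:
        return [op for op in onnx_model if op != "Identity"]


class AppendMarker(PassSketch):
    """Appends a marker op so the order of execution is visible."""

    def __init__(self, marker: str = "Marker") -> None:
        self.marker = marker

    def run(self, onnx_model: Model) -> Model:
        return onnx_model + [self.marker]


# Hypothetical registry: name -> no-argument factory producing a Pass.
REGISTRY: Dict[str, Callable[[], PassSketch]] = {
    "drop_identity": DropIdentity,
    "append_marker": AppendMarker,
}


class PassManagerSketch:
    """Applies passes in list order; each entry is a registered name or an instance."""

    def __init__(self, used_passes: Optional[List[Union[str, PassSketch]]] = None) -> None:
        entries = used_passes or []
        # Resolve names up front so an unknown name fails at construction time.
        self.passes: List[PassSketch] = [
            REGISTRY[p]() if isinstance(p, str) else p for p in entries
        ]

    def run(self, onnx_model: Model) -> Model:
        for p in self.passes:
            onnx_model = p(onnx_model)  # the output of one pass feeds the next
        return onnx_model


pm = PassManagerSketch(["drop_identity", AppendMarker("Done")])
print(pm.run(["MatMul", "Identity", "Add"]))  # ['MatMul', 'Add', 'Done']
```

The sketch uses None rather than a literal [] as the default for used_passes to avoid Python's shared-mutable-default pitfall; the documented signature above shows used_passes=[].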
Note
Because the internal structure of the Pass module may change, the following usage is not recommended:
from poprt.passes.float_to_half import Float2Half
Float2Half()(onnx_model)