5.1. Passes

A Pass is a PopRT component that optimizes an ONNX model. To implement your own Pass, see Custom Passes.

5.1.1. Pass abstraction

Every Pass implemented in PopRT inherits from poprt.Pass and exposes the same interface:

class poprt.Pass(*args, **kwargs)

Abstract Base Class for Passes.

A new Pass can be defined like this:

import onnx
from poprt.passes import register, Pass


@register('dummy_pass')
class Dummy(Pass):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def run(self, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        print(f"producer_name: {onnx_model.producer_name}")
        return onnx_model

Return type

None

run(onnx_model)

Run the Pass. Subclasses should override this method.

Parameters

onnx_model (ModelProto) – input onnx model

Returns

the optimized onnx model

Return type

ModelProto
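
The callable form used elsewhere in this section, poprt.get_pass(...)(onnx_model), suggests that calling a Pass instance forwards to run(). The sketch below illustrates that contract using only the standard library. SketchPass, CountNodes, and the dict standing in for a ModelProto are hypothetical names chosen for illustration; this is not PopRT's implementation.

```python
from abc import ABC, abstractmethod


class SketchPass(ABC):
    """Hypothetical stand-in for poprt.Pass (not the real class)."""

    @abstractmethod
    def run(self, model: dict) -> dict:
        """Transform the model and return it; subclasses must override."""

    def __call__(self, model: dict) -> dict:
        # Assumption: calling a Pass instance is equivalent to calling run().
        return self.run(model)


class CountNodes(SketchPass):
    """Read-only pass: records the node count, returns the model unchanged."""

    def __init__(self) -> None:
        self.node_count = 0

    def run(self, model: dict) -> dict:
        self.node_count = len(model["nodes"])
        return model


# A plain dict stands in for an onnx.ModelProto in this sketch.
model = {"producer_name": "demo", "nodes": ["Conv", "Relu", "Gelu"]}
counter = CountNodes()
result = counter(model)  # same as counter.run(model)
```

Marking run() as abstract means a subclass that forgets to override it fails at instantiation time instead of silently doing nothing.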

Passes are registered automatically. The following CLI command lists all registered Passes:

poprt --list_all_passes

Or, in Python:

import poprt

passes = poprt.get_registered_passes()
print(passes.keys())
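
For readers wondering what automatic registration can look like in practice: a decorator such as @register('dummy_pass') typically stores the class in a name-to-class mapping when its module is imported. The following is a minimal sketch of that pattern under this assumption. _REGISTRY, register_sketch, and get_registered_sketch are hypothetical names and do not reflect PopRT internals.

```python
from typing import Callable, Dict, Type

# Hypothetical name-to-class mapping; PopRT's real registry may differ.
_REGISTRY: Dict[str, Type] = {}


def register_sketch(name: str) -> Callable[[Type], Type]:
    """Return a class decorator that records the class under `name`."""

    def decorator(cls: Type) -> Type:
        if name in _REGISTRY:
            raise ValueError(f"pass {name!r} is already registered")
        _REGISTRY[name] = cls
        return cls  # the class itself is returned unchanged

    return decorator


def get_registered_sketch() -> Dict[str, Type]:
    """Return a copy so that callers cannot mutate the registry."""
    return dict(_REGISTRY)


@register_sketch("dummy_pass")
class Dummy:
    def run(self, model):
        return model


# Registration happened as a side effect of defining the class above.
print(sorted(get_registered_sketch().keys()))
```

Rejecting duplicate names is a design choice of this sketch; it keeps one Pass from silently replacing another that was registered earlier.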

Use the Python interface poprt.get_pass() to apply a single Pass (see Built-in passes for the Passes registered in PopRT), or use poprt.PassManager to apply several Passes.

poprt.get_pass(name, *args, **kwargs)

Get a Pass by registered name.

Parameters

name (str) – registered name of a Pass

Returns

a Pass instance

Return type

Pass

Example:

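import onnx
import poprt

# Load the model to optimize (the path is a placeholder)
onnx_model = onnx.load('model.onnx')

# get a Pass with parameters, then call it on the model
onnx_model = poprt.get_pass('float_to_half', skip_op_types=['Gelu'])(onnx_model)

# a Pass without parameters; the returned model is discarded here
poprt.get_pass('model_overview')(onnx_model)

class poprt.PassManager(used_passes=[], gather_ir_passes=False)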

Manage Passes.

Parameters
  • used_passes (List[Union[str, Pass]]) – passes that will be used

  • gather_ir_passes (bool) – gather ONNX IR passes and execute them in one go.

Return type

None

Example:

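import onnx
import poprt

# Load the model to optimize (the path is a placeholder)
onnx_model = onnx.load('model.onnx')

# Passes may be given by registered name or as Pass instances
pm = poprt.PassManager(
    [
        'model_overview',
        'float_to_half',
        poprt.get_pass('model_overview'),
    ]
)

# run() returns the optimized model, so keep the result
onnx_model = pm.run(onnx_model)

run(onnx_model)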

Apply passes to the onnx model.

Parameters

onnx_model (ModelProto) – onnx model that will be optimized

Return type

ModelProto
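
To make the idea of applying several Passes concrete, here is a minimal sketch of a manager that accepts registered names or callables and feeds the output of each pass into the next. It assumes passes run in list order, which the text above does not state explicitly, and it does not model gather_ir_passes. SketchPassManager, KNOWN_PASSES, mark_half, and overview are hypothetical names, and a dict stands in for the ModelProto.

```python
from typing import Callable, Dict, List, Union

Model = Dict[str, object]          # stand-in for onnx.ModelProto
PassFn = Callable[[Model], Model]  # stand-in for a Pass instance


def mark_half(model: Model) -> Model:
    """Toy pass: returns a copy of the model tagged as float16."""
    return {**model, "dtype": "float16"}


def overview(model: Model) -> Model:
    """Toy read-only pass: prints the dtype, returns the model unchanged."""
    print(f"dtype: {model['dtype']}")
    return model


# Hypothetical lookup table playing the role of the Pass registry.
KNOWN_PASSES: Dict[str, PassFn] = {"mark_half": mark_half, "overview": overview}


class SketchPassManager:
    """Hypothetical manager: resolves names, then applies passes in order."""

    def __init__(self, used_passes: List[Union[str, PassFn]]) -> None:
        self.passes = [
            KNOWN_PASSES[p] if isinstance(p, str) else p for p in used_passes
        ]

    def run(self, model: Model) -> Model:
        for apply_pass in self.passes:
            model = apply_pass(model)  # output of one pass feeds the next
        return model


pm = SketchPassManager(["overview", "mark_half", overview])
original = {"dtype": "float32"}
optimized = pm.run(original)
```

In this sketch the input dict is left untouched and the converted model is only available through the return value, which is why the result of run() is assigned to a new name.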

Note

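The internal layout of the Pass modules may change, so prefer poprt.get_pass() with the registered name. Importing a Pass class directly, as below, is not recommended: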

from poprt.passes.float_to_half import Float2Half

Float2Half()(onnx_model)