7. Passes

A Pass is a component in PopRT used to optimize ONNX models. To implement a custom Pass, see the Developing Custom Passes section.

7.1. Pass Abstraction

All Passes implemented in PopRT inherit from poprt.Pass and share a unified interface:

class poprt.Pass(*args, **kwargs)

Abstract Base Class for Passes.

A new Pass can be implemented like this:

import onnx
from poprt.passes import register, Pass


@register('dummy_pass')
class Dummy(Pass):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def run(self, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        print(f"producer_name: {onnx_model.producer_name}")
        return onnx_model

Return type

None

run(onnx_model)

Run the Pass; subclasses should override this method.

Parameters

onnx_model (ModelProto) – input onnx model

Returns

the optimized onnx model

Return type

ModelProto

Passes are registered automatically. All registered Passes can be listed with the following CLI command:

poprt --list_all_passes

Or in Python:

import poprt

passes = poprt.get_registered_passes()
print(passes.keys())

Use the poprt.get_pass() Python interface to apply a single Pass, and use poprt.PassManager to apply multiple Passes.

poprt.get_pass(name, *args, **kwargs)

Get a Pass by registered name.

Parameters

name (str) – registered name of a Pass

Returns

a Pass instance

Return type

Pass

Example:

import poprt

# get a Pass with parameters
onnx_model = poprt.get_pass('float_to_half', skip_op_types=['Gelu'])(onnx_model)

poprt.get_pass('model_overview')(onnx_model)

class poprt.PassManager(used_passes=[], gather_ir_passes=False)

Manage Passes.

Parameters
  • used_passes (List[Union[str, Pass]]) – passes that will be used

  • gather_ir_passes (bool) – gather ONNX IR passes and execute them in one turn.

Return type

None

Example:

import poprt

pm = poprt.PassManager(
    [
        'model_overview',
        'float_to_half',
        poprt.get_pass('model_overview'),
    ]
)

pm.run(onnx_model)

run(onnx_model)

Apply passes to the onnx model.

Parameters

onnx_model (ModelProto) – onnx model that will be optimized

Return type

ModelProto

Note

Because the structure of the Pass module may change, the following usage is not recommended:

from poprt.passes.float_to_half import Float2Half

Float2Half()(onnx_model)
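
The recommended equivalent looks the Pass up by its registered name:

import poprt

# equivalent to the direct import above, but stable across module refactors
onnx_model = poprt.get_pass('float_to_half')(onnx_model)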

7.2. Passes Registered in PopRT

class poprt.passes.add_checkpoints.AddCheckpoints(checkpoints)

Add intermediate tensors to the model outputs.

Parameters

checkpoints (List[str]) –

Return type

None

Registered as Pass: add_checkpoints
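
For example, a minimal sketch (the tensor name 'Add_1_out' is hypothetical):

import poprt

# expose an intermediate tensor as an extra model output
onnx_model = poprt.get_pass('add_checkpoints', checkpoints=['Add_1_out'])(onnx_model)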

class poprt.passes.apply_ir_pass.ApplyIrPass(passes=[])

Apply passes based on onnx IR.

Parameters

passes (List[str]) –

Return type

None

Registered as Pass: apply_ir_pass

class poprt.passes.auto_insert_remap.AutoInsertRemap(remap_mode='after_matmul')

Insert remap after matmul.

This is an experimental feature. There are two insert modes: after_matmul and before_add. The after_matmul mode is more general but more likely to cause OOM; the before_add mode aims to reduce the cycles of attention + mask in transformer-based models.

Parameters

remap_mode (str) –

Return type

None

Registered as Pass: auto_insert_remap
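
A minimal sketch of selecting the insert mode:

import poprt

# before_add targets attention + mask in transformer-based models;
# after_matmul (the default) is more general but more likely to cause OOM
onnx_model = poprt.get_pass('auto_insert_remap', remap_mode='before_add')(onnx_model)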

class poprt.passes.workarounds.BatchNormWorkaround(*args, **kwargs)

Workaround for BatchNorm Operator.

Return type

None

Registered as Pass: batchnorm_workaround

class poprt.passes.check_with_fake_data.CheckWithFakeData(origin_model)

Check the model with fake data using onnxruntime.

Parameters

origin_model (ModelProto) –

Return type

None

Registered as Pass: check_with_fake_data

class poprt.passes.const_batch_size.ConstBatchSize(const_batch_size=1)

Convert unknown batch size to a const value.

Parameters

const_batch_size (int) –

Return type

None

Registered as Pass: const_batch_size
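
A minimal sketch (the batch size 4 is an arbitrary example value):

import poprt

# fix the unknown batch dimension to a constant
onnx_model = poprt.get_pass('const_batch_size', const_batch_size=4)(onnx_model)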

class poprt.passes.const_input_shape.ConstInputShape(const_input_shape={})

Convert input shape to const values.

Parameters

const_input_shape (Dict[str, Any]) –

Registered as Pass: const_input_shape
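
A minimal sketch (the input name and shape are hypothetical):

import poprt

# map input names to the constant shapes they should take
onnx_model = poprt.get_pass(
    'const_input_shape', const_input_shape={'input': [1, 3, 224, 224]}
)(onnx_model)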

class poprt.passes.constant_folding.ConstantFolding(max_tensor_size=-1)

Support constant folding.

Parameters

max_tensor_size (int) –

Return type

None

Registered as Pass: constant_folding
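
A minimal sketch using the default, unlimited tensor size:

import poprt

# fold constant subgraphs; max_tensor_size=-1 sets no size limit
onnx_model = poprt.get_pass('constant_folding', max_tensor_size=-1)(onnx_model)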

class poprt.passes.workarounds.CumSumWorkaround(*args, **kwargs)

Workaround for CumSum Operator.

Return type

None

Registered as Pass: cumsum_workaround

class poprt.passes.eight_bits_io.EightBitsIO

Insert norm operator after input image.

Registered as Pass: eight_bits_io

class poprt.passes.apply_ir_pass.eliminate_deadend(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_deadend

class poprt.passes.apply_ir_pass.eliminate_duplicate_initializer(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_duplicate_initializer

class poprt.passes.apply_ir_pass.eliminate_identity(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_identity

class poprt.passes.apply_ir_pass.eliminate_nop_arithmetic(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_nop_arithmetic

class poprt.passes.apply_ir_pass.eliminate_nop_cast(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_nop_cast

class poprt.passes.apply_ir_pass.eliminate_nop_expand(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_nop_expand

class poprt.passes.apply_ir_pass.eliminate_nop_flatten(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_nop_flatten

class poprt.passes.apply_ir_pass.eliminate_nop_if(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_nop_if

class poprt.passes.apply_ir_pass.eliminate_nop_pad(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_nop_pad

class poprt.passes.apply_ir_pass.eliminate_nop_reshape(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_nop_reshape

class poprt.passes.apply_ir_pass.eliminate_nop_transpose(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_nop_transpose

class poprt.passes.apply_ir_pass.eliminate_unused_initializer(*args, **kwargs)
Return type

None

Registered as Pass: eliminate_unused_initializer

class poprt.passes.erf_gelu_pattern.ErfGeluPattern(*args, **kwargs)

Recognise the pattern of the Erf-based Gelu and replace it with the Erf Gelu Op.

Return type

None

Registered as Pass: erf_gelu_pattern

class poprt.passes.apply_ir_pass.extract_constant_to_initializer(*args, **kwargs)
Return type

None

Registered as Pass: extract_constant_to_initializer

class poprt.passes.final_check.FinalCheck(*args, **kwargs)

Final check for dtype and shape of the converted model.

Return type

None

Registered as Pass: final_check

class poprt.passes.workarounds.FloatOpsWorkaround(*args, **kwargs)

Workaround for Operators that require float32 / float16 inputs.

Return type

None

Registered as Pass: float_ops_workaround

class poprt.passes.float_to_fp8.Float2FP8(fp8_params=['F143', 'F143', 0, 0], skip_op_names=[], convert_model='fp8', fp8_input_dict=None, fp8_weight_dict=None)

Convert a model from fp32 or fp16 to fp8.

Parameters
  • fp8_params (List[Union[Literal['F143', 'F152'], str]]) – Set parameters for the fp8 model; the format is [input_format, weight_format, input_scale, weight_scale]

  • skip_op_names (List[str]) – The names of the Ops that will stay in fp32/fp16 in fp8 mode, such as ['Conv_1', 'Conv_2']

  • convert_model (Literal['fp8', 'fp8_weight']) – Specifies which type the model is converted to; can be set to 'fp8' or 'fp8_weight'

  • fp8_input_dict (Dict[str, int]) – Set parameters for each fp8 input node of the fp8 model; if not None, fp8_params is ignored

  • fp8_weight_dict (Dict[str, int]) – Set parameters for each fp8 weight node of the fp8 model; if not None, fp8_params is ignored

Return type

None

Registered as Pass: float_to_fp8
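
A minimal sketch following the documented fp8_params format [input_format, weight_format, input_scale, weight_scale]; the skipped Op names are hypothetical:

import poprt

# convert to fp8, keeping two hypothetical Conv nodes in fp32/fp16
onnx_model = poprt.get_pass(
    'float_to_fp8',
    fp8_params=['F143', 'F143', 0, 0],
    skip_op_names=['Conv_1', 'Conv_2'],
    convert_model='fp8',
)(onnx_model)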

class poprt.passes.float_to_half.Float2Half(skip_op_types=[])

Convert a model from fp32 to fp16.

Create Float2Half instance.

Parameters

skip_op_types (List[str]) – The Op types that will stay in fp32 in fp16 mode.

Returns

A Float2Half instance

Return type

None

Registered as Pass: float_to_half

class poprt.passes.float_to_mixed.Float2Mixed

Convert a model from fp32 to mixed precision.

Return type

None

Registered as Pass: float_to_mixed

class poprt.passes.apply_ir_pass.fuse_bn_into_conv(*args, **kwargs)
Return type

None

Registered as Pass: fuse_bn_into_conv

class poprt.passes.fuse_bn_into_gemm.FuseBnIntoGemm

Fuse BatchNormalization into MatMul/Gemm.

Conditions:

  • Condition 1: MatMul/Gemm uses an initializer

  • Condition 2: Gemm/MatMul has no multiple outputs

  • Condition 3: Initializers shared across operators are not supported

Return type

None

Registered as Pass: fuse_bn_into_gemm

class poprt.passes.fuse_cast_into_onehot.FuseCastIntoOnehot

Fuse Cast into OneHot.

Return type

None

Registered as Pass: fuse_cast_into_onehot

class poprt.passes.apply_ir_pass.fuse_consecutive_cast(*args, **kwargs)
Return type

None

Registered as Pass: fuse_consecutive_cast

class poprt.passes.apply_ir_pass.fuse_consecutive_reshape(*args, **kwargs)
Return type

None

Registered as Pass: fuse_consecutive_reshape

class poprt.passes.apply_ir_pass.fuse_consecutive_squeeze(*args, **kwargs)
Return type

None

Registered as Pass: fuse_consecutive_squeeze

class poprt.passes.apply_ir_pass.fuse_consecutive_transpose(*args, **kwargs)
Return type

None

Registered as Pass: fuse_consecutive_transpose

class poprt.passes.apply_ir_pass.fuse_consecutive_unsqueeze(*args, **kwargs)
Return type

None

Registered as Pass: fuse_consecutive_unsqueeze

class poprt.passes.fused_attention.FusedAttention(*args, **kwargs)

Recognise the pattern of MultiHeadAttention and replace it with Fused MultiHeadAttention. The attention pattern is shown below:

Add
|
Reshape -----------------------
|            \            \
MatMul       MatMul       MatMul
|            |            |
Reshape      Reshape      Reshape
|            |            |
Add          Add          Add
|            |            |
Reshape      Reshape      Reshape

The fused attention pattern is shown below:

Add
|
Concat
|
MatMul
|
Add
|
Reshape
|
Transpose
|
Split

Return type

None

Registered as Pass: fused_attention

class poprt.passes.gelu_pattern.GeluPattern(*args, **kwargs)

Recognise the pattern of Gelu Op and replace the pattern with Gelu.

Return type

None

Registered as Pass: gelu_pattern

class poprt.passes.workarounds.IndicesWorkaround(*args, **kwargs)

Workaround for Gather / GatherElements Operator.

Return type

None

Registered as Pass: indices_workaround

class poprt.passes.insert_attention_mask.InsertAttentionMask(*args, **kwargs)

Replace Reshape-Cast-Sub-Mul with Cast-AttentionMask.

Return type

None

Registered as Pass: insert_attention_mask

class poprt.passes.int64_to_int32.Int64ToInt32(*args, **kwargs)

Convert int64 to int32.

Return type

None

Registered as Pass: int64_to_int32

class poprt.passes.layer_norm_pattern.LayerNormPattern(*args, **kwargs)

Recognise the pattern of LayerNorm Op and replace the pattern with Reshape + GroupNorm + Reshape.

Return type

None

Registered as Pass: layer_norm_pattern

class poprt.passes.layer_precision_compare.LayerPrecisionCompare(origin_model, data_preprocess=None, options=None, output_dir='./')

Compare the outputs of the Conv/MatMul/Gemm operators between the original model and the fp8 model.

It randomly takes a batch of data from the calibration data for inference, then records the outputs of the original model and the converted model. Cosine distance is used to evaluate the error because it is a normalized measure of the angle between vectors: the closer the value is to 0, the smaller the error. The results are written to a log file.

Create LayerPrecisionCompare instance.

Parameters
  • data_preprocess (str) – Path of pickle format file for data preprocessing.

  • options (Dict[str, Any]) – options for session.

  • output_dir (str) – The save path of log.

  • origin_model (ModelProto) –

Returns

A LayerPrecisionCompare instance

Return type

None

Registered as Pass: layer_precision_compare
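
A minimal sketch, assuming origin_model holds the ModelProto from before conversion and fp8_model the converted result:

import poprt

# writes the per-layer cosine-distance comparison to a log file under output_dir
compare = poprt.get_pass(
    'layer_precision_compare', origin_model=origin_model, output_dir='./'
)
fp8_model = compare(fp8_model)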

class poprt.passes.manual_sharding.ManualSharding(sharding_info, pipelining_info)

Manually shard the graph into several subgraphs at specific nodes.

Parameters
  • sharding_info (Dict[str, int]) –

  • pipelining_info (Dict[str, int]) –

Return type

None

Registered as Pass: manual_sharding
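
A minimal sketch (the node names are hypothetical, and the assumption here is that the int values denote the subgraph / pipeline stage each named node is assigned to):

import poprt

onnx_model = poprt.get_pass(
    'manual_sharding',
    sharding_info={'Conv_1': 0, 'MatMul_5': 1},
    pipelining_info={'Conv_1': 0, 'MatMul_5': 1},
)(onnx_model)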

class poprt.passes.matmul_rotary_embedding.MatmulRotaryEmbedding

Recognise the pattern of element-wise rotary embedding and replace it with an equivalent MatMul.

Return type

None

Registered as Pass: matmul_rotary_embedding

class poprt.passes.model_overview.ModelOverview(*args, **kwargs)
Return type

None

Registered as Pass: model_overview

class poprt.passes.move_subgraph_initializer.MoveSubgraphInitializer

Move subgraph’s initializers into main graph.

PopART only searches for initializers in the main graph.

Return type

None

Registered as Pass: move_subgraph_initializer

class poprt.passes.overlap_io.OverlapIO

Enable overlapped IO.

Return type

None

Registered as Pass: overlap_io

class poprt.passes.packed_transformer.PackedTransformer(args)

Recognise the pattern of SelfAttention and replace it with Packed SelfAttention.

Registered as Pass: packed_transformer

class poprt.passes.post_expand.PostExpand(*args, **kwargs)
Return type

None

Registered as Pass: post_expand

class poprt.passes.pre_scale.PreScale(*args, **kwargs)

Pre-scale: change the attention matrix Q to Q/sqrt(d) and remove the 1/sqrt(d) node.

Return type

None

Registered as Pass: pre_scale

class poprt.passes.remove_duplicated_initializer.RemoveDuplicatedInitializer

Remove duplicated initializers to save memory.

Return type

None

Registered as Pass: remove_duplicated_initializer

class poprt.passes.remove_initializer_from_input.RemoveInitializerFromInput(*args, **kwargs)

Remove initializer from model inputs.

Model: https://github.com/onnx/models/blob/main/vision/classification/resnet/model/resnet50-v1-7.onnx

Return type

None

Registered as Pass: remove_initializer_from_input

class poprt.passes.remove_input_cast.RemoveInputCast(*args, **kwargs)

Remove input cast: input(fp16)->cast(fp16->int32)->gather to input(int32)->gather.

Return type

None

Registered as Pass: remove_input_cast

class poprt.passes.replace_bn_with_mul_add.ReplaceBNWithMulAdd(*args, **kwargs)

Replace BatchNormalization Op with Mul + Add.

Return type

None

Registered as Pass: replace_bn_with_mul_add

class poprt.passes.replace_celu.ReplaceCelu(*args, **kwargs)

Replace the ONNX Celu op with a Graphcore custom op to support it in opset 11.

Return type

None

Registered as Pass: replace_celu

class poprt.passes.replace_clip_empty_inputs.ReplaceHardSwish(*args, **kwargs)

Replace the empty inputs of the Clip Op.

Return type

None

Registered as Pass: replace_clip_empty_inputs

class poprt.passes.replace_consecutive_cast_with_notzero.ReplaceConsecuiveCastWithNotZero(*args, **kwargs)

Recognise the pattern of consecutive Cast Ops and replace the pattern with a NotZero Op.

Return type

None

Registered as Pass: replace_consecutive_cast_with_notzero

class poprt.passes.replace_div_with_mul.ReplaceDivWithMul(*args, **kwargs)

Replace Div with Mul if the divisor is constant.

Model: https://github.com/onnx/models/blob/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx

Return type

None

Registered as Pass: replace_div_with_mul

class poprt.passes.apply_ir_pass.replace_einsum_with_matmul(*args, **kwargs)
Return type

None

Registered as Pass: replace_einsum_with_matmul

class poprt.passes.replace_erf_with_erfv2.ReplaceErfWithErfV2(*args, **kwargs)

Replace Erf Op with ErfV2.

ErfV2 is more efficient but has a larger error.

Return type

None

Registered as Pass: replace_erf_with_erfv2

class poprt.passes.replace_greater_or_equal.ReplaceGreaterOrEqual(*args, **kwargs)

Replace GreaterOrEqual Op with Less Op and Not Op.

Return type

None

Registered as Pass: replace_greater_or_equal

class poprt.passes.replace_groupnorm_with_fast_norm.ReplaceGroupNormWithFastNorm(*args, **kwargs)

Replace GroupNormalization with FastNorm if the datatype is fp16 and num_groups=1.

Return type

None

Registered as Pass: replace_groupnorm_with_fast_norm

class poprt.passes.replace_half_reducemean.ReplaceHalfReduceMean(*args, **kwargs)

Replace the ReduceMean Op in fp16 mode with ReduceSum + Mul to avoid overflow.

Return type

None

Registered as Pass: replace_half_reducemean

class poprt.passes.replace_hardswish.ReplaceHardSwish(*args, **kwargs)

Replace HardSwish Op with HardSigmoid Op and Mul Op.

This replacement is required for opsets before 14, since HardSwish is only supported from opset 14.

Return type

None

Registered as Pass: replace_hardswish

class poprt.passes.replace_less_or_equal.ReplaceLessOrEqual(*args, **kwargs)

Replace LessOrEqual Op with Less Op and Not Op.

Return type

None

Registered as Pass: replace_less_or_equal

class poprt.passes.replace_nonzero.ReplaceNonZero(*args, **kwargs)

Replace NonZero with ArgMax when the number of nonzero elements is known.

Currently only a single element is supported; multiple elements will be supported with TopK.

Return type

None

Registered as Pass: replace_nonzero

class poprt.passes.replace_pow.ReplacePow(*args, **kwargs)

Replace Pow Op with Square Op and Mul Op.

Return type

None

Registered as Pass: replace_pow

class poprt.passes.replace_softmax.ReplaceSoftmax(*args, **kwargs)

Replace the Softmax Op with the SoftmaxV2 Op when the axis is the last dim and that dim is odd.

Return type

None

Registered as Pass: replace_softmax

class poprt.passes.replace_trilu.ReplaceTrilu(*args, **kwargs)

Replace the ONNX Trilu op with a Graphcore custom op to support it in opset 14.

Return type

None

Registered as Pass: replace_trilu

class poprt.passes.replace_where_mask.ReplaceWhereMask(*args, **kwargs)

Change the attention_mask method from Where to Add.

Return type

None

Registered as Pass: replace_where_mask

class poprt.passes.replace_where_with_mul_add.ReplaceWhereWithMulAdd

Where(condition, X, Y) = Add(Mul(condition, X), Mul(neg_condition, Y)).

Registered as Pass: replace_where_with_mul_add

class poprt.passes.replace_where_with_wherev2.ReplaceWhereWithWhereV2(*args, **kwargs)

Replace Where Op with WhereV2.

Return type

None

Registered as Pass: replace_where_with_wherev2

class poprt.passes.serialize_matmul.SerializeMatmul(serialize_dict=None)

Serialize the MatMul Op to save on-chip memory.

Parameters

serialize_dict (Dict) –

Return type

None

Registered as Pass: serialize_matmul
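
A minimal sketch (the node name and factor are hypothetical; the exact serialize_dict schema is an assumption):

import poprt

# serialize a MatMul node by a factor of 4 to reduce peak on-chip memory
onnx_model = poprt.get_pass('serialize_matmul', serialize_dict={'MatMul_1': 4})(onnx_model)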

class poprt.passes.serialize_matmul_add.SerializeMatmulAdd(serialize_dict=None)
Parameters

serialize_dict (Dict) –

Return type

None

Registered as Pass: serialize_matmul_add

class poprt.passes.shape_inference.ReplacePow(*args, **kwargs)

Do shape inference.

Return type

None

Registered as Pass: shape_inference

class poprt.passes.workarounds.TopKWorkaround(*args, **kwargs)

Workaround for the TopK Op, which requires a positive axis.

Return type

None

Registered as Pass: topk_workaround

class poprt.passes.apply_ir_pass.trace_folding(*args, **kwargs)
Return type

None

Registered as Pass: trace_folding

class poprt.passes.apply_ir_pass.unique_name_for_nodes(*args, **kwargs)
Return type

None

Registered as Pass: unique_name_for_nodes