7. Passes
A Pass is a component in PopRT for optimizing ONNX models. To implement a custom Pass, see Developing Custom Passes.
7.1. Pass 抽象
All Passes implemented in PopRT inherit from poprt.Pass and share a unified interface:
- class poprt.Pass(*args, **kwargs)
Abstract Base Class for Passes.
A new Pass could be like:
import onnx

from poprt.passes import register, Pass


@register('dummy_pass')
class Dummy(Pass):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def run(self, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        print(f"producer_name: {onnx_model.producer_name}")
        return onnx_model
- Return type
None
- run(onnx_model)
Run Pass, inherited subclasses should override this method.
- Parameters
onnx_model (ModelProto) – input onnx model
- Returns
the optimized onnx model
- Return type
ModelProto
Passes are registered automatically. All registered Passes can be listed with the following CLI command:
poprt --list_all_passes
Or in Python:
import poprt
passes = poprt.get_registered_passes()
print(passes.keys())
Use the poprt.get_pass() Python interface to apply a single Pass, and use poprt.PassManager to apply multiple Passes.
- poprt.get_pass(name, *args, **kwargs)
Get a Pass by registered name.
- Parameters
name (str) – registered name of a Pass
- Returns
a Pass instance
- Return type
Pass
Example:
import poprt

# get a Pass with parameters
onnx_model = poprt.get_pass('float_to_half', skip_op_types=['Gelu'])(onnx_model)
poprt.get_pass('model_overview')(onnx_model)
- class poprt.PassManager(used_passes=[], gather_ir_passes=False)
Manage Passes.
- Parameters
used_passes (List[Union[str, Pass]]) – passes that will be used
gather_ir_passes (bool) – gather ONNX IR passes and execute them in a single run.
- Return type
None
Example:
import poprt

pm = poprt.PassManager(
    [
        'model_overview',
        'float_to_half',
        poprt.get_pass('model_overview'),
    ]
)
pm.run(onnx_model)
- run(onnx_model)
Apply passes to the onnx model.
- Parameters
onnx_model (ModelProto) – onnx model that will be optimized
- Return type
ModelProto
Note
Since the structure of the Pass module is subject to change, the following usage is not recommended:
from poprt.passes.float_to_half import Float2Half
Float2Half()(onnx_model)
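The recommended equivalent uses the registered name through the stable public API:
import poprt

onnx_model = poprt.get_pass('float_to_half')(onnx_model)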
7.2. Passes registered in PopRT
- class poprt.passes.add_checkpoints.AddCheckpoints(checkpoints)
Add intermediate tensor to output.
- Parameters
checkpoints (List[str]) –
- Return type
None
Registered as Pass: add_checkpoints
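Example (a minimal sketch; the tensor name is hypothetical and must exist in your model):
import poprt

# Expose an intermediate tensor as an extra model output
onnx_model = poprt.get_pass('add_checkpoints', checkpoints=['Add_1:0'])(onnx_model)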
- class poprt.passes.apply_ir_pass.ApplyIrPass(passes=[])
Apply passes based on onnx IR.
- Parameters
passes (List[str]) –
- Return type
None
Registered as Pass: apply_ir_pass
- class poprt.passes.auto_insert_remap.AutoInsertRemap(remap_mode='after_matmul')
Insert remap after matmul.
This is an experimental feature. There are two insert modes: after_matmul and before_add. The after_matmul mode is more general but more likely to cause OOM; the before_add mode targets reducing the cycles of attention + mask in transformer-based models.
- Parameters
remap_mode (str) –
- Return type
None
Registered as Pass: auto_insert_remap
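Example (a minimal sketch; remap_mode is one of the two modes described above):
import poprt

onnx_model = poprt.get_pass('auto_insert_remap', remap_mode='before_add')(onnx_model)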
- class poprt.passes.workarounds.BatchNormWorkaround(*args, **kwargs)
Workaround for BatchNorm Operator.
- Return type
None
Registered as Pass: batchnorm_workaround
- class poprt.passes.check_with_fake_data.CheckWithFakeData(origin_model)
Checking model with fake data using onnxruntime.
- Parameters
origin_model (ModelProto) –
- Return type
None
Registered as Pass: check_with_fake_data
- class poprt.passes.const_batch_size.ConstBatchSize(const_batch_size=1)
Convert unknown batch size to a const value.
- Parameters
const_batch_size (int) –
- Return type
None
Registered as Pass: const_batch_size
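Example (the batch size value is illustrative):
import poprt

# Fix the dynamic batch dimension to a constant 4
onnx_model = poprt.get_pass('const_batch_size', const_batch_size=4)(onnx_model)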
- class poprt.passes.const_input_shape.ConstInputShape(const_input_shape={})
Convert input shape to const values.
- Parameters
const_input_shape (Dict[str, Any]) –
Registered as Pass: const_input_shape
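Example (a minimal sketch, assuming const_input_shape maps an input name to its concrete shape; the input name and shape are hypothetical):
import poprt

onnx_model = poprt.get_pass(
    'const_input_shape', const_input_shape={'input': [1, 3, 224, 224]}
)(onnx_model)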
- class poprt.passes.constant_folding.ConstantFolding(max_tensor_size=-1)
Support constant folding.
- Parameters
max_tensor_size (int) –
- Return type
None
Registered as Pass: constant_folding
- class poprt.passes.workarounds.CumSumWorkaround(*args, **kwargs)
Workaround for CumSum Operator.
- Return type
None
Registered as Pass: cumsum_workaround
- class poprt.passes.eight_bits_io.EightBitsIO
Insert norm operator after input image.
Registered as Pass: eight_bits_io
- class poprt.passes.apply_ir_pass.eliminate_deadend(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_deadend
- class poprt.passes.apply_ir_pass.eliminate_duplicate_initializer(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_duplicate_initializer
- class poprt.passes.apply_ir_pass.eliminate_identity(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_identity
- class poprt.passes.apply_ir_pass.eliminate_nop_arithmetic(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_nop_arithmetic
- class poprt.passes.apply_ir_pass.eliminate_nop_cast(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_nop_cast
- class poprt.passes.apply_ir_pass.eliminate_nop_expand(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_nop_expand
- class poprt.passes.apply_ir_pass.eliminate_nop_flatten(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_nop_flatten
- class poprt.passes.apply_ir_pass.eliminate_nop_if(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_nop_if
- class poprt.passes.apply_ir_pass.eliminate_nop_pad(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_nop_pad
- class poprt.passes.apply_ir_pass.eliminate_nop_reshape(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_nop_reshape
- class poprt.passes.apply_ir_pass.eliminate_nop_transpose(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_nop_transpose
- class poprt.passes.apply_ir_pass.eliminate_unused_initializer(*args, **kwargs)
- Return type
None
Registered as Pass: eliminate_unused_initializer
- class poprt.passes.erf_gelu_pattern.ErfGeluPattern(*args, **kwargs)
Recognise the pattern of Erf Gelu Op and replace the pattern with Erf Gelu.
- Return type
None
Registered as Pass: erf_gelu_pattern
- class poprt.passes.apply_ir_pass.extract_constant_to_initializer(*args, **kwargs)
- Return type
None
Registered as Pass: extract_constant_to_initializer
- class poprt.passes.final_check.FinalCheck(*args, **kwargs)
Final check for dtype and shape of the converted model.
- Return type
None
Registered as Pass: final_check
- class poprt.passes.workarounds.FloatOpsWorkaround(*args, **kwargs)
Workaround for Operators which require float32 / float16 inputs.
- Return type
None
Registered as Pass: float_ops_workaround
- class poprt.passes.float_to_fp8.Float2FP8(fp8_params=['F143', 'F143', 0, 0], skip_op_names=[], convert_model='fp8', fp8_input_dict=None, fp8_weight_dict=None)
Convert a model from fp32 or fp16 to fp8.
- Parameters
fp8_params (List[Union[Literal['F143', 'F152'], str]]) – Set parameters of the fp8 model; the format is [input_format, weight_format, input_scale, weight_scale]
skip_op_names (List[str]) – The Op names which will keep fp32/fp16 in fp8 mode, such as ['Conv_1', 'Conv_2']
convert_model (Literal['fp8', 'fp8_weight']) – Specifies which type the model is converted to; can be set to 'fp8' or 'fp8_weight'
fp8_input_dict (Dict[str, int]) – Set parameters for each fp8 input node of the fp8 model; if it is not None, fp8_params will be discarded
fp8_weight_dict (Dict[str, int]) – Set parameters for each fp8 weight node of the fp8 model; if it is not None, fp8_params will be discarded
- Return type
None
Registered as Pass: float_to_fp8
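Example (a minimal sketch based on the parameters above; the op names are hypothetical):
import poprt

onnx_model = poprt.get_pass(
    'float_to_fp8',
    fp8_params=['F143', 'F143', 0, 0],
    skip_op_names=['Conv_1', 'Conv_2'],
    convert_model='fp8',
)(onnx_model)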
- class poprt.passes.float_to_half.Float2Half(skip_op_types=[])
Convert a model from fp32 to fp16.
Create Float2Half instance.
- Parameters
skip_op_types (List[str]) – The Op types which will keep fp32 in fp16 mode.
- Returns
A Float2Half instance
- Return type
None
Registered as Pass: float_to_half
- class poprt.passes.float_to_mixed.Float2Mixed
Convert a model from fp32 to mixed precision.
- Return type
None
Registered as Pass: float_to_mixed
- class poprt.passes.apply_ir_pass.fuse_bn_into_conv(*args, **kwargs)
- Return type
None
Registered as Pass: fuse_bn_into_conv
- class poprt.passes.fuse_bn_into_gemm.FuseBnIntoGemm
Fuse BatchNormalization into MatMul/Gemm.
- Condition:
condition 1: MatMul/Gemm uses an initializer
condition 2: Gemm/MatMul has no multiple outputs
condition 3: initializers shared across operators are not supported
- Return type
None
Registered as Pass: fuse_bn_into_gemm
- class poprt.passes.fuse_cast_into_onehot.FuseCastIntoOnehot
Fuse Cast into OneHot.
- Return type
None
Registered as Pass: fuse_cast_into_onehot
- class poprt.passes.apply_ir_pass.fuse_consecutive_cast(*args, **kwargs)
- Return type
None
Registered as Pass: fuse_consecutive_cast
- class poprt.passes.apply_ir_pass.fuse_consecutive_reshape(*args, **kwargs)
- Return type
None
Registered as Pass: fuse_consecutive_reshape
- class poprt.passes.apply_ir_pass.fuse_consecutive_squeeze(*args, **kwargs)
- Return type
None
Registered as Pass: fuse_consecutive_squeeze
- class poprt.passes.apply_ir_pass.fuse_consecutive_transpose(*args, **kwargs)
- Return type
None
Registered as Pass: fuse_consecutive_transpose
- class poprt.passes.apply_ir_pass.fuse_consecutive_unsqueeze(*args, **kwargs)
- Return type
None
Registered as Pass: fuse_consecutive_unsqueeze
- class poprt.passes.fused_attention.FusedAttention(*args, **kwargs)
Recognise the pattern of MultiHeadAttention and replace it with Fused MultiHeadAttention. Attention Pattern as below:
        Add
         |
      Reshape ------
       |    \       \
    MatMul  MatMul  MatMul
       |      |       |
   Reshape Reshape Reshape
       |      |       |
      Add    Add     Add
       |      |       |
   Reshape Reshape Reshape
Fused Attention Pattern as below:
      Add
       |
     Concat
       |
     MatMul
       |
      Add
       |
    Reshape
       |
   Transpose
       |
     Split
- Return type
None
Registered as Pass: fused_attention
- class poprt.passes.gelu_pattern.GeluPattern(*args, **kwargs)
Recognise the pattern of Gelu Op and replace the pattern with Gelu.
- Return type
None
Registered as Pass: gelu_pattern
- class poprt.passes.workarounds.IndicesWorkaround(*args, **kwargs)
Workaround for Gather / GatherElements Operator.
- Return type
None
Registered as Pass: indices_workaround
- class poprt.passes.insert_attention_mask.InsertAttentionMask(*args, **kwargs)
Replace Reshape-Cast-Sub-Mul with Cast-AttentionMask.
- Return type
None
Registered as Pass: insert_attention_mask
- class poprt.passes.int64_to_int32.Int64ToInt32(*args, **kwargs)
Convert int64 to int32.
- Return type
None
Registered as Pass: int64_to_int32
- class poprt.passes.layer_norm_pattern.LayerNormPattern(*args, **kwargs)
Recognise the pattern of LayerNorm Op and replace the pattern with Reshape + GroupNorm + Reshape.
- Return type
None
Registered as Pass: layer_norm_pattern
- class poprt.passes.layer_precision_compare.LayerPrecisionCompare(origin_model, data_preprocess=None, options=None, output_dir='./')
Compare the outputs of the Conv/MatMul/Gemm operators between the origin model and the fp8 model.
It randomly takes a batch of data from the calibration set for inference and records the outputs of the origin model and the converted model. Cosine distance is used to evaluate the error because it is a normalized measure of the angle between vectors: the closer the value is to 0, the smaller the error. The results are written to a log file.
Create LayerPrecisionCompare instance.
- Parameters
data_preprocess (str) – Path of pickle format file for data preprocessing.
options (Dict[str, Any]) – options for session.
output_dir (str) – The save path of log.
origin_model (ModelProto) –
- Returns
A LayerPrecisionCompare instance
- Return type
None
Registered as Pass: layer_precision_compare
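Example (a minimal sketch, assuming the pass runs on the converted model with the original model passed for reference; the path is hypothetical):
import onnx
import poprt

origin_model = onnx.load('model_fp32.onnx')  # hypothetical path
onnx_model = poprt.get_pass(
    'layer_precision_compare', origin_model=origin_model, output_dir='./logs'
)(onnx_model)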
- class poprt.passes.manual_sharding.ManualSharding(sharding_info, pipelining_info)
Manually shard the graph into several subgraphs at specific nodes.
- Parameters
sharding_info (Dict[str, int]) –
pipelining_info (Dict[str, int]) –
- Return type
None
Registered as Pass: manual_sharding
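Example (a hedged sketch; the exact schemas of sharding_info and pipelining_info are not documented in this section, and we assume they map a node name to an IPU index and a pipeline stage respectively; all names are hypothetical):
import poprt

onnx_model = poprt.get_pass(
    'manual_sharding',
    sharding_info={'MatMul_0': 0, 'MatMul_10': 1},
    pipelining_info={'MatMul_0': 0, 'MatMul_10': 1},
)(onnx_model)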
- class poprt.passes.matmul_rotary_embedding.MatmulRotaryEmbedding
Recognise the pattern of element-wised rotary embedding and replace the pattern with equivalent matmul.
- Return type
None
Registered as Pass: matmul_rotary_embedding
- class poprt.passes.model_overview.ModelOverview(*args, **kwargs)
- Return type
None
Registered as Pass: model_overview
- class poprt.passes.move_subgraph_initializer.MoveSubgraphInitializer
Move subgraph’s initializers into main graph.
PopART only searches for initializers in the main graph.
- Return type
None
Registered as Pass: move_subgraph_initializer
- class poprt.passes.overlap_io.OverlapIO
Enable overlap io.
- Return type
None
Registered as Pass: overlap_io
- class poprt.passes.packed_transformer.PackedTransformer(args)
Recognise the pattern of SelfAttention and replace it with Packed SelfAttention.
Registered as Pass: packed_transformer
- class poprt.passes.post_expand.PostExpand(*args, **kwargs)
- Return type
None
Registered as Pass: post_expand
- class poprt.passes.pre_scale.PreScale(*args, **kwargs)
Pre-scale the attention matrix Q to Q/sqrt(d) and remove the 1/sqrt(d) node.
- Return type
None
Registered as Pass: pre_scale
- class poprt.passes.remove_duplicated_initializer.RemoveDuplicatedInitializer
Remove duplicated initializer to save memory.
- Return type
None
Registered as Pass: remove_duplicated_initializer
- class poprt.passes.remove_initializer_from_input.RemoveInitializerFromInput(*args, **kwargs)
Remove initializer from model inputs.
Model: https://github.com/onnx/models/blob/main/vision/classification/resnet/model/resnet50-v1-7.onnx
- Return type
None
Registered as Pass: remove_initializer_from_input
- class poprt.passes.remove_input_cast.RemoveInputCast(*args, **kwargs)
Remove input cast: input(fp16)->cast(fp16->int32)->gather to input(int32)->gather.
- Return type
None
Registered as Pass: remove_input_cast
- class poprt.passes.replace_bn_with_mul_add.ReplaceBNWithMulAdd(*args, **kwargs)
Replace BatchNormalization Op with Mul + Add.
- Return type
None
Registered as Pass: replace_bn_with_mul_add
- class poprt.passes.replace_celu.ReplaceCelu(*args, **kwargs)
Replace the ONNX Celu op with a Graphcore custom op to support it in opset 11.
- Return type
None
Registered as Pass: replace_celu
- class poprt.passes.replace_clip_empty_inputs.ReplaceHardSwish(*args, **kwargs)
Replace the empty optional inputs of the Clip Op.
- Return type
None
Registered as Pass: replace_clip_empty_inputs
- class poprt.passes.replace_consecutive_cast_with_notzero.ReplaceConsecuiveCastWithNotZero(*args, **kwargs)
Recognise the pattern of consecutive Cast Ops and replace the pattern with a NotZero Op.
- Return type
None
Registered as Pass: replace_consecutive_cast_with_notzero
- class poprt.passes.replace_div_with_mul.ReplaceDivWithMul(*args, **kwargs)
Replace Div with Mul if the divisor is constant.
Model: https://github.com/onnx/models/blob/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx
- Return type
None
Registered as Pass: replace_div_with_mul
- class poprt.passes.apply_ir_pass.replace_einsum_with_matmul(*args, **kwargs)
- Return type
None
Registered as Pass: replace_einsum_with_matmul
- class poprt.passes.replace_erf_with_erfv2.ReplaceErfWithErfV2(*args, **kwargs)
Replace Erf Op with ErfV2.
ErfV2 is more efficient but has a larger error.
- Return type
None
Registered as Pass: replace_erf_with_erfv2
- class poprt.passes.replace_greater_or_equal.ReplaceGreaterOrEqual(*args, **kwargs)
Replace GreaterOrEqual Op with Less Op and Not Op.
- Return type
None
Registered as Pass: replace_greater_or_equal
- class poprt.passes.replace_groupnorm_with_fast_norm.ReplaceGroupNormWithFastNorm(*args, **kwargs)
Replace GroupNormalization with FastNorm if the datatype is fp16 and num_groups=1.
- Return type
None
Registered as Pass: replace_groupnorm_with_fast_norm
- class poprt.passes.replace_half_reducemean.ReplaceHalfReduceMean(*args, **kwargs)
Replace ReduceMean Op in fp16 mode with ReduceSum + Mul in case of overflow.
- Return type
None
Registered as Pass: replace_half_reducemean
- class poprt.passes.replace_hardswish.ReplaceHardSwish(*args, **kwargs)
Replace HardSwish Op with HardSigmoid Op and Mul Op.
Replacement is required for opsets before 14 since HardSwish is only supported from opset 14.
- Return type
None
Registered as Pass: replace_hardswish
- class poprt.passes.replace_less_or_equal.ReplaceLessOrEqual(*args, **kwargs)
Replace LessOrEqual Op with Less Op and Not Op.
- Return type
None
Registered as Pass: replace_less_or_equal
- class poprt.passes.replace_nonzero.ReplaceNonZero(*args, **kwargs)
Replace NonZero with ArgMax when the number of nonzero elements is known.
Right now only a single element is supported; multiple elements will be supported with TopK.
- Return type
None
Registered as Pass: replace_nonzero
- class poprt.passes.replace_pow.ReplacePow(*args, **kwargs)
Replace Pow Op with Square Op and Mul Op.
- Return type
None
Registered as Pass: replace_pow
- class poprt.passes.replace_softmax.ReplaceSoftmax(*args, **kwargs)
Replace the Softmax Op with the SoftmaxV2 Op when the axis is the lowest dim and the lowest dim is an odd number.
- Return type
None
Registered as Pass: replace_softmax
- class poprt.passes.replace_trilu.ReplaceTrilu(*args, **kwargs)
Replace the ONNX Trilu op with a Graphcore custom op to support it in opset 14.
- Return type
None
Registered as Pass: replace_trilu
- class poprt.passes.replace_where_mask.ReplaceWhereMask(*args, **kwargs)
Change the attention_mask method from Where to Add.
- Return type
None
Registered as Pass: replace_where_mask
- class poprt.passes.replace_where_with_mul_add.ReplaceWhereWithMulAdd
Where(condition, X, Y) = Add(Mul(condition, X), Mul(neg_condition, Y)).
Registered as Pass: replace_where_with_mul_add
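A small numeric check of the identity this pass relies on (illustrative only, not the pass implementation):
import numpy as np

condition = np.array([1.0, 0.0, 1.0])  # boolean mask as 0/1 floats
X = np.array([10.0, 20.0, 30.0])
Y = np.array([1.0, 2.0, 3.0])

where = np.where(condition.astype(bool), X, Y)
# neg_condition = 1 - condition
mul_add = condition * X + (1.0 - condition) * Y

assert np.allclose(where, mul_add)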
- class poprt.passes.replace_where_with_wherev2.ReplaceWhereWithWhereV2(*args, **kwargs)
Replace Where Op with WhereV2.
- Return type
None
Registered as Pass: replace_where_with_wherev2
- class poprt.passes.serialize_matmul.SerializeMatmul(serialize_dict=None)
Serialize the MatMul Op to save on-chip memory.
- Parameters
serialize_dict (Dict) –
- Return type
None
Registered as Pass: serialize_matmul
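Example (a hedged sketch; the serialize_dict schema is not documented in this section, and we assume it maps a MatMul op name to a serialization factor; both are hypothetical):
import poprt

onnx_model = poprt.get_pass(
    'serialize_matmul', serialize_dict={'MatMul_0': 4}
)(onnx_model)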
- class poprt.passes.serialize_matmul_add.SerializeMatmulAdd(serialize_dict=None)
- Parameters
serialize_dict (Dict) –
- Return type
None
Registered as Pass: serialize_matmul_add
- class poprt.passes.shape_inference.ReplacePow(*args, **kwargs)
Do shape inference.
- Return type
None
Registered as Pass: shape_inference
- class poprt.passes.workarounds.TopKWorkaround(*args, **kwargs)
Workaround for the TopK Op, which requires a positive axis.
- Return type
None
Registered as Pass: topk_workaround
- class poprt.passes.apply_ir_pass.trace_folding(*args, **kwargs)
- Return type
None
Registered as Pass: trace_folding
- class poprt.passes.apply_ir_pass.unique_name_for_nodes(*args, **kwargs)
- Return type
None
Registered as Pass: unique_name_for_nodes