3. Quick Start
This chapter uses an ONNX model as an example to show how to convert a model with PopRT and deploy it quickly.
3.1. Main Parameters
--input_model
(Required) Path and name of the original ONNX model.
--show
(Optional) Only print the input and output information of the model.
--input_shape
(Optional) Specify the input shape(s) of the model.
--output_model
(Optional) Name of the converted model; by default it is saved in the current directory. If this parameter is not specified, the converted model is saved under a default name.
--precision
(Optional) Precision of the converted model. If this parameter is not specified, the precision of the original model is kept.
--run
(Optional) Run the converted model with random inputs.
--export_popef
(Optional) Export the PopEF generated by compilation.
Note
--input_shape can only set the variable dimensions of an input shape; it cannot change the size of dimensions that are already known. The IPU only supports static graphs, so if the input shape of the model is variable, --input_shape must be used to specify it. For more configuration options, refer to Using PopRT.
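For example, the parameters above can be combined into a single invocation; model.onnx and the input name input:0 below are placeholders for your own model, not files shipped with this guide:
poprt \
--input_model model.onnx \
--input_shape input:0=2,256 \
--precision fp16 \
--output_model model_fp16.onnx \
--run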
3.2. Converting and Running a Model
This section uses the BERT-Squad model from the ONNX model zoo as an example to show how to convert a model.
3.2.1. Download the ONNX Model
wget https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-12.onnx
3.2.2. Get the Input and Output Information of the ONNX Model
poprt \
--input_model bertsquad-12.onnx \
--show
# Model input and output information
2023-01-06 02:59:32,897 INFO cli.py:327] Input unique_ids_raw_output___9:0: dtype - int64, shape - [0]
2023-01-06 02:59:32,897 INFO cli.py:327] Input segment_ids:0: dtype - int64, shape - [0, 256]
2023-01-06 02:59:32,897 INFO cli.py:327] Input input_mask:0: dtype - int64, shape - [0, 256]
2023-01-06 02:59:32,898 INFO cli.py:327] Input input_ids:0: dtype - int64, shape - [0, 256]
2023-01-06 02:59:32,898 INFO cli.py:334] Output unstack:1: dtype - float32, shape - [0, 256]
2023-01-06 02:59:32,898 INFO cli.py:334] Output unstack:0: dtype - float32, shape - [0, 256]
2023-01-06 02:59:32,898 INFO cli.py:334] Output unique_ids:0: dtype - int64, shape - [0]
3.2.3. Specify the Input Shape
The input and output information above shows that the input shapes of the original model are variable (a dimension of 0 indicates a variable dimension), so --input_shape must be used to specify them. This example uses batch size = 2.
Note
The inputs of the original model are int64. Since the IPU does not support int64, the inputs of the converted model become int32.
poprt \
--input_model bertsquad-12.onnx \
--input_shape unique_ids_raw_output___9:0=2 segment_ids:0=2,256 input_mask:0=2,256 input_ids:0=2,256 \
--output_model bertsquad-12_fp32_bs2.onnx
3.2.4. Specify the Model Precision
The input and output information above shows that the original model is fp32. This example converts it to fp16.
poprt \
--input_model bertsquad-12.onnx \
--input_shape unique_ids_raw_output___9:0=2 segment_ids:0=2,256 input_mask:0=2,256 input_ids:0=2,256 \
--output_model bertsquad-12_fp16_bs2.onnx \
--precision fp16
3.2.5. Run the Model
Note
--run can be used to quickly verify that the converted model compiles and runs on the IPU. The numbers shown in this example do not represent optimal performance; they only demonstrate the default conversion flow.
# Convert the fp32 ONNX model, then compile and run it
poprt \
--input_model bertsquad-12.onnx \
--input_shape unique_ids_raw_output___9:0=2 segment_ids:0=2,256 input_mask:0=2,256 input_ids:0=2,256 \
--output_model bertsquad-12_fp32_bs2.onnx \
--run
# Results with random inputs
2023-01-06 05:58:14,209 INFO cli.py:452] Bs: 2
2023-01-06 05:58:14,209 INFO cli.py:455] Latency: 4.58ms
2023-01-06 05:58:14,210 INFO cli.py:456] Tput: 436
# Convert the fp16 ONNX model, then compile and run it
poprt \
--input_model bertsquad-12.onnx \
--input_shape unique_ids_raw_output___9:0=2 segment_ids:0=2,256 input_mask:0=2,256 input_ids:0=2,256 \
--output_model bertsquad-12_fp16_bs2.onnx \
--precision fp16 \
--run
# Results with random inputs
2023-01-06 06:00:59,283 INFO cli.py:452] Bs: 2
2023-01-06 06:00:59,283 INFO cli.py:455] Latency: 2.23ms
2023-01-06 06:00:59,283 INFO cli.py:456] Tput: 896
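In these results, Bs is the batch size, Latency is the average time per batch, and Tput is the throughput in samples per second, i.e. batch size divided by latency: 2 / 4.58 ms ≈ 436 for fp32, and 2 / 2.23 ms ≈ 896 for fp16.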
3.2.6. Export PopEF
The exported PopEF can be used for offline deployment.
Note
--export_popef saves the PopEF file under the default name executable.popef; note that repeated exports will overwrite this file.
# Convert and compile the fp32 model, then export the PopEF
poprt \
--input_model bertsquad-12.onnx \
--input_shape unique_ids_raw_output___9:0=2 segment_ids:0=2,256 input_mask:0=2,256 input_ids:0=2,256 \
--export_popef
# Convert and compile the fp16 model, then export the PopEF
poprt \
--input_model bertsquad-12.onnx \
--input_shape unique_ids_raw_output___9:0=2 segment_ids:0=2,256 input_mask:0=2,256 input_ids:0=2,256 \
--precision fp16 \
--export_popef
3.3. Quick Deployment
This section uses the converted ONNX model and the exported PopEF from the previous section as examples to show how to deploy quickly with the PopRT Python API.
3.3.1. Run the Exported PopEF
from poprt import runtime
import numpy as np
# Create a ModelRunner instance and load the PopEF file
runner = runtime.ModelRunner('executable.popef')
# Get the model output information
outputs = runner.get_model_outputs()
# Create input and output data
input_dict = {
    "unique_ids_raw_output___9:0": np.ones([2]).astype(np.int32),
    "segment_ids:0": np.ones([2, 256]).astype(np.int32),
    "input_mask:0": np.ones([2, 256]).astype(np.int32),
    "input_ids:0": np.ones([2, 256]).astype(np.int32),
}
output_dict = {x.name: np.zeros(x.shape).astype(x.numpy_data_type()) for x in outputs}
# Run the PopEF
runner.execute(input_dict, output_dict)
print(output_dict)
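For latency-sensitive serving, requests can also be submitted without blocking. The sketch below reuses runner, input_dict and output_dict from above; it relies only on the executeAsync API demonstrated in the full example in Section 3.4:
# Submit the request asynchronously; execution starts immediately
future = runner.executeAsync(input_dict, output_dict)
# The host is free to do other work here before collecting the result
future.wait()
print(output_dict)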
3.3.2. Run the Converted ONNX Model
Besides compiling an ONNX model into a PopEF file, you can also compile the converted ONNX model directly with PopRT's Compiler, and run the resulting PopEF executable through a ModelRunner instance.
import onnx
import numpy as np
from poprt import runtime
from poprt.compiler import Compiler
# Load the ONNX model
model = onnx.load("bertsquad-12_fp32_bs2.onnx")
model_bytes = model.SerializeToString()
output_names = [o.name for o in model.graph.output]
# Compile the ONNX model and generate a PopEF executable
executable = Compiler.compile(model_bytes, output_names)
# Create a ModelRunner instance and load the PopEF executable
runner = runtime.ModelRunner(executable)
# Get the model output information
outputs = runner.get_model_outputs()
# Create input and output data
input_dict = {
    "unique_ids_raw_output___9:0": np.ones([2]).astype(np.int32),
    "segment_ids:0": np.ones([2, 256]).astype(np.int32),
    "input_mask:0": np.ones([2, 256]).astype(np.int32),
    "input_ids:0": np.ones([2, 256]).astype(np.int32),
}
output_dict = {x.name: np.zeros(x.shape).astype(x.numpy_data_type()) for x in outputs}
# Run the PopEF
runner.execute(input_dict, output_dict)
print(output_dict)
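Compiler.compile also accepts a CompilerOptions object to control compilation. The sketch below reuses model_bytes and output_names from above and is limited to options that appear in the full example in Section 3.4; it is a minimal illustration, not a complete list of options:
from poprt.compiler import CompilerOptions

options = CompilerOptions()
options.num_ipus = 1            # number of IPUs to compile for
options.batches_per_step = 100  # on-chip loop count per host call
executable = Compiler.compile(model_bytes, output_names, options)
runner = runtime.ModelRunner(executable)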
3.4. Python API Example
Below is a complete example of converting, compiling, and running an ONNX model entirely through the Python API:
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import argparse
import time

from typing import Dict

import numpy as np
import onnx

from onnx import helper

from poprt import Pass, runtime
from poprt.compiler import Compiler, CompilerOptions
from poprt.converter import Converter

RuntimeInput = Dict[str, np.ndarray]


def convert(model_proto: onnx.ModelProto, args) -> onnx.ModelProto:
    """Convert an ONNX model to a new, optimized ONNX model."""
    converter = Converter(convert_version=11, precision='fp16')
    converted_model = converter.convert(model_proto)
    # Add other passes here
    converted_model = Pass.get_pass('int64_to_int32')(converted_model)
    converted_model = Pass.get_pass('gelu_pattern')(converted_model)

    return converted_model


def compile(model: onnx.ModelProto, args):
    """Compile ONNX to PopEF."""
    model_bytes = model.SerializeToString()
    outputs = [o.name for o in model.graph.output]

    options = CompilerOptions()
    options.num_ipus = 1
    options.ipu_version = runtime.DeviceManager().ipu_hardware_version()
    options.batches_per_step = args.batches_per_step
    options.partials_type = 'half'
    executable = Compiler.compile(model_bytes, outputs, options)

    return executable


def run_synchronous(
    model_runner: runtime.ModelRunner,
    input: RuntimeInput,
    output: RuntimeInput,
    iterations: int,
) -> None:
    deltas = []
    sess_start = time.time()
    for _ in range(iterations):
        start = time.time()
        model_runner.execute(input, output)
        end = time.time()
        deltas.append(end - start)
    sess_end = time.time()

    latency = sum(deltas) / len(deltas) * 1000
    print(f'Latency : {latency:.3f}ms')
    avg_sess_time = (sess_end - sess_start) / iterations * 1000
    print(f'Synchronous avg Session Time : {avg_sess_time:.3f}ms')


def run_asynchronous(
    model_runner: runtime.ModelRunner,
    input: RuntimeInput,
    output: RuntimeInput,
    iterations: int,
) -> None:
    # Pre-create the inputs and outputs for each iteration
    async_inputs = [input] * iterations
    async_outputs = [output] * iterations
    futures = []

    sess_start = time.time()
    for i in range(iterations):
        f = model_runner.executeAsync(async_inputs[i], async_outputs[i])
        futures.append(f)

    # Wait until all executions have finished
    for i, future in enumerate(futures):
        future.wait()
    sess_end = time.time()

    avg_sess_time = (sess_end - sess_start) / iterations * 1000
    print(f'Asynchronous avg Session Time : {avg_sess_time:.3f}ms')


def run(executable, args):
    """Run PopEF."""
    # Create model runner
    model_runner = runtime.ModelRunner(executable)

    # Fix random number generation
    np.random.seed(2022)

    # Prepare inputs and outputs
    inputs = {}
    inputs_info = model_runner.get_model_inputs()
    for input in inputs_info:
        inputs[input.name] = np.random.uniform(0, 1, input.shape).astype(
            input.numpy_data_type()
        )

    outputs = {}
    outputs_info = model_runner.get_model_outputs()
    for output in outputs_info:
        outputs[output.name] = np.zeros(output.shape, dtype=output.numpy_data_type())

    # Run
    # To correctly generate the PopVision report, the iteration count must be a
    # multiple of batches_per_step and greater than 2 * batches_per_step
    iteration = args.batches_per_step * 10

    # Warm up the device
    for _ in range(10):
        model_runner.execute(inputs, outputs)

    run_synchronous(model_runner, inputs, outputs, iteration)
    run_asynchronous(model_runner, inputs, outputs, iteration)


def default_model():
    TensorProto = onnx.TensorProto
    matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
    add = helper.make_node("Add", ["Z", "A"], ["B"])
    graph = helper.make_graph(
        [matmul, add],
        "test",
        [
            helper.make_tensor_value_info("X", TensorProto.FLOAT, (4, 4, 8, 16)),
            helper.make_tensor_value_info("Y", TensorProto.FLOAT, (4, 4, 16, 8)),
            helper.make_tensor_value_info("A", TensorProto.FLOAT, (4, 4, 8, 8)),
        ],
        [helper.make_tensor_value_info("B", TensorProto.FLOAT, (4, 4, 8, 8))],
    )
    opset_imports = [helper.make_opsetid("", 11)]
    original_model = helper.make_model(graph, opset_imports=opset_imports)
    return original_model


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Convert an ONNX model and run it on the IPU.'
    )
    parser.add_argument('--onnx_model', type=str, help="Full path of the ONNX model.")
    parser.add_argument(
        '--batches_per_step',
        type=int,
        default=100,
        help="The number of on-chip loop iterations.",
    )
    parser.add_argument('--popef', type=str, help="Full path of the PopEF file.")
    args = parser.parse_args()

    if args.popef:
        run(args.popef, args)
    else:
        if not args.onnx_model:
            print("No ONNX model provided, running the default model.")
            model = default_model()
        else:
            print(f"Running ONNX model {args.onnx_model}")
            model = onnx.load(args.onnx_model)

        converted_model = convert(model, args)
        executable = compile(converted_model, args)
        run(executable, args)
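The script runs a previously exported PopEF when --popef is given; otherwise it converts and compiles either the supplied ONNX model or, if none is provided, a small built-in default model. For example (the script name quick_start.py is a placeholder):
python quick_start.py
python quick_start.py --popef executable.popef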