5.2. 使用 Overlap IO
IPU 在执行推理模型时, 一般分成 3 个阶段:
Load: 从 Host 拷贝输入数据到 IPU
Compute: 模型计算
Store: 从 IPU 拷贝结果数据到 Host
这三个阶段是串行执行的, 也就是说在 Load/Store 阶段传输数据时, IPU 的计算资源是闲置状态. 这在一些输入输出数据比较大的模型中, 整个模型的性能是受 IO 限制的. 对于这种模型, 开启 Overlap IO 能够使计算阶段和 IO 阶段重叠起来, 提高 IPU 的计算资源利用率.
5.2.1. 原理
Overlap IO 的原理是通过把 IPU 的片上所有 Tile 划分成两组, 即 Compute Tiles 和 IO Tiles, Compute Tiles 专门处理计算, 而 IO Tiles 专门负责与 Host 之间进行数据拷贝. 这样, 对于一个计算流来说, Load, Compute, Store 三个阶段组成了一个三级的 pipeline, 从而使计算和 IO 重叠起来, 提高 IPU 计算资源的利用率.
5.2.2. 配置 IO Tiles
Overlap IO 的开启只需要设置一个参数, 即 IO Tiles 的数量. 可以调整 IO Tiles 的数量来优化传输的吞吐量. 要计算 IO Tiles 的数量, 可以用所有输入输出的 Tensor 大小之和除以每个 Tile 可用的 SRAM 大小, 然后四舍五入到下一个 2 的幂次方.
通过 PopRT CLI 中
--num_io_tiles
来配置 IO Tiles:
poprt \
--input_model model.onnx \
--export_popef \
--output_dir model \
--num_io_tiles 128
通过
poprt.compiler.CompilerOptions
API 来配置 IO Tiles:
opts = poprt.compiler.CompilerOptions()
opts.num_io_tiles = 128
5.2.3. 调试
通过 PopVision Graph Analyser 工具, 可以观察 IO 和 Compute 是否重叠, 从而来判断 OverlapIO 是否生效, 以及通过调整 IO Tiles 的数量来优化模型的性能.
5.2.4. 并发请求
由于通过 Overlap IO 把推理的 3 个阶段组成了一个三段式的流水线, 因此, 为了能维持流水线运行下去, 必须要有足够的并发数据喂给 IPU. 通过多线程的方式给 IPU 并发的喂数据, 至少需要启动 3 个线程.
5.2.5. 示例
下面是一个简单的 OverlapIO 的 example code:
1# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
2import argparse
3import datetime
4import threading
5
6import numpy as np
7import onnx
8
9from onnx import helper
10
11from poprt import runtime
12from poprt.compiler import Compiler, CompilerOptions
13from poprt.runtime import RuntimeConfig
14
15'''
16PopRT use OverlapInnerLoop strategy as default exchange strategy.
17There are two loops in the main program: outer loop and inner loop.
18Each batch data needs to be processed in three pipeline stages: load/compute/store.
19Therefore, in order to enable the pipeline to run normally, at least three threads
20are required to feed data to the pipeline at the same time.
21==============================================================
22OverlapInnerLoop:
23- Boxes denote subgraphs / subgraph Ops / loops
24- Inputs/outputs are loop carried in order
25
26.- outer loop ----------------------------------------.
27| .- inner loop -. |
28| load - compute - | - store | |
29| load - | - compute -- | - store |
30| | load ----- | - compute - store |
31| '--------------' |
32'-----------------------------------------------------'
33 ^^^^^^^ ^^^^^^^ ^^^^^^^
34 overlap overlap overlap
35
36==============================================================
37'''
38
39
40def compile(model: onnx.ModelProto, args):
41 """Compile ONNX to PopEF."""
42 model_bytes = model.SerializeToString()
43 outputs = [o.name for o in model.graph.output]
44
45 options = CompilerOptions()
46 options.batches_per_step = args.batches_per_step
47 options.num_io_tiles = args.num_io_tiles
48
49 executable = Compiler.compile(model_bytes, outputs, options)
50 return executable
51
52
53def run(executable, args):
54 """Run PopEF."""
55 # Create model runner
56 config = RuntimeConfig()
57 config.timeout_ns = datetime.timedelta(microseconds=0)
58 # Create model runner
59 model_runner = runtime.ModelRunner(executable, config)
60
61 inputs_info = model_runner.get_model_inputs()
62 outputs_info = model_runner.get_model_outputs()
63
64 # Run in multiple threads
65 def execute(bps, inputs_info, outputs_info):
66 inputs = {}
67 outputs = {}
68
69 for input in inputs_info:
70 inputs[input.name] = np.random.uniform(0, 1, input.shape).astype(
71 input.numpy_data_type()
72 )
73 for output in outputs_info:
74 outputs[output.name] = np.zeros(
75 output.shape, dtype=output.numpy_data_type()
76 )
77
78 # To correctly generate the popvision report, iteration must be a
79 # multiple of batches_per_step and greater than 2 * batches_per_step
80 # There are 3 threads, so the total number feed into IPU is 3 * iteration
81 iteration = bps
82 for _ in range(iteration):
83 model_runner.execute(inputs, outputs)
84
85 threads = []
86 num_threads = 3
87 print(f"Run PopEF with {num_threads} threads.")
88 for _ in range(num_threads):
89 threads.append(
90 threading.Thread(
91 target=execute, args=(args.batches_per_step, inputs_info, outputs_info)
92 )
93 )
94
95 for t in threads:
96 t.start()
97
98 for t in threads:
99 t.join()
100 print(f"Complete.")
101
102
103def default_model():
104 TensorProto = onnx.TensorProto
105
106 nodes = []
107 num_matmuls = 4
108 nodes.append(helper.make_node("Expand", ["input", "shape"], ["Act0"]))
109 for i in range(num_matmuls):
110 nodes.append(helper.make_node("MatMul", [f"Act{i}", "Weight"], [f"Act{i+1}"]))
111 nodes.append(
112 helper.make_node("ReduceMean", [f"Act{num_matmuls}"], ["output"], axes=[0, 1])
113 )
114
115 graph = helper.make_graph(
116 nodes,
117 "matmul_test",
118 [
119 helper.make_tensor_value_info("input", TensorProto.FLOAT, (256, 256)),
120 ],
121 [helper.make_tensor_value_info("output", TensorProto.FLOAT, (256, 256))],
122 [
123 helper.make_tensor(
124 "shape",
125 TensorProto.INT64,
126 [4],
127 np.array([4, 4, 256, 256], dtype=np.int64),
128 ),
129 helper.make_tensor(
130 "Weight",
131 TensorProto.FLOAT,
132 (4, 4, 256, 256),
133 np.random.randn(4, 4, 256, 256),
134 ),
135 ],
136 )
137 opset_imports = [helper.make_opsetid("", 11)]
138 original_model = helper.make_model(graph, opset_imports=opset_imports)
139 return original_model
140
141
142if __name__ == '__main__':
143 parser = argparse.ArgumentParser(
144 description='Convert onnx model and run it on IPU.'
145 )
146 parser.add_argument(
147 '--batches_per_step',
148 type=int,
149 default=16,
150 help="The number of on-chip loop count.",
151 )
152 parser.add_argument(
153 '--num_io_tiles',
154 type=int,
155 default=192,
156 help="The number of IO tiles.",
157 )
158 args = parser.parse_args()
159 model = default_model()
160 exec = compile(model, args)
161 run(exec, args)