5.5. Packing
5.5.1. Background
The IPU currently supports only static graphs: the input shapes of a model must be fixed, and a dynamic shape triggers recompilation. In real applications, however, especially in NLP, the input sequence length is usually dynamic. The conventional workaround is to pad all variable-length inputs to the max sequence length before feeding them to the model, but this wastes a great deal of computation and leaves the hardware poorly utilized. On the IPU, Packing can be used to support dynamic sequence lengths and raise compute utilization.
5.5.2. Packing and Unpacking
The following example illustrates Packing and Unpacking. Suppose the maximum model input length is 8 and the batch size is 4, and there are currently 7 requests of batch size 1 whose lengths range from 1 to 7, with 0 denoting invalid padded data. Packing and Unpacking then proceed as shown in the figure below:

Fig. 5.6 Packing and Unpacking
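To make the benefit concrete, here is a quick back-of-the-envelope check (a sketch based on the figure's numbers, not part of the example code). The 7 requests of lengths 1 to 7 contain 28 valid tokens; padding each to length 8 occupies 7 × 8 = 56 slots, while packing them into a (4, 8) batch occupies only 4 × 8 = 32 slots:

lengths = [1, 2, 3, 4, 5, 6, 7]            # the 7 requests above
valid_tokens = sum(lengths)                # 28 valid tokens in total
padded_slots = len(lengths) * 8            # pad each request to 8 -> 56 slots
packed_slots = 4 * 8                       # pack into a (4, 8) batch -> 32 slots
print(valid_tokens / padded_slots)         # 0.5   (50% utilization)
print(valid_tokens / packed_slots)         # 0.875 (87.5% utilization)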
5.5.3. Transformer-based NLP Models
Since its introduction in 2017, the Transformer architecture has been applied to an ever wider range of fields, expanding from NLP to ASR, CV, DLRM and more. A Transformer consists of an Encoder and a Decoder; this section is concerned only with the Encoder, whose structure is shown below:

Fig. 5.7 Transformer Encoder
Taking Bert as an example, the input to the Transformer Encoder typically has shape (batch_size, seq_len, hidden_size). Within the Encoder, all modules except Multi-Head Attention compute only along the last dimension, so for those modules Packing removes the wasted computation. Multi-Head Attention, however, computes correlations between tokens, so unless the attention mask is modified it must run on Unpacked data, and the result is re-Packed as soon as the attention computation finishes. The computation flow can be expressed in the following pseudocode:
packed_input from host
activation = packed_input
for encoder in encoders:
    Unpacking
    Attention
    Packing
    Add & LayerNorm
    Feed-Forward
    Add & LayerNorm
    Update activation
Unpacking
unpacked_output to host
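For intuition, the following NumPy sketch shows what Unpacking and Packing do to the activations. It is illustrative only: it assumes, for simplicity, that the packed activation is a flat buffer of valid tokens, whereas the real PopRT ops are custom IPU ops working on fixed-shape (batch_size, seq_len, hidden_size) tensors.

import numpy as np

def unpack(packed_tokens, lengths, seq_len):
    # scatter a flat buffer of valid tokens into one zero-padded row per sequence
    hidden = packed_tokens.shape[-1]
    rows = np.zeros((len(lengths), seq_len, hidden), packed_tokens.dtype)
    offset = 0
    for i, n in enumerate(lengths):
        rows[i, :n] = packed_tokens[offset:offset + n]
        offset += n
    return rows

def pack(rows, lengths):
    # gather the valid tokens of each row back into a flat packed buffer
    return np.concatenate([rows[i, :n] for i, n in enumerate(lengths)], axis=0)

lengths = [3, 5, 2]
packed = np.random.rand(sum(lengths), 4).astype(np.float32)  # hidden_size = 4
rows = unpack(packed, lengths, seq_len=8)   # attention would run on these rows
assert np.allclose(pack(rows, lengths), packed)  # repack restores the packed buffer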
5.5.4. How to Use Packing
This section uses Bert-Base-Squad as the example; the OS used is Ubuntu 20.04 with Python 3.8.15. For the complete example, see packed_bert_example.
Download the model
Before downloading the model, install the required dependencies:
pip install torch==1.10.0
pip install transformers[onnx]==4.25.1
Then download the model:
python -m transformers.onnx --model=csarron/bert-base-uncased-squad-v1 . --feature question-answering
Convert the model
The model exported by the command above (model.onnx in the current directory) does not take position_ids as an input. When Packing is used on the IPU, the inputs must first be packed on the host, so position_ids needs to be added to the model's inputs. The code is as follows:
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import argparse
import copy
import os

import onnx

# Download model from huggingface
# - python -m transformers.onnx --model=csarron/bert-base-uncased-squad-v1 . --feature question-answering
# reference: https://huggingface.co/csarron/bert-base-uncased-squad-v1


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Preprocess Bert-Squad Model')
    parser.add_argument(
        '--input_model', type=str, default='', help='path of input model'
    )
    args = parser.parse_args()

    if not os.path.exists(args.input_model):
        parser.print_usage()
        raise FileNotFoundError(f'Unable to find model : {args.input_model}')

    model = onnx.load(args.input_model)

    # for packed bert, we need to export position_ids to model's input
    # step 1: remove unneeded nodes that compute position ids internally
    rm_node_names = [
        'Shape_7',
        'Gather_9',
        'Add_11',
        'Unsqueeze_12',
        'Slice_14',
        'Constant_8',
        'Constant_10',
        'Constant_13',
    ]
    rm_nodes = []
    for node in model.graph.node:
        if node.name in rm_node_names:
            rm_nodes.append(node)

    assert len(rm_node_names) == len(rm_nodes)

    for node in rm_nodes:
        model.graph.node.remove(node)

    # step 2: add position_ids to model's input
    position_ids = copy.deepcopy(model.graph.input[0])
    position_ids.name = 'position_ids'
    model.graph.input.append(position_ids)

    for node in model.graph.node:
        if node.op_type == 'Gather' and node.name == 'Gather_18':
            node.input[1] = position_ids.name

    print('Save preprocessed model to squad_bert_base_pos.onnx')
    onnx.save(model, 'squad_bert_base_pos.onnx')
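Assuming the script above is saved as, say, add_position_ids.py (the filename here is arbitrary, not fixed by the example), run it against the exported model:

python add_position_ids.py --input_model model.onnx

This writes squad_bert_base_pos.onnx to the current directory, which is the input to the poprt commands below.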
Generate the model without packing:
poprt \
--input_model squad_bert_base_pos.onnx \
--output_model squad_bert_base_bs16_sl256.onnx \
--precision fp16 \
--input_shape input_ids=16,256 attention_mask=16,256 token_type_ids=16,256 position_ids=16,256
Generate the model with packing:
poprt \
--input_model squad_bert_base_pos.onnx \
--output_model squad_bert_base_bs16_sl256_pack.onnx \
--precision fp16 \
--input_shape input_ids=16,256 attention_mask=16,256 token_type_ids=16,256 position_ids=16,256 \
--pack_args max_valid_num=40 segment_max_size=256
Here, max_valid_num specifies the maximum batch size after Unpacking, and segment_max_size specifies the maximum sequence length of a single request.
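As a rough sanity check on these values (assuming the defaults of the example script below): a packed batch provides 16 × 256 = 4096 token slots, and with an average sequence length of 128 about 4096 / 128 = 32 sequences fit into one pack. max_valid_num=40 therefore leaves headroom for batches that skew short, while segment_max_size=256 matches the model's maximum sequence length.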
Run the model
Run the example with the following command:
python packed_bert_example.py \
--model_with_packing squad_bert_base_bs16_sl256_pack.onnx \
--model_without_packing squad_bert_base_bs16_sl256.onnx
The complete code is as follows:
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import argparse
import csv
import os
import queue
import random
import sys
import tempfile
import threading
import time

from multiprocessing.pool import ThreadPool
from queue import Queue

import numpy as np
import packing_utils

from sklearn.metrics import mean_absolute_error

from poprt import runtime
from poprt.backend import get_session

np.random.seed(2023)
INPUT_IDS = "input_ids"
POSITION_IDS = "position_ids"
ATTENTION_MASK = "attention_mask"
TOKEN_TYPE_IDS = "token_type_ids"
UNPACK_INFO = "unpack_info"
OUTPUT2 = "start_logits"
OUTPUT1 = "end_logits"


class BertInputs(object):
    def __init__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        unpack_info,
        input_len,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.position_ids = position_ids
        self.input_len = input_len
        self.unpack_info = unpack_info


def get_synthetic_data(args):
    # sequence lengths drawn from N(avg_seq_len, avg_seq_len), clipped to [1, max_seq_len]
    input_len = np.random.normal(
        args.avg_seq_len, args.avg_seq_len, size=args.dataset_size
    ).astype(np.int32)
    input_len = np.clip(input_len, 1, args.max_seq_len)

    datasets = []
    for s_len in input_len:
        input_ids = np.random.randint(0, args.emb_size, (s_len)).astype(np.int32)

        attention_mask = np.ones(s_len).astype(np.int32)
        token_type_ids = np.random.randint(0, 2, (s_len)).astype(np.int32)

        position_ids = np.arange(s_len).astype(np.int32)
        unpack_info = np.zeros(args.max_valid_num).astype(np.int32)

        feature = BertInputs(
            input_ids, attention_mask, token_type_ids, position_ids, unpack_info, s_len
        )
        datasets.append(feature)

    return datasets


def dump_results(model_name, results):
    fieldnames = [OUTPUT1, OUTPUT2]
    filename = os.path.basename(model_name)[:-4] + 'csv'
    with open(filename, 'w') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        for result in results:
            dict_name2list = {
                OUTPUT1: result[OUTPUT1],
                OUTPUT2: result[OUTPUT2],
            }
            writer.writerow(dict_name2list)


# create batched inputs and pad samples to max_seq_len
def padding_data(datasets, index, args):
    feed_dicts = {}
    feed_dicts[INPUT_IDS] = np.zeros(
        (args.batch_size, args.max_seq_len), dtype=np.int32
    )
    feed_dicts[ATTENTION_MASK] = np.zeros(
        (args.batch_size, args.max_seq_len), dtype=np.int32
    )
    feed_dicts[POSITION_IDS] = np.zeros(
        (args.batch_size, args.max_seq_len), dtype=np.int32
    )
    feed_dicts[TOKEN_TYPE_IDS] = np.zeros(
        (args.batch_size, args.max_seq_len), dtype=np.int32
    )

    for i in range(args.batch_size):
        input_len = datasets[index].input_len
        feed_dicts[INPUT_IDS][i][:input_len] = datasets[index].input_ids
        feed_dicts[ATTENTION_MASK][i][:input_len] = datasets[index].attention_mask
        feed_dicts[POSITION_IDS][i][:input_len] = datasets[index].position_ids
        feed_dicts[TOKEN_TYPE_IDS][i][:input_len] = datasets[index].token_type_ids
        index = index + 1
    return feed_dicts


# online pack: samples fed to the IPU can fill up to the maximum number of batches in each run
def run_packing_model_with_pack_runner_unpack_repack(args, datasets):
    tmpdir = tempfile.TemporaryDirectory()
    # export popef for PackRunner
    get_session(
        args.model_with_packing_unpack_repack,
        1,
        "poprt",
        output_dir=tmpdir.name,
        export_popef=True,
    ).load()
    config = runtime.PackRunnerConfig(
        timeout_microseconds=args.timeout_microseconds,
        # max_valid_num=args.max_valid_num,
        # dynamic_input_name=args.dynamic_input_name,
    )

    popef_path = tmpdir.name + '/executable.popef'
    pack_runner = runtime.Runner(popef_path, config)

    result_queue = queue.Queue()
    results = []
    start_time = time.time()
    for i in range(args.dataset_size):
        feed_dicts = {
            INPUT_IDS: datasets[i].input_ids,
            ATTENTION_MASK: datasets[i].attention_mask,
            TOKEN_TYPE_IDS: datasets[i].token_type_ids,
            POSITION_IDS: datasets[i].position_ids,
            # unpack_info should be hidden from user in the future
            UNPACK_INFO: np.zeros(args.max_valid_num).astype(np.int32),
        }
        out_dict = {
            OUTPUT1: np.zeros([args.max_seq_len]).astype(np.float16),
            OUTPUT2: np.zeros([args.max_seq_len]).astype(np.float16),
        }
        future = pack_runner.execute_async(feed_dicts, out_dict)
        result_queue.put((future, out_dict))
    # (None, None) is a sentinel marking the end of the request stream
    result_queue.put((None, None))
    while True:
        future, out_dict = result_queue.get()
        if future is None:
            break
        future.wait()
        results.append(out_dict)
    end_time = time.time()

    tput = args.dataset_size / (end_time - start_time)
    latency_ms = (end_time - start_time) / args.dataset_size
    print(
        f"[Pack Online Unpack Repack] Throughput: {tput} samples/s, Latency : {latency_ms * 1000} ms"
    )

    if args.dump_results:
        dump_results(
            "online_unpack_repack" + args.model_with_packing_unpack_repack, results
        )

    tmpdir.cleanup()
    return results


# offline pack: samples fed to the IPU can fill up to the maximum number of batches in each run
# model with pack / unpack ops
def run_packing_model_with_model_runner(args, datasets, model_path, across_rows):
    run_queue = queue.Queue()
    start_time = time.time()
    index = 0
    for i in range(0, args.dataset_size):
        transfer = packing_utils.pack_data(
            datasets,
            index,
            args.batch_size,
            seq_len=256,  # must match the sequence length the model was compiled with
            max_valid_num=args.max_valid_num,
            segment_num=1,
            across_rows=across_rows,
        )

        run_queue.put(transfer)
        index = transfer.count
        if index == args.dataset_size:
            break
    run_queue.put(None)
    duration_of_packing = time.time() - start_time
    mean_latency_of_packing_us = duration_of_packing * 1e6 / args.dataset_size

    print(f"Mean latency of packing data: {mean_latency_of_packing_us} us/sam")
    print(f"Total latency of packing data: {duration_of_packing} s")

    sess = get_session(model_path, 1, "poprt").load()

    pool = ThreadPool(processes=1)

    def execute(feed_dicts, valid_num):
        outputs = sess.run([OUTPUT1, OUTPUT2], feed_dicts)
        res = []
        if across_rows:
            for i in range(valid_num):
                res1 = outputs[0][i].copy().tolist()
                res2 = outputs[1][i].copy().tolist()
                res.append({OUTPUT1: res1, OUTPUT2: res2})
        else:
            # single-row mode: the attention mask holds segment ids (1, 2, ...),
            # so each pass of the loop below extracts one packed sequence from the row
            outlen = len(outputs[0][0])
            for index in range(len(feed_dicts[ATTENTION_MASK])):
                start = 0
                arr = np.array(feed_dicts[ATTENTION_MASK][index])
                while start < outlen and arr[start] > 0:
                    arr = arr - 1
                    zero_num = len(arr) - np.count_nonzero(arr)
                    out1 = [0] * outlen
                    out2 = [0] * outlen
                    out1[:zero_num] = outputs[0][index][start : start + zero_num]
                    out2[:zero_num] = outputs[1][index][start : start + zero_num]
                    res.append({OUTPUT1: out1, OUTPUT2: out2})
                    start += zero_num
        return res

    asy_results = []

    total_start_time = time.time()
    while True:
        input_data = run_queue.get()
        if input_data is None:
            break

        feed_dicts = {
            INPUT_IDS: input_data.data[INPUT_IDS],
            ATTENTION_MASK: input_data.data[ATTENTION_MASK],
            TOKEN_TYPE_IDS: input_data.data[TOKEN_TYPE_IDS],
            POSITION_IDS: input_data.data[POSITION_IDS],
            # unpack_info should be hidden from user in the future
            UNPACK_INFO: input_data.unpack_info,
        }
        if not across_rows:
            feed_dicts.pop(UNPACK_INFO)

        valid_num = len(input_data.specs)
        async_result = pool.apply_async(execute, (feed_dicts, valid_num))
        asy_results.append(async_result)

    results = []
    for asy in asy_results:
        for res in asy.get():
            results.append(res)
    total_end_time = time.time()

    tput = len(results) / (total_end_time - total_start_time)
    latency = (total_end_time - total_start_time) / len(results)
    if across_rows:
        print(
            f"[Pack Offline Unpack Repack] Throughput: {tput} samples/s, Latency: {latency * 1000} ms"
        )
    else:
        print(
            f"[Pack Offline AttentionMask] Throughput: {tput} samples/s, Latency: {latency * 1000} ms"
        )

    if args.dump_results:
        dump_results("offline_" + model_path, results)

    return results


# online pack: samples fed to the IPU can fill up to the maximum number of batches in each run
# the model only adds the AttentionMask op in this mode
def run_packing_model_with_pack_runner_attention_mask(args, datasets, algo):
    tmpdir = tempfile.TemporaryDirectory()
    # export popef for PackRunner
    get_session(
        args.model_with_packing_attention_mask,
        1,
        "poprt",
        output_dir=tmpdir.name,
        export_popef=True,
    ).load()
    config = runtime.PackRunnerConfig(
        timeout_microseconds=args.timeout_microseconds,
        max_valid_num=args.max_valid_num,
        dynamic_input_name=args.dynamic_input_name,
    )

    if algo == "next_fit":
        config.algorithm = runtime.PackAlgorithm.next_fit
    else:
        config.algorithm = runtime.PackAlgorithm.first_fit

    config.enable_input_single_row_mode("attention_mask")
    popef_path = tmpdir.name + '/executable.popef'
    pack_runner = runtime.Runner(popef_path, config)

    result_queue = queue.Queue()
    results = []
    start_time = time.time()
    for i in range(args.dataset_size):
        feed_dicts = {
            INPUT_IDS: datasets[i].input_ids,
            ATTENTION_MASK: datasets[i].attention_mask,
            TOKEN_TYPE_IDS: datasets[i].token_type_ids,
            # position_ids is an optional input for first_fit/next_fit mode
            # POSITION_IDS: datasets[i].position_ids,
        }
        out_dict = {
            OUTPUT1: np.zeros([args.max_seq_len]).astype(np.float16),
            OUTPUT2: np.zeros([args.max_seq_len]).astype(np.float16),
        }
        future = pack_runner.execute_async(feed_dicts, out_dict)
        result_queue.put((future, out_dict))
    result_queue.put((None, None))
    while True:
        future, out_dict = result_queue.get()
        if future is None:
            break
        future.wait()
        results.append(out_dict)
    end_time = time.time()

    tput = args.dataset_size / (end_time - start_time)
    latency_ms = (end_time - start_time) / args.dataset_size
    print(
        f"[Pack Online AttentionMask({algo})] Throughput: {tput} samples/s, Latency : {latency_ms * 1000} ms"
    )

    if args.dump_results:
        dump_results(
            "online_attention_mask_"
            + algo
            + "_"
            + args.model_with_packing_attention_mask,
            results,
        )

    tmpdir.cleanup()
    return results


def latency_distribution_with_pack_runner_attention_mask(args, datasets, algo):
    tmpdir = tempfile.TemporaryDirectory()
    # export popef for PackRunner
    get_session(
        args.model_with_packing_attention_mask,
        1,
        "poprt",
        output_dir=tmpdir.name,
        export_popef=True,
    ).load()
    config = runtime.PackRunnerConfig(
        timeout_microseconds=args.timeout_microseconds,
        max_valid_num=args.max_valid_num,
        dynamic_input_name=args.dynamic_input_name,
    )

    if algo == "next_fit":
        config.algorithm = runtime.PackAlgorithm.next_fit
    else:
        config.algorithm = runtime.PackAlgorithm.first_fit

    config.enable_input_single_row_mode("attention_mask")
    popef_path = tmpdir.name + '/executable.popef'
    pack_runner = runtime.Runner(popef_path, config)

    sample_num = args.batch_size * args.iterations
    clients = int(args.batch_size * 3.5)
    count_percent = 0.6

    q = Queue()

    def perf_count(model_runner, iteration):
        durations = []
        for i in range(sample_num):
            start_time = time.time()
            random.randint(0, sample_num)
            feed_dicts = {
                INPUT_IDS: datasets[i].input_ids,
                ATTENTION_MASK: datasets[i].attention_mask,
                TOKEN_TYPE_IDS: datasets[i].token_type_ids,
            }
            out_dict = {
                OUTPUT1: np.zeros([args.max_seq_len]).astype(np.float16),
                OUTPUT2: np.zeros([args.max_seq_len]).astype(np.float16),
            }
            pack_runner.execute(feed_dicts, out_dict)
            end_time = time.time()
            durations.append((start_time, end_time))
        # drop the warm-up and cool-down samples at either end of the run
        ignored_samples = int(sample_num * (1 - count_percent) / 2)
        durations = durations[ignored_samples:-ignored_samples]
        q.put(durations, timeout=10)

    thp = [
        threading.Thread(target=perf_count, args=(pack_runner, args.iterations))
        for _ in range(clients)
    ]
    for t in thp:
        t.start()
    for t in thp:
        t.join()

    durations_from_th = []
    while not q.empty():
        durations_from_th += q.get()
    times_range = [y - x for x, y in durations_from_th]

    times_range.sort()
    tail_latency = round(times_range[int(len(times_range) * 0.99)] * 1000, 2)
    avg_latency = round(sum(times_range) / len(times_range) * 1000, 2)

    print(f"Average Latency: {avg_latency}ms, P99 latency: {tail_latency}ms.")
    return tail_latency, avg_latency


# no pack: pad each row with 0 if the input is not long enough;
# the number of samples equals the batch size in every run
def run_original_model_with_model_runner(args, datasets):
    run_queue = queue.Queue()
    start_time = time.time()
    for i in range(0, args.dataset_size, args.batch_size):
        feed_dicts = padding_data(datasets, i, args)
        run_queue.put((args.batch_size, feed_dicts))
    run_queue.put((0, None))
    duration_of_padding_s = time.time() - start_time

    mean_latency_of_padding_us = duration_of_padding_s * 1e6 / args.dataset_size
    print(f"Mean latency of padding data: {mean_latency_of_padding_us} us/sam")
    print(f"Total latency of padding data: {duration_of_padding_s} s")

    sess = get_session(args.model_without_packing, 1, "poprt").load()

    asy_results = []

    def execute(feed_dicts, valid_num):
        outputs = sess.run([OUTPUT1, OUTPUT2], feed_dicts)
        res = []
        for i in range(valid_num):
            res1 = outputs[0][i].copy().tolist()
            res2 = outputs[1][i].copy().tolist()
            res.append({OUTPUT1: res1, OUTPUT2: res2})
        return res

    # execute
    pool = ThreadPool(processes=1)
    total_start_time = time.time()
    while True:
        valid_num, feed_dicts = run_queue.get()
        if feed_dicts is None:
            break
        async_result = pool.apply_async(execute, (feed_dicts, valid_num))
        asy_results.append(async_result)
    results = []
    for asy in asy_results:
        for res in asy.get():
            results.append(res)
    total_end_time = time.time()

    tput = len(results) / (total_end_time - total_start_time)
    latency = (total_end_time - total_start_time) / len(results)

    if args.dump_results:
        dump_results("original_" + args.model_without_packing, results)

    print(f"[Original] Throughput: {tput} samples/s, Latency: {latency * 1000} ms")

    return results


def calculate_mae(expected_results, output_results, datasets, enable_debug):
    assert len(datasets) == len(expected_results)
    assert len(datasets) == len(output_results)
    maes = []
    zipped_data = zip(datasets, expected_results, output_results)
    for i, (data, expected, output) in enumerate(zipped_data):
        np.testing.assert_equal(len(expected), len(output))
        input_len = data.input_len
        output_1_mae = mean_absolute_error(
            expected[OUTPUT1][:input_len], output[OUTPUT1][:input_len]
        )
        output_2_mae = mean_absolute_error(
            expected[OUTPUT2][:input_len], output[OUTPUT2][:input_len]
        )
        maes.append([i, output_1_mae, output_2_mae])

    k = 10 if len(datasets) > 10 else len(datasets)

    def print_topk(k, out_name, out_index):
        for i in range(1, k + 1):
            print(f"Sample: {maes[-i][0]}, {out_name} mae : {maes[-i][out_index]}")

    if enable_debug:
        maes.sort(key=lambda e: e[1])
        print(f"\n***** Top {k} mae of output: {OUTPUT1} *****")
        print_topk(k, OUTPUT1, 1)

        maes.sort(key=lambda e: e[2])
        print(f"\n***** Top {k} mae of output: {OUTPUT2} *****")
        print_topk(k, OUTPUT2, 2)

    print(f"{OUTPUT1} average mae: {np.mean(maes, axis=0)[1]}")
    print(f"{OUTPUT2} average mae: {np.mean(maes, axis=0)[2]}")


def main():
    parser = argparse.ArgumentParser(description='packed bert-base-squad')
    parser.add_argument(
        '--avg_seq_len', type=int, default=128, help='average sequence length of input'
    )
    parser.add_argument(
        '--batch_size', type=int, default=16, help='batch size of model'
    )
    parser.add_argument('--dump_results', action='store_true', help='dump results')
    parser.add_argument(
        '--dynamic_input_name', type=str, default=INPUT_IDS, help='dynamic input name'
    )
    parser.add_argument(
        '--emb_size', type=int, default=30522, help='word embedding table size'
    )
    parser.add_argument(
        '--enable_debug', action='store_true', help='enable output debug info'
    )
    parser.add_argument(
        '--iterations', type=int, default=100, help='number of batches to run'
    )
    parser.add_argument(
        '--max_seq_len', type=int, default=256, help='max sequence length of input'
    )
    parser.add_argument(
        '--max_valid_num', type=int, default=40, help='max valid num for pack'
    )
    parser.add_argument(
        '--model_without_packing', help='model without pack, unpack, repack op'
    )
    parser.add_argument(
        '--model_with_packing_unpack_repack',
        help='model with pack, unpack, repack op converted by PopRT',
    )
    parser.add_argument(
        '--model_with_packing_attention_mask',
        help='model with AttentionMask op converted by PopRT',
    )
    parser.add_argument(
        '--timeout_microseconds',
        type=int,
        default=15000,
        help='timeout in microseconds',
    )

    args = parser.parse_args()
    args.dataset_size = args.iterations * args.batch_size

    # generate synthetic dataset
    datasets = get_synthetic_data(args)
    original_result = run_original_model_with_model_runner(args, datasets)

    offline_pack_result_unpack_repack = run_packing_model_with_model_runner(
        args, datasets, args.model_with_packing_unpack_repack, True
    )
    online_pack_result_unpack_repack = run_packing_model_with_pack_runner_unpack_repack(
        args, datasets
    )

    offline_pack_result_attention_mask = run_packing_model_with_model_runner(
        args, datasets, args.model_with_packing_attention_mask, False
    )
    online_pack_result_attention_mask_first_fit = (
        run_packing_model_with_pack_runner_attention_mask(args, datasets, "first_fit")
    )
    online_pack_result_attention_mask_next_fit = (
        run_packing_model_with_pack_runner_attention_mask(args, datasets, "next_fit")
    )
    latency_distribution_with_pack_runner_attention_mask(args, datasets, "first_fit")

    # compare the results
    print("\nCompare results between original and online pack(with unpack repack)")
    calculate_mae(
        original_result, online_pack_result_unpack_repack, datasets, args.enable_debug
    )
    print("\nCompare results between offline and online pack with unpack repack op")
    calculate_mae(
        offline_pack_result_unpack_repack,
        online_pack_result_unpack_repack,
        datasets,
        args.enable_debug,
    )

    print(
        "\nCompare results between original and online_first_fit pack with attention_mask op"
    )
    calculate_mae(
        original_result,
        online_pack_result_attention_mask_first_fit,
        datasets,
        args.enable_debug,
    )
    print(
        "\nCompare results between original and online_next_fit pack with attention_mask op"
    )
    calculate_mae(
        original_result,
        online_pack_result_attention_mask_next_fit,
        datasets,
        args.enable_debug,
    )

    print(
        "\nCompare results between offline and online_next_fit pack with attention_mask op"
    )
    calculate_mae(
        offline_pack_result_attention_mask,
        online_pack_result_attention_mask_next_fit,
        datasets,
        args.enable_debug,
    )


if __name__ == "__main__":
    sys.exit(main())
When the run completes, output similar to the following is printed; in this run, packing raises throughput by roughly 1.5 times over the padded baseline:
[Original] Throughput: 1860.9792005501781 samples/s, Latency: 0.5373515188694 ms
....
[Pack Offline] Throughput: 2830.8140869025283 samples/s, Latency: 0.3532552719116211 ms
....
[Pack Online] Throughput: 2782.587696947809 samples/s, Latency : 0.3593777120113373 ms
....