5.9. 开发 Custom Transforms

Transform 的概念来源于 PopART 框架, 是图优化的一种手段. 不同于 Pattern 在 PopART IR 上对 OP 做匹配, Transforms 在 PopART IR 上进行整图级别的优化, 以更复杂的方式改变 PopART IR.

例如 Transform SubgraphOutline 的作用是将重复的 Ops 结构提取为新的 Graphs 中, 并用 CallOps 调用以节省内存.

本教程帮助用户了解如何实现 Custom PopART Transforms 并在 PopRT 中使用.

在阅读本教程之前, 需要首先了解以下主题:

5.9.1. 实现 Custom PopART Transform

要在 PopART 中创建 Custom Transform, 需要编写至少一个 C++ 文件.

示例代码如下, 这个例子用于打印当前的序列化 IR, 可以将它编译成独立的动态库并在使用 PopRT 的时候链接:

Listing 5.16 ir_serialise_transform.cpp
 1// Copyright (c) 2022 Graphcore Ltd. All rights reserved.
 2
 3#include <iostream>
 4#include <string>
 5#include <popart/graph.hpp>
 6#include <popart/ir.hpp>
 7#include <popart/op.hpp>
 8#include <popart/transforms/transform.hpp>
 9
10namespace popart {
11class Graph;
12
13class IrSerialise : public Transform {
14public:
15  static std::size_t id();
16
17  IrSerialise() : Transform() {}
18  virtual ~IrSerialise() override {}
19
20  virtual bool apply(Graph &graph) const final;
21
22  virtual std::size_t getId() const final { return id(); }
23
24  virtual std::string getName() const final { return "IrSerialise"; }
25};
26
27std::size_t IrSerialise::id() { return typeid(IrSerialise).hash_code(); }
28
29bool IrSerialise::apply(Graph &graph) const {
30  const auto &ir = graph.getIr();
31  std::stringstream ss;
32  ir.serialise(Ir::SerialiseFormat::JSON, ss);
33  const auto modelStr = ss.str();
34  std::cout << "SerializedIr : " << std::endl;
35  std::cout << modelStr << std::endl;
36  return true;
37}
38
39namespace {
40bool init = Transform::registerTransform(new IrSerialise);
41}
42
43} // namespace popart

为了实现用户自定义的 PopART Custom Transform, 需要继承 popart::Transform 并覆盖或实现主要方法:

  • apply() 实现 IR 转换以及其他功能

  • getId() 唯一的 Transform ID

  • getName() 定义 Custom Transform 的名称, 需要避免与 PopART 已有的 Transform 名称冲突

  • registerTransform() 向 PopART 注册 Custom Transform

可以参考 /popart/willow/src/transforms 包含的 PopART 中默认的 Transform.

5.9.2. 在 PopRT 中使用 Custom Transform

到目前为止, 已经完成了这个 Custom Transform 的编写. 接下来需要做的是在 PopRT 中使用它.

需要将 Custom Transform 源码编译成独立的动态链接库并在使用 PopRT 的时候链接, 编译的命令示例如下:

g++ \
    -std=c++14 \
    -fPIC \
    -O3 \
    -DONNX_NAMESPACE=onnx \
    ir_serialise_transform.cpp \
    -shared \
    -lpopart \
    -o custom_transform.so

然后可以通过 PopRT CLI 的 --custom_library_so_paths 参数来链接包含 Custom Transforms 的动态库:

poprt  --custom_library_so_paths path/to/shared/library

由于 Transform 以比较复杂的方式改变 PopART IR, 所以需要按照预定义的顺序执行, 通常在编写 Transform 之前就需要考虑应该把它放哪一个执行位置.

Transform 的执行顺序分为几个阶段, PopART 允许在每个阶段完成后的 checkpoint 中调用用户自定义的 Custom Transform.

预定义的 checkpoint 有:

  • Fwd0: 从 ONNX Lowering 到 PopART IR 后的初始 IR

  • Fwd1: 在 pre-alias patterns 被应用到 FWD0 之后

  • Bwd0: 在 backward pass 之后

  • Prealias: 在 pre-alias patterns 被应用到 BWD0 之后

  • MainLoops: 在应用 MainLoops transform 之后

  • Final: 所有 Transform 被应用的最终 IR 之后

参阅 /popart/willow/src/popart/ir.cpp

通过 PopRT CLI 的 --compiler_options 参数来配置 Custom Transform, 示例如下

poprt \
    --input_model model.onnx \
    --output_model model_export.onnx \
    --export_popef \
    --output_dir model \
    --custom_library_so_paths build/custom_transforms.so \
    --compiler_options custom_transform_applier_settings="{'Fwd0': ['IrSerialise']}"