AI Compilation Introduction

This article introduces the basic concepts of AI compilation, including the typical compilation workflow, and surveys several representative systems: PyTorch's graph-capture tools, TVM, and MLIR.


AI Compilation Designs

AI compilers transform a computation graph into low-level code (e.g., LLVM IR, SPIR-V, or MLIR) that can be interpreted or further compiled for a target device. An AI compiler typically consists of the parts shown in Figure 1.

Figure 1: AI Compiler Routine

There are three levels of optimization in AI compilers:

Graph-Level Optimization.

The computation graph is optimized by transforming its structure: for example, eliminating redundant or dead nodes (operators), fusing adjacent nodes into a single kernel, and folding constant subgraphs.
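A minimal sketch of two such passes, common-subexpression elimination (merging nodes that compute the same thing) and dead-node elimination, on a hypothetical toy IR: a list of (name, op, inputs) tuples in topological order. The op names are placeholders; real AI compilers work on much richer graph representations.

```python
from typing import Dict, List, Tuple

# A node is (name, op, input_names); the list is in topological order.
Node = Tuple[str, str, Tuple[str, ...]]


def cse(nodes: List[Node], outputs: List[str]) -> Tuple[List[Node], List[str]]:
    """Common-subexpression elimination: reuse an earlier node that applies
    the same op to the same inputs instead of keeping a redundant copy."""
    seen: Dict[Tuple[str, Tuple[str, ...]], str] = {}
    rename: Dict[str, str] = {}
    kept: List[Node] = []
    for name, op, inputs in nodes:
        inputs = tuple(rename.get(i, i) for i in inputs)
        key = (op, inputs)
        if op != "input" and key in seen:  # graph inputs are never merged
            rename[name] = seen[key]
            continue
        seen[key] = name
        kept.append((name, op, inputs))
    return kept, [rename.get(o, o) for o in outputs]


def dce(nodes: List[Node], outputs: List[str]) -> List[Node]:
    """Dead-node elimination: keep only nodes that some output depends on."""
    live = set(outputs)
    kept: List[Node] = []
    for name, op, inputs in reversed(nodes):
        if name in live or op == "input":
            kept.append((name, op, inputs))
            live.update(inputs)
    return kept[::-1]


graph = [
    ("x", "input", ()),
    ("a", "relu", ("x",)),
    ("b", "relu", ("x",)),    # redundant: same op on the same input as "a"
    ("c", "add", ("a", "b")),
    ("d", "exp", ("x",)),     # dead: nothing consumes it
]
nodes, outs = cse(graph, ["c"])
optimized = dce(nodes, outs)
for node in optimized:
    print(node)
```

In practice, pass order matters: one pass often creates new opportunities for another, so compilers typically run such passes repeatedly.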

Operator-Level Optimization.

An operator is optimized by transforming the loop nest that implements it, for example with loop unrolling, loop tiling, and loop fusion.
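Loop tiling can be illustrated with a pure-Python sketch of matrix multiplication, where the i/j/p loops are split into blocks of an assumed `tile` size. Python will not show the cache benefit that motivates tiling on real hardware; the sketch only demonstrates that the transformation reorders iterations without changing the result.

```python
import random
from typing import List

Matrix = List[List[float]]


def matmul_naive(a: Matrix, b: Matrix) -> Matrix:
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            for p in range(k):
                c[i][j] += a[i][p] * b[p][j]
    return c


def matmul_tiled(a: Matrix, b: Matrix, tile: int = 4) -> Matrix:
    """Same computation with every loop split into tile-sized blocks.
    On real hardware, the tile is chosen so a block's data stays in cache."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    for i0 in range(0, n, tile):              # outer loops step from tile to tile
        for j0 in range(0, m, tile):
            for p0 in range(0, k, tile):
                # inner loops stay inside one tile; min() handles ragged edges
                for i in range(i0, min(i0 + tile, n)):
                    for j in range(j0, min(j0 + tile, m)):
                        acc = c[i][j]
                        for p in range(p0, min(p0 + tile, k)):
                            acc += a[i][p] * b[p][j]
                        c[i][j] = acc
    return c


random.seed(0)
A = [[random.random() for _ in range(7)] for _ in range(10)]   # 10 x 7
B = [[random.random() for _ in range(9)] for _ in range(7)]    # 7 x 9
C_naive, C_tiled = matmul_naive(A, B), matmul_tiled(A, B, tile=4)
# 0.0: each element receives the same additions in the same order
print(max(abs(x - y) for r1, r2 in zip(C_naive, C_tiled) for x, y in zip(r1, r2)))
```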

ISA-Level Optimization.

The backend generates code that leverages special instructions on the target device, such as SIMD instructions, tensor instructions, and other hardware accelerator-specific instructions.
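Python cannot emit SIMD instructions, so the following sketch only simulates the loop structure a backend typically produces when vectorizing an element-wise loop: a main loop that processes `LANES` elements per step (one vector instruction each in real generated code), followed by a scalar epilogue for the remainder. `LANES = 4` is an assumed vector width.

```python
from typing import List

LANES = 4  # assumed vector width, e.g. four 32-bit floats per 128-bit register


def vector_add(a: List[float], b: List[float]) -> List[float]:
    """Simulate how c[i] = a[i] + b[i] is vectorized: a main loop over
    LANES-wide chunks, then a scalar loop for the leftover elements."""
    n = len(a)
    c = [0.0] * n
    main = n - n % LANES  # largest multiple of LANES that fits
    for i in range(0, main, LANES):
        # In generated code, this chunk-wise add would be a single SIMD instruction.
        c[i:i + LANES] = [x + y for x, y in zip(a[i:i + LANES], b[i:i + LANES])]
    for i in range(main, n):  # scalar epilogue
        c[i] = a[i] + b[i]
    return c


print(vector_add([1.0] * 10, [2.0] * 10))
```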


PyTorch Infrastructure

By default, PyTorch runs a model in eager mode: each operator executes as soon as the Python code calls it, so the computation graph is constructed on the fly while the model runs. Such a graph is referred to as a dynamic graph.

PyTorch provides several ways to convert the dynamic graph to a static graph for AI compilation.

TorchScript

TorchScript (deprecated) converts the dynamic graph into a static graph expressed in TorchScript, a statically typed subset of Python. It provides two ways to do so, tracing and scripting:

import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# An eager nn.Module has no .graph attribute; only the TorchScript versions below do
print("my_cell: \n", my_cell, "\n", my_cell(x, h))

# Way 1: tracing the model
traced_cell = torch.jit.trace(my_cell, (x, h))
print("\n\ntraced_cell: \n", traced_cell, "\n", traced_cell(x, h), "\n", traced_cell.graph)

# Way 2: scripting the model
scripted_cell = torch.jit.script(my_cell)
print("\n\nscripted_cell: \n", scripted_cell, "\n", scripted_cell(x, h), "\n", scripted_cell.graph)

This outputs

my_cell:
MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[0.0229, 0.3737, 0.6648, 0.7164],
        [0.6495, 0.6735, 0.4847, 0.3583],
        [0.2359, 0.4631, 0.9131, 0.4446]], grad_fn=<TanhBackward0>), tensor([[0.0229, 0.3737, 0.6648, 0.7164],
        [0.6495, 0.6735, 0.4847, 0.3583],
        [0.2359, 0.4631, 0.9131, 0.4446]], grad_fn=<TanhBackward0>))


traced_cell:
MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)
(tensor([[0.0229, 0.3737, 0.6648, 0.7164],
        [0.6495, 0.6735, 0.4847, 0.3583],
        [0.2359, 0.4631, 0.9131, 0.4446]], grad_fn=<TanhBackward0>), tensor([[0.0229, 0.3737, 0.6648, 0.7164],
        [0.6495, 0.6735, 0.4847, 0.3583],
        [0.2359, 0.4631, 0.9131, 0.4446]], grad_fn=<TanhBackward0>))
graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /pwd/main.py:9:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /pwd/main.py:9:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /pwd/main.py:9:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)



scripted_cell:
RecursiveScriptModule(
  original_name=MyCell
  (linear): RecursiveScriptModule(original_name=Linear)
)
(tensor([[0.0229, 0.3737, 0.6648, 0.7164],
        [0.6495, 0.6735, 0.4847, 0.3583],
        [0.2359, 0.4631, 0.9131, 0.4446]], grad_fn=<TanhBackward0>), tensor([[0.0229, 0.3737, 0.6648, 0.7164],
        [0.6495, 0.6735, 0.4847, 0.3583],
        [0.2359, 0.4631, 0.9131, 0.4446]], grad_fn=<TanhBackward0>))
graph(%self : __torch__.___torch_mangle_3.MyCell,
      %x.1 : Tensor,
      %h.1 : Tensor):
  %7 : int = prim::Constant[value=1]()
  %linear : __torch__.torch.nn.modules.linear.___torch_mangle_2.Linear = prim::GetAttr[name="linear"](%self)
  %5 : Tensor = prim::CallMethod[name="forward"](%linear, %x.1) # /pwd/main.py:9:27
  %8 : Tensor = aten::add(%5, %h.1, %7) # /pwd/main.py:9:27
  %new_h.1 : Tensor = aten::tanh(%8) # /pwd/main.py:9:16
  %12 : (Tensor, Tensor) = prim::TupleConstruct(%new_h.1, %new_h.1)
  return (%12)

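  • Tracing (torch.jit.trace) converts the dynamic graph into a static graph by running the model on example inputs and recording the operations that are executed. Data-dependent control flow cannot be captured this way: only the branch taken for the example inputs is recorded. For example, tracing the following module records either the identity or the negation, depending on the example input, but never the if statement itself:

    class MyDecisionGate(torch.nn.Module):
        def forward(self, x):
            if x.sum() > 0:
                return x
            else:
                return -x
  • Scripting (torch.jit.script) compiles the model's Python source code directly into TorchScript, so it preserves control flow such as the if statement above. However, the model must be written in the subset of Python that TorchScript supports.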
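The tracing limitation can be reproduced without PyTorch. In this sketch, a hypothetical `Tracer` class records each operation applied to it on a tape, while the `if` in `decision_gate` is evaluated by plain Python on a concrete value, just as in `MyDecisionGate`. Replaying the tape recorded for a positive example therefore gives the wrong answer for a negative input. This is a simplified model of tracing, not how `torch.jit.trace` is implemented.

```python
from typing import Callable, List, Tuple

# A "tape" of recorded operations: (name, function applied to the value).
Tape = List[Tuple[str, Callable[[float], float]]]


class Tracer:
    """Wraps a concrete value and records every operation applied to it."""

    def __init__(self, value: float, tape: Tape):
        self.value = value
        self.tape = tape

    def __neg__(self) -> "Tracer":
        self.tape.append(("neg", lambda v: -v))
        return Tracer(-self.value, self.tape)


def decision_gate(x: Tracer) -> Tracer:
    # Like MyDecisionGate.forward: the condition is evaluated on a concrete
    # value by ordinary Python, so the tracer never sees the `if` itself.
    if x.value > 0:
        return x
    return -x


def trace(fn: Callable[[Tracer], Tracer], example: float) -> Tape:
    tape: Tape = []
    fn(Tracer(example, tape))
    return tape


def replay(tape: Tape, value: float) -> float:
    for _, op in tape:
        value = op(value)
    return value


graph = trace(decision_gate, example=1.0)      # only the `return x` branch ran
print([name for name, _ in graph])             # []
print(replay(graph, -2.0))                     # -2.0
print(decision_gate(Tracer(-2.0, [])).value)   # 2.0 when run eagerly
```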

Torch FX

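Torch FX is a toolkit for capturing and transforming the computation graph of an nn.Module. Its symbolic tracer (torch.fx.symbolic_trace) records the module's operations into a torch.fx.Graph, a simple Python-level IR in which each node is an input (placeholder), an attribute fetch (get_attr), a call to a function, module, or method, or the output. Because the graph can be inspected and rewritten from Python, optimizations such as operator fusion can be implemented as graph transformations.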

import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(
            torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3
        )


m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

This outputs

opcode         name           target                                                   args                kwargs
-------------  -------------  -------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                        ()                  {}
get_attr       linear_weight  linear.weight                                            ()                  {}
call_function  add            <built-in function add>                                  (x, linear_weight)  {}
call_module    linear         linear                                                   (add,)              {}
call_method    relu           relu                                                     (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x7dc9e439af60>   (relu,)             {'dim': -1}
call_function  topk           <built-in method topk of type object at 0x7dc9e439af60>  (sum_1, 3)          {}
output         output         output                                                   (topk,)             {}
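To make the opcode column more concrete, here is a hypothetical pure-Python interpreter for an FX-like node list. It follows the same idea as `torch.fx.Interpreter` (walk the nodes in order and keep an environment of computed values), but it is not that class's API; floats and lambdas stand in for tensors and submodules, and `scale` and `w` are made-up names.

```python
import operator
from typing import Any, Callable, Dict, List, Tuple

# (opcode, name, target, arg_names), mirroring the columns of print_tabular()
FxNode = Tuple[str, str, Any, Tuple[str, ...]]


def interpret(nodes: List[FxNode], attrs: Dict[str, Any],
              submodules: Dict[str, Callable[..., Any]], *inputs: Any) -> Any:
    env: Dict[str, Any] = {}            # node name -> computed value
    remaining = iter(inputs)
    for opcode, name, target, arg_names in nodes:
        args = [env[a] for a in arg_names]
        if opcode == "placeholder":       # graph input
            env[name] = next(remaining)
        elif opcode == "get_attr":        # parameter or buffer of the module
            env[name] = attrs[target]
        elif opcode == "call_function":   # free function such as operator.add
            env[name] = target(*args)
        elif opcode == "call_module":     # submodule looked up by its name
            env[name] = submodules[target](*args)
        elif opcode == "call_method":     # method called on the first argument
            self_arg, *rest = args
            env[name] = getattr(self_arg, target)(*rest)
        elif opcode == "output":
            return args[0]
        else:
            raise ValueError(f"unknown opcode {opcode!r}")
    raise ValueError("graph has no output node")


graph = [
    ("placeholder", "x", "x", ()),
    ("get_attr", "w", "w", ()),
    ("call_function", "add", operator.add, ("x", "w")),
    ("call_module", "scale", "scale", ("add",)),
    ("call_method", "neg", "__neg__", ("scale",)),
    ("output", "output", "output", ("neg",)),
]
result = interpret(graph, {"w": 1.0}, {"scale": lambda v: 10.0 * v}, 2.0)
print(result)  # -30.0
```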

Torch Compile and Torch Dynamo

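  • Torch Compile (torch.compile) provides a high-level API that compiles PyTorch models (or functions) into optimized code just in time. It uses Torch Dynamo as its frontend to capture static graphs, which are then handed to a compiler backend (TorchInductor by default) for code generation.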

  • Torch Dynamo, unlike TorchScript and Torch FX, works at the level of Python bytecode: it hooks into CPython's frame evaluation (PEP 523) and extracts static graphs from the bytecode just in time as the code runs. When it encounters code it cannot capture without further information (for example, some data-dependent control flow or an unsupported Python construct), it inserts a graph break: the program is split into several smaller graphs that are compiled one by one, while the code at the break runs eagerly in ordinary Python.

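  • Each graph captured by Torch Dynamo is represented as a Torch FX graph. If no graph breaks (as described above) occur, the whole model is captured as a single FX graph; passing fullgraph=True to torch.compile turns any graph break into an error.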
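Graph breaks can be pictured with a toy model in which a program is a list of (op, capturable) pairs. The sketch groups consecutive capturable ops into graphs and leaves the rest to run eagerly, and `require_single_graph` mimics the effect of `fullgraph=True`. The op list and the capturability flags are hypothetical; Torch Dynamo makes this decision while analyzing real bytecode.

```python
from typing import List, Tuple

# A toy "program": (op_name, capturable) pairs in execution order.
Program = List[Tuple[str, bool]]
Segment = Tuple[str, List[str]]  # ("graph" or "eager", op names)


def split_at_graph_breaks(program: Program) -> List[Segment]:
    """Group consecutive capturable ops into graphs; other ops run eagerly."""
    segments: List[Segment] = []
    for op, capturable in program:
        kind = "graph" if capturable else "eager"
        if segments and segments[-1][0] == kind:
            segments[-1][1].append(op)      # extend the current segment
        else:
            segments.append((kind, [op]))   # a break starts a new segment
    return segments


def require_single_graph(program: Program) -> List[str]:
    """Rough analogue of fullgraph=True: fail instead of splitting."""
    segments = split_at_graph_breaks(program)
    if len(segments) != 1 or segments[0][0] != "graph":
        raise RuntimeError(f"graph break: {len(segments)} segments")
    return segments[0][1]


program = [("linear", True), ("add", True), ("print", False), ("tanh", True), ("mul", True)]
for kind, ops in split_at_graph_breaks(program):
    print(kind, ops)
# graph ['linear', 'add']
# eager ['print']
# graph ['tanh', 'mul']
```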

import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

# torch.compile wraps a module instance (or a function), not the class itself;
# compilation happens lazily on the first call
my_cell = torch.compile(MyCell())
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(my_cell(x, h))


TVM (Tensor Virtual Machine)

Concepts

TVM compiles deep learning models (computation graphs) into code for a wide range of hardware devices, executed by a minimal runtime. It imports models from PyTorch, TensorFlow, ONNX, and other frameworks, and compiles them for the chosen target device.

Figure 2: TVM working flow

Figure 3: TVM compilation and runtime

Running Example

Consider a simple MLP model. Here we use TVM's own frontend (tvm.relax.frontend.nn) to describe the model; in practice, one can also define it in PyTorch or another framework and import it.[1]

import tvm
from tvm import relax
from tvm.relax.frontend import nn


class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

Then we export the model into an IRModule, explicitly specifying the shape and dtype of the input to forward:

mod, param_spec = MLPModel().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod.show()

The module is in the following form, showing the entire computation graph:

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv

Next, optimization passes are applied to the IRModule as a pipeline. Typical steps include library dispatch (offloading suitable subgraphs to vendor libraries) and auto-tuning (searching for the best configuration on the target device). For example, the following pass fuses the first nn.Linear and the nn.ReLU that follows it into a single function, and rewrites the call into R.call_dps_packed, which invokes a kernel generated for the cuBLAS library.

# Import the cuBLAS patterns
import tvm.relax.backend.cuda.cublas as _cublas
from tvm import IRModule


# Define a new pass for cuBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Get the patterns of interest
        patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
        # Note: in real-world cases, we usually get all patterns
        # patterns = relax.backend.get_patterns_with_prefix("cublas")

        # Fuse ops by patterns and then run codegen
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)
        return mod


mod = CublasDispatch()(mod)
mod.show()

And the IRModule is transformed into the following form:

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(lv, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
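The effect of `FuseOpsByPattern` on this module can be approximated with a toy pass over the dataflow bindings of the exported module, written as hypothetical (var, op, args) tuples. It looks for the chain permute_dims → matmul → add → relu whose intermediates have no other users and replaces it with a single call, keeping the argument order of the generated `call_dps_packed`. The fused op name `fused_matmul_bias_relu` is made up, and TVM's real matcher works on Relax dataflow patterns rather than on a flat list.

```python
from typing import List, Tuple

# (var, op, args): one binding of a Relax dataflow block, simplified.
Binding = Tuple[str, str, Tuple[str, ...]]

PATTERN = ("permute_dims", "matmul", "add", "relu")


def num_uses(block: List[Binding], var: str) -> int:
    return sum(args.count(var) for _, _, args in block)


def fuse_linear_relu(block: List[Binding]) -> List[Binding]:
    """Replace permute_dims -> matmul -> add -> relu chains with one fused call."""
    out: List[Binding] = []
    i = 0
    while i < len(block):
        window = block[i:i + len(PATTERN)]
        if tuple(op for _, op, _ in window) == PATTERN:
            (wt, _, (w,)), (mm, _, (x, wt_in)), (ad, _, (mm_in, bias)), (r, _, (ad_in,)) = window
            chained = wt_in == wt and mm_in == mm and ad_in == ad
            # Intermediates must have no other users, or fusing would drop them.
            private = all(num_uses(block, v) == 1 for v in (wt, mm, ad))
            if chained and private:
                out.append((r, "fused_matmul_bias_relu", (w, x, bias)))
                i += len(PATTERN)
                continue
        out.append(block[i])
        i += 1
    return out


block = [
    ("permute_dims", "permute_dims", ("fc1_weight",)),
    ("matmul", "matmul", ("x", "permute_dims")),
    ("add", "add", ("matmul", "fc1_bias")),
    ("relu", "relu", ("add",)),
    ("permute_dims1", "permute_dims", ("fc2_weight",)),
    ("matmul1", "matmul", ("relu", "permute_dims1")),
    ("add1", "add", ("matmul1", "fc2_bias")),
]
for binding in fuse_linear_relu(block):
    print(binding)
```

The fused binding keeps the name of the last op in the chain (`relu`), so later users such as `matmul1` need no renaming; the second linear layer is left alone because no relu follows it.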

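Additionally, auto-tuning is available. MetaSchedule searches the space of possible schedules for each TensorIR function, guided by a learned (machine-learning) cost model and by measurements on the target device, and records the best schedules in a database that is then applied to the module: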

import tempfile

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
trials = 2000
with target, tempfile.TemporaryDirectory() as tmp_dir:
    mod = tvm.ir.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            # Search for good schedules and store them in tmp_dir
            relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
            # Apply the best schedules found during tuning
            relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
        ]
    )(mod)

mod.show()
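The core loop of measurement-based auto-tuning can be sketched in a few lines: try each candidate configuration, time it, and keep the fastest. Here the search space is a hypothetical set of tile sizes for a pure-Python tiled matmul. MetaSchedule explores far larger spaces of TensorIR schedules and uses a learned cost model to decide which candidates are worth measuring, which this exhaustive sketch does not attempt.

```python
import random
import time
from typing import Callable, Dict, List, Sequence, Tuple

Matrix = List[List[float]]


def matmul_tiled(a: Matrix, b: Matrix, tile: int) -> Matrix:
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    for i0 in range(0, n, tile):
        for j0 in range(0, m, tile):
            for p0 in range(0, k, tile):
                for i in range(i0, min(i0 + tile, n)):
                    for j in range(j0, min(j0 + tile, m)):
                        acc = c[i][j]
                        for p in range(p0, min(p0 + tile, k)):
                            acc += a[i][p] * b[p][j]
                        c[i][j] = acc
    return c


def tune(run: Callable[[int], object], candidates: Sequence[int],
         repeats: int = 3) -> Tuple[int, Dict[int, float]]:
    """Time every candidate configuration and return the fastest one."""
    timings: Dict[int, float] = {}
    for cand in candidates:
        best = float("inf")
        for _ in range(repeats):  # keep the minimum over repeats to reduce noise
            start = time.perf_counter()
            run(cand)
            best = min(best, time.perf_counter() - start)
        timings[cand] = best
    return min(timings, key=lambda c: timings[c]), timings


random.seed(0)
N = 24
A = [[random.random() for _ in range(N)] for _ in range(N)]
B = [[random.random() for _ in range(N)] for _ in range(N)]
best_tile, timings = tune(lambda t: matmul_tiled(A, B, t), candidates=[2, 4, 8, 16, 24])
print("best tile:", best_tile)
```

Because timings are measured, the winner can differ between runs and machines; real tuners therefore persist their results (MetaSchedule uses the database in `work_dir`).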

Finally, the optimized IRModule is compiled into an executable, and the TVM runtime (here, the Relax virtual machine) loads it and runs it on the target device.

import numpy as np

# Generate optimized code (which can be saved and loaded).
# Note: this targets the CPU ("llvm"); a module that was dispatched to cuBLAS
# or tuned for CUDA above needs a CUDA target and tvm.cuda() instead.
target = tvm.target.Target("llvm")
ex = relax.build(mod, target)

# ------

# Run VM (at the device side)
device = tvm.cpu()
vm = relax.VirtualMachine(ex, device)

data = np.random.rand(1, 784).astype("float32")
tvm_data = tvm.nd.array(data, device=device)
# Random weights matching the parameter shapes recorded by export_tvm
params = [np.random.rand(*param.shape).astype("float32") for _, param in param_spec]
params = [tvm.nd.array(param, device=device) for param in params]
print(vm["forward"](tvm_data, *params).numpy())


MLIR (Multi-Level Intermediate Representation)

Concepts

To support domain-specific languages (DSLs), including those for AI compilation, and to represent many application- and device-specific constructs, a single centralized IR such as LLVM IR is insufficient. MLIR is designed to address this problem. It provides a hierarchical and extensible IR infrastructure in which many dialects can be defined, each representing a different level of abstraction or a different domain, and ops from different dialects can coexist in the same program.
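Progressive lowering through dialects can be sketched with a toy rewrite system. Ops are strings named `<dialect>.<op>`, following MLIR's naming convention, and each pass rewrites the ops of one dialect into ops of a lower-level one while leaving other ops untouched, so ops from several dialects coexist in intermediate stages. Apart from the `llvm.` ops echoing MLIR's LLVM dialect, the dialects, ops, and rewrite rules here are hypothetical.

```python
from typing import Dict, List

# Hypothetical rewrite rules: each op expands into ops of a lower-level dialect.
RULES: Dict[str, List[str]] = {
    # high-level "nn" dialect -> tensor-level ops
    "nn.linear": ["tensor.matmul", "tensor.add"],
    "nn.relu": ["tensor.max"],
    # tensor-level ops -> explicit loops around scalar ops
    "tensor.matmul": ["loop.for", "scalar.fma", "loop.end"],
    "tensor.add": ["loop.for", "scalar.add", "loop.end"],
    "tensor.max": ["loop.for", "scalar.max", "loop.end"],
    # loops and scalar ops -> the low-level "llvm" dialect
    "loop.for": ["llvm.br"],
    "loop.end": ["llvm.br"],
    "scalar.fma": ["llvm.fmul", "llvm.fadd"],
    "scalar.add": ["llvm.fadd"],
    "scalar.max": ["llvm.fcmp", "llvm.select"],
}


def dialect(op: str) -> str:
    return op.split(".", 1)[0]


def lower(ops: List[str], from_dialect: str) -> List[str]:
    """One lowering pass: rewrite every op of `from_dialect`, keep the rest."""
    out: List[str] = []
    for op in ops:
        out.extend(RULES[op] if dialect(op) == from_dialect else [op])
    return out


program = ["nn.linear", "nn.relu"]
for d in ["nn", "tensor", "loop", "scalar"]:
    program = lower(program, d)
    print(f"after lowering {d!r}:", program)
```

After the "loop" pass, for example, `llvm.br` ops sit next to not-yet-lowered `scalar.*` ops, which is the kind of mixed-dialect IR that MLIR is designed to represent.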

Figure 4: MLIR Dialects



  1. Example from https://tvm.apache.org/docs/