AI Compilation Introduction
This article introduces the basic concepts of AI compilation, including its workflow, and surveys related systems such as PyTorch, TVM, and MLIR.
AI Compilation Introduction
AI Compilation Designs
AI compilers transform a computation graph into low-level code (e.g., LLVM IR, SPIR-V, or MLIR) that can be interpreted or further compiled for a target device. An AI compiler typically consists of the following parts.
Figure 1: AI Compiler Routine
There are three levels of optimization in AI compilers:
- Graph-Level Optimization. The computation graph is optimized by operating on the graph structure, such as removing redundant nodes (operators), fusing nodes, and applying other graph-level transformations.
- Operator-Level Optimization. An operator can be optimized by applying transformations to its loop nests, such as loop unrolling, loop tiling, and loop fusion (see the sketch after this list).
- ISA-Level Optimization. The backend generates code that leverages special instructions on the target device, such as SIMD instructions, tensor instructions, and other hardware accelerator-specific instructions.
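To make the operator level concrete, here is a plain-Python sketch (not from the original article) of loop tiling applied to a matrix-multiplication loop nest; an AI compiler performs the same transformation automatically on its loop-level IR:

```python
import numpy as np

def matmul_naive(A, B, C):
    # Reference triple loop: C += A @ B.
    n, k, m = A.shape[0], A.shape[1], B.shape[1]
    for i in range(n):
        for j in range(m):
            for p in range(k):
                C[i, j] += A[i, p] * B[p, j]

def matmul_tiled(A, B, C, tile=32):
    # Loop tiling: visit the iteration space block by block so that the
    # working set of A, B, and C stays small enough to fit in cache.
    n, k, m = A.shape[0], A.shape[1], B.shape[1]
    for i0 in range(0, n, tile):
        for j0 in range(0, m, tile):
            for p0 in range(0, k, tile):
                for i in range(i0, min(i0 + tile, n)):
                    for j in range(j0, min(j0 + tile, m)):
                        for p in range(p0, min(p0 + tile, k)):
                            C[i, j] += A[i, p] * B[p, j]

# Both variants compute the same result; only the loop structure differs.
A, B = np.random.rand(64, 64), np.random.rand(64, 64)
C1, C2 = np.zeros((64, 64)), np.zeros((64, 64))
matmul_naive(A, B, C1)
matmul_tiled(A, B, C2)
assert np.allclose(C1, C2)
```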
Great Articles
- Awesome Tensor Compilers: https://github.com/merrymercy/awesome-tensor-compilers
PyTorch Infrastructure
By default, PyTorch uses its eager execution mode to describe a model: the computation graph is constructed and executed dynamically, and is therefore referred to as a dynamic graph.
PyTorch provides several ways to convert the dynamic graph to a static graph for AI compilation.
TorchScript
TorchScript (deprecated) is a method to convert the dynamic graph to a static graph using a Python-like language.
```python
import torch
```
This outputs
```
my_cell:
```
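A runnable sketch in the spirit of the official TorchScript tutorial's MyCell example (module and variable names are assumptions, not the original article's code) that would produce output beginning with the `my_cell:` line above:

```python
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
print("my_cell:", my_cell)      # prints the module structure
print(my_cell(x, h))            # runs the model eagerly
```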
Trace (`torch.jit.trace`) is a way to convert the dynamic graph to a static graph by tracing the model with a set of example inputs. However, some control flow cannot be traced. For example:

```python
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x
```
Script (`torch.jit.script`) is a way to convert the dynamic graph to a static graph by compiling the model's source, written in a Python-like language. It can handle control flow, but it requires the model to be written in a subset of Python.
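To contrast the two paths, here is a minimal runnable sketch (not from the original article) that traces and scripts the MyDecisionGate module above:

```python
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

gate = MyDecisionGate()
x = torch.rand(3, 4)

# Tracing records only the branch taken for this particular input,
# so the data-dependent `if` disappears (PyTorch emits a TracerWarning).
traced = torch.jit.trace(gate, x)

# Scripting compiles the Python source instead, so the control flow
# is preserved in the generated TorchScript.
scripted = torch.jit.script(gate)

print(traced.code)
print(scripted.code)
```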
Torch FX
Torch FX is a tool for modifying the computation graph. Some optimizations, such as operator fusion, can be applied to the graph using this tool. Torch FX also provides `torch.fx.Graph`, which represents the computation graph in a more flexible way (it is like a call graph!).
```python
import torch
```
This outputs
```
opcode    name    target    args    kwargs
```
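A runnable sketch (module definition assumed, not the original article's code) that produces this kind of node table:

```python
import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

# Symbolically trace the module into a torch.fx.GraphModule.
gm = torch.fx.symbolic_trace(MyModule())

# Every node records an opcode (placeholder, call_module, call_function,
# output), a target, and its arguments.
for node in gm.graph.nodes:
    print(node.op, node.name, node.target, node.args, node.kwargs)

# Equivalently, gm.graph.print_tabular() prints the table shown above
# (it requires the `tabulate` package).
```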
Torch Compile and Torch Dynamo
Torch Compile (`torch.compile`) provides a high-level API to compile PyTorch models into static graphs just in time. It relies on Torch Dynamo to capture the static graph. Torch Dynamo, unlike TorchScript and Torch FX, works on the model's Python bytecode: it JIT-compiles the bytecode into a static graph. When some operations in the captured graph are hard to compile without further information, Torch Dynamo breaks the whole graph into several smaller graphs and compiles them one by one, leaving the barrier operation unchanged (it falls back to eager execution).
Torch Dynamo provides a way to compile the model into a Torch FX graph, provided that NO barriers (as mentioned above) are encountered.
```python
import torch
```
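A minimal sketch of the just-in-time path (the function below is an assumption, not from the original article):

```python
import torch

def f(x):
    return torch.sin(x) + torch.cos(x)

# Dynamo captures the function into an FX graph just in time;
# the captured graph is then lowered by the default backend (TorchInductor).
compiled_f = torch.compile(f)

x = torch.randn(8)
print(torch.allclose(f(x), compiled_f(x)))  # the compiled result matches eager
```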
Great Articles
- PyTorch Official Tutorial: https://pytorch.org/tutorials/
TVM (Tensor Virtual Machine)
Concepts
TVM compiles deep learning models (computation graphs) into instruction sets for various hardware devices, packaged with a minimal runtime. It accepts models from PyTorch, TensorFlow, ONNX, and other frameworks, and compiles them for a wide range of target devices.
Figure 2: TVM working flow

Figure 3: TVM compilation and runtime
Running Example
Given a simple MLP model (here we use the TVM frontend to describe the model; in practice, one can use PyTorch or another framework to define it) [1]:
```python
import tvm
```
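A sketch of such an MLP in TVM's Relax `nn` frontend, modeled on the TVM quick-start example (layer sizes and names are assumptions):

```python
import tvm
from tvm import relax
from tvm.relax.frontend import nn

class MLPModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A two-layer perceptron for 28x28 flattened inputs (assumed sizes).
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x
```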
Then we manually instruct TVM to generate an IRModule.
```python
mod, param_spec = MLPModel().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor((1, 784), "float32")}})  # input spec assumed
```
The module is in the following form, showing the entire computation graph:
```python
# from tvm.script import ir as I
```
Then optimization passes are applied to the IRModule in a pipelined manner. Library dispatch can be applied, and auto-tuning can be used to find the best configuration for the target device. For example, the following pass fuses `nn.Linear` and `nn.ReLU` and rewrites them into a `call_dps_packed` function that calls into the cuBLAS library.
```python
# Import cublas pattern
```
And the IRModule is transformed into the following form:
```python
# from tvm.script import ir as I
```
Additionally, auto-tuning (which uses machine-learning-guided search to find good schedules for the target device) is available:
```python
device = tvm.cuda(0)
```
Finally, after extensive optimization, compilation proceeds and the final code is generated. The TVM runtime then loads the generated code and runs it on the target device.
```python
import numpy as np
```
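A sketch of building and running the module with the Relax virtual machine (target, input shape, and parameter initialization are assumptions in the spirit of the TVM quick start):

```python
import numpy as np
import tvm
from tvm import relax

# Build the optimized IRModule for the GPU and load it into the Relax VM.
ex = relax.build(mod, target="cuda")
dev = tvm.cuda(0)
vm = relax.VirtualMachine(ex, dev)

# Random input plus randomly initialized parameters following param_spec.
data = tvm.nd.array(np.random.rand(1, 784).astype("float32"), dev)
params = [tvm.nd.array(np.random.rand(*p.shape).astype("float32"), dev)
          for _, p in param_spec]

out = vm["forward"](data, *params)
print(out)
```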
Great Articles
- TVM official documentation: https://tvm.apache.org/docs
- Tutorial of TVM (in Chinese): https://zhuanlan.zhihu.com/p/532873577
MLIR (Multi-Level Intermediate Representation)
Concepts
To support DSLs (including those for AI compilation) and represent many application- and device-specific constructs, a single centralized IR such as LLVM IR is insufficient. MLIR is designed to address this problem. It has a modular, hierarchical, and extensible design, where many dialects can be defined to represent different levels of abstraction and different domains.

Figure 4: MLIR Dialects
Great Articles
- MLIR learning path (in Chinese): https://www.zhihu.com/question/435109274
- MLIR inspiration (in Chinese): https://www.lei.chat/zh/posts/compilers-and-irs-llvm-ir-spirv-and-mlir
- How TVM is different from MLIR: https://stackoverflow.com/questions/65288033/how-tvm-is-different-from-mlir
[1] Example from https://tvm.apache.org/docs/