MLIR:マルチレベル中間表現
はじめに
MLIR(Multi-Level Intermediate Representation)は、LLVMプロジェクトの一部として開発された新しいコンパイラ基盤です。機械学習コンパイラから汎用言語まで、幅広い用途に対応する柔軟なフレームワークを提供します。
MLIRの概要
MLIRは、異なる抽象度の中間表現を統一的に扱えるフレームワークです。従来のコンパイラは単一の中間表現を使用していましたが、MLIRでは複数レベルの中間表現を階層的に扱えます。
設計目標
- 拡張性: 新しい方言(Dialect)の容易な追加
- 再利用性: 共通パスとインフラの活用
- 表現力: 高レベルから低レベルまで表現可能
- 統一性: 全レベルで共通のインフラを使用
LLVM IRとの違い
| 特性 | LLVM IR | MLIR |
|---|---|---|
| 抽象レベル | 低レベルのみ | 複数レベル対応 |
| 拡張性 | 限定的 | 高い(Dialect) |
| 用途 | 汎用コンパイラ | 機械学習、DSL、汎用 |
| 型システム | 固定 | 拡張可能 |
Dialect(方言)
MLIRの中核概念はDialectです。各Dialectは特定の抽象度や領域に特化した操作と型を定義します。
主要なDialect
| Dialect | 用途 |
|---|---|
builtin | 組み込み操作と型 |
func | 関数定義と呼び出し |
arith | 算術演算 |
memref | メモリ参照 |
scf | 標準制御フロー |
cf | 制御フロー(低レベル) |
linalg | 線形代数 |
tensor | テンソル操作 |
vector | ベクトル演算 |
affine | アフィン変換 |
gpu | GPU操作 |
llvm | LLVM IRとの連携 |
spirv | SPIR-V操作 |
機械学習関連のDialect
| Dialect | 用途 |
|---|---|
tosa | Tensor Operator Set Architecture |
mhlo | MLIR-HLO(XLA) |
stablehlo | StableHLO |
iree_linalg_ext | IREE拡張 |
MLIRの基本構造
モジュール構造
HLJSMLIR
module {
func.func @main() -> i32 {
%c42 = arith.constant 42 : i32
func.return %c42 : i32
}
}
関数定義
HLJSMLIR
module {
func.func @add(%a: i32, %b: i32) -> i32 {
%result = arith.addi %a, %b : i32
func.return %result : i32
}
func.func @multiply(%a: i32, %b: i32) -> i32 {
%result = arith.muli %a, %b : i32
func.return %result : i32
}
}
制御フロー(scf)
HLJSMLIR
module {
func.func @sum_to_n(%n: i32) -> i32 {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%sum = scf.for %i = %c0 to %n step %c1
iter_args(%acc = %c0) -> (i32) {
%new_acc = arith.addi %acc, %i : i32
scf.yield %new_acc : i32
}
func.return %sum : i32
}
}
条件分岐(scf.if)
HLJSMLIR
module {
func.func @abs(%x: i32) -> i32 {
%c0 = arith.constant 0 : i32
%is_negative = arith.cmpi slt, %x, %c0 : i32
%result = scf.if %is_negative -> i32 {
%neg = arith.subi %c0, %x : i32
scf.yield %neg : i32
} else {
scf.yield %x : i32
}
func.return %result : i32
}
}
メモリ操作(memref)
HLJSMLIR
module {
func.func @array_sum(%arr: memref<10xi32>) -> i32 {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c10 = arith.constant 10 : i32
%sum = scf.for %i = %c0 to %c10 step %c1
iter_args(%acc = %c0) -> (i32) {
%val = memref.load %arr[%i] : memref<10xi32>
%new_acc = arith.addi %acc, %val : i32
scf.yield %new_acc : i32
}
func.return %sum : i32
}
}
テンソル操作
HLJSMLIR
module {
func.func @tensor_add(%a: tensor<4xi32>, %b: tensor<4xi32>) -> tensor<4xi32> {
%result = arith.addi %a, %b : tensor<4xi32>
func.return %result : tensor<4xi32>
}
func.func @tensor_sum(%t: tensor<4xi32>) -> i32 {
%c0 = arith.constant 0 : i32
%sum = "tensor.reduce"(%t, %c0) ({
^bb0(%arg0: i32, %arg1: i32):
%add = arith.addi %arg0, %arg1 : i32
tensor.yield %add : i32
}) {dimensions = [0]} : (tensor<4xi32>, i32) -> i32
func.return %sum : i32
}
}
線形代数(linalg)
行列乗算
HLJSMLIR
module {
func.func @matmul(%A: memref<64x64xf32>,
%B: memref<64x64xf32>,
%C: memref<64x64xf32>) {
linalg.matmul ins(%A, %B: memref<64x64xf32>, memref<64x64xf32>)
outs(%C: memref<64x64xf32>)
func.return
}
}
汎用演算
HLJSMLIR
module {
func.func @elementwise_add(%A: memref<64xf32>,
%B: memref<64xf32>,
%C: memref<64xf32>) {
linalg.generic
{ indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)> ],
iterator_types = ["parallel"] }
ins(%A, %B : memref<64xf32>, memref<64xf32>)
outs(%C : memref<64xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
%sum = arith.addf %a, %b : f32
linalg.yield %sum : f32
}
func.return
}
}
Lowering(低レベル化)
MLIRでは、高レベルのDialectから低レベルのDialectへ段階的に変換(Lowering)します。
Loweringパイプライン
高レベル Dialect
(tosa, mhlo, stablehlo)
↓
中間 Dialect (linalg, tensor)
↓
制御フロー Dialect (scf, affine)
↓
メモリ Dialect (memref, vector)
↓
低レベル Dialect (llvm)
↓
LLVM IR
mlir-optによるLowering
HLJSBASH
# 検証のみ
mlir-opt input.mlir --verify-diagnostics
# TOSAからlinalgへの変換
mlir-opt input.mlir --tosa-to-linalg -o output.mlir
# linalgからループへの変換
mlir-opt input.mlir --convert-linalg-to-loops -o output.mlir
# scfからcfへの変換
mlir-opt input.mlir --convert-scf-to-cf -o output.mlir
# memrefからllvmへの変換
mlir-opt input.mlir --convert-memref-to-llvm -o output.mlir
# funcからllvmへの変換
mlir-opt input.mlir --convert-func-to-llvm -o output.mlir
# LLVM IRへの変換
mlir-opt input.mlir --convert-arith-to-llvm \
--convert-scf-to-cf \
--convert-cf-to-llvm \
--convert-func-to-llvm \
--convert-memref-to-llvm \
--reconcile-unrealized-casts | \
mlir-translate --mlir-to-llvmir -o output.ll
完全なLoweringパイプラインの例
HLJSBASH
mlir-opt matmul.mlir \
--linalg-bufferize \
--linalg-promote-subviews \
--convert-linalg-to-loops \
--convert-scf-to-cf \
--convert-memref-to-llvm \
--convert-arith-to-llvm \
--convert-func-to-llvm \
--convert-cf-to-llvm \
--reconcile-unrealized-casts \
-o lowered.mlir
カスタムDialectの作成
Dialect定義(ODS)
HLJSTABLEGEN
// MyDialect.td
def My_Dialect : Dialect {
let name = "my";
let summary = "My custom dialect";
let description = [{
A custom dialect for demonstration purposes.
}];
let cppNamespace = "::my";
}
def My_AddOp : My_Op<"add"> {
let summary = "Addition operation";
let arguments = (ins My_Type:$a, My_Type:$b);
let results = (outs My_Type:$result);
let assemblyFormat = "$a `,` $b attr-dict `:` type($result)";
}
C++実装
HLJSCPP
// MyDialect.cpp
namespace mlir {
namespace my {
void MyDialect::initialize() {
addOperations<
>();
addTypes<MyType>();
}
LogicalResult MyAddOp::verify() {
Type lhsType = getLhs().getType();
Type rhsType = getRhs().getType();
Type resultType = getResult().getType();
if (lhsType != rhsType || lhsType != resultType)
return emitOpError("operand and result types must match");
return success();
}
} // namespace my
} // namespace mlir
Passの実装
Dialect変換Pass
HLJSCPP
namespace mlir {
namespace my {
struct MyToLLVMLoweringPass
: public PassWrapper<MyToLLVMLoweringPass, OperationPass<ModuleOp>> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
ConversionTarget target(*context);
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
LLVMTypeConverter typeConverter(context);
RewritePatternSet patterns(context);
populateMyToLLVMConversionPatterns(patterns, typeConverter);
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
};
std::unique_ptr<Pass> createMyToLLVMLoweringPass() {
return std::make_unique<MyToLLVMLoweringPass>();
}
} // namespace my
} // namespace mlir
Passの登録
HLJSCPP
void registerMyPasses() {
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return mlir::my::createMyToLLVMLoweringPass();
});
}
GPU向けの変換
GPU Dialectの使用
HLJSMLIR
module {
func.func @gpu_add(%a: memref<1024xf32>,
%b: memref<1024xf32>,
%c: memref<1024xf32>) {
%c0 = arith.constant 0 : index
%c1024 = arith.constant 1024 : index
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1024, %grid_y = %c0, %grid_z = %c0)
threads(%tx, %ty, %tz) in (%block_x = %c1024, %block_y = %c0, %block_z = %c0) {
%idx = arith.addi %bx, %tx : index
%va = memref.load %a[%idx] : memref<1024xf32>
%vb = memref.load %b[%idx] : memref<1024xf32>
%vc = arith.addf %va, %vb : f32
memref.store %vc, %c[%idx] : memref<1024xf32>
gpu.terminator
}
func.return
}
}
GPUへの変換
HLJSBASH
mlir-opt input.mlir \
--convert-linalg-to-loops \
--convert-scf-to-cf \
--convert-cf-to-llvm \
--convert-arith-to-llvm \
--convert-func-to-llvm \
--convert-memref-to-llvm \
--gpu-kernel-outlining \
--convert-gpu-to-nvvm \
--convert-nvvm-to-llvm \
-o output.mlir
実践:mlir-optとmlir-translate
検証と解析
HLJSBASH
# 構文検証
mlir-opt input.mlir --verify-diagnostics
# 統計情報の表示
mlir-opt input.mlir --mlir-print-op-stats
# IRの正規化
mlir-opt input.mlir --canonicalize -o canonicalized.mlir
# CSE(共通部分式除去)
mlir-opt input.mlir --cse -o optimized.mlir
# インライン化
mlir-opt input.mlir --inline -o inlined.mlir
可視化
HLJSBASH
# 制御フローグラフの可視化
mlir-opt input.mlir --view-op-graph
# DOT形式での出力
mlir-opt input.mlir --dump-op-graph -o graph.dot
dot -Tpng graph.dot -o graph.png
LLVM IRへの変換
HLJSBASH
# MLIRからLLVM IR
mlir-translate --mlir-to-llvmir input.mlir -o output.ll
# LLVM IRからオブジェクトファイル
llc output.ll -o output.o
# 実行ファイルの作成
clang output.o -o program
実際のプロジェクトでの活用
IREE(Machine Learning Compiler)
IREEはMLIRを使用した機械学習コンパイラです。
HLJSBASH
# IREEでのコンパイル
iree-compile model.mlir -o module.vmfb
# 実行
iree-run-module --module=module.vmfb --function=main
MLIR-AIE(AMD AI Engine)
HLJSBASH
# AIE向けの変換
mlir-opt aie.mlir --aie-standard-lowering -o aie_lowered.mlir
まとめ
MLIRは、コンパイラ開発に新たなパラダイムをもたらしています。
MLIRの利点
- 統一フレームワーク: 異なる抽象度を同じインフラで扱える
- 再利用性: 共通の最適化と変換を共有
- 拡張性: 新しいDialectの追加が容易
- 生産性: DSLやドメイン固有の最適化が容易
学習リソース
- MLIR公式ドキュメント: https://mlir.llvm.org/
- MLIR Language Reference: https://mlir.llvm.org/docs/LangRef/
- MLIR Tutorials: https://mlir.llvm.org/docs/Tutorials/
MLIRは現在、機械学習コンパイラ、ドメイン固有言語、そして従来のコンパイラ開発において急速に採用されています。その柔軟性と拡張性は、将来のコンパイラ技術の基盤となるでしょう。