low-layer

MLIR:マルチレベル中間表現

MLIRの基本概念からDialect、Lowering、実践的な活用方法まで詳しく解説します。

7 min read

MLIR:マルチレベル中間表現

はじめに

MLIR(Multi-Level Intermediate Representation)は、LLVMプロジェクトの一部として開発された新しいコンパイラ基盤です。機械学習コンパイラから汎用言語まで、幅広い用途に対応する柔軟なフレームワークを提供します。

MLIRの概要

MLIRは、異なる抽象度の中間表現を統一的に扱えるフレームワークです。従来のコンパイラは単一の中間表現を使用していましたが、MLIRでは複数レベルの中間表現を階層的に扱えます。

設計目標

  • 拡張性: 新しい方言(Dialect)の容易な追加
  • 再利用性: 共通パスとインフラの活用
  • 表現力: 高レベルから低レベルまで表現可能
  • 統一性: 全レベルで共通のインフラを使用

LLVM IRとの違い

特性LLVM IRMLIR
抽象レベル低レベルのみ複数レベル対応
拡張性限定的高い(Dialect)
用途汎用コンパイラ機械学習、DSL、汎用
型システム固定拡張可能

Dialect(方言)

MLIRの中核概念はDialectです。各Dialectは特定の抽象度や領域に特化した操作と型を定義します。

主要なDialect

Dialect用途
builtin組み込み操作と型
func関数定義と呼び出し
arith算術演算
memrefメモリ参照
scf標準制御フロー
cf制御フロー(低レベル)
linalg線形代数
tensorテンソル操作
vectorベクトル演算
affineアフィン変換
gpuGPU操作
llvmLLVM IRとの連携
spirvSPIR-V操作

機械学習関連のDialect

Dialect用途
tosaTensor Operator Set Architecture
mhloMLIR-HLO(XLA)
stablehloStableHLO
iree_linalg_extIREE拡張

MLIRの基本構造

モジュール構造

HLJSMLIR
module { func.func @main() -> i32 { %c42 = arith.constant 42 : i32 func.return %c42 : i32 } }

関数定義

HLJSMLIR
module { func.func @add(%a: i32, %b: i32) -> i32 { %result = arith.addi %a, %b : i32 func.return %result : i32 } func.func @multiply(%a: i32, %b: i32) -> i32 { %result = arith.muli %a, %b : i32 func.return %result : i32 } }

制御フロー(scf)

HLJSMLIR
module { func.func @sum_to_n(%n: i32) -> i32 { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %sum = scf.for %i = %c0 to %n step %c1 iter_args(%acc = %c0) -> (i32) { %new_acc = arith.addi %acc, %i : i32 scf.yield %new_acc : i32 } func.return %sum : i32 } }

条件分岐(scf.if)

HLJSMLIR
module { func.func @abs(%x: i32) -> i32 { %c0 = arith.constant 0 : i32 %is_negative = arith.cmpi slt, %x, %c0 : i32 %result = scf.if %is_negative -> i32 { %neg = arith.subi %c0, %x : i32 scf.yield %neg : i32 } else { scf.yield %x : i32 } func.return %result : i32 } }

メモリ操作(memref)

HLJSMLIR
module { func.func @array_sum(%arr: memref<10xi32>) -> i32 { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %c10 = arith.constant 10 : i32 %sum = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %c0) -> (i32) { %val = memref.load %arr[%i] : memref<10xi32> %new_acc = arith.addi %acc, %val : i32 scf.yield %new_acc : i32 } func.return %sum : i32 } }

テンソル操作

HLJSMLIR
module { func.func @tensor_add(%a: tensor<4xi32>, %b: tensor<4xi32>) -> tensor<4xi32> { %result = arith.addi %a, %b : tensor<4xi32> func.return %result : tensor<4xi32> } func.func @tensor_sum(%t: tensor<4xi32>) -> i32 { %c0 = arith.constant 0 : i32 %sum = "tensor.reduce"(%t, %c0) ({ ^bb0(%arg0: i32, %arg1: i32): %add = arith.addi %arg0, %arg1 : i32 tensor.yield %add : i32 }) {dimensions = [0]} : (tensor<4xi32>, i32) -> i32 func.return %sum : i32 } }

線形代数(linalg)

行列乗算

HLJSMLIR
module { func.func @matmul(%A: memref<64x64xf32>, %B: memref<64x64xf32>, %C: memref<64x64xf32>) { linalg.matmul ins(%A, %B: memref<64x64xf32>, memref<64x64xf32>) outs(%C: memref<64x64xf32>) func.return } }

汎用演算

HLJSMLIR
module { func.func @elementwise_add(%A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) { linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)> ], iterator_types = ["parallel"] } ins(%A, %B : memref<64xf32>, memref<64xf32>) outs(%C : memref<64xf32>) { ^bb0(%a: f32, %b: f32, %c: f32): %sum = arith.addf %a, %b : f32 linalg.yield %sum : f32 } func.return } }

Lowering(低レベル化)

MLIRでは、高レベルのDialectから低レベルのDialectへ段階的に変換(Lowering)します。

Loweringパイプライン

高レベル Dialect (tosa, mhlo, stablehlo) ↓ 中間 Dialect (linalg, tensor) ↓ 制御フロー Dialect (scf, affine) ↓ メモリ Dialect (memref, vector) ↓ 低レベル Dialect (llvm) ↓ LLVM IR

mlir-optによるLowering

HLJSBASH
# 検証のみ mlir-opt input.mlir --verify-diagnostics # TOSAからlinalgへの変換 mlir-opt input.mlir --tosa-to-linalg -o output.mlir # linalgからループへの変換 mlir-opt input.mlir --convert-linalg-to-loops -o output.mlir # scfからcfへの変換 mlir-opt input.mlir --convert-scf-to-cf -o output.mlir # memrefからllvmへの変換 mlir-opt input.mlir --convert-memref-to-llvm -o output.mlir # funcからllvmへの変換 mlir-opt input.mlir --convert-func-to-llvm -o output.mlir # LLVM IRへの変換 mlir-opt input.mlir --convert-arith-to-llvm \ --convert-scf-to-cf \ --convert-cf-to-llvm \ --convert-func-to-llvm \ --convert-memref-to-llvm \ --reconcile-unrealized-casts | \ mlir-translate --mlir-to-llvmir -o output.ll

完全なLoweringパイプラインの例

HLJSBASH
mlir-opt matmul.mlir \ --linalg-bufferize \ --linalg-promote-subviews \ --convert-linalg-to-loops \ --convert-scf-to-cf \ --convert-memref-to-llvm \ --convert-arith-to-llvm \ --convert-func-to-llvm \ --convert-cf-to-llvm \ --reconcile-unrealized-casts \ -o lowered.mlir

カスタムDialectの作成

Dialect定義(ODS)

HLJSTABLEGEN
// MyDialect.td def My_Dialect : Dialect { let name = "my"; let summary = "My custom dialect"; let description = [{ A custom dialect for demonstration purposes. }]; let cppNamespace = "::my"; } def My_AddOp : My_Op<"add"> { let summary = "Addition operation"; let arguments = (ins My_Type:$a, My_Type:$b); let results = (outs My_Type:$result); let assemblyFormat = "$a `,` $b attr-dict `:` type($result)"; }

C++実装

HLJSCPP
// MyDialect.cpp #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { namespace my { void MyDialect::initialize() { addOperations< #define GET_OP_LIST #include "MyOps.cpp.inc" >(); addTypes<MyType>(); } LogicalResult MyAddOp::verify() { Type lhsType = getLhs().getType(); Type rhsType = getRhs().getType(); Type resultType = getResult().getType(); if (lhsType != rhsType || lhsType != resultType) return emitOpError("operand and result types must match"); return success(); } } // namespace my } // namespace mlir

Passの実装

Dialect変換Pass

HLJSCPP
#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace my { struct MyToLLVMLoweringPass : public PassWrapper<MyToLLVMLoweringPass, OperationPass<ModuleOp>> { void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); ConversionTarget target(*context); target.addLegalDialect<LLVM::LLVMDialect>(); target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(); LLVMTypeConverter typeConverter(context); RewritePatternSet patterns(context); populateMyToLLVMConversionPatterns(patterns, typeConverter); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } }; std::unique_ptr<Pass> createMyToLLVMLoweringPass() { return std::make_unique<MyToLLVMLoweringPass>(); } } // namespace my } // namespace mlir

Passの登録

HLJSCPP
#include "mlir/InitAllPasses.h" void registerMyPasses() { mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> { return mlir::my::createMyToLLVMLoweringPass(); }); }

GPU向けの変換

GPU Dialectの使用

HLJSMLIR
module { func.func @gpu_add(%a: memref<1024xf32>, %b: memref<1024xf32>, %c: memref<1024xf32>) { %c0 = arith.constant 0 : index %c1024 = arith.constant 1024 : index gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1024, %grid_y = %c0, %grid_z = %c0) threads(%tx, %ty, %tz) in (%block_x = %c1024, %block_y = %c0, %block_z = %c0) { %idx = arith.addi %bx, %tx : index %va = memref.load %a[%idx] : memref<1024xf32> %vb = memref.load %b[%idx] : memref<1024xf32> %vc = arith.addf %va, %vb : f32 memref.store %vc, %c[%idx] : memref<1024xf32> gpu.terminator } func.return } }

GPUへの変換

HLJSBASH
mlir-opt input.mlir \ --convert-linalg-to-loops \ --convert-scf-to-cf \ --convert-cf-to-llvm \ --convert-arith-to-llvm \ --convert-func-to-llvm \ --convert-memref-to-llvm \ --gpu-kernel-outlining \ --convert-gpu-to-nvvm \ --convert-nvvm-to-llvm \ -o output.mlir

実践:mlir-optとmlir-translate

検証と解析

HLJSBASH
# 構文検証 mlir-opt input.mlir --verify-diagnostics # 統計情報の表示 mlir-opt input.mlir --mlir-print-op-stats # IRの正規化 mlir-opt input.mlir --canonicalize -o canonicalized.mlir # CSE(共通部分式除去) mlir-opt input.mlir --cse -o optimized.mlir # インライン化 mlir-opt input.mlir --inline -o inlined.mlir

可視化

HLJSBASH
# 制御フローグラフの可視化 mlir-opt input.mlir --view-op-graph # DOT形式での出力 mlir-opt input.mlir --dump-op-graph -o graph.dot dot -Tpng graph.dot -o graph.png

LLVM IRへの変換

HLJSBASH
# MLIRからLLVM IR mlir-translate --mlir-to-llvmir input.mlir -o output.ll # LLVM IRからオブジェクトファイル llc output.ll -o output.o # 実行ファイルの作成 clang output.o -o program

実際のプロジェクトでの活用

IREE(Machine Learning Compiler)

IREEはMLIRを使用した機械学習コンパイラです。

HLJSBASH
# IREEでのコンパイル iree-compile model.mlir -o module.vmfb # 実行 iree-run-module --module=module.vmfb --function=main

MLIR-AIE(AMD AI Engine)

HLJSBASH
# AIE向けの変換 mlir-opt aie.mlir --aie-standard-lowering -o aie_lowered.mlir

まとめ

MLIRは、コンパイラ開発に新たなパラダイムをもたらしています。

MLIRの利点

  1. 統一フレームワーク: 異なる抽象度を同じインフラで扱える
  2. 再利用性: 共通の最適化と変換を共有
  3. 拡張性: 新しいDialectの追加が容易
  4. 生産性: DSLやドメイン固有の最適化が容易

学習リソース

MLIRは現在、機械学習コンパイラ、ドメイン固有言語、そして従来のコンパイラ開発において急速に採用されています。その柔軟性と拡張性は、将来のコンパイラ技術の基盤となるでしょう。