Skip to content

Commit 719695d

Browse files
committed
[mlir][python] allow upstream dialect registration
1 parent 21a9c7e commit 719695d

File tree

9 files changed

+218
-9
lines changed

9 files changed

+218
-9
lines changed

mlir/examples/standalone/python/CMakeLists.txt

-6
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ add_mlir_python_common_capi_library(StandalonePythonCAPI
4040
RELATIVE_INSTALL_ROOT "../../../.."
4141
DECLARED_SOURCES
4242
StandalonePythonSources
43-
# TODO: Remove this in favor of showing fine grained registration once
44-
# available.
45-
MLIRPythonExtension.RegisterEverything
4643
MLIRPythonSources.Core
4744
)
4845

@@ -55,9 +52,6 @@ add_mlir_python_modules(StandalonePythonModules
5552
INSTALL_PREFIX "python_packages/standalone/mlir_standalone"
5653
DECLARED_SOURCES
5754
StandalonePythonSources
58-
# TODO: Remove this in favor of showing fine grained registration once
59-
# available.
60-
MLIRPythonExtension.RegisterEverything
6155
MLIRPythonSources
6256
COMMON_CAPI_LINK_LIBS
6357
StandalonePythonCAPI

mlir/examples/standalone/test/python/smoketest.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from mlir_standalone.ir import *
44
from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d
55

6+
add_dialect_to_dialect_registry(get_dialect_registry(), "arith")
7+
68
with Context():
79
standalone_d.register_dialect()
810
module = Module.parse(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#ifndef MLIR_C_REMAINING_DIALECTS_H
2+
#define MLIR_C_REMAINING_DIALECTS_H
3+
4+
#include "mlir-c/IR.h"
5+
6+
#ifdef __cplusplus
7+
extern "C" {
8+
#endif
9+
10+
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE) \
11+
MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##NAMESPACE##__();
12+
13+
#define FORALL_DIALECTS(_) \
14+
_(acc) \
15+
_(affine) \
16+
_(amx) \
17+
_(arm_neon) \
18+
_(arm_sme) \
19+
_(arm_sve) \
20+
_(bufferization) \
21+
_(complex) \
22+
_(dlti) \
23+
_(emitc) \
24+
_(index) \
25+
_(irdl) \
26+
_(mesh) \
27+
_(spirv) \
28+
_(tosa) \
29+
_(ub) \
30+
_(x86vector)
31+
32+
FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
33+
34+
#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
35+
#undef FORALL_DIALECTS
36+
37+
#ifdef __cplusplus
38+
}
39+
#endif
40+
41+
#endif // MLIR_C_REMAINING_DIALECTS_H

mlir/lib/Bindings/Python/MainModule.cpp

+83
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,32 @@
1212
#include "IRModule.h"
1313
#include "Pass.h"
1414

15+
#include "mlir-c/Dialect/AMDGPU.h"
16+
#include "mlir-c/Dialect/Arith.h"
17+
#include "mlir-c/Dialect/Async.h"
18+
#include "mlir-c/Dialect/ControlFlow.h"
19+
#include "mlir-c/Dialect/Func.h"
20+
#include "mlir-c/Dialect/GPU.h"
21+
#include "mlir-c/Dialect/LLVM.h"
22+
#include "mlir-c/Dialect/Linalg.h"
23+
#include "mlir-c/Dialect/MLProgram.h"
24+
#include "mlir-c/Dialect/Math.h"
25+
#include "mlir-c/Dialect/MemRef.h"
26+
#include "mlir-c/Dialect/NVGPU.h"
27+
#include "mlir-c/Dialect/NVVM.h"
28+
#include "mlir-c/Dialect/OpenMP.h"
29+
#include "mlir-c/Dialect/PDL.h"
30+
#include "mlir-c/Dialect/Quant.h"
31+
#include "mlir-c/Dialect/ROCDL.h"
32+
#include "mlir-c/Dialect/SCF.h"
33+
#include "mlir-c/Dialect/Shape.h"
34+
#include "mlir-c/Dialect/SparseTensor.h"
35+
#include "mlir-c/Dialect/Tensor.h"
36+
#include "mlir-c/Dialect/Transform.h"
37+
#include "mlir-c/Dialect/Vector.h"
38+
39+
#include "mlir-c/Dialect/RemainingDialects.h"
40+
1541
namespace py = pybind11;
1642
using namespace mlir;
1743
using namespace py::literals;
@@ -65,6 +91,63 @@ PYBIND11_MODULE(_mlir, m) {
6591
},
6692
"dialect_class"_a,
6793
"Class decorator for registering a custom Dialect wrapper");
94+
m.def(
95+
"add_dialect_to_dialect_registry",
96+
[](MlirDialectRegistry registry, const std::string &dialectNamespace) {
97+
98+
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE) \
99+
if (dialectNamespace == #NAMESPACE) { \
100+
mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(), \
101+
registry); \
102+
return; \
103+
}
104+
105+
#define FORALL_DIALECTS(_) \
106+
_(acc) \
107+
_(affine) \
108+
_(amdgpu) \
109+
_(amx) \
110+
_(arith) \
111+
_(arm_neon) \
112+
_(arm_sme) \
113+
_(arm_sve) \
114+
_(async) \
115+
_(bufferization) \
116+
_(cf) \
117+
_(complex) \
118+
_(emitc) \
119+
_(func) \
120+
_(gpu) \
121+
_(index) \
122+
_(irdl) \
123+
_(linalg) \
124+
_(llvm) \
125+
_(math) \
126+
_(memref) \
127+
_(mesh) \
128+
_(ml_program) \
129+
_(nvgpu) \
130+
_(nvvm) \
131+
_(omp) \
132+
_(pdl) \
133+
_(quant) \
134+
_(rocdl) \
135+
_(scf) \
136+
_(shape) \
137+
_(spirv) \
138+
_(tensor) \
139+
_(tosa) \
140+
_(ub) \
141+
_(vector) \
142+
_(x86vector)
143+
FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
144+
145+
#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
146+
#undef FORALL_DIALECTS
147+
throw std::runtime_error("unknown dialect namespace: " +
148+
dialectNamespace);
149+
},
150+
"dialect_registry"_a, "dialect_namespace"_a);
68151
m.def(
69152
"register_operation",
70153
[](const py::object &dialectClass, bool replace) -> py::cpp_function {

mlir/lib/CAPI/Dialect/CMakeLists.txt

+11
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,14 @@ add_mlir_upstream_c_api_library(MLIRCAPIVector
224224
MLIRCAPIIR
225225
MLIRVectorDialect
226226
)
227+
228+
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
229+
230+
add_mlir_upstream_c_api_library(MLIRCAPIRemainingDialects
231+
RemainingDialects.cpp
232+
233+
PARTIAL_SOURCES_INTENDED
234+
LINK_LIBS PUBLIC
235+
MLIRCAPIIR
236+
${dialect_libs}
237+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include "mlir-c/Dialect/RemainingDialects.h"
2+
3+
#include "mlir/CAPI/Registration.h"
4+
#include "mlir/InitAllDialects.h"
5+
6+
using namespace mlir;
7+
8+
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE, NAME) \
9+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NAME, NAMESPACE, \
10+
NAMESPACE::NAME##Dialect)
11+
12+
#define FORALL_DIALECTS(_) \
13+
_(acc, OpenACC) \
14+
_(affine, Affine) \
15+
_(amx, AMX) \
16+
_(arith, Arith) \
17+
_(arm_neon, ArmNeon) \
18+
_(arm_sme, ArmSME) \
19+
_(arm_sve, ArmSVE) \
20+
_(bufferization, Bufferization) \
21+
_(complex, Complex) \
22+
_(emitc, EmitC) \
23+
_(index, Index) \
24+
_(irdl, IRDL) \
25+
_(mesh, Mesh) \
26+
_(spirv, SPIRV) \
27+
_(tosa, Tosa) \
28+
_(ub, UB) \
29+
_(x86vector, X86Vector)
30+
31+
FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
32+
33+
#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
34+
#undef FORALL_DIALECTS
35+
36+
static void mlirDialectRegistryInsertDLTIDialect(MlirDialectRegistry registry) {
37+
unwrap(registry)->insert<mlir::DLTIDialect>();
38+
}
39+
40+
static MlirDialect mlirContextLoadDLTIDialect(MlirContext context) {
41+
return wrap(unwrap(context)->getOrLoadDialect<mlir::DLTIDialect>());
42+
}
43+
44+
static MlirStringRef mlirDLTIDialectGetNamespace() {
45+
return wrap(mlir::DLTIDialect::getDialectNamespace());
46+
}
47+
48+
MlirDialectHandle mlirGetDialectHandle__dlti__() {
49+
static MlirDialectRegistrationHooks hooks = {
50+
mlirDialectRegistryInsertDLTIDialect, mlirContextLoadDLTIDialect,
51+
mlirDLTIDialectGetNamespace};
52+
return MlirDialectHandle{&hooks};
53+
}

mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
#include "mlir/InitAllExtensions.h"
1515
#include "mlir/InitAllPasses.h"
1616
#include "mlir/Target/LLVMIR/Dialect/All.h"
17-
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
18-
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
1917

2018
void mlirRegisterAllDialects(MlirDialectRegistry registry) {
2119
mlir::registerAllDialects(*unwrap(registry));

mlir/python/CMakeLists.txt

+23
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,30 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
423423
MLIRCAPIInterfaces
424424

425425
# Dialects
426+
MLIRCAPIAMDGPU
427+
MLIRCAPIArith
428+
MLIRCAPIAsync
429+
MLIRCAPIControlFlow
426430
MLIRCAPIFunc
431+
MLIRCAPIGPU
432+
MLIRCAPILLVM
433+
MLIRCAPILinalg
434+
MLIRCAPIMLProgram
435+
MLIRCAPIMath
436+
MLIRCAPIMemRef
437+
MLIRCAPINVGPU
438+
MLIRCAPINVVM
439+
MLIRCAPIOpenMP
440+
MLIRCAPIPDL
441+
MLIRCAPIQuant
442+
MLIRCAPIROCDL
443+
MLIRCAPISCF
444+
MLIRCAPIShape
445+
MLIRCAPISparseTensor
446+
MLIRCAPITensor
447+
MLIRCAPITransformDialect
448+
MLIRCAPIVector
449+
MLIRCAPIRemainingDialects
427450
)
428451

429452
# This extension exposes an API to register all dialects, extensions, and passes

mlir/python/mlir/ir.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
from ._mlir_libs._mlir.ir import *
66
from ._mlir_libs._mlir.ir import _GlobalDebug
7-
from ._mlir_libs._mlir import register_type_caster, register_value_caster
7+
from ._mlir_libs._mlir import (
8+
register_type_caster,
9+
register_value_caster,
10+
add_dialect_to_dialect_registry,
11+
)
812
from ._mlir_libs import get_dialect_registry
913

1014

0 commit comments

Comments
 (0)