From 719695dcf53084c4d439f4de53d6c05c80cb32e9 Mon Sep 17 00:00:00 2001 From: max Date: Sun, 3 Dec 2023 14:45:16 -0600 Subject: [PATCH] [mlir][python] allow upstream dialect registration --- .../examples/standalone/python/CMakeLists.txt | 6 -- .../standalone/test/python/smoketest.py | 2 + .../mlir-c/Dialect/RemainingDialects.h | 41 +++++++++ mlir/lib/Bindings/Python/MainModule.cpp | 83 +++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 11 +++ mlir/lib/CAPI/Dialect/RemainingDialects.cpp | 53 ++++++++++++ .../RegisterEverything/RegisterEverything.cpp | 2 - mlir/python/CMakeLists.txt | 23 +++++ mlir/python/mlir/ir.py | 6 +- 9 files changed, 218 insertions(+), 9 deletions(-) create mode 100644 mlir/include/mlir-c/Dialect/RemainingDialects.h create mode 100644 mlir/lib/CAPI/Dialect/RemainingDialects.cpp diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt index a8c43827a5a37..014d6061f7f0f 100644 --- a/mlir/examples/standalone/python/CMakeLists.txt +++ b/mlir/examples/standalone/python/CMakeLists.txt @@ -40,9 +40,6 @@ add_mlir_python_common_capi_library(StandalonePythonCAPI RELATIVE_INSTALL_ROOT "../../../.." DECLARED_SOURCES StandalonePythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything MLIRPythonSources.Core ) @@ -55,9 +52,6 @@ add_mlir_python_modules(StandalonePythonModules INSTALL_PREFIX "python_packages/standalone/mlir_standalone" DECLARED_SOURCES StandalonePythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything MLIRPythonSources COMMON_CAPI_LINK_LIBS StandalonePythonCAPI diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py index 08e08cbd2fe24..6e82a91e0bfc7 100644 --- a/mlir/examples/standalone/test/python/smoketest.py +++ b/mlir/examples/standalone/test/python/smoketest.py @@ -3,6 +3,8 @@ from mlir_standalone.ir import * from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d +add_dialect_to_dialect_registry(get_dialect_registry(), "arith") + with Context(): standalone_d.register_dialect() module = Module.parse( diff --git a/mlir/include/mlir-c/Dialect/RemainingDialects.h b/mlir/include/mlir-c/Dialect/RemainingDialects.h new file mode 100644 index 0000000000000..e98f084798cd6 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/RemainingDialects.h @@ -0,0 +1,41 @@ +#ifndef MLIR_C_REMAINING_DIALECTS_H +#define MLIR_C_REMAINING_DIALECTS_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE) \ + MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##NAMESPACE##__(); + +#define FORALL_DIALECTS(_) \ + _(acc) \ + _(affine) \ + _(amx) \ + _(arm_neon) \ + _(arm_sme) \ + _(arm_sve) \ + _(bufferization) \ + _(complex) \ + _(dlti) \ + _(emitc) \ + _(index) \ + _(irdl) \ + _(mesh) \ + _(spirv) \ + _(tosa) \ + _(ub) \ + _(x86vector) + +FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_) + +#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_ +#undef FORALL_DIALECTS + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_REMAINING_DIALECTS_H diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 17272472ccca4..dc062244b828c 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -12,6 +12,32 @@ #include "IRModule.h" #include "Pass.h" +#include "mlir-c/Dialect/AMDGPU.h" +#include "mlir-c/Dialect/Arith.h" +#include "mlir-c/Dialect/Async.h" +#include "mlir-c/Dialect/ControlFlow.h" +#include "mlir-c/Dialect/Func.h" +#include "mlir-c/Dialect/GPU.h" +#include "mlir-c/Dialect/LLVM.h" +#include "mlir-c/Dialect/Linalg.h" +#include "mlir-c/Dialect/MLProgram.h" +#include "mlir-c/Dialect/Math.h" +#include "mlir-c/Dialect/MemRef.h" +#include "mlir-c/Dialect/NVGPU.h" +#include "mlir-c/Dialect/NVVM.h" +#include "mlir-c/Dialect/OpenMP.h" +#include "mlir-c/Dialect/PDL.h" +#include "mlir-c/Dialect/Quant.h" +#include "mlir-c/Dialect/ROCDL.h" +#include "mlir-c/Dialect/SCF.h" +#include "mlir-c/Dialect/Shape.h" +#include "mlir-c/Dialect/SparseTensor.h" +#include "mlir-c/Dialect/Tensor.h" +#include "mlir-c/Dialect/Transform.h" +#include "mlir-c/Dialect/Vector.h" + +#include "mlir-c/Dialect/RemainingDialects.h" + namespace py = pybind11; using namespace mlir; using namespace py::literals; @@ -65,6 +91,63 @@ PYBIND11_MODULE(_mlir, m) { }, "dialect_class"_a, "Class decorator for registering a custom Dialect wrapper"); + m.def( + "add_dialect_to_dialect_registry", + [](MlirDialectRegistry registry, const std::string &dialectNamespace) { + +#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE) \ + if (dialectNamespace == #NAMESPACE) { \ + mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(), \ + registry); \ + return; \ + } + +#define FORALL_DIALECTS(_) \ + _(acc) \ + _(affine) \ + _(amdgpu) \ + _(amx) \ + _(arith) \ + _(arm_neon) \ + _(arm_sme) \ + _(arm_sve) \ + _(async) \ + _(bufferization) \ + _(cf) \ + _(complex) \ + _(emitc) \ + _(func) \ + _(gpu) \ + _(index) \ + _(irdl) \ + _(linalg) \ + _(llvm) \ + _(math) \ + _(memref) \ + _(mesh) \ + _(ml_program) \ + _(nvgpu) \ + _(nvvm) \ + _(omp) \ + _(pdl) \ + _(quant) \ + _(rocdl) \ + _(scf) \ + _(shape) \ + _(spirv) \ + _(tensor) \ + _(tosa) \ + _(ub) \ + _(vector) \ + _(x86vector) + FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_) + +#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_ +#undef FORALL_DIALECTS + throw std::runtime_error("unknown dialect namespace: " + + dialectNamespace); + }, + "dialect_registry"_a, "dialect_namespace"_a); m.def( "register_operation", [](const py::object &dialectClass, bool replace) -> py::cpp_function { diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index d815eba48d9b9..de0c4b9ac2478 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -224,3 +224,14 @@ add_mlir_upstream_c_api_library(MLIRCAPIVector MLIRCAPIIR MLIRVectorDialect ) + +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + +add_mlir_upstream_c_api_library(MLIRCAPIRemainingDialects + RemainingDialects.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + ${dialect_libs} +) diff --git a/mlir/lib/CAPI/Dialect/RemainingDialects.cpp b/mlir/lib/CAPI/Dialect/RemainingDialects.cpp new file mode 100644 index 0000000000000..a35814376a3eb --- /dev/null +++ b/mlir/lib/CAPI/Dialect/RemainingDialects.cpp @@ -0,0 +1,53 @@ +#include "mlir-c/Dialect/RemainingDialects.h" + +#include "mlir/CAPI/Registration.h" +#include "mlir/InitAllDialects.h" + +using namespace mlir; + +#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE, NAME) \ + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NAME, NAMESPACE, \ + NAMESPACE::NAME##Dialect) + +#define FORALL_DIALECTS(_) \ + _(acc, OpenACC) \ + _(affine, Affine) \ + _(amx, AMX) \ + _(arith, Arith) \ + _(arm_neon, ArmNeon) \ + _(arm_sme, ArmSME) \ + _(arm_sve, ArmSVE) \ + _(bufferization, Bufferization) \ + _(complex, Complex) \ + _(emitc, EmitC) \ + _(index, Index) \ + _(irdl, IRDL) \ + _(mesh, Mesh) \ + _(spirv, SPIRV) \ + _(tosa, Tosa) \ + _(ub, UB) \ + _(x86vector, X86Vector) + +FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_) + +#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_ +#undef FORALL_DIALECTS + +static void mlirDialectRegistryInsertDLTIDialect(MlirDialectRegistry registry) { + unwrap(registry)->insert(); +} + +static MlirDialect mlirContextLoadDLTIDialect(MlirContext context) { + return wrap(unwrap(context)->getOrLoadDialect()); +} + +static MlirStringRef mlirDLTIDialectGetNamespace() { + return wrap(mlir::DLTIDialect::getDialectNamespace()); +} + +MlirDialectHandle mlirGetDialectHandle__dlti__() { + static MlirDialectRegistrationHooks hooks = { + mlirDialectRegistryInsertDLTIDialect, mlirContextLoadDLTIDialect, + mlirDLTIDialectGetNamespace}; + return MlirDialectHandle{&hooks}; +} diff --git a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp index c1c4a418b2552..767e7631de17d 100644 --- a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp +++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp @@ -14,8 +14,6 @@ #include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" #include "mlir/Target/LLVMIR/Dialect/All.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" void mlirRegisterAllDialects(MlirDialectRegistry registry) { mlir::registerAllDialects(*unwrap(registry)); diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 55731943fb78d..456cf5f205cc5 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -423,7 +423,30 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MLIRCAPIInterfaces # Dialects + MLIRCAPIAMDGPU + MLIRCAPIArith + MLIRCAPIAsync + MLIRCAPIControlFlow MLIRCAPIFunc + MLIRCAPIGPU + MLIRCAPILLVM + MLIRCAPILinalg + MLIRCAPIMLProgram + MLIRCAPIMath + MLIRCAPIMemRef + MLIRCAPINVGPU + MLIRCAPINVVM + MLIRCAPIOpenMP + MLIRCAPIPDL + MLIRCAPIQuant + MLIRCAPIROCDL + MLIRCAPISCF + MLIRCAPIShape + MLIRCAPISparseTensor + MLIRCAPITensor + MLIRCAPITransformDialect + MLIRCAPIVector + MLIRCAPIRemainingDialects ) # This extension exposes an API to register all dialects, extensions, and passes diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 6d21da3b4179f..d46134b24416e 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,7 +4,11 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug -from ._mlir_libs._mlir import register_type_caster, register_value_caster +from ._mlir_libs._mlir import ( + register_type_caster, + register_value_caster, + add_dialect_to_dialect_registry, +) from ._mlir_libs import get_dialect_registry