Skip to content

[mlir][python] allow upstream dialect registration #74252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions mlir/examples/standalone/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ add_mlir_python_common_capi_library(StandalonePythonCAPI
RELATIVE_INSTALL_ROOT "../../../.."
DECLARED_SOURCES
StandalonePythonSources
# TODO: Remove this in favor of showing fine grained registration once
# available.
MLIRPythonExtension.RegisterEverything
MLIRPythonSources.Core
)

Expand All @@ -55,9 +52,6 @@ add_mlir_python_modules(StandalonePythonModules
INSTALL_PREFIX "python_packages/standalone/mlir_standalone"
DECLARED_SOURCES
StandalonePythonSources
# TODO: Remove this in favor of showing fine grained registration once
# available.
MLIRPythonExtension.RegisterEverything
MLIRPythonSources
COMMON_CAPI_LINK_LIBS
StandalonePythonCAPI
Expand Down
2 changes: 2 additions & 0 deletions mlir/examples/standalone/test/python/smoketest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from mlir_standalone.ir import *
from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d

add_dialect_to_dialect_registry(get_dialect_registry(), "arith")

with Context():
standalone_d.register_dialect()
module = Module.parse(
Expand Down
41 changes: 41 additions & 0 deletions mlir/include/mlir-c/Dialect/RemainingDialects.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#ifndef MLIR_C_REMAINING_DIALECTS_H
#define MLIR_C_REMAINING_DIALECTS_H

#include "mlir-c/IR.h"

#ifdef __cplusplus
extern "C" {
#endif

#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE) \
MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##NAMESPACE##__();

#define FORALL_DIALECTS(_) \
_(acc) \
_(affine) \
_(amx) \
_(arm_neon) \
_(arm_sme) \
_(arm_sve) \
_(bufferization) \
_(complex) \
_(dlti) \
_(emitc) \
_(index) \
_(irdl) \
_(mesh) \
_(spirv) \
_(tosa) \
_(ub) \
_(x86vector)

FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)

#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
#undef FORALL_DIALECTS

#ifdef __cplusplus
}
#endif

#endif // MLIR_C_REMAINING_DIALECTS_H
83 changes: 83 additions & 0 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,32 @@
#include "IRModule.h"
#include "Pass.h"

#include "mlir-c/Dialect/AMDGPU.h"
#include "mlir-c/Dialect/Arith.h"
#include "mlir-c/Dialect/Async.h"
#include "mlir-c/Dialect/ControlFlow.h"
#include "mlir-c/Dialect/Func.h"
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/Dialect/MLProgram.h"
#include "mlir-c/Dialect/Math.h"
#include "mlir-c/Dialect/MemRef.h"
#include "mlir-c/Dialect/NVGPU.h"
#include "mlir-c/Dialect/NVVM.h"
#include "mlir-c/Dialect/OpenMP.h"
#include "mlir-c/Dialect/PDL.h"
#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/Dialect/ROCDL.h"
#include "mlir-c/Dialect/SCF.h"
#include "mlir-c/Dialect/Shape.h"
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/Dialect/Tensor.h"
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Dialect/Vector.h"

#include "mlir-c/Dialect/RemainingDialects.h"

namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
Expand Down Expand Up @@ -65,6 +91,63 @@ PYBIND11_MODULE(_mlir, m) {
},
"dialect_class"_a,
"Class decorator for registering a custom Dialect wrapper");
m.def(
"add_dialect_to_dialect_registry",
[](MlirDialectRegistry registry, const std::string &dialectNamespace) {

#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE) \
if (dialectNamespace == #NAMESPACE) { \
mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(), \
registry); \
return; \
}

#define FORALL_DIALECTS(_) \
_(acc) \
_(affine) \
_(amdgpu) \
_(amx) \
_(arith) \
_(arm_neon) \
_(arm_sme) \
_(arm_sve) \
_(async) \
_(bufferization) \
_(cf) \
_(complex) \
_(emitc) \
_(func) \
_(gpu) \
_(index) \
_(irdl) \
_(linalg) \
_(llvm) \
_(math) \
_(memref) \
_(mesh) \
_(ml_program) \
_(nvgpu) \
_(nvvm) \
_(omp) \
_(pdl) \
_(quant) \
_(rocdl) \
_(scf) \
_(shape) \
_(spirv) \
_(tensor) \
_(tosa) \
_(ub) \
_(vector) \
_(x86vector)
FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)

#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
#undef FORALL_DIALECTS
throw std::runtime_error("unknown dialect namespace: " +
dialectNamespace);
},
"dialect_registry"_a, "dialect_namespace"_a);
m.def(
"register_operation",
[](const py::object &dialectClass, bool replace) -> py::cpp_function {
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/CAPI/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,14 @@ add_mlir_upstream_c_api_library(MLIRCAPIVector
MLIRCAPIIR
MLIRVectorDialect
)

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

add_mlir_upstream_c_api_library(MLIRCAPIRemainingDialects
RemainingDialects.cpp

PARTIAL_SOURCES_INTENDED
LINK_LIBS PUBLIC
MLIRCAPIIR
${dialect_libs}
)
53 changes: 53 additions & 0 deletions mlir/lib/CAPI/Dialect/RemainingDialects.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "mlir-c/Dialect/RemainingDialects.h"

#include "mlir/CAPI/Registration.h"
#include "mlir/InitAllDialects.h"

using namespace mlir;

#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE, NAME) \
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NAME, NAMESPACE, \
NAMESPACE::NAME##Dialect)

#define FORALL_DIALECTS(_) \
_(acc, OpenACC) \
_(affine, Affine) \
_(amx, AMX) \
_(arith, Arith) \
_(arm_neon, ArmNeon) \
_(arm_sme, ArmSME) \
_(arm_sve, ArmSVE) \
_(bufferization, Bufferization) \
_(complex, Complex) \
_(emitc, EmitC) \
_(index, Index) \
_(irdl, IRDL) \
_(mesh, Mesh) \
_(spirv, SPIRV) \
_(tosa, Tosa) \
_(ub, UB) \
_(x86vector, X86Vector)

FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)

#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
#undef FORALL_DIALECTS

static void mlirDialectRegistryInsertDLTIDialect(MlirDialectRegistry registry) {
unwrap(registry)->insert<mlir::DLTIDialect>();
}

static MlirDialect mlirContextLoadDLTIDialect(MlirContext context) {
return wrap(unwrap(context)->getOrLoadDialect<mlir::DLTIDialect>());
}

static MlirStringRef mlirDLTIDialectGetNamespace() {
return wrap(mlir::DLTIDialect::getDialectNamespace());
}

MlirDialectHandle mlirGetDialectHandle__dlti__() {
static MlirDialectRegistrationHooks hooks = {
mlirDialectRegistryInsertDLTIDialect, mlirContextLoadDLTIDialect,
mlirDLTIDialectGetNamespace};
return MlirDialectHandle{&hooks};
}
2 changes: 0 additions & 2 deletions mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
#include "mlir/InitAllExtensions.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
Comment on lines -17 to -18
Copy link
Contributor Author

@makslevental makslevental Dec 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just dead code (or, alternatively, an IWYU failure somewhere else...)


void mlirRegisterAllDialects(MlirDialectRegistry registry) {
mlir::registerAllDialects(*unwrap(registry));
Expand Down
23 changes: 23 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,30 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
MLIRCAPIInterfaces

# Dialects
MLIRCAPIAMDGPU
MLIRCAPIArith
MLIRCAPIAsync
MLIRCAPIControlFlow
MLIRCAPIFunc
MLIRCAPIGPU
MLIRCAPILLVM
MLIRCAPILinalg
MLIRCAPIMLProgram
MLIRCAPIMath
MLIRCAPIMemRef
MLIRCAPINVGPU
MLIRCAPINVVM
MLIRCAPIOpenMP
MLIRCAPIPDL
MLIRCAPIQuant
MLIRCAPIROCDL
MLIRCAPISCF
MLIRCAPIShape
MLIRCAPISparseTensor
MLIRCAPITensor
MLIRCAPITransformDialect
MLIRCAPIVector
MLIRCAPIRemainingDialects
Comment on lines 425 to +449
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want "Core" bindings to depend on all dialects...

Copy link
Contributor Author

@makslevental makslevental Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that right now most of the dialects don't have their own pybind modules. I was trying to avoid introducing 20 new modules that just have register_dialect in them. I also couldn't put these in RegisterEverything because the mere presence of that module triggers the omnibus loading.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. I need to activate some old caches for this, but I think it was this issue that kept me from doing this the first time around.

If we can get the dependency dag such that RegisterEverything carries the weight of "everything" but is strictly optional, I think that is what we are going for.

)

# This extension exposes an API to register all dialects, extensions, and passes
Expand Down
6 changes: 5 additions & 1 deletion mlir/python/mlir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
from ._mlir_libs._mlir import (
register_type_caster,
register_value_caster,
add_dialect_to_dialect_registry,
)
from ._mlir_libs import get_dialect_registry


Expand Down