[mlir][python] allow upstream dialect registration #74252

Open. Wants to merge 1 commit into base: main.


@makslevental makslevental commented Dec 3, 2023

This (currently just a sketch) PR allows users that do not build the MLIRPythonExtension.RegisterEverything target to still register upstream dialects from Python (discussed in #74245). It uses the get_dialect_registry API introduced in #72488 and adds/exposes MlirDialectHandles for all upstream dialects (i.e., fills in those that weren't already generating).

Right now the API is lame - pass a string that's associated with the dialect (the dialect namespace) and the MlirDialectHandle is fetched/gotten/materialized internally (by calling the appropriate C API). Ignoring whether we should keep it like this, the basic idea is captured in the m.def:

  m.def(
      "add_dialect_to_dialect_registry",
      [](MlirDialectRegistry registry, const std::string &dialectNamespace) {

#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE)                      \
  if (dialectNamespace == #NAMESPACE) {                                        \
    mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(),    \
                                   registry);                                  \
    return;                                                                    \
  }

where (currently!) the user would do ir.add_dialect_to_dialect_registry(get_dialect_registry(), "arith") and achieve the desired effect i.e., the next ir.Context() will include the arith dialect in its dialect registry. At least I believe that's the desired effect (gleaned from #74245). So before I do more manicuring on the API I just want to make sure this is what everyone wants.
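
To make the dispatch concrete, here is a toy Python model of what the m.def above does: look up a namespace string, insert the matching handle into a registry, and raise on an unknown name. The DialectRegistry class and the _HANDLES table are stand-ins invented for illustration; they are not part of the MLIR bindings, and only the function name and the error message come from the PR.

```python
# Toy model of the string-keyed dispatch; not the real MLIR bindings.


class DialectRegistry:
    """Hypothetical stand-in for MlirDialectRegistry: a set of namespaces."""

    def __init__(self):
        self.namespaces = set()


# Hypothetical table playing the role of the mlirGetDialectHandle__<ns>__ functions.
_HANDLES = {ns: f"handle:{ns}" for ns in ("arith", "func", "scf")}


def add_dialect_to_dialect_registry(registry, dialect_namespace):
    """Insert the dialect named by `dialect_namespace`, or raise if unknown."""
    if dialect_namespace not in _HANDLES:
        raise RuntimeError("unknown dialect namespace: " + dialect_namespace)
    registry.namespaces.add(dialect_namespace)


registry = DialectRegistry()
add_dialect_to_dialect_registry(registry, "arith")
```

In this model a misspelled namespace fails loudly at the call site, which mirrors the std::runtime_error thrown at the end of the lambda in the diff.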

Also I'm not sure how to test this upstream because it's not like I can turn off MLIRPythonExtension.RegisterEverything for the core bindings. There is the Standalone example that conveniently has a TODO about moving to finer grained registration so maybe now is the time to check off that TODO?

This is now tested by removing MLIRPythonExtension.RegisterEverything from the standalone example and calling add_dialect_to_dialect_registry(get_dialect_registry(), "arith") explicitly (the smoketest has a module with an arith op in it).

cc @hawkinsp

@llvmbot
Member

llvmbot commented Dec 3, 2023

@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

This (currently just a sketch) PR allows users that do not build the MLIRPythonExtension.RegisterEverything target to still register upstream dialects from Python (discussed in #74245). It uses the get_dialect_registry API introduced in #72488 and adds/exposes MlirDialectHandles for all upstream dialects.

Right now the API is lame - pass a string that's associated with the dialect (the dialect namespace) and the MlirDialectHandle is fetched/gotten/materialized internally (by calling the appropriate C API). Ignoring whether we should keep it like this, the basic idea is captured in the m.def:

  m.def(
      "add_dialect_to_dialect_registry",
      [](MlirDialectRegistry registry, const std::string &dialectNamespace) {

#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE)                      \
  if (dialectNamespace == #NAMESPACE) {                                        \
    mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(),    \
                                   registry);                                  \
    return;                                                                    \
  }

Before I do more manicuring on the API I just want to make sure this is what everyone wants. Also I'm not sure how to test this upstream because it's not like I can turn off MLIRPythonExtension.RegisterEverything for the core bindings. There is the Standalone example that conveniently has a TODO about moving to finer grained registration so maybe now is the time to check off that TODO?


Full diff: https://github.com/llvm/llvm-project/pull/74252.diff

3 Files Affected:

  • (modified) mlir/include/mlir-c/RegisterEverything.h (+26)
  • (modified) mlir/lib/Bindings/Python/MainModule.cpp (+82)
  • (modified) mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp (+32-2)
diff --git a/mlir/include/mlir-c/RegisterEverything.h b/mlir/include/mlir-c/RegisterEverything.h
index ea2ea86449727..f894419ecb1e4 100644
--- a/mlir/include/mlir-c/RegisterEverything.h
+++ b/mlir/include/mlir-c/RegisterEverything.h
@@ -31,6 +31,32 @@ MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context);
 /// Register all compiler passes of MLIR.
 MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(void);
 
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE)                      \
+  MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##NAMESPACE##__();
+
+#define FORALL_DIALECTS(_)                                                     \
+  _(acc)                                                                       \
+  _(affine)                                                                    \
+  _(amx)                                                                       \
+  _(arm_neon)                                                                  \
+  _(arm_sme)                                                                   \
+  _(arm_sve)                                                                   \
+  _(bufferization)                                                             \
+  _(complex)                                                                   \
+  _(emitc)                                                                     \
+  _(index)                                                                     \
+  _(irdl)                                                                      \
+  _(mesh)                                                                      \
+  _(spirv)                                                                     \
+  _(tosa)                                                                      \
+  _(ub)                                                                        \
+  _(x86vector)
+
+FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 17272472ccca4..4605ccc0f2935 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,31 @@
 #include "IRModule.h"
 #include "Pass.h"
 
+#include "mlir-c/Dialect/AMDGPU.h"
+#include "mlir-c/Dialect/Arith.h"
+#include "mlir-c/Dialect/Async.h"
+#include "mlir-c/Dialect/ControlFlow.h"
+#include "mlir-c/Dialect/Func.h"
+#include "mlir-c/Dialect/GPU.h"
+#include "mlir-c/Dialect/LLVM.h"
+#include "mlir-c/Dialect/Linalg.h"
+#include "mlir-c/Dialect/MLProgram.h"
+#include "mlir-c/Dialect/Math.h"
+#include "mlir-c/Dialect/MemRef.h"
+#include "mlir-c/Dialect/NVGPU.h"
+#include "mlir-c/Dialect/NVVM.h"
+#include "mlir-c/Dialect/OpenMP.h"
+#include "mlir-c/Dialect/PDL.h"
+#include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/Dialect/ROCDL.h"
+#include "mlir-c/Dialect/SCF.h"
+#include "mlir-c/Dialect/Shape.h"
+#include "mlir-c/Dialect/SparseTensor.h"
+#include "mlir-c/Dialect/Tensor.h"
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir-c/Dialect/Vector.h"
+#include "mlir-c/RegisterEverything.h"
+
 namespace py = pybind11;
 using namespace mlir;
 using namespace py::literals;
@@ -65,6 +90,63 @@ PYBIND11_MODULE(_mlir, m) {
       },
       "dialect_class"_a,
       "Class decorator for registering a custom Dialect wrapper");
+  m.def(
+      "add_dialect_to_dialect_registry",
+      [](MlirDialectRegistry registry, const std::string &dialectNamespace) {
+
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE)                      \
+  if (dialectNamespace == #NAMESPACE) {                                        \
+    mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(),    \
+                                   registry);                                  \
+    return;                                                                    \
+  }
+
+#define FORALL_DIALECTS(_)                                                     \
+  _(acc)                                                                       \
+  _(affine)                                                                    \
+  _(amdgpu)                                                                    \
+  _(amx)                                                                       \
+  _(arith)                                                                     \
+  _(arm_neon)                                                                  \
+  _(arm_sme)                                                                   \
+  _(arm_sve)                                                                   \
+  _(async)                                                                     \
+  _(bufferization)                                                             \
+  _(cf)                                                                        \
+  _(complex)                                                                   \
+  _(emitc)                                                                     \
+  _(func)                                                                      \
+  _(gpu)                                                                       \
+  _(index)                                                                     \
+  _(irdl)                                                                      \
+  _(linalg)                                                                    \
+  _(llvm)                                                                      \
+  _(math)                                                                      \
+  _(memref)                                                                    \
+  _(mesh)                                                                      \
+  _(ml_program)                                                                \
+  _(nvgpu)                                                                     \
+  _(nvvm)                                                                      \
+  _(omp)                                                                       \
+  _(pdl)                                                                       \
+  _(quant)                                                                     \
+  _(rocdl)                                                                     \
+  _(scf)                                                                       \
+  _(shape)                                                                     \
+  _(spirv)                                                                     \
+  _(tensor)                                                                    \
+  _(tosa)                                                                      \
+  _(ub)                                                                        \
+  _(vector)                                                                    \
+  _(x86vector)
+        FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
+        throw std::runtime_error("unknown dialect namespace: " +
+                                 dialectNamespace);
+      },
+      "dialect_registry"_a, "dialect_namespace"_a);
   m.def(
       "register_operation",
       [](const py::object &dialectClass, bool replace) -> py::cpp_function {
diff --git a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
index c1c4a418b2552..debebe58ad64f 100644
--- a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
+++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
@@ -9,13 +9,43 @@
 #include "mlir-c/RegisterEverything.h"
 
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Registration.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/InitAllExtensions.h"
 #include "mlir/InitAllPasses.h"
 #include "mlir/Target/LLVMIR/Dialect/All.h"
-#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+
+using namespace mlir;
+
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE, NAME)                \
+  MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NAME, NAMESPACE,                       \
+                                        NAMESPACE::NAME##Dialect)
+
+#define FORALL_DIALECTS(_)                                                     \
+  _(acc, OpenACC)                                                              \
+  _(affine, Affine)                                                            \
+  _(amx, AMX)                                                                  \
+  _(arith, Arith)                                                              \
+  _(arm_neon, ArmNeon)                                                         \
+  _(arm_sme, ArmSME)                                                           \
+  _(arm_sve, ArmSVE)                                                           \
+  _(bufferization, Bufferization)                                              \
+  _(complex, Complex)                                                          \
+  _(emitc, EmitC)                                                              \
+  _(mlir, DLTI)                                                                \
+  _(index, Index)                                                              \
+  _(irdl, IRDL)                                                                \
+  _(mesh, Mesh)                                                                \
+  _(spirv, SPIRV)                                                              \
+  _(tosa, Tosa)                                                                \
+  _(ub, UB)                                                                    \
+  _(x86vector, X86Vector)
+
+FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
 
 void mlirRegisterAllDialects(MlirDialectRegistry registry) {
   mlir::registerAllDialects(*unwrap(registry));

Comment on lines -17 to -18
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
Contributor Author

@makslevental makslevental Dec 3, 2023

Just dead code (or, alternatively, an IWYU failure somewhere else...)

@makslevental makslevental requested a review from jpienaar December 3, 2023 21:45
@makslevental makslevental force-pushed the upstream_dialect_registration branch from f732927 to bc9e216 Compare December 3, 2023 22:14
@makslevental makslevental force-pushed the upstream_dialect_registration branch 2 times, most recently from f53d4ff to bb78956 Compare December 4, 2023 00:45
@makslevental makslevental force-pushed the upstream_dialect_registration branch from bb78956 to 719695d Compare December 4, 2023 01:30
Member

@ftynse ftynse left a comment

I'm not sure about the intent here, seems to be the usual confusion with dialect registration and loading. Does this intend to register the dialects, i.e. make them available to the parser? Does this rather intend to load the dialects into the context, i.e. allow one to construct objects from these dialects? For the latter, I'd look at what downstreams already do with various load_<downstream>_dialects(context) functions and reuse that, modulo some homogenization, for upstream dialects.

As a vague design suggestion, I would find it convenient to have from mlir.dialects import foo to have foo loaded in whatever context I create afterwards and from mlir.dialects.foo import parser to have foo registered (but not loaded) in whatever context I create afterwards + have a mechanism to override this behavior should we need a clean context for some reason. But maybe it's too magic for other folks.
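
The registration/loading distinction raised here can be illustrated with a toy model. This is only a sketch: the Context class, its load and parse_op methods, and the set-based bookkeeping are hypothetical and do not correspond to the real bindings. It shows a registered dialect staying unloaded until something (here, a parser) asks for it.

```python
# Toy model of "registered" vs "loaded"; class and method names are hypothetical.


class Context:
    def __init__(self, registered=()):
        self.registered = set(registered)  # known to the context, loadable on demand
        self.loaded = set()                # actually instantiated in the context

    def load(self, namespace):
        """Eagerly load a dialect; it must have been registered first."""
        if namespace not in self.registered:
            raise KeyError(f"dialect {namespace!r} is not registered")
        self.loaded.add(namespace)

    def parse_op(self, op_name):
        """Parsing lazily loads the op's dialect if it is registered."""
        namespace = op_name.split(".", 1)[0]
        self.load(namespace)
        return op_name


ctx = Context(registered={"arith"})
before = set(ctx.loaded)          # registered, but nothing loaded yet
ctx.parse_op("arith.constant")    # the parser triggers the load
```

Under this model, "register on import" only changes what a context is able to load later, while "load on import" changes what is already present in it.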

Comment on lines 425 to +449
# Dialects
MLIRCAPIAMDGPU
MLIRCAPIArith
MLIRCAPIAsync
MLIRCAPIControlFlow
MLIRCAPIFunc
MLIRCAPIGPU
MLIRCAPILLVM
MLIRCAPILinalg
MLIRCAPIMLProgram
MLIRCAPIMath
MLIRCAPIMemRef
MLIRCAPINVGPU
MLIRCAPINVVM
MLIRCAPIOpenMP
MLIRCAPIPDL
MLIRCAPIQuant
MLIRCAPIROCDL
MLIRCAPISCF
MLIRCAPIShape
MLIRCAPISparseTensor
MLIRCAPITensor
MLIRCAPITransformDialect
MLIRCAPIVector
MLIRCAPIRemainingDialects
Member

I don't think we want "Core" bindings to depend on all dialects...

Contributor Author

@makslevental makslevental Dec 4, 2023

The issue is that right now most of the dialects don't have their own pybind modules. I was trying to avoid introducing 20 new modules that just have register_dialect in them. I also couldn't put these in RegisterEverything because the mere presence of that module triggers the omnibus loading.

Contributor

Yeah. I need to activate some old caches for this, but I think it was this issue that kept me from doing this the first time around.

If we can get the dependency dag such that RegisterEverything carries the weight of "everything" but is strictly optional, I think that is what we are going for.

@makslevental
Contributor Author

makslevental commented Dec 4, 2023

> Does this intend to register the dialects, i.e. make them available to the parser? Does this rather intend to load the dialects into the context, i.e. allow one to construct objects from these dialects?

For upstream dialects, registration is sufficient because the auto-loading mechanisms are infallible. That's why I proposed this version rather than the one alluded to in the issue (in terms of loading) because this version can indeed be a side-effecting import, which as you say is very fine UX.

> mechanism to override this behavior should we need a clean context for some reason

That's doable since now the default registry is an exposed global.

@makslevental
Contributor Author

@hawkinsp just pinging again to get your feedback.

@stellaraccident
Contributor

> As a vague design suggestion, I would find it convenient to have from mlir.dialects import foo to have foo loaded in whatever context I create afterwards and from mlir.dialects.foo import parser to have foo registered (but not loaded) in whatever context I create afterwards + have a mechanism to override this behavior should we need a clean context for some reason. But maybe it's too magic for other folks.

I think the issue is that we use this in various jit and multi-client flows where "after" has no defined meaning -- it is all concurrent. I'm certainly not against doing something more ergonomic for this class of things but would like to avoid non-deterministic behavior in concurrent situations.
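
A small sketch of the ordering concern, under the assumption that a side-effecting import mutates one process-wide registry and that a context snapshots whatever that registry holds at creation time. GLOBAL_REGISTRY, import_dialect, and make_context are hypothetical names, and the two "interleavings" are simulated sequentially rather than with real threads.

```python
# Toy model: a process-wide registry vs. an explicit per-client registry.
# All names here are hypothetical and only illustrate the ordering concern.

GLOBAL_REGISTRY = set()


def import_dialect(namespace):
    """Simulates a side-effecting import that mutates the global registry."""
    GLOBAL_REGISTRY.add(namespace)


def make_context(registry=None):
    """A context snapshots the registry it is given (the global one by default)."""
    return frozenset(GLOBAL_REGISTRY if registry is None else registry)


# Client A creates a context; an unrelated client B imports a dialect either
# after or before that point. With the global registry the result differs.
ctx_a_first = make_context()
import_dialect("scf")              # client B's import lands after A's context
ctx_a_second = make_context()      # same client A code, different interleaving

# With an explicit registry, client A sees the same dialects either way.
ctx_explicit = make_context(registry={"arith"})
```

The explicit-registry variant costs one extra argument but makes a context's contents independent of what other clients happen to have imported.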

@nicolasvasilache
Contributor

I've been studying this a bit and digging deeper I find the current situation unfortunate and inconsistent.

Atm it seems we have the following:

  1. there is a @_ods_cext.register_dialect that seems to be a registration decorator, it seems related to https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/MainModule.cpp#L59 (?) but I am not sure what to do with it
  2. the current register the world and/or link the world does not seem appealing.
  3. context.get_dialect_descriptor is wrapping getOrLoadDialect (why the name change btw?) and requires pre-registration.
  4. context.allow_unregistered_dialects, I am sure many people are using this blanket solution ..

I also see that only a tiny subset of dialects have a dedicated module (in mlir/lib/Bindings/Python/DialectXXX.cpp) making things inconsistent.
Could we just add a module per dialect with a single register_dialect(bool load = false) function ?
This would also provide a location for grounding all future dialect module things and all per-dialect pybind things.

I imagine this could be tablegen'd but a low-tech solution can already improve things significantly.

@makslevental
Contributor Author

makslevental commented Dec 20, 2023

> there is a @_ods_cext.register_dialect that seems to be a registration decorator, it seems related to https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/MainModule.cpp#L59 (?) but I am not sure what to do with it

This is related to "loading" dialects when it comes to the python bindings i.e., how op.operation.opview materializes one of the classes in the generated bindings. That decorator itself doesn't do much except decouple the dialect name from the generated module name (arith vs _arith_ops_gen.py).

> the current register the world and/or link the world does not seem appealing.

Linking is a rabbit hole that I've explored pretty extensively. The issue (as I see it) is that you want to enable dialects (upstream and downstream) to implement bindings post hoc, i.e., in addition to whatever builtin dialects are currently handled by the _mlir extension. The current theory/approach on how to accomplish this necessitates building several separate C extensions, which leads to this asymmetry where upstream dialects are registered automatically/wholesale and downstream dialects require an explicit register_MYDIALECT_dialect(ctx, /*load*/true) call, where register_MYDIALECT_dialect is exposed through their C extension. Note that the separate modules in upstream for transform and pdl and quant don't have anything to do with registration but are about exposing various types.

If the question is about piece-meal registration then this PR could be the vehicle but I don't think sprawling out into ~20 separate C extension modules (dumped into _mlir_libs) is the answer. It wouldn't be hard to turn off register_all and just have ~20 separate register_arith_dialect, register_scf_dialect, etc.

If the question is really about linking i.e. people want to build bindings that don't have any object code for affine but do for scf that's a much much harder task because right now all that code doesn't live in the C extension but actually in the AggregateCAPI omnibus shlib which is opt-in from each dialect's perspective (that's what ENABLE_AGGREGATION on add_mlir_public_c_api_library does). So by the time you get to building the bindings that shlib is already fixed.

> context.get_dialect_descriptor is wrapping getOrLoadDialect (why the name change btw?) and requires pre-registration.

I guess your issue is you'd like for there to be something like context.register_and_load_and_get_dialect_descriptor(dialect_name)? Or just context.register_and_load(dialect_name)? The problem is you can't do that unless you have, somewhere, a map from dialect_name -> MlirDialectHandle for all the dialects you're interested in. That's certainly possible for upstream dialects but isn't for downstream (where you have to actually generate/sign up for mlirGetDialectHandle__YOURDIALECT__). And then doing it just for upstream would again further enshrine the asymmetry. If these were generated by default then we could actually put a global map somewhere in the AggregateCAPI and add to it.
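
As a rough illustration of the "global map that everyone adds to" idea, the sketch below models a shared name -> handle-getter table that upstream and downstream dialects populate through the same call. The names sign_up_dialect and _HANDLE_GETTERS, the dict-as-context, and the string handles are all hypothetical; register_and_load is borrowed from the comment above purely as a label.

```python
# Toy model of a shared name -> handle-getter map; all names are hypothetical.

_HANDLE_GETTERS = {}


def sign_up_dialect(namespace, getter):
    """What a generated mlirGetDialectHandle__<ns>__ could register itself through."""
    _HANDLE_GETTERS[namespace] = getter


def register_and_load(context, namespace):
    """Look up the handle by name, then record the dialect in `context` (a dict)."""
    getter = _HANDLE_GETTERS.get(namespace)
    if getter is None:
        raise LookupError(f"no handle generated for dialect {namespace!r}")
    handle = getter()
    context[namespace] = handle
    return handle


# Upstream and downstream dialects sign up through the same call.
sign_up_dialect("arith", lambda: "upstream:arith")
sign_up_dialect("mydialect", lambda: "downstream:mydialect")

ctx = {}
register_and_load(ctx, "mydialect")
```

The lookup only works for dialects that have signed up, which is exactly the gap described above for downstreams that never generate a handle getter.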

And to be honest, we should just start generating these. We've talked about generating C API for a while now and so we should just do it and this is a reasonable enough testcase.

> context.allow_unregistered_dialects, I am sure many people are using this blanket solution

You can actually go further - if you turn on allow_unregistered_dialects you can just copy-paste generated bindings from somewhere and they'll generate legal "unregistered IR" that'll then (assuming you copy-pasted recently) verify in a context that does have the dialects registered - I've done this for triton when I want to generate triton IR using my bindings but I don't want to build/link triton (so I just pass textual IR to their package/module/C extension).

> I also see that only a tiny subset of dialects have a dedicated module (in mlir/lib/Bindings/Python/DialectXXX.cpp) making things inconsistent. Could we just add a module per dialect with a single register_dialect(bool load = false) function ? This would also provide a location for grounding all future dialect module things and all per-dialect pybind things.

You could but you're not going to save yourself any linked baggage as I mentioned above.

> I imagine this could be tablegen'd but a low-tech solution can already improve things significantly.

Then you'll end up in a world where some people want to write their own C extensions (in order to pick and choose which types/attributes etc to expose) and some people don't. Not a serious barrier except I cringe at the CMake shenanigans to line up all the object files.

After writing all this out I'm basically in the same place I was when I opened this PR: what exactly do people want??? I think if someone (@nicolasvasilache or @hawkinsp) could articulate what their goal is it would go a long way to deciding how to proceed. Just "fixing asymmetry" is a nice goal but not great motivation to rearchitect :)

@stellaraccident
Contributor

stellaraccident commented Dec 20, 2023

This part of the system was, frankly, never finished. Every stakeholder wanted something else and the desire to not end up in a linking-the-world-required state in the upstream world made this hard.

I think the solution may be relatively simple: Treat "RegisterEverything" as something that everyone will want one of -- but make it so that everyone can have their own. Right now "RegisterEverything" is hard coded to be literally all upstream dialects, but what if there was just an easy way to, given a list of dialects, generate the registration handles and statically populate the name->handle map. Then upstream uses that rule to make its "RegisterEverything". Downstreams generate their own C code for that using the same generator, etc. I'd much prefer one C implementation file that needed to be generated for a list of dialects vs. the current trend toward tons of little pieces that have to fit together.

If we had that, it could replace the ad-hoc _site_initialize stuff for most users, and the way upstream does it would be the same as downstreams. Could ultimately be one CMake target to generate the RegisterEverything.c.

Edit:

Or we go even simpler. If you want your very own RegisterEverything in your project, then you just put a cpp file in your bindings that, well, registers everything you want. One function exported and that is it. And if you have that, the rest just works. If you can't be bothered, you just copy the one upstream has, add the one or two lines you need, put it next to your CMakeLists and done.

Basically, instead of trying to factor RegisterEverything out, let's just make that the norm but make it controllable.
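
One way to picture the "generate it from a list of dialects" variant is a small script that emits the C file. The sketch below is an assumption-laden illustration, not a proposed implementation: the function name myRegisterEverything, the choice of mlir-c/IR.h as the only include, and the plain forward declarations are hypothetical, while mlirDialectHandleInsertDialect and the mlirGetDialectHandle__<ns>__ naming are taken from the diff in this PR.

```python
# Sketch of a generator for a project-specific "register everything" file.
# The emitted function name and the dialect list are hypothetical.


def generate_register_everything(namespaces, function_name="myRegisterEverything"):
    """Return C source that inserts each dialect's handle into a registry."""
    lines = ['#include "mlir-c/IR.h"', ""]
    # Declare the per-dialect handle getters we expect to link against.
    for ns in namespaces:
        lines.append(f"MlirDialectHandle mlirGetDialectHandle__{ns}__(void);")
    lines += ["", f"void {function_name}(MlirDialectRegistry registry) {{"]
    # One insertion call per requested dialect.
    for ns in namespaces:
        lines.append(
            f"  mlirDialectHandleInsertDialect(mlirGetDialectHandle__{ns}__(), registry);"
        )
    lines.append("}")
    return "\n".join(lines) + "\n"


source = generate_register_everything(["arith", "scf"])
```

A downstream would pass its own list (upstream dialects plus its own namespaces) and compile the result next to its bindings; the hand-written "copy the upstream file and add two lines" option from the edit above amounts to doing the same thing manually.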

Contributor

@stellaraccident stellaraccident left a comment

Marking this as Request Changes because I think this PR does a good job of showing us why we have a problem -- but we need to think more on the solution.

Labels
mlir:python MLIR Python bindings mlir

5 participants