|
12 | 12 | #include "IRModule.h"
|
13 | 13 | #include "Pass.h"
|
14 | 14 |
|
| 15 | +#include "mlir-c/Dialect/AMDGPU.h" |
| 16 | +#include "mlir-c/Dialect/Arith.h" |
| 17 | +#include "mlir-c/Dialect/Async.h" |
| 18 | +#include "mlir-c/Dialect/ControlFlow.h" |
| 19 | +#include "mlir-c/Dialect/Func.h" |
| 20 | +#include "mlir-c/Dialect/GPU.h" |
| 21 | +#include "mlir-c/Dialect/LLVM.h" |
| 22 | +#include "mlir-c/Dialect/Linalg.h" |
| 23 | +#include "mlir-c/Dialect/MLProgram.h" |
| 24 | +#include "mlir-c/Dialect/Math.h" |
| 25 | +#include "mlir-c/Dialect/MemRef.h" |
| 26 | +#include "mlir-c/Dialect/NVGPU.h" |
| 27 | +#include "mlir-c/Dialect/NVVM.h" |
| 28 | +#include "mlir-c/Dialect/OpenMP.h" |
| 29 | +#include "mlir-c/Dialect/PDL.h" |
| 30 | +#include "mlir-c/Dialect/Quant.h" |
| 31 | +#include "mlir-c/Dialect/ROCDL.h" |
| 32 | +#include "mlir-c/Dialect/SCF.h" |
| 33 | +#include "mlir-c/Dialect/Shape.h" |
| 34 | +#include "mlir-c/Dialect/SparseTensor.h" |
| 35 | +#include "mlir-c/Dialect/Tensor.h" |
| 36 | +#include "mlir-c/Dialect/Transform.h" |
| 37 | +#include "mlir-c/Dialect/Vector.h" |
| 38 | + |
| 39 | +#include "mlir-c/Dialect/RemainingDialects.h" |
| 40 | + |
15 | 41 | namespace py = pybind11;
|
16 | 42 | using namespace mlir;
|
17 | 43 | using namespace py::literals;
|
@@ -65,6 +91,63 @@ PYBIND11_MODULE(_mlir, m) {
|
65 | 91 | },
|
66 | 92 | "dialect_class"_a,
|
67 | 93 | "Class decorator for registering a custom Dialect wrapper");
|
| 94 | + m.def( |
| 95 | + "add_dialect_to_dialect_registry", |
| 96 | + [](MlirDialectRegistry registry, const std::string &dialectNamespace) { |
| 97 | + |
| 98 | +#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE) \ |
| 99 | + if (dialectNamespace == #NAMESPACE) { \ |
| 100 | + mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(), \ |
| 101 | + registry); \ |
| 102 | + return; \ |
| 103 | + } |
| 104 | + |
| 105 | +#define FORALL_DIALECTS(_) \ |
| 106 | + _(acc) \ |
| 107 | + _(affine) \ |
| 108 | + _(amdgpu) \ |
| 109 | + _(amx) \ |
| 110 | + _(arith) \ |
| 111 | + _(arm_neon) \ |
| 112 | + _(arm_sme) \ |
| 113 | + _(arm_sve) \ |
| 114 | + _(async) \ |
| 115 | + _(bufferization) \ |
| 116 | + _(cf) \ |
| 117 | + _(complex) \ |
| 118 | + _(emitc) \ |
| 119 | + _(func) \ |
| 120 | + _(gpu) \ |
| 121 | + _(index) \ |
| 122 | + _(irdl) \ |
| 123 | + _(linalg) \ |
| 124 | + _(llvm) \ |
| 125 | + _(math) \ |
| 126 | + _(memref) \ |
| 127 | + _(mesh) \ |
| 128 | + _(ml_program) \ |
| 129 | + _(nvgpu) \ |
| 130 | + _(nvvm) \ |
| 131 | + _(omp) \ |
| 132 | + _(pdl) \ |
| 133 | + _(quant) \ |
| 134 | + _(rocdl) \ |
| 135 | + _(scf) \ |
| 136 | + _(shape) \ |
| 137 | + _(spirv) \ |
| 138 | + _(tensor) \ |
| 139 | + _(tosa) \ |
| 140 | + _(ub) \ |
| 141 | + _(vector) \ |
| 142 | + _(x86vector) |
| 143 | + FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_) |
| 144 | + |
| 145 | +#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_ |
| 146 | +#undef FORALL_DIALECTS |
| 147 | + throw std::runtime_error("unknown dialect namespace: " + |
| 148 | + dialectNamespace); |
| 149 | + }, |
| 150 | + "dialect_registry"_a, "dialect_namespace"_a); |
68 | 151 | m.def(
|
69 | 152 | "register_operation",
|
70 | 153 | [](const py::object &dialectClass, bool replace) -> py::cpp_function {
|
|
0 commit comments