[mlir] Make C/Python ExecutionEngine constructible with an Operation. #86329

Conversation
This continues the long deprivileging of `mlir.ir.Module` as having any semantic meaning. Given the potential for silent/deadly failures from changing a C API signature, I added a new C API entrypoint with a new name and marked the original as deprecated. The Python `ExecutionEngine()` constructor was extended to accept either a `Module` or an `Operation`, so there should be no user-level API breakage. A test was added to verify this. The Python ExecutionEngine tests were modernized to use `Operation.parse` and explicit outer modules.
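For reference, here is a minimal usage sketch of the extended constructor, mirroring the updated `execution_engine.py` tests. The `mlir.ir` and `mlir.execution_engine` import paths assume the standard in-tree Python bindings package layout:

```python
# Minimal sketch: the ExecutionEngine constructor now accepts either an
# Operation or a Module; both paths end up in mlirExecutionEngineCreateFromOp.
from mlir.ir import Context, Module, Operation
from mlir.execution_engine import ExecutionEngine

ASM = r"""
builtin.module {
  llvm.func @none() {
    llvm.return
  }
}
"""

with Context():
    # Construct from a generic Operation.
    ExecutionEngine(Operation.parse(ASM))
    # Constructing from a Module still works; it is unwrapped to its
    # module operation inside the binding.
    ExecutionEngine(Module.parse(ASM))
```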
@llvm/pr-subscribers-mlir-execution-engine @llvm/pr-subscribers-mlir

Author: Stella Laurenzo (stellaraccident)

Full diff: https://github.com/llvm/llvm-project/pull/86329.diff

5 Files Affected:
diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h
index 99cddc5c2598d4..311451a029181b 100644
--- a/mlir/include/mlir-c/ExecutionEngine.h
+++ b/mlir/include/mlir-c/ExecutionEngine.h
@@ -42,8 +42,15 @@ DEFINE_C_API_STRUCT(MlirExecutionEngine, void);
/// that will be loaded are specified via `numPaths` and `sharedLibPaths`
/// respectively.
/// TODO: figure out other options.
+MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreateFromOp(
+ MlirOperation op, int optLevel, int numPaths,
+ const MlirStringRef *sharedLibPaths, bool enableObjectDump);
+
+// Deprecated variant which takes an MlirModule instead of an operation.
+// This is being preserved as of 2024-Mar for short term consistency and should
+// be dropped soon.
MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(
- MlirModule op, int optLevel, int numPaths,
+ MlirModule module, int optLevel, int numPaths,
const MlirStringRef *sharedLibPaths, bool enableObjectDump);
/// Destroy an ExecutionEngine instance.
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index b3df30583fc963..9ed5ee80f97f8b 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -71,15 +71,34 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) {
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
py::class_<PyExecutionEngine>(m, "ExecutionEngine", py::module_local())
- .def(py::init<>([](MlirModule module, int optLevel,
+ .def(py::init<>([](py::object operation_or_module, int optLevel,
const std::vector<std::string> &sharedLibPaths,
bool enableObjectDump) {
+ // Manually type cast from either a Module or Operation. The
+ // automatic type casters do not handle such cascades well,
+ // so be explicit.
+ py::object capsule = mlirApiObjectToCapsule(operation_or_module);
+ MlirOperation module_op =
+ mlirPythonCapsuleToOperation(capsule.ptr());
+ if (mlirOperationIsNull(module_op)) {
+ // If null, then a PyErr_Set has set an exception, which we must
+ // clear.
+ PyErr_Clear();
+ MlirModule mod = mlirPythonCapsuleToModule(capsule.ptr());
+ if (mlirModuleIsNull(mod)) {
+ throw py::type_error(
+ "ExecutionEngine expects a Module or Operation");
+ }
+ module_op = mlirModuleGetOperation(mod);
+ }
+
llvm::SmallVector<MlirStringRef, 4> libPaths;
for (const std::string &path : sharedLibPaths)
libPaths.push_back({path.c_str(), path.length()});
MlirExecutionEngine executionEngine =
- mlirExecutionEngineCreate(module, optLevel, libPaths.size(),
- libPaths.data(), enableObjectDump);
+ mlirExecutionEngineCreateFromOp(
+ module_op, optLevel, libPaths.size(), libPaths.data(),
+ enableObjectDump);
if (mlirExecutionEngineIsNull(executionEngine))
throw std::runtime_error(
"Failure while creating the ExecutionEngine.");
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
index 507be9171d328d..8bd7e8b354f341 100644
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -20,9 +20,18 @@
using namespace mlir;
extern "C" MlirExecutionEngine
-mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
+mlirExecutionEngineCreate(MlirModule module, int optLevel, int numPaths,
const MlirStringRef *sharedLibPaths,
bool enableObjectDump) {
+ return mlirExecutionEngineCreateFromOp(mlirModuleGetOperation(module),
+ optLevel, numPaths, sharedLibPaths,
+ enableObjectDump);
+}
+
+extern "C" MlirExecutionEngine
+mlirExecutionEngineCreateFromOp(MlirOperation op, int optLevel, int numPaths,
+ const MlirStringRef *sharedLibPaths,
+ bool enableObjectDump) {
static bool initOnce = [] {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm
@@ -104,9 +113,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
void *sym) {
unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
llvm::orc::SymbolMap symbolMap;
- symbolMap[interner(unwrap(name))] =
- { llvm::orc::ExecutorAddr::fromPtr(sym),
- llvm::JITSymbolFlags::Exported };
+ symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym),
+ llvm::JITSymbolFlags::Exported};
return symbolMap;
});
}
diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
index 893dab8a431fd1..c32b5db13241c0 100644
--- a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
@@ -4,7 +4,7 @@
# * Relative imports for cross-module references.
# * Add __all__
-from typing import List, Sequence
+from typing import List, Sequence, Union
from ._mlir import ir as _ir
@@ -13,7 +13,7 @@ __all__ = [
]
class ExecutionEngine:
- def __init__(self, module: _ir.Module, opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ...
+ def __init__(self, module: Union[_ir.Operation, _ir.Module], opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ...
def _CAPICreate(self) -> object: ...
def _testing_release(self) -> None: ...
def dump_to_object_file(self, file_name: str) -> None: ...
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e8b47007a8907d..647e6667b69a34 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -21,17 +21,46 @@ def run(f):
assert Context._get_live_count() == 0
-# Verify capsule interop.
-# CHECK-LABEL: TEST: testCapsule
-def testCapsule():
+# Verify capsule interop for passing an Operation.
+# CHECK-LABEL: TEST: testAcceptsOperation
+def testAcceptsOperation():
+ with Context():
+ module = Operation.parse(
+ r"""
+builtin.module {
+llvm.func @none() {
+llvm.return
+}
+}
+ """
+ )
+ execution_engine = ExecutionEngine(module)
+ execution_engine_capsule = execution_engine._CAPIPtr
+ # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
+ log(repr(execution_engine_capsule))
+ execution_engine._testing_release()
+ execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
+ # CHECK: _mlirExecutionEngine.ExecutionEngine
+ log(repr(execution_engine1))
+
+
+run(testAcceptsOperation)
+
+
+# Verify capsule interop for passing a Module.
+# CHECK-LABEL: TEST: testAcceptsModule
+def testAcceptsModule():
with Context():
module = Module.parse(
r"""
+builtin.module {
llvm.func @none() {
- llvm.return
+llvm.return
+}
}
"""
)
+ print("MODULE:", type(module))
execution_engine = ExecutionEngine(module)
execution_engine_capsule = execution_engine._CAPIPtr
# CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
@@ -42,7 +71,7 @@ def testCapsule():
log(repr(execution_engine1))
-run(testCapsule)
+run(testAcceptsModule)
# Test invalid ExecutionEngine creation
@@ -50,9 +79,11 @@ def testCapsule():
def testInvalidModule():
with Context():
# Builtin function
- module = Module.parse(
+ module = Operation.parse(
r"""
+ builtin.module {
func.func @foo() { return }
+ }
"""
)
# CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
@@ -69,7 +100,7 @@ def lowerToLLVM(module):
pm = PassManager.parse(
"builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)"
)
- pm.run(module.operation)
+ pm.run(module)
return module
@@ -77,10 +108,12 @@ def lowerToLLVM(module):
# CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
with Context():
- module = Module.parse(
+ module = Operation.parse(
r"""
+builtin.module {
func.func @void() attributes { llvm.emit_c_interface } {
return
+}
}
"""
)
@@ -96,11 +129,13 @@ def testInvokeVoid():
# CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
with Context():
- module = Module.parse(
+ module = Operation.parse(
r"""
+builtin.module {
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
%add = arith.addf %arg0, %arg1 : f32
return %add : f32
+}
}
"""
)
@@ -129,13 +164,15 @@ def callback(a, b):
with Context():
# The module just forwards to a runtime function known as "some_callback_into_python".
- module = Module.parse(
+ module = Operation.parse(
r"""
+builtin.module {
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
%resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
+}
"""
)
execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -168,13 +205,15 @@ def callback(a):
with Context():
# The module just forwards to a runtime function known as "some_callback_into_python".
- module = Module.parse(
+ module = Operation.parse(
r"""
+builtin.module {
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
+}
"""
)
execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -221,13 +260,15 @@ def callback(a):
with Context():
# The module just forwards to a runtime function known as "some_callback_into_python".
- module = Module.parse(
+ module = Operation.parse(
r"""
+builtin.module {
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
+}
"""
)
execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -262,8 +303,9 @@ def callback(a):
with Context():
# The module takes a subview of the argument memref and calls the callback with it
- module = Module.parse(
+ module = Operation.parse(
r"""
+builtin.module {
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
%base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
%reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
@@ -272,6 +314,7 @@ def callback(a):
return
}
func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
+}
"""
)
execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -301,8 +344,9 @@ def callback(a):
with Context():
# The module takes a subview of the argument memref, casts it to an unranked memref and
# calls the callback with it.
- module = Module.parse(
+ module = Operation.parse(
r"""
+builtin.module {
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
%base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
%reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
@@ -311,6 +355,7 @@ def callback(a):
return
}
func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
+}
"""
)
execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -330,9 +375,9 @@ def callback(a):
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
with Context():
- module = Module.parse(
+ module = Operation.parse(
"""
- module {
+ builtin.module {
func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
%0 = arith.constant 0 : index
%1 = memref.load %arg0[%0] : memref<1xf32>
@@ -372,9 +417,9 @@ def testMemrefAdd():
# CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
with Context():
- module = Module.parse(
+ module = Operation.parse(
"""
- module {
+ builtin.module {
func.func @main(%arg0: memref<1xf16>,
%arg1: memref<1xf16>,
%arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
@@ -422,9 +467,9 @@ def testF16MemrefAdd():
# CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
with Context():
- module = Module.parse(
+ module = Operation.parse(
"""
- module {
+ builtin.module {
func.func @main(%arg0: memref<1xcomplex<f64>>,
%arg1: memref<1xcomplex<f64>>,
%arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
@@ -472,9 +517,9 @@ def testComplexMemrefAdd():
# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
with Context():
- module = Module.parse(
+ module = Operation.parse(
"""
- module {
+ builtin.module {
func.func @main(%arg0: memref<*xcomplex<f32>>,
%arg1: memref<*xcomplex<f32>>,
%arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
@@ -525,9 +570,9 @@ def testComplexUnrankedMemrefAdd():
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
with Context():
- module = Module.parse(
+ module = Operation.parse(
"""
- module {
+ builtin.module {
func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
@@ -589,9 +634,9 @@ def testDynamicMemrefAdd2D():
# CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
with Context():
- module = Module.parse(
+ module = Operation.parse(
"""
- module {
+ builtin.module {
func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
%c0 = arith.constant 0 : index
%cst42 = arith.constant 42.0 : f32
@@ -640,9 +685,9 @@ def testSharedLibLoad():
# CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
with Context():
- module = Module.parse(
+ module = Operation.parse(
"""
- module {
+ builtin.module {
func.func @main() attributes { llvm.emit_c_interface } {
%now = call @nanoTime() : () -> i64
%memref = memref.alloca() : memref<1xi64>
@@ -686,9 +731,9 @@ def testDumpToObjectFile():
try:
with Context():
- module = Module.parse(
+ module = Operation.parse(
"""
- module {
+ builtin.module {
func.func @main() attributes { llvm.emit_c_interface } {
return
}
✅ With the latest revision this PR passed the C/C++ code formatter.
✅ With the latest revision this PR passed the Python code formatter.
This looks good in principle / "deprivileging" SGTM
py::object capsule = mlirApiObjectToCapsule(operation_or_module);
MlirOperation module_op =
    mlirPythonCapsuleToOperation(capsule.ptr());
if (mlirOperationIsNull(module_op)) {
This is confusing ... I'd have expected this to be a check if not-null.
module = Operation.parse(
    r"""
builtin.module {
llvm.func @none() {
Nit: indent body
llvm.func @none() {
llvm.return
Same
}
"""
)
print("MODULE:", type(module))
Leftover debugging?