Skip to content

[mlir][polynomial] python bindings #93109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mlir/include/mlir-c/Dialect/Polynomial.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//===-- mlir-c/Dialect/Polynomial.h - C API for Polynomial --------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_C_DIALECT_POLYNOMIAL_H
#define MLIR_C_DIALECT_POLYNOMIAL_H

#include "mlir/CAPI/Wrap.h"
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
Comment on lines +13 to +14
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't include headers from include/mlir inside include/mlir-c.


#ifdef __cplusplus
extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Polynomial, polynomial);

#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
}; \
typedef struct name name
Comment on lines +22 to +26
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Including mlir-c/IR.h should bring this macro. Normally we don't have it redefined in dialects.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We define and undef in all usages at the moment. Given the complexity it would seem better to avoid relying on macro def across headers.


DEFINE_C_API_STRUCT(MlirIntMonomial, void);

#undef DEFINE_C_API_STRUCT

DEFINE_C_API_PTR_METHODS(MlirIntMonomial, mlir::polynomial::IntMonomial);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think this will compile with C compiler either


MLIR_CAPI_EXPORTED MlirIntMonomial mlirPolynomialGetIntMonomial(int64_t coeff,
uint64_t expo);

MLIR_CAPI_EXPORTED int64_t
mlirPolynomialIntMonomialGetCoefficient(MlirIntMonomial intMonomial);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this purely for Python usage? (getting the feeling its making C shims for feasible C++ accessors)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean... yes? Is there a way around adding all of these? Actually this is exactly why I paused dev on this PR (because it immediately started to look like reproducing the entire class hierarchy in C and then pybind).


MLIR_CAPI_EXPORTED uint64_t
mlirPolynomialIntMonomialGetExponent(MlirIntMonomial intMonomial);

#ifdef __cplusplus
}
#endif

#endif // MLIR_C_DIALECT_POLYNOMIAL_H
102 changes: 102 additions & 0 deletions mlir/lib/Bindings/Python/DialectPolynomial.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//===- DialectPolynomial.cpp - 'polynomial' dialect submodule -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/Polynomial.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"

#include <pybind11/pybind11.h>
#include <vector>

namespace py = pybind11;
using namespace llvm;
using namespace mlir;
using namespace mlir::python::adaptors;

class PyIntMonomial {
public:
PyIntMonomial(MlirIntMonomial intMonomial) : intMonomial(intMonomial) {}
PyIntMonomial(int64_t coeff, uint64_t expo)
: intMonomial(mlirPolynomialGetIntMonomial(coeff, expo)) {}
operator MlirIntMonomial() const { return intMonomial; }
MlirIntMonomial get() { return intMonomial; }

int64_t getCoefficient() {
return mlirPolynomialIntMonomialGetCoefficient(this->get());
}

uint64_t getExponent() {
return mlirPolynomialIntMonomialGetExponent(this->get());
}

private:
MlirIntMonomial intMonomial;
};

#define MLIR_PYTHON_CAPSULE_INT_POLYNOMIAL \
MAKE_MLIR_PYTHON_QUALNAME("dialects.polynomial.IntMonomial._CAPIPtr")

static inline MlirIntMonomial
mlirPythonCapsuleToIntMonomial(PyObject *capsule) {
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_INT_POLYNOMIAL);
MlirIntMonomial intMonomial = {ptr};
return intMonomial;
}

static inline PyObject *
mlirPythonIntMonomialToCapsule(MlirIntMonomial intMonomial) {
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(intMonomial),
MLIR_PYTHON_CAPSULE_INT_POLYNOMIAL, nullptr);
}

static inline bool mlirIntMonomialIsNull(MlirIntMonomial intMonomial) {
return !intMonomial.ptr;
}

namespace pybind11 {
namespace detail {

/// Casts object <-> MlirIntMonomial.
template <>
struct type_caster<MlirIntMonomial> {
PYBIND11_TYPE_CASTER(MlirIntMonomial, _("MlirIntMonomial"));
bool load(handle src, bool) {
py::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToIntMonomial(capsule.ptr());
return !mlirIntMonomialIsNull(value);
}

static handle cast(MlirIntMonomial v, return_value_policy, handle) {
py::object capsule =
py::reinterpret_steal<py::object>(mlirPythonIntMonomialToCapsule(v));
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("dialects.polynomial"))
.attr("IntMonomial")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
}
};
} // namespace detail
} // namespace pybind11

PYBIND11_MODULE(_mlirDialectsPolynomial, m) {
m.doc() = "MLIR Polynomial dialect";

py::class_<PyIntMonomial>(m, "IntMonomial", py::module_local())
.def(py::init<PyIntMonomial &>())
.def(py::init<MlirIntMonomial>())
.def(py::init<int64_t, uint64_t>())
.def_property_readonly("coefficient", &PyIntMonomial::getCoefficient)
.def_property_readonly("exponent", &PyIntMonomial::getExponent)
.def("__str__", [](PyIntMonomial &self) {
return std::string("<")
.append(std::to_string(self.getCoefficient()))
.append(", ")
.append(std::to_string(self.getExponent()))
.append(">");
});
}
10 changes: 10 additions & 0 deletions mlir/lib/CAPI/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,16 @@ add_mlir_upstream_c_api_library(MLIRCAPIQuant
MLIRQuantDialect
)

add_mlir_upstream_c_api_library(MLIRCAPIPolynomial
Polynomial.cpp

PARTIAL_SOURCES_INTENDED
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRPolynomialDialect
)


add_mlir_upstream_c_api_library(MLIRCAPIOpenMP
OpenMP.cpp

Expand Down
32 changes: 32 additions & 0 deletions mlir/lib/CAPI/Dialect/Polynomial.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===- Polynomial.cpp - C Interface for Polynomial dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/Polynomial.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"

using namespace mlir;

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(polynomial, polynomial,
polynomial::PolynomialDialect)

MlirIntMonomial mlirPolynomialGetIntMonomial(int64_t coeff, uint64_t expo) {
return wrap(new mlir::polynomial::IntMonomial(coeff, expo));
}

int64_t mlirPolynomialIntMonomialGetCoefficient(MlirIntMonomial intMonomial) {
return unwrap(intMonomial)
->getCoefficient()
.getLimitedValue(/*Limit = UINT64_MAX*/);
}

uint64_t mlirPolynomialIntMonomialGetExponent(MlirIntMonomial intMonomial) {
return unwrap(intMonomial)
->getExponent(/*Limit = UINT64_MAX*/)
.getLimitedValue();
}
25 changes: 23 additions & 2 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,22 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformPDLExtensionOps.td
SOURCES
dialects/transform/pdl.py
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/Polynomial.td
SOURCES
dialects/polynomial.py
DIALECT_NAME polynomial)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down Expand Up @@ -537,6 +545,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
MLIRCAPIQuant
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.Polynomial.Pybind
MODULE_NAME _mlirDialectsPolynomial
ADD_TO_PARENT MLIRPythonSources.Dialects.polynomial
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectPolynomial.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPIPolynomial
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
MODULE_NAME _mlirDialectsNVGPU
ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu
Expand Down
14 changes: 14 additions & 0 deletions mlir/python/mlir/dialects/Polynomial.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//===-- Polynomial.td - Entry point for Polynomial bind ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_POLYNOMIAL_OPS
#define PYTHON_BINDINGS_POLYNOMIAL_OPS

include "mlir/Dialect/Polynomial/IR/Polynomial.td"

#endif
6 changes: 6 additions & 0 deletions mlir/python/mlir/dialects/polynomial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._mlir_libs._mlirDialectsPolynomial import *
from ._polynomial_ops_gen import *
27 changes: 27 additions & 0 deletions mlir/test/python/dialects/polynomial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import polynomial


def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
return f


# CHECK-LABEL: TEST: test_smoke
@constructAndPrintInModule
def test_smoke():
value = Attribute.parse("#polynomial.float_polynomial<0.5 + 1.3e06 x**2>")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking this is the one where you'd rely on the builders rather than the generic parse (e.g., polynomial.FloatPolynomail) so that you'd be able to have a better error experience Python side/make errors less prone.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just a first pass/attempt/"just make it pass" impl. But what would a builder look like here? Since 0.5 + 1.3e06 x**2 is an arbitrary polynomial..?

res = polynomial.constant(value)
# CHECK: polynomial.constant float<0.5 + 1.3E+6x**2> : <ring = <coefficientType = f32>>
print(res)

int_poly = polynomial.IntMonomial(1, 10)
# CHECK: <1, 10>
print(int_poly)
Loading