Skip to content

Commit 104bfc8

Browse files
committed
[mlir][polynomial] python bindings
1 parent c98a799 commit 104bfc8

File tree

8 files changed

+142
-2
lines changed

8 files changed

+142
-2
lines changed
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===-- mlir-c/Dialect/Polynomial.h - C API for LLVM --------------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef MLIR_C_DIALECT_POLYNOMIAL_H
11+
#define MLIR_C_DIALECT_POLYNOMIAL_H
12+
13+
#include "mlir-c/IR.h"
14+
15+
#ifdef __cplusplus
16+
extern "C" {
17+
#endif
18+
19+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Polynomial, polynomial);
20+
21+
#ifdef __cplusplus
22+
}
23+
#endif
24+
25+
#endif // MLIR_C_DIALECT_POLYNOMIAL_H
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- DialectPolynomial.cpp - 'polynomial' dialect submodule -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/Polynomial.h"
10+
#include "mlir-c/IR.h"
11+
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
13+
#include <pybind11/pybind11.h>
14+
#include <vector>
15+
16+
namespace py = pybind11;
17+
using namespace llvm;
18+
using namespace mlir;
19+
using namespace mlir::python::adaptors;
20+
21+
PYBIND11_MODULE(_mlirDialectsPolynomial, m) {
22+
m.doc() = "MLIR Polynomial dialect";
23+
}

mlir/lib/CAPI/Dialect/CMakeLists.txt

+10
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,16 @@ add_mlir_upstream_c_api_library(MLIRCAPIQuant
225225
MLIRQuantDialect
226226
)
227227

228+
add_mlir_upstream_c_api_library(MLIRCAPIPolynomial
229+
Polynomial.cpp
230+
231+
PARTIAL_SOURCES_INTENDED
232+
LINK_LIBS PUBLIC
233+
MLIRCAPIIR
234+
MLIRPolynomialDialect
235+
)
236+
237+
228238
add_mlir_upstream_c_api_library(MLIRCAPIOpenMP
229239
OpenMP.cpp
230240

mlir/lib/CAPI/Dialect/Polynomial.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//===- Polynomial.cpp - C Interface for Polynomial dialect
2+
//--------------------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "mlir-c/Dialect/Polynomial.h"
11+
#include "mlir/CAPI/Registration.h"
12+
#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
13+
14+
using namespace mlir;
15+
16+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(polynomial, polynomial,
17+
polynomial::PolynomialDialect)

mlir/python/CMakeLists.txt

+23-2
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,22 @@ declare_mlir_dialect_python_bindings(
162162
GEN_ENUM_BINDINGS)
163163

164164
declare_mlir_dialect_extension_python_bindings(
165-
ADD_TO_PARENT MLIRPythonSources.Dialects
166-
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
165+
ADD_TO_PARENT MLIRPythonSources.Dialects
166+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
167167
TD_FILE dialects/TransformPDLExtensionOps.td
168168
SOURCES
169169
dialects/transform/pdl.py
170170
DIALECT_NAME transform
171171
EXTENSION_NAME transform_pdl_extension)
172172

173+
declare_mlir_dialect_python_bindings(
174+
ADD_TO_PARENT MLIRPythonSources.Dialects
175+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
176+
TD_FILE dialects/Polynomial.td
177+
SOURCES
178+
dialects/polynomial.py
179+
DIALECT_NAME polynomial)
180+
173181
declare_mlir_dialect_python_bindings(
174182
ADD_TO_PARENT MLIRPythonSources.Dialects
175183
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -537,6 +545,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
537545
MLIRCAPIQuant
538546
)
539547

548+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Polynomial.Pybind
549+
MODULE_NAME _mlirDialectsPolynomial
550+
ADD_TO_PARENT MLIRPythonSources.Dialects.polynomial
551+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
552+
SOURCES
553+
DialectPolynomial.cpp
554+
PRIVATE_LINK_LIBS
555+
LLVMSupport
556+
EMBED_CAPI_LINK_LIBS
557+
MLIRCAPIIR
558+
MLIRCAPIPolynomial
559+
)
560+
540561
declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
541562
MODULE_NAME _mlirDialectsNVGPU
542563
ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===-- PolynomialOps.td - Entry point for PolynomialOps bind ------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_POLYNOMIAL_OPS
10+
#define PYTHON_BINDINGS_POLYNOMIAL_OPS
11+
12+
include "mlir/Dialect/Polynomial/IR/Polynomial.td"
13+
14+
#endif
+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from .._mlir_libs._mlirDialectsPolynomial import *
6+
from ._polynomial_ops_gen import *
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import polynomial
5+
6+
7+
def constructAndPrintInModule(f):
8+
print("\nTEST:", f.__name__)
9+
with Context(), Location.unknown():
10+
module = Module.create()
11+
with InsertionPoint(module.body):
12+
f()
13+
print(module)
14+
return f
15+
16+
17+
# CHECK-LABEL: TEST: test_smoke
18+
@constructAndPrintInModule
19+
def test_smoke():
20+
value = Attribute.parse("#polynomial.float_polynomial<0.5 + 1.3e06 x**2>")
21+
output = Type.parse("!polynomial.polynomial<ring=<coefficientType=f32>>")
22+
res = polynomial.constant(output, value)
23+
# CHECK: polynomial.constant {value = #polynomial.float_polynomial<0.5 + 1.3E+6x**2>} : <ring = <coefficientType = f32>>
24+
print(res)

0 commit comments

Comments
 (0)