Skip to content

[mlir][python] namespace generated enums in python #77830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/lib/TableGen/Attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ bool Attribute::isSymbolRefAttr() const {
return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
}

bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
bool Attribute::isEnumAttr() const {
return isSubClassOf("EnumAttrInfo") || isSubClassOf("EnumAttr");
}

StringRef Attribute::getStorageType() const {
const auto *init = def->getValueInit("storageType");
Expand Down
9 changes: 8 additions & 1 deletion mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -606,12 +606,19 @@ if(MLIR_INCLUDE_TESTS)
"dialects/_python_test_ops_gen.py"
-gen-python-op-bindings
-bind-dialect=python_test)
mlir_tablegen(
"dialects/_python_test_enums_gen.py"
-gen-python-enum-bindings
EXTRA_INCLUDES
"${MLIR_MAIN_SRC_DIR}/test/python")
add_public_tablegen_target(PythonTestDialectPyIncGen)
declare_mlir_python_sources(
MLIRPythonTestSources.Dialects.PythonTest.ops_gen
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest
SOURCES "dialects/_python_test_ops_gen.py")
SOURCES
"dialects/_python_test_ops_gen.py"
"dialects/_python_test_enums_gen.py")

declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension
MODULE_NAME _mlirPythonTest
Expand Down
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._python_test_ops_gen import *
from ._python_test_enums_gen import *
from .._mlir_libs._mlirPythonTest import (
TestAttr,
TestType,
Expand Down
23 changes: 15 additions & 8 deletions mlir/test/mlir-tblgen/enums-python-bindings.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def One : I32EnumAttrCase<"CaseOne", 1, "one">;
def Two : I32EnumAttrCase<"CaseTwo", 2, "two">;
def NegOne : I32EnumAttrCase<"CaseNegOne", -1, "negone">;

def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>;
def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]> {
let cppNamespace = "test";
}
// CHECK-LABEL: class MyEnum(IntEnum):
// CHECK: """An example 32-bit enum"""

Expand All @@ -35,16 +37,20 @@ def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>
// CHECK: return "negone"
// CHECK: raise ValueError("Unknown MyEnum enum entry.")

// CHECK: @register_attribute_builder("MyEnum")
// CHECK: @register_attribute_builder("test_MyEnum")
// CHECK: def _myenum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))

def TestMyEnum_Attr : EnumAttr<Test_Dialect, MyEnum, "enum">;
def TestMyEnum_Attr : EnumAttr<Test_Dialect, MyEnum, "enum"> {
let cppNamespace = "test";
}

def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">;

def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>;
def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]> {
let cppNamespace = "test";
}
// CHECK-LABEL: class MyEnum64(IntEnum):
// CHECK: """An example 64-bit enum"""

Expand All @@ -58,7 +64,7 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>
// CHECK: return "two"
// CHECK: raise ValueError("Unknown MyEnum64 enum entry.")

// CHECK: @register_attribute_builder("MyEnum64")
// CHECK: @register_attribute_builder("test_MyEnum64")
// CHECK: def _myenum64(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))

Expand All @@ -70,6 +76,7 @@ def TestBitEnum
]> {
let genSpecializedAttr = 0;
let separator = " | ";
let cppNamespace = "test";
}

def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
Expand All @@ -96,14 +103,14 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
// CHECK: return "other"
// CHECK: raise ValueError("Unknown TestBitEnum enum entry.")

// CHECK: @register_attribute_builder("TestBitEnum")
// CHECK: @register_attribute_builder("test_TestBitEnum")
// CHECK: def _testbitenum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))

// CHECK: @register_attribute_builder("TestBitEnum_Attr")
// CHECK: @register_attribute_builder("test_TestBitEnum_Attr")
// CHECK: def _testbitenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)

// CHECK: @register_attribute_builder("TestMyEnum_Attr")
// CHECK: @register_attribute_builder("test_TestMyEnum_Attr")
// CHECK: def _testmyenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)
7 changes: 4 additions & 3 deletions mlir/test/mlir-tblgen/op-python-bindings.td
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,10 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: attributes["i32attr"] = (i32attr if (
// CHECK-NEXT: issubclass(type(i32attr), _ods_ir.Attribute) or
// CHECK-NEXT: not _ods_ir.AttrBuilder.contains('I32Attr')
// CHECK-NEXT: _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
// CHECK-NEXT: issubclass(type(i32attr), _ods_ir.Attribute) or
// CHECK-NEXT: not (_ods_ir.AttrBuilder.contains('I32Attr') or _ods_ir.AttrBuilder.contains('I32Attr'))) else
// CHECK-NEXT: (_ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context) if _ods_ir.AttrBuilder.contains('I32Attr')
// CHECK-NEXT: else _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)))
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls)
mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
mlir_tablegen(lib/PythonTestEnums.h.inc -gen-enum-decls)
mlir_tablegen(lib/PythonTestEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(lib/PythonTestAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(lib/PythonTestAttributes.cpp.inc -gen-attrdef-defs)
mlir_tablegen(lib/PythonTestTypes.h.inc -gen-typedef-decls)
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/python/dialects/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
from mlir.dialects import llvm
from mlir.dialects._llvm_enum_gen import (
_llvm_integeroverflowflagsattr as llvm_integeroverflowflagsattr,
)
from mlir.dialects._python_test_enums_gen import (
_llvm_integeroverflowflagsattr as python_test_integeroverflowflagsattr,
)

test.register_python_test_dialect(get_dialect_registry())

Expand Down Expand Up @@ -543,3 +550,13 @@ def testInferTypeOpInterface():
two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
# CHECK: f32
print(two_operands.result.type)


# CHECK-LABEL: TEST: testEnumNamespacing
@run
def testEnumNamespacing():
with Context() as ctx, Location.unknown(ctx):
# CHECK: #llvm.overflow<none>
print(llvm_integeroverflowflagsattr(llvm.IntegerOverflowFlags.none, ctx))
# CHECK: #python_test.overflow<none>
print(python_test_integeroverflowflagsattr(test.IntegerOverflowFlags.none, ctx))
2 changes: 2 additions & 0 deletions mlir/test/python/lib/PythonTestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
#include "PythonTestDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#include "PythonTestDialect.cpp.inc"

#include "PythonTestEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "PythonTestAttributes.cpp.inc"

Expand Down
1 change: 1 addition & 0 deletions mlir/test/python/lib/PythonTestDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define GET_OP_CLASSES
#include "PythonTestOps.h.inc"

#include "PythonTestEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "PythonTestAttributes.h.inc"

Expand Down
1 change: 1 addition & 0 deletions mlir/test/python/lit.local.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
if not config.enable_bindings_python:
config.unsupported = True
config.excludes.add("python_test_ops.td")
config.excludes.add("python_test_enums.td")
19 changes: 19 additions & 0 deletions mlir/test/python/python_test_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define PYTHON_TEST_OPS

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

Expand Down Expand Up @@ -48,6 +49,24 @@ def TestType : TestType<"TestType", "test_type">;

def TestAttr : TestAttr<"TestAttr", "test_attr">;

def IOFnone : I32BitEnumAttrCaseNone<"none">;
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;

def IntegerOverflowFlags : I32BitEnumAttr<
"IntegerOverflowFlags", "",
[IOFnone, IOFnsw, IOFnuw]> {
let separator = ", ";
let cppNamespace = "python_test";
let genSpecializedAttr = 0;
}

// This is intentionally prefixed with LLVM to test for collision in AttrBuilder.
def LLVM_IntegerOverflowFlagsAttr :
EnumAttr<Python_Test_Dialect, IntegerOverflowFlags, "overflow"> {
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// Operation definitions.
//===----------------------------------------------------------------------===//
Expand Down
70 changes: 40 additions & 30 deletions mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,20 @@ static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
/// Emits an attribute builder for the given enum attribute to support automatic
/// conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
static bool emitAttributeBuilderRegistration(const EnumAttr &enumAttr,
raw_ostream &os) {
int64_t bitwidth;
if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
llvm::errs() << "failed to identify bitwidth of "
<< enumAttr.getUnderlyingType();
return true;
}

os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
if (!namespace_.empty())
namespace_ += "_";

os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
enumAttr.getAttrDefName());
os << llvm::formatv("def _{0}(x, context):\n",
enumAttr.getAttrDefName().lower());
Expand All @@ -120,11 +125,34 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
/// Emits an attribute builder for the given dialect enum attribute to support
/// automatic conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
StringRef formatString,
raw_ostream &os) {
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
static bool emitDialectEnumAttributeBuilderRegistration(const llvm::Record &def,
raw_ostream &os) {
const AttrOrTypeDef attr(&def);
StringRef mnemonic = attr.getMnemonic().value();
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
StringRef dialect = attr.getDialect().getName();
std::string formatString;
if (assemblyFormat == "`<` $value `>`")
formatString =
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str();
else if (assemblyFormat == "$value")
formatString =
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str();
else {
llvm::errs()
<< "unsupported assembly format for python enum bindings generation";
return true;
}

EnumAttr enumAttr(def);
std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
if (!namespace_.empty())
namespace_ += "_";

os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
enumAttr.getAttrDefName());
os << llvm::formatv("def _{0}(x, context):\n",
enumAttr.getAttrDefName().lower());
os << llvm::formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
Expand All @@ -140,31 +168,13 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
EnumAttr enumAttr(*it);
emitEnumClass(enumAttr, os);
emitAttributeBuilder(enumAttr, os);
}
for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
AttrOrTypeDef attr(&*it);
if (!attr.getMnemonic()) {
llvm::errs() << "enum case " << attr
<< " needs mnemonic for python enum bindings generation";
if (emitAttributeBuilderRegistration(enumAttr, os))
return true;
}
StringRef mnemonic = attr.getMnemonic().value();
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
StringRef dialect = attr.getDialect().getName();
if (assemblyFormat == "`<` $value `>`") {
emitDialectEnumAttributeBuilder(
attr.getName(),
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
} else if (assemblyFormat == "$value") {
emitDialectEnumAttributeBuilder(
attr.getName(),
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
} else {
llvm::errs()
<< "unsupported assembly format for python enum bindings generation";
}
for (const auto &it :
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
if (emitDialectEnumAttributeBuilderRegistration(*it, os))
return true;
}
}

return false;
Expand Down
15 changes: 15 additions & 0 deletions mlir/tools/mlir-tblgen/OpGenHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "llvm/Support/Regex.h"
#include "llvm/TableGen/Error.h"

#include <iostream>

using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
Expand Down Expand Up @@ -79,4 +81,17 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
reserved.insert("issubclass");
reserved.insert("type");
return reserved.contains(str);
}

std::string mlir::tblgen::getEnumAttributeNameSpace(const EnumAttr &enumAttr) {
std::string namespace_;
llvm::SmallVector<StringRef> namespaces;
enumAttr.getCppNamespace().ltrim("::").split(namespaces, "::");
if (namespaces[0] == "mlir")
namespace_ = llvm::join(llvm::drop_begin(namespaces), "_");
else
namespace_ = llvm::join(namespaces, "_");
std::transform(namespace_.begin(), namespace_.end(), namespace_.begin(),
tolower);
return namespace_;
}
3 changes: 3 additions & 0 deletions mlir/tools/mlir-tblgen/OpGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_

#include "mlir/TableGen/Attribute.h"
#include "llvm/TableGen/Record.h"
#include <vector>

Expand All @@ -28,6 +29,8 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
bool isPythonReserved(llvm::StringRef str);

std::string getEnumAttributeNameSpace(const EnumAttr &enumAttr);

} // namespace tblgen
} // namespace mlir

Expand Down
Loading