Skip to content

Commit 0229db7

Browse files
committed
working
1 parent 184e9fe commit 0229db7

13 files changed

+102
-30
lines changed

mlir/lib/TableGen/Attribute.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ bool Attribute::isSymbolRefAttr() const {
5353
return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
5454
}
5555

56-
bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
56+
bool Attribute::isEnumAttr() const {
57+
return isSubClassOf("EnumAttrInfo") || isSubClassOf("EnumAttr");
58+
}
5759

5860
StringRef Attribute::getStorageType() const {
5961
const auto *init = def->getValueInit("storageType");

mlir/python/CMakeLists.txt

+8-1
Original file line numberDiff line numberDiff line change
@@ -606,12 +606,19 @@ if(MLIR_INCLUDE_TESTS)
606606
"dialects/_python_test_ops_gen.py"
607607
-gen-python-op-bindings
608608
-bind-dialect=python_test)
609+
mlir_tablegen(
610+
"dialects/_python_test_enums_gen.py"
611+
-gen-python-enum-bindings
612+
EXTRA_INCLUDES
613+
"${MLIR_MAIN_SRC_DIR}/test/python")
609614
add_public_tablegen_target(PythonTestDialectPyIncGen)
610615
declare_mlir_python_sources(
611616
MLIRPythonTestSources.Dialects.PythonTest.ops_gen
612617
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
613618
ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest
614-
SOURCES "dialects/_python_test_ops_gen.py")
619+
SOURCES
620+
"dialects/_python_test_ops_gen.py"
621+
"dialects/_python_test_enums_gen.py")
615622

616623
declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension
617624
MODULE_NAME _mlirPythonTest

mlir/python/mlir/dialects/python_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._python_test_ops_gen import *
6+
from ._python_test_enums_gen import *
67
from .._mlir_libs._mlirPythonTest import (
78
TestAttr,
89
TestType,

mlir/test/python/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls)
33
mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
44
mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
55
mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
6+
mlir_tablegen(lib/PythonTestEnums.h.inc -gen-enum-decls)
7+
mlir_tablegen(lib/PythonTestEnums.cpp.inc -gen-enum-defs)
68
mlir_tablegen(lib/PythonTestAttributes.h.inc -gen-attrdef-decls)
79
mlir_tablegen(lib/PythonTestAttributes.cpp.inc -gen-attrdef-defs)
810
mlir_tablegen(lib/PythonTestTypes.h.inc -gen-typedef-decls)

mlir/test/python/dialects/python_test.py

+17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
import mlir.dialects.python_test as test
66
import mlir.dialects.tensor as tensor
77
import mlir.dialects.arith as arith
8+
from mlir.dialects import llvm
9+
from mlir.dialects._llvm_enum_gen import (
10+
_llvm_integeroverflowflagsattr as llvm_integeroverflowflagsattr,
11+
)
12+
from mlir.dialects._python_test_enums_gen import (
13+
_llvm_integeroverflowflagsattr as python_test_integeroverflowflagsattr,
14+
)
815

916
test.register_python_test_dialect(get_dialect_registry())
1017

@@ -543,3 +550,13 @@ def testInferTypeOpInterface():
543550
two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
544551
# CHECK: f32
545552
print(two_operands.result.type)
553+
554+
555+
# CHECK-LABEL: TEST: testEnumNamespacing
556+
@run
557+
def testEnumNamespacing():
558+
with Context() as ctx, Location.unknown(ctx):
559+
# CHECK: #llvm.overflow<none>
560+
print(llvm_integeroverflowflagsattr(llvm.IntegerOverflowFlags.none, ctx))
561+
# CHECK: #python_test.overflow<none>
562+
print(python_test_integeroverflowflagsattr(test.IntegerOverflowFlags.none, ctx))

mlir/test/python/lib/PythonTestDialect.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
#include "PythonTestDialect.h"
1010
#include "mlir/IR/DialectImplementation.h"
1111
#include "mlir/IR/OpImplementation.h"
12+
#include "llvm/ADT/StringExtras.h"
1213
#include "llvm/ADT/TypeSwitch.h"
1314

1415
#include "PythonTestDialect.cpp.inc"
1516

17+
#include "PythonTestEnums.cpp.inc"
1618
#define GET_ATTRDEF_CLASSES
1719
#include "PythonTestAttributes.cpp.inc"
1820

mlir/test/python/lib/PythonTestDialect.h

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#define GET_OP_CLASSES
2020
#include "PythonTestOps.h.inc"
2121

22+
#include "PythonTestEnums.h.inc"
2223
#define GET_ATTRDEF_CLASSES
2324
#include "PythonTestAttributes.h.inc"
2425

mlir/test/python/lit.local.cfg

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
22
if not config.enable_bindings_python:
33
config.unsupported = True
44
config.excludes.add("python_test_ops.td")
5+
config.excludes.add("python_test_enums.td")

mlir/test/python/python_test_ops.td

+19
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define PYTHON_TEST_OPS
1111

1212
include "mlir/IR/AttrTypeBase.td"
13+
include "mlir/IR/EnumAttr.td"
1314
include "mlir/IR/OpBase.td"
1415
include "mlir/Interfaces/InferTypeOpInterface.td"
1516

@@ -48,6 +49,24 @@ def TestType : TestType<"TestType", "test_type">;
4849

4950
def TestAttr : TestAttr<"TestAttr", "test_attr">;
5051

52+
def IOFnone : I32BitEnumAttrCaseNone<"none">;
53+
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
54+
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
55+
56+
def IntegerOverflowFlags : I32BitEnumAttr<
57+
"IntegerOverflowFlags", "",
58+
[IOFnone, IOFnsw, IOFnuw]> {
59+
let separator = ", ";
60+
let cppNamespace = "python_test";
61+
let genSpecializedAttr = 0;
62+
}
63+
64+
// This is intentionally prefixed with LLVM to test for collision in AttrBuilder.
65+
def LLVM_IntegerOverflowFlagsAttr :
66+
EnumAttr<Python_Test_Dialect, IntegerOverflowFlags, "overflow"> {
67+
let assemblyFormat = "`<` $value `>`";
68+
}
69+
5170
//===----------------------------------------------------------------------===//
5271
// Operation definitions.
5372
//===----------------------------------------------------------------------===//

mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp

+30-16
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,16 @@ static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
9797
/// Emits an attribute builder for the given enum attribute to support automatic
9898
/// conversion between enum values and attributes in Python. Returns
9999
/// `false` on success, `true` on failure.
100-
static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
100+
static bool emitAttributeBuilderRegistration(const EnumAttr &enumAttr,
101+
raw_ostream &os) {
101102
int64_t bitwidth;
102103
if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
103104
llvm::errs() << "failed to identify bitwidth of "
104105
<< enumAttr.getUnderlyingType();
105106
return true;
106107
}
107108

108-
llvm::SmallVector<StringRef> namespaces;
109-
enumAttr.getStorageType().ltrim("::").split(namespaces, "::");
110-
namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
111-
std::string namespace_ = getAttributeNameSpace(namespaces);
109+
std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
112110
if (!namespace_.empty())
113111
namespace_ += "_";
114112

@@ -127,8 +125,9 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
127125
/// Emits an attribute builder for the given dialect enum attribute to support
128126
/// automatic conversion between enum values and attributes in Python. Returns
129127
/// `false` on success, `true` on failure.
130-
static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
131-
raw_ostream &os) {
128+
static bool emitDialectEnumAttributeBuilderRegistration(const llvm::Record &def,
129+
raw_ostream &os) {
130+
const AttrOrTypeDef attr(&def);
132131
StringRef mnemonic = attr.getMnemonic().value();
133132
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
134133
StringRef dialect = attr.getDialect().getName();
@@ -145,15 +144,15 @@ static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
145144
return true;
146145
}
147146

148-
llvm::SmallVector<StringRef> namespaces;
149-
attr.getStorageNamespace().ltrim("::").split(namespaces, "::");
150-
std::string namespace_ = getAttributeNameSpace(namespaces);
147+
EnumAttr enumAttr(def);
148+
std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
151149
if (!namespace_.empty())
152150
namespace_ += "_";
153151

154152
os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
155-
attr.getName());
156-
os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
153+
enumAttr.getAttrDefName());
154+
os << llvm::formatv("def _{0}(x, context):\n",
155+
enumAttr.getAttrDefName().lower());
157156
os << llvm::formatv(" return "
158157
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
159158
formatString);
@@ -164,17 +163,32 @@ static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
164163
/// `false` on success, `true` on failure.
165164
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
166165
raw_ostream &os) {
166+
llvm::SmallDenseSet<StringRef> alreadyEmitted;
167167
os << fileHeader;
168+
for (const auto &it :
169+
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
170+
EnumAttr *enumAttr;
171+
for (const auto &value : it->getValues())
172+
if (value.getType()->getAsString() == "EnumAttrInfo")
173+
enumAttr = new EnumAttr(value.getValue()->getRecordKeeper().getDef(
174+
value.getValue()->getAsString()));
175+
if (enumAttr) {
176+
emitEnumClass(*enumAttr, os);
177+
alreadyEmitted.insert(enumAttr->getEnumClassName());
178+
}
179+
}
168180
for (auto &it :
169181
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
170182
EnumAttr enumAttr(*it);
171-
emitEnumClass(enumAttr, os);
172-
emitAttributeBuilder(enumAttr, os);
183+
if (!alreadyEmitted.contains(enumAttr.getEnumClassName()))
184+
emitEnumClass(enumAttr, os);
185+
if (emitAttributeBuilderRegistration(enumAttr, os))
186+
return true;
173187
}
174188
for (const auto &it :
175189
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
176-
const AttrOrTypeDef attr(&*it);
177-
return emitDialectEnumAttributeBuilder(attr, os);
190+
if (emitDialectEnumAttributeBuilderRegistration(*it, os))
191+
return true;
178192
}
179193

180194
return false;

mlir/tools/mlir-tblgen/OpGenHelpers.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,16 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
8181
return reserved.contains(str);
8282
}
8383

84-
std::string
85-
mlir::tblgen::getAttributeNameSpace(llvm::SmallVector<StringRef> namespaces) {
84+
std::string mlir::tblgen::getEnumAttributeNameSpace(const EnumAttr &enumAttr) {
85+
llvm::SmallVector<StringRef> namespaces;
86+
if (enumAttr.getCppNamespace().empty() &&
87+
enumAttr.getBaseAttr().isEnumAttr()) {
88+
EnumAttr(enumAttr.getBaseAttr().getDef())
89+
.getCppNamespace()
90+
.ltrim("::")
91+
.split(namespaces, "::");
92+
} else
93+
enumAttr.getCppNamespace().ltrim("::").split(namespaces, "::");
8694
std::string namespace_;
8795
if (namespaces[0] == "mlir")
8896
namespace_ = llvm::join(llvm::drop_begin(namespaces), "_");

mlir/tools/mlir-tblgen/OpGenHelpers.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
1414
#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
1515

16+
#include "mlir/TableGen/Attribute.h"
1617
#include "llvm/TableGen/Record.h"
1718
#include <vector>
1819

@@ -28,8 +29,7 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
2829
/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
2930
bool isPythonReserved(llvm::StringRef str);
3031

31-
std::string
32-
getAttributeNameSpace(llvm::SmallVector<llvm::StringRef> namespaces);
32+
std::string getEnumAttributeNameSpace(const EnumAttr &enumAttr);
3333

3434
} // namespace tblgen
3535
} // namespace mlir

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -530,30 +530,30 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
530530
/// {0} is the builder argument name;
531531
/// {1} is the attribute builder from raw;
532532
/// {2} is the attribute builder from raw;
533-
/// {3} is the attribute's dialect.
533+
/// {3} is the attribute's fully qualified namespace.
534534
/// Use the value the user passed in if either it is already an Attribute or
535535
/// there is no method registered to make it an Attribute.
536536
constexpr const char *initAttributeWithBuilderTemplate =
537537
R"Py(attributes["{1}"] = ({0} if (
538538
issubclass(type({0}), _ods_ir.Attribute) or
539539
not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
540540
(_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
541-
else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
541+
else _ods_ir.AttrBuilder.get('{2}{3}')({0}, context=_ods_context))))Py";
542542

543543
/// Template for attribute builder from raw input for optional attribute in the
544544
/// operation builder.
545545
/// {0} is the builder argument name;
546546
/// {1} is the attribute builder from raw;
547547
/// {2} is the attribute builder from raw;
548-
/// {3} is the attribute's dialect.
548+
/// {3} is the attribute's fully qualified namespace.
549549
/// Use the value the user passed in if either it is already an Attribute or
550550
/// there is no method registered to make it an Attribute.
551551
constexpr const char *initOptionalAttributeWithBuilderTemplate =
552552
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
553553
issubclass(type({0}), _ods_ir.Attribute) or
554554
not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
555555
(_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
556-
else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
556+
else _ods_ir.AttrBuilder.get('{2}{3}')({0}, context=_ods_context))))Py";
557557

558558
constexpr const char *initUnitAttributeTemplate =
559559
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -681,10 +681,8 @@ populateBuilderLinesAttr(const Operator &op,
681681
continue;
682682
}
683683

684-
llvm::SmallVector<StringRef> namespaces;
685-
attribute->attr.getStorageType().ltrim("::").split(namespaces, "::");
686-
namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
687-
std::string namespace_ = getAttributeNameSpace(namespaces);
684+
EnumAttr enumAttr(attribute->attr.getDef());
685+
std::string namespace_ = getEnumAttributeNameSpace(enumAttr);
688686
if (!namespace_.empty())
689687
namespace_ += "_";
690688

0 commit comments

Comments
 (0)