Skip to content

Commit 184e9fe

Browse files
committed
namespace generated enums in python
1 parent 5b9be0e commit 184e9fe

File tree

4 files changed

+72
-35
lines changed

4 files changed

+72
-35
lines changed

mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp

+38-28
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,14 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
105105
return true;
106106
}
107107

108-
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
108+
llvm::SmallVector<StringRef> namespaces;
109+
enumAttr.getStorageType().ltrim("::").split(namespaces, "::");
110+
namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
111+
std::string namespace_ = getAttributeNameSpace(namespaces);
112+
if (!namespace_.empty())
113+
namespace_ += "_";
114+
115+
os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
109116
enumAttr.getAttrDefName());
110117
os << llvm::formatv("def _{0}(x, context):\n",
111118
enumAttr.getAttrDefName().lower());
@@ -120,11 +127,33 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
120127
/// Emits an attribute builder for the given dialect enum attribute to support
121128
/// automatic conversion between enum values and attributes in Python. Returns
122129
/// `false` on success, `true` on failure.
123-
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
124-
StringRef formatString,
130+
static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
125131
raw_ostream &os) {
126-
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
127-
os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
132+
StringRef mnemonic = attr.getMnemonic().value();
133+
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
134+
StringRef dialect = attr.getDialect().getName();
135+
std::string formatString;
136+
if (assemblyFormat == "`<` $value `>`")
137+
formatString =
138+
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str();
139+
else if (assemblyFormat == "$value")
140+
formatString =
141+
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str();
142+
else {
143+
llvm::errs()
144+
<< "unsupported assembly format for python enum bindings generation";
145+
return true;
146+
}
147+
148+
llvm::SmallVector<StringRef> namespaces;
149+
attr.getStorageNamespace().ltrim("::").split(namespaces, "::");
150+
std::string namespace_ = getAttributeNameSpace(namespaces);
151+
if (!namespace_.empty())
152+
namespace_ += "_";
153+
154+
os << llvm::formatv("@register_attribute_builder(\"{0}{1}\")\n", namespace_,
155+
attr.getName());
156+
os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
128157
os << llvm::formatv(" return "
129158
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
130159
formatString);
@@ -142,29 +171,10 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
142171
emitEnumClass(enumAttr, os);
143172
emitAttributeBuilder(enumAttr, os);
144173
}
145-
for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
146-
AttrOrTypeDef attr(&*it);
147-
if (!attr.getMnemonic()) {
148-
llvm::errs() << "enum case " << attr
149-
<< " needs mnemonic for python enum bindings generation";
150-
return true;
151-
}
152-
StringRef mnemonic = attr.getMnemonic().value();
153-
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
154-
StringRef dialect = attr.getDialect().getName();
155-
if (assemblyFormat == "`<` $value `>`") {
156-
emitDialectEnumAttributeBuilder(
157-
attr.getName(),
158-
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
159-
} else if (assemblyFormat == "$value") {
160-
emitDialectEnumAttributeBuilder(
161-
attr.getName(),
162-
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
163-
} else {
164-
llvm::errs()
165-
<< "unsupported assembly format for python enum bindings generation";
166-
return true;
167-
}
174+
for (const auto &it :
175+
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
176+
const AttrOrTypeDef attr(&*it);
177+
return emitDialectEnumAttributeBuilder(attr, os);
168178
}
169179

170180
return false;

mlir/tools/mlir-tblgen/OpGenHelpers.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,16 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
7979
reserved.insert("issubclass");
8080
reserved.insert("type");
8181
return reserved.contains(str);
82+
}
83+
84+
std::string
85+
mlir::tblgen::getAttributeNameSpace(llvm::SmallVector<StringRef> namespaces) {
86+
std::string namespace_;
87+
if (namespaces[0] == "mlir")
88+
namespace_ = llvm::join(llvm::drop_begin(namespaces), "_");
89+
else
90+
namespace_ = llvm::join(namespaces, "_");
91+
std::transform(namespace_.begin(), namespace_.end(), namespace_.begin(),
92+
tolower);
93+
return namespace_;
8294
}

mlir/tools/mlir-tblgen/OpGenHelpers.h

+3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
2828
/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
2929
bool isPythonReserved(llvm::StringRef str);
3030

31+
std::string
32+
getAttributeNameSpace(llvm::SmallVector<llvm::StringRef> namespaces);
33+
3134
} // namespace tblgen
3235
} // namespace mlir
3336

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

+19-7
Original file line numberDiff line numberDiff line change
@@ -529,27 +529,31 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
529529
/// Template for attribute builder from raw input in the operation builder.
530530
/// {0} is the builder argument name;
531531
/// {1} is the attribute builder from raw;
532-
/// {2} is the attribute builder from raw.
532+
/// {2} is the attribute builder from raw;
533+
/// {3} is the attribute's dialect.
533534
/// Use the value the user passed in if either it is already an Attribute or
534535
/// there is no method registered to make it an Attribute.
535536
constexpr const char *initAttributeWithBuilderTemplate =
536537
R"Py(attributes["{1}"] = ({0} if (
537538
issubclass(type({0}), _ods_ir.Attribute) or
538-
not _ods_ir.AttrBuilder.contains('{2}')) else
539-
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
539+
not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
540+
(_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
541+
else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
540542

541543
/// Template for attribute builder from raw input for optional attribute in the
542544
/// operation builder.
543545
/// {0} is the builder argument name;
544546
/// {1} is the attribute builder from raw;
545-
/// {2} is the attribute builder from raw.
547+
/// {2} is the attribute builder from raw;
548+
/// {3} is the attribute's dialect.
546549
/// Use the value the user passed in if either it is already an Attribute or
547550
/// there is no method registered to make it an Attribute.
548551
constexpr const char *initOptionalAttributeWithBuilderTemplate =
549552
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
550553
issubclass(type({0}), _ods_ir.Attribute) or
551-
not _ods_ir.AttrBuilder.contains('{2}')) else
552-
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
554+
not (_ods_ir.AttrBuilder.contains('{3}') or _ods_ir.AttrBuilder.contains('{2}{3}'))) else
555+
(_ods_ir.AttrBuilder.get('{3}')({0}, context=_ods_context) if _ods_ir.AttrBuilder.contains('{3}')
556+
else _ods_ir.AttrBuilder.contains('{2}{3}')({0}, context=_ods_context))))Py";
553557

554558
constexpr const char *initUnitAttributeTemplate =
555559
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -677,11 +681,19 @@ populateBuilderLinesAttr(const Operator &op,
677681
continue;
678682
}
679683

684+
llvm::SmallVector<StringRef> namespaces;
685+
attribute->attr.getStorageType().ltrim("::").split(namespaces, "::");
686+
namespaces = llvm::SmallVector<StringRef>{llvm::drop_end(namespaces)};
687+
std::string namespace_ = getAttributeNameSpace(namespaces);
688+
if (!namespace_.empty())
689+
namespace_ += "_";
690+
680691
builderLines.push_back(llvm::formatv(
681692
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
682693
? initOptionalAttributeWithBuilderTemplate
683694
: initAttributeWithBuilderTemplate,
684-
argNames[i], attribute->name, attribute->attr.getAttrDefName()));
695+
argNames[i], attribute->name, namespace_,
696+
attribute->attr.getAttrDefName()));
685697
}
686698
}
687699

0 commit comments

Comments
 (0)