-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][python] namespace generated enums in python #77830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
239e37d
to
c9ffacc
Compare
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesA recent PR broke enum bindings generation because of collision of attrs with the same name across dialects: #77211 (comment). So we need to namespace these now. In the current form of the PR, this is a breaking change (anyone that is supplying their own attribute builders won't have the Will add/update tests shortly. Full diff: https://github.com/llvm/llvm-project/pull/77830.diff 2 Files Affected:
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index f4ced0803772ed..d5f36ff3bc0fd2 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -105,7 +105,8 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
return true;
}
- os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
+ os << llvm::formatv("@register_attribute_builder(\"{0}_{1}\")\n",
+ enumAttr.getDialect().getName(),
enumAttr.getAttrDefName());
os << llvm::formatv("def _{0}(x, context):\n",
enumAttr.getAttrDefName().lower());
@@ -120,11 +121,26 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
/// Emits an attribute builder for the given dialect enum attribute to support
/// automatic conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
- StringRef formatString,
+static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
raw_ostream &os) {
- os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
- os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
+ StringRef mnemonic = attr.getMnemonic().value();
+ std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
+ StringRef dialect = attr.getDialect().getName();
+ std::string formatString;
+ if (assemblyFormat == "`<` $value `>`") {
+ formatString =
+ llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str();
+ } else if (assemblyFormat == "$value") {
+ formatString =
+ llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str();
+ } else {
+ llvm::errs()
+ << "unsupported assembly format for python enum bindings generation";
+ return true;
+ }
+ os << llvm::formatv("@register_attribute_builder(\"{0}_{1}\")\n",
+ attr.getDialect().getName(), attr.getName());
+ os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
os << llvm::formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
@@ -142,29 +158,10 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
emitEnumClass(enumAttr, os);
emitAttributeBuilder(enumAttr, os);
}
- for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
- AttrOrTypeDef attr(&*it);
- if (!attr.getMnemonic()) {
- llvm::errs() << "enum case " << attr
- << " needs mnemonic for python enum bindings generation";
- return true;
- }
- StringRef mnemonic = attr.getMnemonic().value();
- std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
- StringRef dialect = attr.getDialect().getName();
- if (assemblyFormat == "`<` $value `>`") {
- emitDialectEnumAttributeBuilder(
- attr.getName(),
- llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
- } else if (assemblyFormat == "$value") {
- emitDialectEnumAttributeBuilder(
- attr.getName(),
- llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
- } else {
- llvm::errs()
- << "unsupported assembly format for python enum bindings generation";
- return true;
- }
+ for (const auto &it :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+ const AttrOrTypeDef attr(&*it);
+ return emitDialectEnumAttributeBuilder(attr, os);
}
return false;
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e7..de343df1c434fa 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -529,27 +529,29 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
-/// {2} is the attribute builder from raw.
+/// {2} is the attribute builder from raw;
+/// {3} is the attribute's dialect.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
R"Py(attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
- not _ods_ir.AttrBuilder.contains('{2}')) else
- _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+ not _ods_ir.AttrBuilder.contains('{2}_{3}')) else
+ _ods_ir.AttrBuilder.get('{2}_{3}')({0}, context=_ods_context)))Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
-/// {2} is the attribute builder from raw.
+/// {2} is the attribute builder from raw;
+/// {3} is the attribute's dialect.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
- not _ods_ir.AttrBuilder.contains('{2}')) else
- _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+ not _ods_ir.AttrBuilder.contains('{2}_{3}')) else
+ _ods_ir.AttrBuilder.get('{2}_{3}')({0}, context=_ods_context)))Py";
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -681,7 +683,8 @@ populateBuilderLinesAttr(const Operator &op,
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
- argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+ argNames[i], attribute->name, attribute->attr.getAttrDefName(),
+ attribute->attr.getDialect().getName()));
}
}
|
As discussed offline, I like uniform. But overall this looks like a good step. |
c9ffacc
to
184e9fe
Compare
0229db7
to
9e5daf8
Compare
a2288e5
to
d973eac
Compare
d973eac
to
07c3931
Compare
A recent PR broke enum bindings generation because of collision of attrs with the same name across dialects: #77211 (comment). So we need to namespace these now. In the current form of the PR, this is a breaking change (anyone that is supplying their own attribute builders won't have the
<dialect>_attr
prefix). I can rewrite to just check for both<dialect>_attr
and justattr
in theAttrBuilder
queries but I don't know what people think is best.Will add/update tests shortly.
cc @Hardcode84
update:
Instead of disambiguating/namespacing by dialect mnemonic (what I originally planned) I use the real thing:
cppNamespace
. That means generated registered builders now look like this:Note in
"llvm_LLVM_IntegerOverflowFlagsAttr"
the lowercasellvm
is from this PR and the uppercaseLLVM
is from the def name in tablegen. The corresponding getters in the generated ops looks like thiswhere the action is here
in order to be backwards compatible for people that would have their own attribute builders registered (for their own attributes). Doing some kind of explicit "warning: this is deprecated" might be possible but seems like overengineering - I propose we inform people on discourse and then remove in a quarter (3 months).
The added test simulates the breakage that prompted this PR (overlapping attribute named
IntegerOverflowFlagsAttr
).