Skip to content

[mlir][python] namespace generated enums in python #77830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Jan 11, 2024

A recent PR broke enum bindings generation because of collision of attrs with the same name across dialects: #77211 (comment). So we need to namespace these now. In the current form of the PR, this is a breaking change (anyone that is supplying their own attribute builders won't have the <dialect>_attr prefix). I can rewrite to just check for both <dialect>_attr and just attr in the AttrBuilder queries but I don't know what people think is best.

Will add/update tests shortly.

cc @Hardcode84

update:

Instead of disambiguating/namespacing by dialect mnemonic (what I originally planned) I use the real thing: cppNamespace. That means generated registered builders now look like this:

@register_attribute_builder("llvm_LLVM_IntegerOverflowFlagsAttr")
def _llvm_integeroverflowflagsattr(x, context):
    return _ods_ir.Attribute.parse(f'#llvm.overflow<{str(x)}>', context=context)

Note in "llvm_LLVM_IntegerOverflowFlagsAttr" the lowercase llvm is from this PR and the uppercase LLVM is from the def name in tablegen. The corresponding getters in the generated ops looks like this

if overflowFlags is not None:
    attributes["overflowFlags"] = (
        overflowFlags
        if (
            issubclass(type(overflowFlags), _ods_ir.Attribute)
            or not (
                _ods_ir.AttrBuilder.contains("LLVM_IntegerOverflowFlagsAttr")
                or _ods_ir.AttrBuilder.contains(
                    "llvm_LLVM_IntegerOverflowFlagsAttr"
                )
            )
        )
        else (
            _ods_ir.AttrBuilder.get("LLVM_IntegerOverflowFlagsAttr")(
                overflowFlags, context=_ods_context
            )
            if _ods_ir.AttrBuilder.contains("LLVM_IntegerOverflowFlagsAttr")
            else _ods_ir.AttrBuilder.get("llvm_LLVM_IntegerOverflowFlagsAttr")(
                overflowFlags, context=_ods_context
            )
        )
    )

where the action is here

_ods_ir.AttrBuilder.get("LLVM_IntegerOverflowFlagsAttr")(
    overflowFlags, context=_ods_context
)
if _ods_ir.AttrBuilder.contains("LLVM_IntegerOverflowFlagsAttr")
else _ods_ir.AttrBuilder.get("llvm_LLVM_IntegerOverflowFlagsAttr")(
    overflowFlags, context=_ods_context
)

in order to be backwards compatible for people that would have their own attribute builders registered (for their own attributes). Doing some kind of explicit "warning: this is deprecated" might be possible but seems like overengineering - I propose we inform people on discourse and then remove in a quarter (3 months).

The added test simulates the breakage that prompted this PR (overlapping attribute named IntegerOverflowFlagsAttr).

@makslevental makslevental force-pushed the fix_enum_bindings branch 3 times, most recently from 239e37d to c9ffacc Compare January 11, 2024 22:37
@makslevental makslevental marked this pull request as ready for review January 11, 2024 22:42
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jan 11, 2024
@llvmbot
Copy link
Member

llvmbot commented Jan 11, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

A recent PR broke enum bindings generation because of collision of attrs with the same name across dialects: #77211 (comment). So we need to namespace these now. In the current form of the PR, this is a breaking change (anyone that is supplying their own attribute builders won't have the &lt;dialect&gt;_attr prefix. I can rewrite to just check for both &lt;dialect&gt;_attr and just attr in the AttrBuilder queries but I don't know what people think is best.

Will add/update tests shortly.


Full diff: https://github.com/llvm/llvm-project/pull/77830.diff

2 Files Affected:

  • (modified) mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp (+25-28)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+10-7)
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index f4ced0803772ed..d5f36ff3bc0fd2 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -105,7 +105,8 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
     return true;
   }
 
-  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
+  os << llvm::formatv("@register_attribute_builder(\"{0}_{1}\")\n",
+                      enumAttr.getDialect().getName(),
                       enumAttr.getAttrDefName());
   os << llvm::formatv("def _{0}(x, context):\n",
                       enumAttr.getAttrDefName().lower());
@@ -120,11 +121,26 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
 /// Emits an attribute builder for the given dialect enum attribute to support
 /// automatic conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
-                                            StringRef formatString,
+static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
                                             raw_ostream &os) {
-  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
-  os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
+  StringRef mnemonic = attr.getMnemonic().value();
+  std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
+  StringRef dialect = attr.getDialect().getName();
+  std::string formatString;
+  if (assemblyFormat == "`<` $value `>`") {
+    formatString =
+        llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str();
+  } else if (assemblyFormat == "$value") {
+    formatString =
+        llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str();
+  } else {
+    llvm::errs()
+        << "unsupported assembly format for python enum bindings generation";
+    return true;
+  }
+  os << llvm::formatv("@register_attribute_builder(\"{0}_{1}\")\n",
+                      attr.getDialect().getName(), attr.getName());
+  os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
   os << llvm::formatv("    return "
                       "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
                       formatString);
@@ -142,29 +158,10 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
     emitEnumClass(enumAttr, os);
     emitAttributeBuilder(enumAttr, os);
   }
-  for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
-    AttrOrTypeDef attr(&*it);
-    if (!attr.getMnemonic()) {
-      llvm::errs() << "enum case " << attr
-                   << " needs mnemonic for python enum bindings generation";
-      return true;
-    }
-    StringRef mnemonic = attr.getMnemonic().value();
-    std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
-    StringRef dialect = attr.getDialect().getName();
-    if (assemblyFormat == "`<` $value `>`") {
-      emitDialectEnumAttributeBuilder(
-          attr.getName(),
-          llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
-    } else if (assemblyFormat == "$value") {
-      emitDialectEnumAttributeBuilder(
-          attr.getName(),
-          llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
-    } else {
-      llvm::errs()
-          << "unsupported assembly format for python enum bindings generation";
-      return true;
-    }
+  for (const auto &it :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+    const AttrOrTypeDef attr(&*it);
+    return emitDialectEnumAttributeBuilder(attr, os);
   }
 
   return false;
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e7..de343df1c434fa 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -529,27 +529,29 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
 /// Template for attribute builder from raw input in the operation builder.
 ///   {0} is the builder argument name;
 ///   {1} is the attribute builder from raw;
-///   {2} is the attribute builder from raw.
+///   {2} is the attribute builder from raw;
+///   {3} is the attribute's dialect.
 /// Use the value the user passed in if either it is already an Attribute or
 /// there is no method registered to make it an Attribute.
 constexpr const char *initAttributeWithBuilderTemplate =
     R"Py(attributes["{1}"] = ({0} if (
     issubclass(type({0}), _ods_ir.Attribute) or
-    not _ods_ir.AttrBuilder.contains('{2}')) else
-      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+    not _ods_ir.AttrBuilder.contains('{2}_{3}')) else
+      _ods_ir.AttrBuilder.get('{2}_{3}')({0}, context=_ods_context)))Py";
 
 /// Template for attribute builder from raw input for optional attribute in the
 /// operation builder.
 ///   {0} is the builder argument name;
 ///   {1} is the attribute builder from raw;
-///   {2} is the attribute builder from raw.
+///   {2} is the attribute builder from raw;
+///   {3} is the attribute's dialect.
 /// Use the value the user passed in if either it is already an Attribute or
 /// there is no method registered to make it an Attribute.
 constexpr const char *initOptionalAttributeWithBuilderTemplate =
     R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
         issubclass(type({0}), _ods_ir.Attribute) or
-        not _ods_ir.AttrBuilder.contains('{2}')) else
-          _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+        not _ods_ir.AttrBuilder.contains('{2}_{3}')) else
+          _ods_ir.AttrBuilder.get('{2}_{3}')({0}, context=_ods_context)))Py";
 
 constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -681,7 +683,8 @@ populateBuilderLinesAttr(const Operator &op,
         attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
             ? initOptionalAttributeWithBuilderTemplate
             : initAttributeWithBuilderTemplate,
-        argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+        argNames[i], attribute->name, attribute->attr.getAttrDefName(),
+        attribute->attr.getDialect().getName()));
   }
 }
 

@jpienaar
Copy link
Member

As discussed offline, I like uniform. But overall this looks like a good step.

@llvmbot llvmbot added the mlir:python MLIR Python bindings label Jan 12, 2024
@makslevental makslevental force-pushed the fix_enum_bindings branch 3 times, most recently from 0229db7 to 9e5daf8 Compare January 13, 2024 00:49
@makslevental makslevental force-pushed the fix_enum_bindings branch 4 times, most recently from a2288e5 to d973eac Compare January 13, 2024 02:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants