Skip to content

[mlir][vector] Use DenseI64ArrayAttr in vector.multi_reduction #102637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 10, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Aug 9, 2024

This prevents some unnecessary conversions to/from int64_t and IntegerAttr.

This prevents some unnecessary conversions to/from int64_t and
IntegerAttr.
@llvmbot
Copy link
Member

llvmbot commented Aug 9, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Benjamin Maxwell (MacDue)

Changes

This prevents some unnecessary conversions to/from int64_t and IntegerAttr.


Full diff: https://github.com/llvm/llvm-project/pull/102637.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+3-3)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+3-4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+1-4)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 925eb80dbe71ec..b96f5c2651bce5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -286,7 +286,7 @@ def Vector_MultiDimReductionOp :
     Arguments<(ins Vector_CombiningKindAttr:$kind,
                    AnyVector:$source,
                    AnyType:$acc,
-                   I64ArrayAttr:$reduction_dims)>,
+                   DenseI64ArrayAttr:$reduction_dims)>,
     Results<(outs AnyType:$dest)> {
   let summary = "Multi-dimensional reduction operation";
   let description = [{
@@ -325,8 +325,8 @@ def Vector_MultiDimReductionOp :
 
     SmallVector<bool> getReductionMask() {
       SmallVector<bool> res(getSourceVectorType().getRank(), false);
-      for (auto ia : getReductionDims().getAsRange<IntegerAttr>())
-        res[ia.getInt()] = true;
+      for (int64_t dim : getReductionDims())
+        res[dim] = true;
       return res;
     }
     static SmallVector<bool> getReductionMask(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ab4485c37e5e7f..60b4f93a53ad43 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -445,8 +445,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   for (const auto &en : llvm::enumerate(reductionMask))
     if (en.value())
       reductionDims.push_back(en.index());
-  build(builder, result, kind, source, acc,
-        builder.getI64ArrayAttr(reductionDims));
+  build(builder, result, kind, source, acc, reductionDims);
 }
 
 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
@@ -467,8 +466,8 @@ LogicalResult MultiDimReductionOp::verify() {
   Type inferredReturnType;
   auto sourceScalableDims = getSourceVectorType().getScalableDims();
   for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
-    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
+    if (!llvm::any_of(getReductionDims(), [&](int64_t dim) {
+          return dim == static_cast<int64_t>(it.index());
         })) {
       targetShape.push_back(it.value());
       scalableDims.push_back(sourceScalableDims[it.index()]);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index ac576ed0b4f097..716da55ba09aec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -67,10 +67,7 @@ class InnerOuterDimReductionConversion
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
 
     // Separate reduction and parallel dims
-    auto reductionDimsRange =
-        multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
-    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
-        reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
     llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                   reductionDims.end());
     int64_t reductionSize = reductionDims.size();

@MacDue
Copy link
Member Author

MacDue commented Aug 9, 2024

This is the final one of these tidy ups (alongside #101850)

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@MacDue MacDue merged commit 5f26497 into llvm:main Aug 10, 2024
8 checks passed
@MacDue MacDue deleted the dense_multi_reduction branch August 10, 2024 13:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants