-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][vector] Use DenseI64ArrayAttr
in vector.multi_reduction
#102637
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This prevents some unnecessary conversions to/from int64_t and IntegerAttr.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Benjamin Maxwell (MacDue) ChangesThis prevents some unnecessary conversions to/from int64_t and IntegerAttr. Full diff: https://github.com/llvm/llvm-project/pull/102637.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 925eb80dbe71ec..b96f5c2651bce5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -286,7 +286,7 @@ def Vector_MultiDimReductionOp :
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVector:$source,
AnyType:$acc,
- I64ArrayAttr:$reduction_dims)>,
+ DenseI64ArrayAttr:$reduction_dims)>,
Results<(outs AnyType:$dest)> {
let summary = "Multi-dimensional reduction operation";
let description = [{
@@ -325,8 +325,8 @@ def Vector_MultiDimReductionOp :
SmallVector<bool> getReductionMask() {
SmallVector<bool> res(getSourceVectorType().getRank(), false);
- for (auto ia : getReductionDims().getAsRange<IntegerAttr>())
- res[ia.getInt()] = true;
+ for (int64_t dim : getReductionDims())
+ res[dim] = true;
return res;
}
static SmallVector<bool> getReductionMask(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ab4485c37e5e7f..60b4f93a53ad43 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -445,8 +445,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
for (const auto &en : llvm::enumerate(reductionMask))
if (en.value())
reductionDims.push_back(en.index());
- build(builder, result, kind, source, acc,
- builder.getI64ArrayAttr(reductionDims));
+ build(builder, result, kind, source, acc, reductionDims);
}
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
@@ -467,8 +466,8 @@ LogicalResult MultiDimReductionOp::verify() {
Type inferredReturnType;
auto sourceScalableDims = getSourceVectorType().getScalableDims();
for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
- if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
- return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
+ if (!llvm::any_of(getReductionDims(), [&](int64_t dim) {
+ return dim == static_cast<int64_t>(it.index());
})) {
targetShape.push_back(it.value());
scalableDims.push_back(sourceScalableDims[it.index()]);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index ac576ed0b4f097..716da55ba09aec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -67,10 +67,7 @@ class InnerOuterDimReductionConversion
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Separate reduction and parallel dims
- auto reductionDimsRange =
- multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
- auto reductionDims = llvm::to_vector<4>(llvm::map_range(
- reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
reductionDims.end());
int64_t reductionSize = reductionDims.size();
|
This is the final one of these tidy ups (alongside #101850) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
This prevents some unnecessary conversions to/from int64_t and IntegerAttr.