[mlir][vector] Relax constraints on shape_cast #136587
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
The shape_cast operation casts from a source vector to a target vector,
retaining the element type and number of elements.
In addition, the number of scalable dimensions needs to be preserved. We should be verifying that.
Ok. Is something like <3x[4]xf32> -> <2x[2]x3xf32> valid, or do I need to check more than just the number?
I realized this was already in the verifier, so logic unchanged w.r.t. scalable dims. Added a comment in the .td.
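For reference, a minimal sketch of the scalable-dim rule, reusing the valid example from the updated .td description and the invalid case from invalid.mlir in this diff (SSA value names are illustrative):

```mlir
// Valid: 24 elements (times vscale^2) on both sides, and both types
// carry exactly two scalable dimensions.
%a = vector.shape_cast %x : vector<[2]x3x[4]xi8> to vector<3x[1]x[8]xi8>

// Invalid: the element count matches, but the source has one scalable
// dimension while the result has none.
%b = vector.shape_cast %y : vector<15x[2]xf32> to vector<30xf32>
```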
This approach makes more sense to me — thanks! I don't see a need to "special-case" vector.shape_cast; this PR aligns more closely with how I’ve been thinking about it. In fact, I wasn’t even aware of the restrictions this is removing - and I suspect Diego wasn’t either, based on his comment. Given the feedback so far, I’d suggest closing #135855 and marking this PR as ready for review instead - unless I’m missing something? 🤔
Yeah I think I prefer this design too. Plus, simpler is better!
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir
Author: James Newling (newling)
Changes
This PR also adds a simple canonicalizer; a representative before/after is sketched below.
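A representative before/after, assuming the back-to-back shape_cast fold exercised by the fold_expand_collapse test updated in this PR (value names are illustrative):

```mlir
// Before: a chain of two shape_casts. Prior to this PR the pair did not
// fold, because 8x8 is not a contiguous regrouping of 1x1x64.
%0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
%1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>

// After folding: a single shape_cast.
%1 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<8x8xf32>
```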
Full diff: https://github.com/llvm/llvm-project/pull/136587.diff
6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d7518943229ea..4d49e52b21563 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,8 @@ def Vector_ShapeCastOp :
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "shape_cast casts between vector shapes";
let description = [{
- The shape_cast operation casts between an n-D source vector shape and
- a k-D result vector shape (the element type remains the same).
-
- If reducing rank (n > k), result dimension sizes must be a product
- of contiguous source dimension sizes.
- If expanding rank (n < k), source dimensions must factor into a
- contiguous sequence of destination dimension sizes.
- Each source dim is expanded (or contiguous sequence of source dims combined)
- in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
- sequence of result dims (or a single result dim), in result dimension list
- order (i.e. 0 <= j < k). The product of all source dimension sizes and all
- result dimension sizes must match.
+ Casts to a vector with the same number of elements, element type, and
+ number of scalable dimensions.
It is currently assumed that this operation does not require moving data,
and that it will be folded away before lowering vector operations.
@@ -2265,15 +2255,13 @@ def Vector_ShapeCastOp :
2-D MLIR vector to a 1-D flattened LLVM vector.shape_cast lowering to LLVM
is supported in that particular case, for now.
- Example:
+ Examples:
```mlir
- // Example casting to a lower vector rank.
- %1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
-
- // Example casting to a higher vector rank.
- %3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
+ %1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
+ // with 2 scalable dimensions (number of which must be preserved).
+ %3 = vector.shape_cast %2 : vector<[2]x3x[4]xi8> to vector<3x[1]x[8]xi8>
```
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 368259b38b153..732a5d21a4b87 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5505,127 +5505,67 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
- unsigned rankA = a.size();
- unsigned rankB = b.size();
- assert(rankA < rankB);
-
- auto isOne = [](int64_t v) { return v == 1; };
-
- // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
- // casted to a 0-d vector.
- if (rankA == 0 && llvm::all_of(b, isOne))
- return true;
-
- unsigned i = 0;
- unsigned j = 0;
- while (i < rankA && j < rankB) {
- int64_t dimA = a[i];
- int64_t dimB = 1;
- while (dimB < dimA && j < rankB)
- dimB *= b[j++];
- if (dimA != dimB)
- break;
- ++i;
+LogicalResult ShapeCastOp::verify() {
- // Handle the case when trailing dimensions are of size 1.
- // Include them into the contiguous sequence.
- if (i < rankA && llvm::all_of(a.slice(i), isOne))
- i = rankA;
- if (j < rankB && llvm::all_of(b.slice(j), isOne))
- j = rankB;
- }
+ VectorType sourceType = getSourceVectorType();
+ VectorType resultType = getResultVectorType();
- return i == rankA && j == rankB;
-}
+ // Check that element type is preserved
+ if (sourceType.getElementType() != resultType.getElementType())
+ return emitOpError("has different source and result element types");
-static LogicalResult verifyVectorShapeCast(Operation *op,
- VectorType sourceVectorType,
- VectorType resultVectorType) {
- // Check that element type is the same.
- if (sourceVectorType.getElementType() != resultVectorType.getElementType())
- return op->emitOpError("source/result vectors must have same element type");
- auto sourceShape = sourceVectorType.getShape();
- auto resultShape = resultVectorType.getShape();
-
- // Check that product of source dim sizes matches product of result dim sizes.
- int64_t sourceDimProduct = std::accumulate(
- sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
- int64_t resultDimProduct = std::accumulate(
- resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
- if (sourceDimProduct != resultDimProduct)
- return op->emitOpError("source/result number of elements must match");
-
- // Check that expanding/contracting rank cases.
- unsigned sourceRank = sourceVectorType.getRank();
- unsigned resultRank = resultVectorType.getRank();
- if (sourceRank < resultRank) {
- if (!isValidShapeCast(sourceShape, resultShape))
- return op->emitOpError("invalid shape cast");
- } else if (sourceRank > resultRank) {
- if (!isValidShapeCast(resultShape, sourceShape))
- return op->emitOpError("invalid shape cast");
+ // Check that number of elements is preserved
+ int64_t sourceNElms = sourceType.getNumElements();
+ int64_t resultNElms = resultType.getNumElements();
+ if (sourceNElms != resultNElms) {
+ return emitOpError() << "has different number of elements at source ("
+ << sourceNElms << ") and result (" << resultNElms
+ << ")";
}
// Check that (non-)scalability is preserved
- int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
- int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
+ int64_t sourceNScalableDims = sourceType.getNumScalableDims();
+ int64_t resultNScalableDims = resultType.getNumScalableDims();
if (sourceNScalableDims != resultNScalableDims)
- return op->emitOpError("different number of scalable dims at source (")
- << sourceNScalableDims << ") and result (" << resultNScalableDims
- << ")";
- sourceVectorType.getNumDynamicDims();
-
- return success();
-}
-
-LogicalResult ShapeCastOp::verify() {
- auto sourceVectorType =
- llvm::dyn_cast_or_null<VectorType>(getSource().getType());
- auto resultVectorType =
- llvm::dyn_cast_or_null<VectorType>(getResult().getType());
-
- // Check if source/result are of vector type.
- if (sourceVectorType && resultVectorType)
- return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
+ return emitOpError() << "has different number of scalable dims at source ("
+ << sourceNScalableDims << ") and result ("
+ << resultNScalableDims << ")";
return success();
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
+ VectorType resultType = getType();
+
// No-op shape cast.
- if (getSource().getType() == getType())
+ if (getSource().getType() == resultType)
return getSource();
- VectorType resultType = getType();
-
- // Canceling shape casts.
+ // Y = shape_cast(shape_cast(X)))
+ // -> X, if X and Y have same type
+ // -> shape_cast(X) otherwise.
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
- // Only allows valid transitive folding (expand/collapse dimensions).
VectorType srcType = otherOp.getSource().getType();
if (resultType == srcType)
return otherOp.getSource();
- if (srcType.getRank() < resultType.getRank()) {
- if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
- return {};
- } else if (srcType.getRank() > resultType.getRank()) {
- if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
- return {};
- } else {
- return {};
- }
setOperand(otherOp.getSource());
return getResult();
}
- // Cancelling broadcast and shape cast ops.
+ // Y = shape_cast(broadcast(X))
+ // -> X, if X and Y have same type, else
+ // -> shape_cast(X) if X is a vector and the broadcast preserves
+ // number of elements.
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
if (bcastOp.getSourceType() == resultType)
return bcastOp.getSource();
+ if (auto bcastSrcType = dyn_cast<VectorType>(bcastOp.getSourceType())) {
+ if (bcastSrcType.getNumElements() == resultType.getNumElements()) {
+ setOperand(bcastOp.getSource());
+ return getResult();
+ }
+ }
}
// shape_cast(constant) -> constant
diff --git a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
index ae2b5393ca449..60ad54bf5c370 100644
--- a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
@@ -26,8 +26,7 @@ func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
// CHECK-NEXT: vector.insert {{.*}}[1]
// CHECK-NEXT: vector.insert {{.*}}[2]
// CHECK-NEXT: vector.insert {{.*}}[3]
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
+ // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<8x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
return %0 : vector<8x4xf32>
}
@@ -54,8 +53,7 @@ func.func @transpose021_1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32>
// CHECK-NEXT: vector.insert {{.*}}[1]
// CHECK-NEXT: vector.insert {{.*}}[2]
// CHECK-NEXT: vector.insert {{.*}}[3]
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32>
+ // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<1x8x4xf32>
%0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to vector<1x8x4xf32>
return %0 : vector<1x8x4xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2d365ac2b4287..04d8e613d4156 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,10 +950,9 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
// -----
-// CHECK-LABEL: dont_fold_expand_collapse
-// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-// CHECK: return %[[B]] : vector<8x8xf32>
+// CHECK-LABEL: fold_expand_collapse
+// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<8x8xf32>
+// CHECK: return %[[A]] : vector<8x8xf32>
func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
%0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
%1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
@@ -973,6 +972,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
// -----
+// CHECK-LABEL: func @fold_count_preserving_broadcast_shapecast
+// CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
+// CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[V]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: return %[[SHAPECAST]] : vector<2x2xf32>
+func.func @fold_count_preserving_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<2x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
+ %1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
// CHECK: vector.broadcast
// CHECK-NOT: vector.shape_cast
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3a8320971bac4..fa4837126accb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1131,34 +1131,21 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
// -----
+
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{op source/result vectors must have same element type}}
+ // expected-error@+1 {{'vector.shape_cast' op has different source and result element types}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
}
// -----
func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{op source/result number of elements must match}}
+ // expected-error@+1 {{'vector.shape_cast' op has different number of elements at source (30) and result (20)}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
}
// -----
-func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{invalid shape cast}}
- %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
-}
-
-// -----
-
-func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
- // expected-error@+1 {{invalid shape cast}}
- %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
-}
-
-// -----
-
func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
// expected-error@+1 {{different number of scalable dims at source (1) and result (0)}}
%0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..36f7db8c39d4d 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -543,6 +543,20 @@ func.func @vector_print_on_scalar(%arg0: i64) {
return
}
+// CHECK-LABEL: @shape_cast_valid_rank_reduction
+func.func @shape_cast_valid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
+ // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<2x15xf32>
+ %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
+ return
+}
+
+// CHECK-LABEL: @shape_cast_valid_rank_expansion
+func.func @shape_cast_valid_rank_expansion(%arg0 : vector<15x2xf32>) {
+ // CHECK: vector.shape_cast %{{.*}} : vector<15x2xf32> to vector<5x2x3x1xf32>
+ %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
+ return
+}
+
// CHECK-LABEL: @shape_cast
func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x1xf32>,
LGTM, thanks!
I suggested adding a bit of clarification in the tests. Also, given the significance of this change, please leave ~48 hours for other reviewers to take a look, in case they'd like to weigh in.
Thanks again for the clean-up!
vector.shape_cast was initially designed to be the union of collapse_shape and expand_shape. There was an inconsistency in the verifier that allowed any shape cast when the rank did not change, which led to a strange middle ground where you could cast from shape (4,3) to (3,4) but not from (4,3) to (2,3,2). That issue was fixed in #135855, but further feedback and polling there suggest that vector.shape_cast should simply allow all shape casts (making it more like tensor.reshape than tensor.collapse_shape/tensor.expand_shape). This PR makes that change by relaxing the verifier.
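A sketch of that middle ground (value names are illustrative):

```mlir
// Was already accepted: the rank is unchanged, so the old verifier allowed
// an arbitrary reshuffle of dimension sizes.
%a = vector.shape_cast %v : vector<4x3xf32> to vector<3x4xf32>

// Was rejected: the rank changes and (2,3,2) is not a contiguous regrouping
// of (4,3). After this PR it is accepted, since the element type and total
// element count (12) are preserved.
%b = vector.shape_cast %v : vector<4x3xf32> to vector<2x3x2xf32>
```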