[mlir][vector] Better handle rank-preserving shape_cast #135855
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Before this PR, a rank-preserving shape_cast that arbitrarily reshapes its operand (such as `vector<3x2xf32>` to `vector<2x3xf32>`, the case added to invalid.mlir in this patch) was valid. There were checks that n-D to k-D shape_casts are either strictly expanding (n < k) or collapsing (n > k), but the case of n = k was always considered legal w.r.t. shape. This was inconsistent -- why should rank-preserving shape_casts be allowed to be arbitrary reshapes?

With this PR, rank-preserving shape_casts are only legal if they insert/remove dimensions of size 1. For example `vector<1x4xf32>` to `vector<4x1xf32>` is valid, but the example above is not. This makes the n = k case consistent with the n < k and n > k cases.

This PR also improves the error messages generated with `emitOpError`.

One alternative to this PR: make all shape_casts valid (as long as element type and number of elements are unchanged). This would be nice and simple, and wouldn't cause any problems lowering to LLVM afaict.

Patch is 24.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135855.diff

7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7fc56b1aa4e7e..7d5b5048131d8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,20 @@ def Vector_ShapeCastOp :
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "shape_cast casts between vector shapes";
let description = [{
- The shape_cast operation casts between an n-D source vector shape and
- a k-D result vector shape (the element type remains the same).
+ The shape_cast operation casts from an n-D source vector to a k-D result
+ vector. The element type remains the same, as does the number of elements
+ (product of dimensions).
+
+ A shape_cast must be either collapsing or expanding. Collapsing means all
+ result dimension sizes are products of contiguous source dimension sizes.
+ Expanding means source dimensions all factor into contiguous sequences of
+ destination dimension sizes. Size 1 dimensions in source and destination
+ are ignored.
- If reducing rank (n > k), result dimension sizes must be a product
- of contiguous source dimension sizes.
- If expanding rank (n < k), source dimensions must factor into a
- contiguous sequence of destination dimension sizes.
Each source dim is expanded (or contiguous sequence of source dims combined)
in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
sequence of result dims (or a single result dim), in result dimension list
- order (i.e. 0 <= j < k). The product of all source dimension sizes and all
- result dimension sizes must match.
+ order (i.e. 0 <= j < k).
It is currently assumed that this operation does not require moving data,
and that it will be folded away before lowering vector operations.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..20162e93c88e8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5532,40 +5532,38 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
- unsigned rankA = a.size();
- unsigned rankB = b.size();
- assert(rankA < rankB);
-
- auto isOne = [](int64_t v) { return v == 1; };
-
- // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
- // casted to a 0-d vector.
- if (rankA == 0 && llvm::all_of(b, isOne))
- return true;
+/// Returns true if each element of 'a' is either 1 or equal to the product of a
+/// contiguous sequence of the elements of 'b'. Returns false otherwise.
+///
+/// This function assumes that the product of elements in a and b are the same.
+static bool isExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+ unsigned rankA = a.size();
unsigned i = 0;
unsigned j = 0;
- while (i < rankA && j < rankB) {
+ while (i < rankA) {
+ if (a[i] == 1) {
+ ++i;
+ continue;
+ }
+
int64_t dimA = a[i];
int64_t dimB = 1;
- while (dimB < dimA && j < rankB)
+
+ while (dimB < dimA) {
dimB *= b[j++];
- if (dimA != dimB)
- break;
- ++i;
+ }
- // Handle the case when trailing dimensions are of size 1.
- // Include them into the contiguous sequence.
- if (i < rankA && llvm::all_of(a.slice(i), isOne))
- i = rankA;
- if (j < rankB && llvm::all_of(b.slice(j), isOne))
- j = rankB;
+ if (dimA != dimB) {
+ return false;
+ }
+ ++i;
}
+ return true;
+}
- return i == rankA && j == rankB;
+static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+ return isExpandingShapeCast(a, b) || isExpandingShapeCast(b, a);
}
static LogicalResult verifyVectorShapeCast(Operation *op,
@@ -5573,34 +5571,33 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
VectorType resultVectorType) {
// Check that element type is the same.
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
- return op->emitOpError("source/result vectors must have same element type");
- auto sourceShape = sourceVectorType.getShape();
- auto resultShape = resultVectorType.getShape();
-
- // Check that product of source dim sizes matches product of result dim sizes.
- int64_t sourceDimProduct = std::accumulate(
- sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
- int64_t resultDimProduct = std::accumulate(
- resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
- if (sourceDimProduct != resultDimProduct)
- return op->emitOpError("source/result number of elements must match");
-
- // Check that expanding/contracting rank cases.
- unsigned sourceRank = sourceVectorType.getRank();
- unsigned resultRank = resultVectorType.getRank();
- if (sourceRank < resultRank) {
- if (!isValidShapeCast(sourceShape, resultShape))
- return op->emitOpError("invalid shape cast");
- } else if (sourceRank > resultRank) {
- if (!isValidShapeCast(resultShape, sourceShape))
- return op->emitOpError("invalid shape cast");
+ return op->emitOpError("has different source and result element types");
+ ArrayRef<int64_t> inShape = sourceVectorType.getShape();
+ ArrayRef<int64_t> outShape = resultVectorType.getShape();
+
+ // Check that product of source dim sizes matches product of result dim
+ // sizes.
+ int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL,
+ std::multiplies<int64_t>{});
+ int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL,
+ std::multiplies<int64_t>{});
+
+ if (nInElms != nOutElms) {
+ return op->emitOpError(
+ "has a different number of source and result elements");
+ }
+
+ if (!isValidShapeCast(inShape, outShape)) {
+ return op->emitOpError(
+ "is invalid (does not uniformly collapse or expand)");
}
// Check that (non-)scalability is preserved
int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
if (sourceNScalableDims != resultNScalableDims)
- return op->emitOpError("different number of scalable dims at source (")
+ return op->emitOpError(
+ "has a different number of scalable dims at source (")
<< sourceNScalableDims << ") and result (" << resultNScalableDims
<< ")";
sourceVectorType.getNumDynamicDims();
@@ -5634,17 +5631,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// Only allows valid transitive folding (expand/collapse dimensions).
VectorType srcType = otherOp.getSource().getType();
+
if (resultType == srcType)
return otherOp.getSource();
- if (srcType.getRank() < resultType.getRank()) {
- if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
- return {};
- } else if (srcType.getRank() > resultType.getRank()) {
- if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
- return {};
- } else {
+
+ if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
- }
+
setOperand(otherOp.getSource());
return getResult();
}
@@ -6459,8 +6452,8 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
- // Keep the default yield terminator if the number of masked operations is not
- // the expected. This case will trigger a verification failure.
+ // Keep the default yield terminator if the number of masked operations is
+ // not as expected. This case will trigger a verification failure.
Block &block = region.front();
if (block.getOperations().size() != 2)
return;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..986e11d948052 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,14 +950,16 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
// -----
+// The definition of shape_cast stipulates that it must be either expanding or collapsing,
+// it cannot be a mixture of both.
// CHECK-LABEL: dont_fold_expand_collapse
-// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-// CHECK: return %[[B]] : vector<8x8xf32>
-func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
- %0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
- %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
- return %1 : vector<8x8xf32>
+// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+// CHECK: return %[[B]] : vector<4x3x3xf32>
+func.func @dont_fold_expand_collapse(%arg0: vector<2x2x9xf32>) -> vector<4x3x3xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+ %1 = vector.shape_cast %0 : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+ return %1 : vector<4x3x3xf32>
}
// -----
@@ -1290,12 +1292,12 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
// -----
// CHECK-LABEL: consecutive_shape_cast
-// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
-// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
-func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<2x2x4xf16>
+// CHECK-NEXT: return %[[C]] : vector<2x2x4xf16>
+func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<2x2x4xf16> {
%0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
- %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
- return %1 : vector<4x4xf16>
+ %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16>
+ return %1 : vector<2x2x4xf16>
}
// -----
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index dbf829e014b8d..45b7b44d47039 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1132,28 +1132,35 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
// -----
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{op source/result vectors must have same element type}}
+ // expected-error@+1 {{op has different source and result element types}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
}
// -----
func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{op source/result number of elements must match}}
+ // expected-error@+1 {{op has a different number of source and result elements}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
}
// -----
+func.func @shape_cast_invalid_rank_preserving(%arg0 : vector<3x2xf32>) {
+ // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
+ %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
+}
+
+// -----
+
func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{invalid shape cast}}
+ // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
}
// -----
func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
- // expected-error@+1 {{invalid shape cast}}
+ // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
%0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..504c6c300e9f0 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -576,6 +576,34 @@ func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
return %1 : vector<1x1x1x1xf32>
}
+// CHECK-LABEL: @shape_cast_rank_preserving
+func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
+
+ // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+ %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
+
+ return %0 : vector<4x1xf32>
+}
+
+
+// CHECK-LABEL: @collapse_but_increase_rank
+func.func @collapse_but_increase_rank(%arg0 : vector<2x3x5x7xf32>) -> vector<1x6x1x35x1xf32> {
+
+ // CHECK: vector.shape_cast %{{.*}} : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+ %0 = vector.shape_cast %arg0 : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+
+ return %0 : vector<1x6x1x35x1xf32>
+}
+
+// CHECK-LABEL: @expand_but_decrease_rank
+func.func @expand_but_decrease_rank(%arg0 : vector<1x1x6xi8>) -> vector<2x3xi8> {
+
+ // CHECK: vector.shape_cast %{{.*}} : vector<1x1x6xi8> to vector<2x3xi8>
+ %0 = vector.shape_cast %arg0 : vector<1x1x6xi8> to vector<2x3xi8>
+
+ return %0 : vector<2x3xi8>
+}
+
// CHECK-LABEL: @bitcast
func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x1xi32>,
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
index f4becad3c79c1..2faa47c1b08a8 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -74,23 +74,23 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]
// CHECK-LABEL: f32_permute_leading_non_scalable_dims
// CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
-func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
- // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[4]xf32>
+func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<1x6x[4]xf32> {
+ // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [0, 2] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [0, 3] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [0, 4] : vector<[4]xf32> into vector<1x6x[4]xf32>
// CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
- // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
- %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
- // CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32>
- return %res : vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [0, 5] : vector<[4]xf32> into vector<1x6x[4]xf32>
+ %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<1x6x[4]xf32>
+ // CHECK-NEXT: return %[[res5]] : vector<1x6x[4]xf32>
+ return %res : vector<1x6x[4]xf32>
}
// -----
@@ -117,48 +117,48 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
// CHECK-LABEL: f32_reduce_trailing_scalable_dim
// CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
-func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
+func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<3x2x[2]xf32>
{
- // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<6x[2]xf32>
+ // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[2]xf32>
// CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
// CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<[4]xf32> from vector<3x[4]xf32>
// CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<[4]xf32> from vector<3x[4]xf32>
// CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
// CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32>
- // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
- %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32>
- // CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32>
- return %res: vector<6x[2]xf32>
+ // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[2]xf32> into ...
[truncated]
Hey, thanks for looking into this! What is the issue you are trying to address? The shape cast op in your example seems valid to me. Shape cast is a generic reshape op that allows rank increasing/decreasing/preserving transformations.
Unfortunately it isn't. The description here says it cannot mix dimensions, so it is like the union of the tensor dialect's expand_shape and collapse_shape. Perhaps it should be like the tensor dialect's reshape; I would be happy to go down that path instead. But this PR is to make it consistent with how it's currently defined. Example currently:
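A hedged sketch of the kind of pair being compared here, with shapes taken from the tests touched by this patch (%arg0 and %arg1 are placeholder operands):

```mlir
// Rank-reducing shape_cast that mixes dimensions: currently rejected by the verifier.
%a = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>

// Rank-preserving shape_cast that also mixes dimensions: currently accepted.
%b = vector.shape_cast %arg1 : vector<3x2xf32> to vector<2x3xf32>
```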
IMO it doesn't make sense that the second one above verifies but the first one doesn't -- both mix dimensions.
The changes needed to make this true: #136587
`vector.shape_cast` was initially designed to be the union of collapse_shape and expand_shape. There was an inconsistency in the verifier that allowed any shape casts when the rank did not change, which led to a strange middle ground where you could cast from shape (4,3) to (3,4) but not from (4,3) to (2,3,2). That issue was fixed (verifier made stricter) in #135855, but further feedback there (and polling) suggests that vector.shape_cast should rather allow all shape casts (so more like tensor.reshape than tensor.collapse_shape/tensor.expand_shape). This PR makes this simplification by relaxing the verifier.
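A hedged illustration of the relaxed rule described above, using the shapes from that note (the f32 element type and the operand %v are assumed for illustration):

```mlir
// With the relaxed verifier, both casts verify: element type and element count (12) are unchanged.
%a = vector.shape_cast %v : vector<4x3xf32> to vector<3x4xf32>
%b = vector.shape_cast %v : vector<4x3xf32> to vector<2x3x2xf32>
```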
Before this PR, a rank-preserving shape_cast that arbitrarily reshapes its operand (such as `vector<3x2xf32>` to `vector<2x3xf32>`, the case added to invalid.mlir in this patch) was valid. There were checks that n-D to k-D shape_casts are either strictly expanding (n < k) or collapsing (n > k), but the case of n = k was always considered legal w.r.t. shape. This was inconsistent -- why should rank-preserving shape_casts be allowed to be arbitrary reshapes?

With this PR, rank-preserving shape_casts are only legal if they insert/remove dimensions of size 1. For example `<1x4xf32> -> <4x1xf32>` is valid, but the example above is not. This makes the n = k case consistent with the n < k and n > k cases.

This PR also improves the error messages generated with `emitOpError`.

One alternative to this PR: make all shape_casts valid (as long as element type and number of elements are unchanged). This would be nice and simple, and wouldn't cause any problems lowering to LLVM afaict.
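A short sketch of the rule with this PR applied, with shapes taken from the tests added in this patch (%v0, %v1, %v2 are placeholder operands):

```mlir
// Valid: rank-preserving, but it only moves a size-1 dimension.
%a = vector.shape_cast %v0 : vector<1x4xf32> to vector<4x1xf32>

// Valid: a collapse (2x3 -> 6, 5x7 -> 35); the inserted size-1 dimensions are ignored.
%b = vector.shape_cast %v1 : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>

// Invalid under this PR: rank-preserving but an arbitrary reshape.
// error: 'vector.shape_cast' op is invalid (does not uniformly collapse or expand)
%c = vector.shape_cast %v2 : vector<3x2xf32> to vector<2x3xf32>
```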