Skip to content

[mlir][vector] Better handle rank-preserving shape_cast #135855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from

Conversation

newling
Copy link
Contributor

@newling newling commented Apr 15, 2025

Before this PR, the following operation

%1 = vector.shape_cast %0 : vector<3x2xf32> to vector<2x3xf32>

was valid. There were checks that n-d to k-d shape_casts are either strictly expanding (n < k) or collapsing (n > k) but the case of n = k was considered always legal w.r.t. shape. This was inconsistent -- why should rank-preserving shape_casts be allowed to be arbitrary reshapes?

With this PR, rank-preserving shape_casts are only legal if they insert/remove dimensions of size 1. For example <1x4xf32> -> <4x1xf32> is valid, but the example above is not. This makes it consistent with the n < k and n > k cases.

This PR also improves the error messages generated with emitOpError.

One alternative to this PR: make all shape_casts valid (as long as element type and number of elements are unchanged). This would be nice and simple, and wouldn't cause any problems lowering to LLVM afaict?

@newling newling changed the title [mlir][vector] Consistently handle rank-preserving shape_cast [mlir][vector] Better handle rank-preserving shape_cast Apr 15, 2025
Copy link

github-actions bot commented Apr 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Member

llvmbot commented Apr 16, 2025

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Before this PR, the following operation

%1 = vector.shape_cast %0 : vector&lt;3x2xf32&gt; to vector&lt;2x3xf32&gt;

was valid. There were checks that n-d to k-d shape_casts are either strictly expanding (n < k) or collapsing (n > k) but the case of n = k was considered always legal w.r.t. shape. This was inconsistent -- why should rank-preserving shape_casts be allowed to be arbitrary reshapes?

With this PR, rank-preserving shape_casts are only legal if they insert/remove dimensions of size 1. For example &lt;1x4xf32&gt; -&gt; &lt;4x1xf32&gt; is valid, but the example above is not. This makes it consistent with the n &lt; k and n &gt; k cases.

This PR also improves the error messages generated with emitOpError.

One alternative to this PR: make all shape_casts valid (as long as element type and number of elements are unchanged). This would be nice and simple, and wouldn't cause any problems lowering to LLVM afaict?


Patch is 24.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135855.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+10-8)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+50-57)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14-12)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+11-4)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+28)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir (+29-29)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (-21)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7fc56b1aa4e7e..7d5b5048131d8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,20 @@ def Vector_ShapeCastOp :
     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
-    The shape_cast operation casts between an n-D source vector shape and
-    a k-D result vector shape (the element type remains the same).
+    The shape_cast operation casts from an n-D source vector to a k-D result
+    vector. The element type remains the same, as does the number of elements
+    (product of dimensions).
+
+    A shape_cast must be either collapsing or expanding. Collapsing means all
+    result dimension sizes are products of contiguous source dimension sizes.
+    Expanding means source dimensions all factor into contiguous sequences of
+    destination dimension sizes. Size 1 dimensions in source and destination
+    are ignored.
 
-    If reducing rank (n > k), result dimension sizes must be a product
-    of contiguous source dimension sizes.
-    If expanding rank (n < k), source dimensions must factor into a
-    contiguous sequence of destination dimension sizes.
     Each source dim is expanded (or contiguous sequence of source dims combined)
     in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
     sequence of result dims (or a single result dim), in result dimension list
-    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
-    result dimension sizes must match.
+    order (i.e. 0 <= j < k).
 
     It is currently assumed that this operation does not require moving data,
     and that it will be folded away before lowering vector operations.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..20162e93c88e8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5532,40 +5532,38 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
-  unsigned rankA = a.size();
-  unsigned rankB = b.size();
-  assert(rankA < rankB);
-
-  auto isOne = [](int64_t v) { return v == 1; };
-
-  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
-  // casted to a 0-d vector.
-  if (rankA == 0 && llvm::all_of(b, isOne))
-    return true;
+/// Returns true if each element of 'a' is either 1 or equal to the product of a
+/// contiguous sequence of the elements of 'b'. Returns false otherwise.
+///
+/// This function assumes that the product of elements in a and b are the same.
+static bool isExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
 
+  unsigned rankA = a.size();
   unsigned i = 0;
   unsigned j = 0;
-  while (i < rankA && j < rankB) {
+  while (i < rankA) {
+    if (a[i] == 1) {
+      ++i;
+      continue;
+    }
+
     int64_t dimA = a[i];
     int64_t dimB = 1;
-    while (dimB < dimA && j < rankB)
+
+    while (dimB < dimA) {
       dimB *= b[j++];
-    if (dimA != dimB)
-      break;
-    ++i;
+    }
 
-    // Handle the case when trailing dimensions are of size 1.
-    // Include them into the contiguous sequence.
-    if (i < rankA && llvm::all_of(a.slice(i), isOne))
-      i = rankA;
-    if (j < rankB && llvm::all_of(b.slice(j), isOne))
-      j = rankB;
+    if (dimA != dimB) {
+      return false;
+    }
+    ++i;
   }
+  return true;
+}
 
-  return i == rankA && j == rankB;
+static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+  return isExpandingShapeCast(a, b) || isExpandingShapeCast(b, a);
 }
 
 static LogicalResult verifyVectorShapeCast(Operation *op,
@@ -5573,34 +5571,33 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
                                            VectorType resultVectorType) {
   // Check that element type is the same.
   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
-    return op->emitOpError("source/result vectors must have same element type");
-  auto sourceShape = sourceVectorType.getShape();
-  auto resultShape = resultVectorType.getShape();
-
-  // Check that product of source dim sizes matches product of result dim sizes.
-  int64_t sourceDimProduct = std::accumulate(
-      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
-  int64_t resultDimProduct = std::accumulate(
-      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
-  if (sourceDimProduct != resultDimProduct)
-    return op->emitOpError("source/result number of elements must match");
-
-  // Check that expanding/contracting rank cases.
-  unsigned sourceRank = sourceVectorType.getRank();
-  unsigned resultRank = resultVectorType.getRank();
-  if (sourceRank < resultRank) {
-    if (!isValidShapeCast(sourceShape, resultShape))
-      return op->emitOpError("invalid shape cast");
-  } else if (sourceRank > resultRank) {
-    if (!isValidShapeCast(resultShape, sourceShape))
-      return op->emitOpError("invalid shape cast");
+    return op->emitOpError("has different source and result element types");
+  ArrayRef<int64_t> inShape = sourceVectorType.getShape();
+  ArrayRef<int64_t> outShape = resultVectorType.getShape();
+
+  // Check that product of source dim sizes matches product of result dim
+  // sizes.
+  int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL,
+                                    std::multiplies<int64_t>{});
+  int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL,
+                                     std::multiplies<int64_t>{});
+
+  if (nInElms != nOutElms) {
+    return op->emitOpError(
+        "has a different number of source and result elements");
+  }
+
+  if (!isValidShapeCast(inShape, outShape)) {
+    return op->emitOpError(
+        "is invalid (does not uniformly collapse or expand)");
   }
 
   // Check that (non-)scalability is preserved
   int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
   int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
   if (sourceNScalableDims != resultNScalableDims)
-    return op->emitOpError("different number of scalable dims at source (")
+    return op->emitOpError(
+               "has a different number of scalable dims at source (")
            << sourceNScalableDims << ") and result (" << resultNScalableDims
            << ")";
   sourceVectorType.getNumDynamicDims();
@@ -5634,17 +5631,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
     // Only allows valid transitive folding (expand/collapse dimensions).
     VectorType srcType = otherOp.getSource().getType();
+
     if (resultType == srcType)
       return otherOp.getSource();
-    if (srcType.getRank() < resultType.getRank()) {
-      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
-        return {};
-    } else if (srcType.getRank() > resultType.getRank()) {
-      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
-        return {};
-    } else {
+
+    if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
       return {};
-    }
+
     setOperand(otherOp.getSource());
     return getResult();
   }
@@ -6459,8 +6452,8 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
   OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
       MaskOp>::ensureTerminator(region, builder, loc);
-  // Keep the default yield terminator if the number of masked operations is not
-  // the expected. This case will trigger a verification failure.
+  // Keep the default yield terminator if the number of masked operations is
+  // not as expected. This case will trigger a verification failure.
   Block &block = region.front();
   if (block.getOperations().size() != 2)
     return;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..986e11d948052 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,14 +950,16 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
 
 // -----
 
+// The definition of shape_cast stipulates that it must be either expanding or collapsing,
+// it cannot be a mixture of both.
 // CHECK-LABEL: dont_fold_expand_collapse
-//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-//       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-//       CHECK:   return %[[B]] : vector<8x8xf32>
-func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
-    %0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
-    return %1 : vector<8x8xf32>
+//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+//       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+//       CHECK:   return %[[B]] : vector<4x3x3xf32>
+func.func @dont_fold_expand_collapse(%arg0: vector<2x2x9xf32>) -> vector<4x3x3xf32> {
+    %0 = vector.shape_cast %arg0 : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+    %1 = vector.shape_cast %0 : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+    return %1 : vector<4x3x3xf32>
 }
 
 // -----
@@ -1290,12 +1292,12 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 // -----
 
 // CHECK-LABEL: consecutive_shape_cast
-//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
-//  CHECK-NEXT:   return %[[C]] : vector<4x4xf16>
-func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<2x2x4xf16>
+//  CHECK-NEXT:   return %[[C]] : vector<2x2x4xf16>
+func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<2x2x4xf16> {
   %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
-  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
-  return %1 : vector<4x4xf16>
+  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16>
+  return %1 : vector<2x2x4xf16>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index dbf829e014b8d..45b7b44d47039 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1132,28 +1132,35 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 // -----
 
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result vectors must have same element type}}
+  // expected-error@+1 {{op has different source and result element types}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
 // -----
 
 func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result number of elements must match}}
+  // expected-error@+1 {{op has a different number of source and result elements}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
 }
 
 // -----
 
+func.func @shape_cast_invalid_rank_preserving(%arg0 : vector<3x2xf32>) {
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
+  %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
+}
+
+// -----
+
 func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
 }
 
 // -----
 
 func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..504c6c300e9f0 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -576,6 +576,34 @@ func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
   return %1 : vector<1x1x1x1xf32>
 }
 
+// CHECK-LABEL: @shape_cast_rank_preserving
+func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
+
+  return %0 : vector<4x1xf32>
+}
+
+
+// CHECK-LABEL: @collapse_but_increase_rank
+func.func @collapse_but_increase_rank(%arg0 : vector<2x3x5x7xf32>) -> vector<1x6x1x35x1xf32> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+
+  return %0 : vector<1x6x1x35x1xf32>
+}
+
+// CHECK-LABEL: @expand_but_decrease_rank
+func.func @expand_but_decrease_rank(%arg0 : vector<1x1x6xi8>) -> vector<2x3xi8> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x1x6xi8> to vector<2x3xi8>
+  %0 = vector.shape_cast %arg0 : vector<1x1x6xi8> to vector<2x3xi8>
+
+  return %0 : vector<2x3xi8>
+}
+
 // CHECK-LABEL: @bitcast
 func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xi32>,
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
index f4becad3c79c1..2faa47c1b08a8 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -74,23 +74,23 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]
 
 // CHECK-LABEL: f32_permute_leading_non_scalable_dims
 // CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
-func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
-  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[4]xf32>
+func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<1x6x[4]xf32> {
+  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [0, 2] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [0, 3] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [0, 4] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
-  %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
-  // CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32>
-  return %res : vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [0, 5] : vector<[4]xf32> into vector<1x6x[4]xf32>
+  %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<1x6x[4]xf32>
+  // CHECK-NEXT: return %[[res5]] : vector<1x6x[4]xf32>
+  return %res : vector<1x6x[4]xf32>
 }
 
 // -----
@@ -117,48 +117,48 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
 
 // CHECK-LABEL: f32_reduce_trailing_scalable_dim
 // CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
-func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
+func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<3x2x[2]xf32>
 {
-  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<6x[2]xf32>
+  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
-  %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32>
-  // CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32>
-  return %res: vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[2]xf32> into ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Apr 16, 2025

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

Before this PR, the following operation

%1 = vector.shape_cast %0 : vector&lt;3x2xf32&gt; to vector&lt;2x3xf32&gt;

was valid. There were checks that n-d to k-d shape_casts are either strictly expanding (n < k) or collapsing (n > k) but the case of n = k was considered always legal w.r.t. shape. This was inconsistent -- why should rank-preserving shape_casts be allowed to be arbitrary reshapes?

With this PR, rank-preserving shape_casts are only legal if they insert/remove dimensions of size 1. For example &lt;1x4xf32&gt; -&gt; &lt;4x1xf32&gt; is valid, but the example above is not. This makes it consistent with the n &lt; k and n &gt; k cases.

This PR also improves the error messages generated with emitOpError.

One alternative to this PR: make all shape_casts valid (as long as element type and number of elements are unchanged). This would be nice and simple, and wouldn't cause any problems lowering to LLVM afaict?


Patch is 24.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135855.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+10-8)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+50-57)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14-12)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+11-4)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+28)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir (+29-29)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (-21)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7fc56b1aa4e7e..7d5b5048131d8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,20 @@ def Vector_ShapeCastOp :
     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
-    The shape_cast operation casts between an n-D source vector shape and
-    a k-D result vector shape (the element type remains the same).
+    The shape_cast operation casts from an n-D source vector to a k-D result
+    vector. The element type remains the same, as does the number of elements
+    (product of dimensions).
+
+    A shape_cast must be either collapsing or expanding. Collapsing means all
+    result dimension sizes are products of contiguous source dimension sizes.
+    Expanding means source dimensions all factor into contiguous sequences of
+    destination dimension sizes. Size 1 dimensions in source and destination
+    are ignored.
 
-    If reducing rank (n > k), result dimension sizes must be a product
-    of contiguous source dimension sizes.
-    If expanding rank (n < k), source dimensions must factor into a
-    contiguous sequence of destination dimension sizes.
     Each source dim is expanded (or contiguous sequence of source dims combined)
     in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
     sequence of result dims (or a single result dim), in result dimension list
-    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
-    result dimension sizes must match.
+    order (i.e. 0 <= j < k).
 
     It is currently assumed that this operation does not require moving data,
     and that it will be folded away before lowering vector operations.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..20162e93c88e8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5532,40 +5532,38 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
-  unsigned rankA = a.size();
-  unsigned rankB = b.size();
-  assert(rankA < rankB);
-
-  auto isOne = [](int64_t v) { return v == 1; };
-
-  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
-  // casted to a 0-d vector.
-  if (rankA == 0 && llvm::all_of(b, isOne))
-    return true;
+/// Returns true if each element of 'a' is either 1 or equal to the product of a
+/// contiguous sequence of the elements of 'b'. Returns false otherwise.
+///
+/// This function assumes that the product of elements in a and b are the same.
+static bool isExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
 
+  unsigned rankA = a.size();
   unsigned i = 0;
   unsigned j = 0;
-  while (i < rankA && j < rankB) {
+  while (i < rankA) {
+    if (a[i] == 1) {
+      ++i;
+      continue;
+    }
+
     int64_t dimA = a[i];
     int64_t dimB = 1;
-    while (dimB < dimA && j < rankB)
+
+    while (dimB < dimA) {
       dimB *= b[j++];
-    if (dimA != dimB)
-      break;
-    ++i;
+    }
 
-    // Handle the case when trailing dimensions are of size 1.
-    // Include them into the contiguous sequence.
-    if (i < rankA && llvm::all_of(a.slice(i), isOne))
-      i = rankA;
-    if (j < rankB && llvm::all_of(b.slice(j), isOne))
-      j = rankB;
+    if (dimA != dimB) {
+      return false;
+    }
+    ++i;
   }
+  return true;
+}
 
-  return i == rankA && j == rankB;
+static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+  return isExpandingShapeCast(a, b) || isExpandingShapeCast(b, a);
 }
 
 static LogicalResult verifyVectorShapeCast(Operation *op,
@@ -5573,34 +5571,33 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
                                            VectorType resultVectorType) {
   // Check that element type is the same.
   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
-    return op->emitOpError("source/result vectors must have same element type");
-  auto sourceShape = sourceVectorType.getShape();
-  auto resultShape = resultVectorType.getShape();
-
-  // Check that product of source dim sizes matches product of result dim sizes.
-  int64_t sourceDimProduct = std::accumulate(
-      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
-  int64_t resultDimProduct = std::accumulate(
-      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
-  if (sourceDimProduct != resultDimProduct)
-    return op->emitOpError("source/result number of elements must match");
-
-  // Check that expanding/contracting rank cases.
-  unsigned sourceRank = sourceVectorType.getRank();
-  unsigned resultRank = resultVectorType.getRank();
-  if (sourceRank < resultRank) {
-    if (!isValidShapeCast(sourceShape, resultShape))
-      return op->emitOpError("invalid shape cast");
-  } else if (sourceRank > resultRank) {
-    if (!isValidShapeCast(resultShape, sourceShape))
-      return op->emitOpError("invalid shape cast");
+    return op->emitOpError("has different source and result element types");
+  ArrayRef<int64_t> inShape = sourceVectorType.getShape();
+  ArrayRef<int64_t> outShape = resultVectorType.getShape();
+
+  // Check that product of source dim sizes matches product of result dim
+  // sizes.
+  int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL,
+                                    std::multiplies<int64_t>{});
+  int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL,
+                                     std::multiplies<int64_t>{});
+
+  if (nInElms != nOutElms) {
+    return op->emitOpError(
+        "has a different number of source and result elements");
+  }
+
+  if (!isValidShapeCast(inShape, outShape)) {
+    return op->emitOpError(
+        "is invalid (does not uniformly collapse or expand)");
   }
 
   // Check that (non-)scalability is preserved
   int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
   int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
   if (sourceNScalableDims != resultNScalableDims)
-    return op->emitOpError("different number of scalable dims at source (")
+    return op->emitOpError(
+               "has a different number of scalable dims at source (")
            << sourceNScalableDims << ") and result (" << resultNScalableDims
            << ")";
   sourceVectorType.getNumDynamicDims();
@@ -5634,17 +5631,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
     // Only allows valid transitive folding (expand/collapse dimensions).
     VectorType srcType = otherOp.getSource().getType();
+
     if (resultType == srcType)
       return otherOp.getSource();
-    if (srcType.getRank() < resultType.getRank()) {
-      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
-        return {};
-    } else if (srcType.getRank() > resultType.getRank()) {
-      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
-        return {};
-    } else {
+
+    if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
       return {};
-    }
+
     setOperand(otherOp.getSource());
     return getResult();
   }
@@ -6459,8 +6452,8 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
   OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
       MaskOp>::ensureTerminator(region, builder, loc);
-  // Keep the default yield terminator if the number of masked operations is not
-  // the expected. This case will trigger a verification failure.
+  // Keep the default yield terminator if the number of masked operations is
+  // not as expected. This case will trigger a verification failure.
   Block &block = region.front();
   if (block.getOperations().size() != 2)
     return;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..986e11d948052 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,14 +950,16 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
 
 // -----
 
+// The definition of shape_cast stipulates that it must be either expanding or collapsing,
+// it cannot be a mixture of both.
 // CHECK-LABEL: dont_fold_expand_collapse
-//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-//       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-//       CHECK:   return %[[B]] : vector<8x8xf32>
-func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
-    %0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
-    return %1 : vector<8x8xf32>
+//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+//       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+//       CHECK:   return %[[B]] : vector<4x3x3xf32>
+func.func @dont_fold_expand_collapse(%arg0: vector<2x2x9xf32>) -> vector<4x3x3xf32> {
+    %0 = vector.shape_cast %arg0 : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+    %1 = vector.shape_cast %0 : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+    return %1 : vector<4x3x3xf32>
 }
 
 // -----
@@ -1290,12 +1292,12 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 // -----
 
 // CHECK-LABEL: consecutive_shape_cast
-//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
-//  CHECK-NEXT:   return %[[C]] : vector<4x4xf16>
-func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<2x2x4xf16>
+//  CHECK-NEXT:   return %[[C]] : vector<2x2x4xf16>
+func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<2x2x4xf16> {
   %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
-  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
-  return %1 : vector<4x4xf16>
+  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16>
+  return %1 : vector<2x2x4xf16>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index dbf829e014b8d..45b7b44d47039 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1132,28 +1132,35 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 // -----
 
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result vectors must have same element type}}
+  // expected-error@+1 {{op has different source and result element types}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
 // -----
 
 func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result number of elements must match}}
+  // expected-error@+1 {{op has a different number of source and result elements}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
 }
 
 // -----
 
+func.func @shape_cast_invalid_rank_preserving(%arg0 : vector<3x2xf32>) {
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
+  %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
+}
+
+// -----
+
 func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
 }
 
 // -----
 
 func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..504c6c300e9f0 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -576,6 +576,34 @@ func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
   return %1 : vector<1x1x1x1xf32>
 }
 
+// CHECK-LABEL: @shape_cast_rank_preserving
+func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
+
+  return %0 : vector<4x1xf32>
+}
+
+
+// CHECK-LABEL: @collapse_but_increase_rank
+func.func @collapse_but_increase_rank(%arg0 : vector<2x3x5x7xf32>) -> vector<1x6x1x35x1xf32> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+
+  return %0 : vector<1x6x1x35x1xf32>
+}
+
+// CHECK-LABEL: @expand_but_decrease_rank
+func.func @expand_but_decrease_rank(%arg0 : vector<1x1x6xi8>) -> vector<2x3xi8> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x1x6xi8> to vector<2x3xi8>
+  %0 = vector.shape_cast %arg0 : vector<1x1x6xi8> to vector<2x3xi8>
+
+  return %0 : vector<2x3xi8>
+}
+
 // CHECK-LABEL: @bitcast
 func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xi32>,
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
index f4becad3c79c1..2faa47c1b08a8 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -74,23 +74,23 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]
 
 // CHECK-LABEL: f32_permute_leading_non_scalable_dims
 // CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
-func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
-  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[4]xf32>
+func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<1x6x[4]xf32> {
+  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [0, 2] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [0, 3] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [0, 4] : vector<[4]xf32> into vector<1x6x[4]xf32>
   // CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
-  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
-  %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
-  // CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32>
-  return %res : vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [0, 5] : vector<[4]xf32> into vector<1x6x[4]xf32>
+  %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<1x6x[4]xf32>
+  // CHECK-NEXT: return %[[res5]] : vector<1x6x[4]xf32>
+  return %res : vector<1x6x[4]xf32>
 }
 
 // -----
@@ -117,48 +117,48 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
 
 // CHECK-LABEL: f32_reduce_trailing_scalable_dim
 // CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
-func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
+func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<3x2x[2]xf32>
 {
-  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<6x[2]xf32>
+  // CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
   // CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32>
-  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
-  %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32>
-  // CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32>
-  return %res: vector<6x[2]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[2]xf32> into ...
[truncated]

@dcaballe
Copy link
Contributor

Hey, thanks for looking into this! What is the issue you are trying to address? The shape cast op in your example seems valid to me. Shape cast is a generic reshape op that allows rank increasing/decreasing/preserving transformations.

@newling
Copy link
Contributor Author

newling commented Apr 16, 2025

Shape cast is a generic reshape op that allows rank increasing/decreasing/preserving transformations.

Unfortunately it isn't. The description here says it cannot mix dimensions, so it is like the union of the tensor dialect's expand_shape and collapse_shape

Perhaps it should be like tensor dialects reshape, I would be happy to go down that path instead. But this PR is to make it consistent with how it's currently defined.

Example currently:

// error: 'vector.shape_cast' op invalid shape cast
func.func @invalid_0(%arg : vector<4x3xi32>) -> vector<2x3x2xi32> {
  %1 = vector.shape_cast %arg : vector<4x3xi32> to vector<2x3x2xi32>
  return %1 : vector<2x3x2xi32>
}

func.func @valid_0(%arg : vector<4x3x1xi32>) -> vector<2x3x2xi32> {
  %1 = vector.shape_cast %arg : vector<4x3x1xi32> to vector<2x3x2xi32>
  return %1 : vector<2x3x2xi32>
}

IMO it doesn't make sense that the second one above verifies, but the first one doesn't -- both mix dimensions.

The reason valid_0 doesn't fail to verify is because the current logic fails to consider the case where the rank is unchanged (see here. All I can think is that the initial commit of this op just forgot to consider this case.

@newling
Copy link
Contributor Author

newling commented Apr 21, 2025

Shape cast is a generic reshape op that allows rank increasing/decreasing/preserving transformations.

The changes needed to make this true: #136587

@newling
Copy link
Contributor Author

newling commented Apr 22, 2025

As suggested in a comment in #136587, that PR is preferred so I'm closing this one

thanks for the initial feedback @dcaballe ... you're not alone thinking that shape_cast can do any reshape, seems like a unanimous view!

@newling newling closed this Apr 22, 2025
newling added a commit that referenced this pull request May 1, 2025
`vector.shape_cast` was initially designed to be the union of
collapse_shape and expand_shape. There was an inconsistency in the
verifier that allowed any shape casts when the rank did not change, which
led to a strange middle ground where you could cast from shape (4,3) to
(3,4) but not from (4,3) to (2,3,2). That issue was fixed (verifier made stricter) 
in #135855, but further feedback
there (and polling) suggests that vector.shape_cast should rather allow all 
shape casts (so more like tensor.reshape than 
tensor.collapse_shape/tensor.expand_shape). This PR makes this simplification
by relaxing the verifier.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants