Skip to content

Commit 5812516

Browse files
alaa-ali (Alaa Ali) and joker-eph
authored
[MLIR] Fix canonicalization pattern for 'shape.shape_of' (#134234)
This PR will fix a bug in a canonicalization pattern (operation shape.shape_of: shape of reshape) ``` // Before func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> { %reshape = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32> %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex> return %0 : tensor<3xindex> } // This will error out as follows: error: 'tensor.cast' op operand type 'tensor<3xi32>' and result type 'tensor<3xindex>' are cast incompatible %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex> ^ note: see current operation: %0 = "tensor.cast"(%arg1) : (tensor<3xi32>) -> tensor<3xindex> ``` ``` // After func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> { %0 = arith.index_cast %arg1 : tensor<3xi32> to tensor<3xindex> return %0 : tensor<3xindex> } ``` See file canonicalize.mlir in the change list for an example. For context, this bug was found while running a test on Keras 3: the canonicalizer errors out due to an invalid tensor.cast operation when the batch size is dynamic. The operands of the op are tensor<3xi32> cast to tensor<3xindex>. This change is related to a previous PR: #98531 --------- Co-authored-by: Alaa Ali <alaaali@ah-alaaali-l.dhcp.mathworks.com> Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
1 parent 4da5e9d commit 5812516

File tree

2 files changed

+60
-5
lines changed

2 files changed

+60
-5
lines changed

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,10 +1734,23 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
17341734
// Operand 'shape' of 'tensor.reshape' may now be used as the result of
17351735
// 'shape.shape_of'. While its type is guaranteed to be compatible in well-
17361736
// formed IR, it may not be identical (dynamically vs statically shaped),
1737-
// in which case it needs to be cast first.
1737+
// in which case it needs to be cast first using 'tensor.cast'.
1738+
// Additionally, it may not have identical element type (i32 vs index)
1739+
// while it has identical shaped type (dynamic vs static), in which case it
1740+
// needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
1741+
// op result must be shape or extent tensor.
17381742
Value shape = tensorReshapeOp.getShape();
1739-
if (op.getType() != shape.getType())
1740-
shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
1743+
1744+
auto opTensorTy = cast<RankedTensorType>(op.getType());
1745+
auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
1746+
1747+
if (opTensorTy != shapeTensorTy) {
1748+
if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1749+
shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
1750+
else if (!isExtentTensorType(shapeTensorTy))
1751+
shape =
1752+
rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
1753+
}
17411754

17421755
rewriter.replaceOp(op, shape);
17431756
return success();

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,10 +1389,25 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -
13891389

13901390
// -----
13911391

1392-
// CHECK-LABEL: func @shape_of_from_reshape_compatible_types
1392+
// Check statically shaped types, with element types i32 to index.
1393+
// CHECK-LABEL: func @shape_of_from_reshape_int_to_index
1394+
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x1xf32>
1395+
// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
1396+
func.func @shape_of_from_reshape_int_to_index(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
1397+
// CHECK: %[[CAST_SHAPE:.*]] = arith.index_cast %[[SHAPE]] : tensor<3xi32> to tensor<3xindex>
1398+
// CHECK: return %[[CAST_SHAPE]] : tensor<3xindex>
1399+
%0 = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
1400+
%1 = shape.shape_of %0 : tensor<?x1x1xf32> -> tensor<3xindex>
1401+
return %1 : tensor<3xindex>
1402+
}
1403+
1404+
// -----
1405+
1406+
// Check similar element types, with statically shaped to dynamically shaped.
1407+
// CHECK-LABEL: func @shape_of_from_reshape_static_to_dynamic
13931408
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
13941409
// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
1395-
func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
1410+
func.func @shape_of_from_reshape_static_to_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
13961411
// CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor<?xindex>
13971412
// CHECK: return %[[CAST_SHAPE]] : tensor<?xindex>
13981413
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
@@ -1402,6 +1417,33 @@ func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: t
14021417

14031418
// -----
14041419

1420+
// Check similar element types, with dynamically shaped to statically shaped.
1421+
// CHECK-LABEL: func @shape_of_from_reshape_dynamic_to_static
1422+
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
1423+
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
1424+
func.func @shape_of_from_reshape_dynamic_to_static(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<5xindex> {
1425+
// CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<?xindex> to tensor<5xindex>
1426+
// CHECK: return %[[CAST_SHAPE]] : tensor<5xindex>
1427+
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1428+
%1 = shape.shape_of %0 : tensor<*xf32> -> tensor<5xindex>
1429+
return %1 : tensor<5xindex>
1430+
}
1431+
1432+
// -----
1433+
1434+
// Check similar element types and similar static shape.
1435+
// CHECK-LABEL: func @shape_of_from_reshape_identical_types
1436+
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
1437+
// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
1438+
func.func @shape_of_from_reshape_identical_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<5xindex> {
1439+
// CHECK: return %[[SHAPE]] : tensor<5xindex>
1440+
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
1441+
%1 = shape.shape_of %0 : tensor<*xf32> -> tensor<5xindex>
1442+
return %1 : tensor<5xindex>
1443+
}
1444+
1445+
// -----
1446+
14051447
// CHECK-LABEL: func @shape_of_from_reshape_nofold
14061448
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
14071449
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>

0 commit comments

Comments
 (0)