You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[MLIR] Fix canonicalization pattern for 'shape.shape_of' (#134234)
This PR will fix a bug in a canonicalization pattern (operation
shape.shape_of: shape of reshape)
```
// Before
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
%reshape = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
%0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
return %0 : tensor<3xindex>
}
//This is will error out as follows:
error: 'tensor.cast' op operand type 'tensor<3xi32>' and result type 'tensor<3xindex>' are cast incompatible
%0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
^
note: see current operation: %0 = "tensor.cast"(%arg1) : (tensor<3xi32>) -> tensor<3xindex>
```
```
// After
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
%0 = arith.index_cast %arg1 : tensor<3xi32> to tensor<3xindex>
return %0 : tensor<3xindex>
}
```
See file canonicalize.mlir in the change list for an example.
For the context, this bug was found while running a test on Keras 3, the
canonicalizer errors out due to an invalid tensor.cast operation when
the batch size is dynamic.
The operands of the op are tensor<3xi32> cast to tensor<3xindex>.
This change is related to a previous PR:
#98531
---------
Co-authored-by: Alaa Ali <alaaali@ah-alaaali-l.dhcp.mathworks.com>
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
0 commit comments