Skip to content

Commit 5fc7342

Browse files
committed
[mlir][linalg][nfc] Fix linalg.matmul_transpose_a def.
The `matmul_transpose_a` input data format should be `KxM * KxN` instead of current `KxN * KxM` format. It's a NFC fix.
1 parent 173514d commit 5fc7342

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1336,7 +1336,7 @@ structured_op: !LinalgStructuredOpConfig
13361336
name: C
13371337
kind: output_tensor
13381338
type_var: U
1339-
shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
1339+
shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
13401340
- !LinalgOperandDefConfig
13411341
name: cast
13421342
kind: type_fn_attr

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -429,8 +429,8 @@ def quantized_matmul(
429429

430430
@linalg_structured_op
431431
def matmul_transpose_a(
432-
A=TensorDef(T1, S.K, S.N),
433-
B=TensorDef(T2, S.K, S.M),
432+
A=TensorDef(T1, S.K, S.M),
433+
B=TensorDef(T2, S.K, S.N),
434434
C=TensorDef(U, S.M, S.N, output=True),
435435
cast=TypeFnAttrDef(default=TypeFn.cast_signed),
436436
):

0 commit comments

Comments
 (0)