Skip to content

Commit 8e87a7f

Browse files
[MLIR][ArmSVE] Add initial lowering of vector.contract to SVE *MMLA instructions
1 parent 5e91c2e commit 8e87a7f

File tree

16 files changed

+1322
-1
lines changed

16 files changed

+1322
-1
lines changed

mlir/include/mlir/Conversion/Passes.td

+4
Original file line numberDiff line numberDiff line change
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14061406
"bool", /*default=*/"false",
14071407
"Enables the use of ArmSVE dialect while lowering the vector "
14081408
"dialect.">,
1409+
Option<"armI8MM", "enable-arm-i8mm",
1410+
"bool", /*default=*/"false",
1411+
"Enables the use of Arm FEAT_I8MM instructions while lowering "
1412+
"the vector dialect.">,
14091413
Option<"x86Vector", "enable-x86vector",
14101414
"bool", /*default=*/"false",
14111415
"Enables the use of X86Vector dialect while lowering the vector "

mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ class RewritePatternSet;
2020
void populateArmSVELegalizeForLLVMExportPatterns(
2121
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
2222

23+
void populateLowerContractionToSVEI8MMPatternPatterns(
24+
RewritePatternSet &patterns);
25+
2326
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
2427
/// intrinsics.
2528
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);

mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
3535
MLIRVectorToLLVM
3636

3737
MLIRArmNeonDialect
38+
MLIRArmNeonTransforms
3839
MLIRArmSVEDialect
3940
MLIRArmSVETransforms
4041
MLIRAMXDialect

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/AMX/Transforms.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
17+
#include "mlir/Dialect/ArmNeon/Transforms.h"
1718
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
1819
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
1920
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8283
populateVectorStepLoweringPatterns(patterns);
8384
populateVectorRankReducingFMAPattern(patterns);
8485
populateVectorGatherLoweringPatterns(patterns);
86+
if (armI8MM) {
87+
if (armNeon)
88+
arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
89+
if (armSVE)
90+
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
91+
}
8592
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
8693
}
8794

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
5656
// Avoid 0-D vectors and 1-D rhs:
5757
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
5858
return failure();
59+
// Avoid scalable vectors.
60+
if (lhsType.isScalable() || rhsType.isScalable())
61+
return failure();
5962
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
6063
auto dimN = rhsType.getDimSize(0);
6164
auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
238241
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
239242
RewritePatternSet &patterns) {
240243
MLIRContext *context = patterns.getContext();
241-
patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
244+
patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
242245
}

mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRArmSVETransforms
22
LegalizeForLLVMExport.cpp
33
LegalizeVectorStorage.cpp
4+
LowerContractionToSVEI8MMPattern.cpp
45

56
DEPENDS
67
MLIRArmSVEConversionsIncGen
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to
// SVE I8MM operations.
//
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
16+
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19+
#include "mlir/Dialect/Utils/IndexingUtils.h"
20+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
21+
#include "mlir/IR/AffineMap.h"
22+
#include "mlir/IR/PatternMatch.h"
23+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24+
25+
#include "mlir/Dialect/UB/IR/UBOps.h"
26+
27+
#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
28+
29+
using namespace mlir;
30+
using namespace mlir::arm_sve;
31+
32+
namespace {
33+
// Check if the given value is a result of the operation `T` (which must be
34+
// sign- or zero- extend) from i8 to i32. Return the value before the extension.
35+
template <typename T>
36+
inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
37+
std::is_base_of_v<arith::ExtUIOp, T>),
38+
std::optional<Value>>
39+
extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
40+
auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
41+
if (!extOp)
42+
return {};
43+
44+
auto inOp = extOp.getIn();
45+
auto inTy = dyn_cast<VectorType>(inOp.getType());
46+
if (!inTy || inTy.getElementType() != i8Ty)
47+
return {};
48+
49+
auto outTy = dyn_cast<VectorType>(extOp.getType());
50+
if (!outTy || outTy.getElementType() != i32Ty)
51+
return {};
52+
53+
return inOp;
54+
}
55+
56+
// Designate the operation (resp. instruction) used to do sub-tile matrix
// multiplications. The choice depends on the signedness of the (pre-extension)
// LHS and RHS operands of the contraction.
enum class MMLA {
  Signed,      // smmla:  signed LHS, signed RHS
  Unsigned,    // ummla:  unsigned LHS, unsigned RHS
  Mixed,       // usmmla: unsigned LHS, signed RHS
  MixedSwapped // usmmla with LHS and RHS swapped (signed LHS, unsigned RHS)
};
64+
65+
// Create the matrix multply and accumulate operation according to `op`.
66+
Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
67+
mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
68+
switch (op) {
69+
case MMLA::Signed:
70+
return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
71+
case MMLA::Unsigned:
72+
return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
73+
case MMLA::Mixed:
74+
return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
75+
case MMLA::MixedSwapped:
76+
// The accumulator comes transposed and the result will be transposed
77+
// later, so all we have to do here is swap the operands.
78+
return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
79+
}
80+
}
81+
82+
// Rewrite a matching vector.contract into a tiled sequence of ArmSVE *mmla
// operations (FEAT_I8MM). The input is an MxK (fixed) by Kx[N] (scalable)
// i8->i32 widening matmul; each *mmla computes a 2x2 i32 sub-tile from
// flattened 2x8 i8 operand sub-tiles.
class LowerContractionToSVEI8MMPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    mlir::VectorType lhsType = op.getLhsType();
    mlir::VectorType rhsType = op.getRhsType();

    // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
    // eventually expect from MMT4D. M and N dimensions must be even and at
    // least 2.
    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
        rhsType.getRank() != 2)
      return failure();

    // Only a fixed-size LHS combined with a scalable RHS is supported.
    if (lhsType.isScalable() || !rhsType.isScalable())
      return failure();

    // M, N, and K are the conventional names for matrix dimensions in the
    // context of matrix multiplication.
    auto M = lhsType.getDimSize(0);
    auto N = rhsType.getDimSize(0);
    auto K = rhsType.getDimSize(1);

    // K must be exactly 8 (the *mmla reduction depth); M and N must be even
    // and >= 2 so the tile decomposes into 2x2 sub-tiles; only the leading
    // (N) dimension of the RHS may be scalable.
    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
        N % 2 != 0 || !rhsType.getScalableDims()[0])
      return failure();

    // Check permutation maps. For now only accept
    //   lhs: (d0, d1, d2) -> (d0, d2)
    //   rhs: (d0, d1, d2) -> (d1, d2)
    //   acc: (d0, d1, d2) -> (d0, d1)
    // Note: RHS is transposed.
    if (op.getIndexingMapsArray()[0] !=
            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
                                                 op.getContext()) ||
        op.getIndexingMapsArray()[1] !=
            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
                                                 op.getContext()) ||
        op.getIndexingMapsArray()[2] !=
            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
                                                 op.getContext()))
      return failure();

    // Check iterator types for matrix multiplication: two parallel dimensions
    // (M, N) and one reduction dimension (K).
    auto itTypes = op.getIteratorTypesArray();
    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
        itTypes[1] != vector::IteratorType::parallel ||
        itTypes[2] != vector::IteratorType::reduction)
      return failure();

    // Check the combining kind is addition.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // Check the output is a vector of i32 elements.
    auto outTy = dyn_cast<VectorType>(op.getType());
    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
      return failure();

    // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
    // before the extension. All four signed/unsigned combinations for input
    // operands are supported, but they are lowered to different operations.
    // Determine which is the appropriate operation to lower to.
    MMLA mmlaOp = MMLA::Signed;
    auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
    if (!maybeLhs) {
      mmlaOp = MMLA::Unsigned;
      maybeLhs = extractExtOperand<arith::ExtUIOp>(
          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
    }
    if (!maybeLhs)
      return failure();

    auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
    if (maybeRhs) {
      // Unsigned LHS * signed RHS -> usmmla.
      if (mmlaOp == MMLA::Unsigned)
        mmlaOp = MMLA::Mixed;
    } else {
      // Signed LHS * unsigned RHS -> usmmla with operands swapped.
      if (mmlaOp == MMLA::Signed)
        mmlaOp = MMLA::MixedSwapped;
      maybeRhs = extractExtOperand<arith::ExtUIOp>(
          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
    }
    if (!maybeRhs)
      return failure();

    // One-dimensional vector types for arm_sve.*mmla: operands are
    // vector<[16]xi8>, accumulator/result is vector<[4]xi32>.
    auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
    auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});

    // Extract LHS sub-tiles.
    SmallVector<Value> lhsTile;
    for (int64_t i = 0; i < M; i += 2) {
      // Extract two consecutive rows of the LHS tile.
      auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
                                                   ArrayRef<int64_t>{i});
      auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
                                                   ArrayRef<int64_t>{i + 1});
      // Concatenate to obtain a 16 x i8 flattened sub-tile.
      auto t = rewriter.create<vector::ShuffleOp>(
          loc, r0, r1,
          llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                                  14, 15});
      // Turn it into a scalable vector.
      auto s = rewriter.create<vector::ScalableInsertOp>(
          loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
      // Replicate the sub-tile VSCALE times to fill the entire vector.
      auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
      lhsTile.push_back(r);
    }

    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
    auto RHS = rewriter.create<vector::ShapeCastOp>(
        maybeRhs->getLoc(),
        VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);

    // Extract the RHS sub-tiles: each is 2 rows x 8 columns = 16 consecutive
    // i8 elements of the flattened RHS.
    SmallVector<Value> rhsTile;
    for (int64_t j = 0; j < N; j += 2)
      rhsTile.push_back(
          rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));

    // Handy types for packing/unpacking of the accumulator tile.
    auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
    auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
    auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
    auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});

    // Extract and pack the ACC sub-tiles.
    SmallVector<Value> accTile;
    for (int64_t i = 0; i < M; i += 2) {
      // Extract two consecutive rows of the accumulator tile.
      auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
                                                   ArrayRef<int64_t>{i});
      auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
                                                   ArrayRef<int64_t>{i + 1});
      Value accTileVec;
      if (mmlaOp == MMLA::MixedSwapped) {
        // We need to swap the positions of the LHS and RHS (since we don't have
        // a signed * unsigned operation), but then each individual 2x2 tile of
        // the accumulator and (later) the result need to be transposed.
        accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
      } else {
        // Bitcast them to 64-bit elements, so subsequent
        // interleave/deinterleave work on pairs of 32-bit numbers.
        auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
        auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);

        // Interleave the rows, effectively flattening each 2x2 tile into 4
        // consecutive elements.
        auto intr_i64 =
            rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);

        // Bitcast back to 32-bit elements.
        accTileVec =
            rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
      }
      // Extract ACC sub-tiles: each is 4 consecutive i32 elements (a packed
      // 2x2 tile) of the interleaved row pair.
      for (int64_t j = 0; j < N; j += 2)
        accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
            loc, nxv4i32, accTileVec, j * 2));
    }

    // Emit sub-tile matrix multiplications: one *mmla per (row-pair,
    // column-pair) combination, accumulating into the matching ACC sub-tile.
    SmallVector<Value> outTile;
    for (int64_t i = 0; i < M / 2; ++i)
      for (int64_t j = 0; j < N / 2; ++j) {
        Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
                                accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
        outTile.push_back(mmla);
      }

    // Unpack the OUT sub-tiles and insert into the result.
    Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
    for (int64_t i = 0; i < M / 2; ++i) {
      // Collect a number of sub-tiles in a row.
      Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
      for (int64_t j = 0; j < N / 2; ++j)
        row = rewriter.create<vector::ScalableInsertOp>(
            loc, outTile[i * N / 2 + j], row, j * 4);

      // Unpack the row to obtain two rows of the output. If we have the out
      // sub-tiles transposed we obtain two consecutive output rows by
      // separating even and odd elements, i.e. a simple deinterleave.
      // Otherwise, the interleave is by pairs.
      Value out0, out1;
      if (mmlaOp == MMLA::MixedSwapped) {
        auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
        out0 = tmp.getRes1();
        out1 = tmp.getRes2();
      } else {
        // Deinterleave by pairs.
        auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
        auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);

        // Bitcast back into 32-bit elements and insert into the result.
        out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
                                                  deintr64.getRes1());
        out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
                                                  deintr64.getRes2());
      }
      result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
      result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
297+
298+
} // namespace
299+
300+
// Register the vector.contract -> ArmSVE *mmla lowering pattern.
// Benefit 2 so it is preferred over the default contraction lowerings.
void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LowerContractionToSVEI8MMPattern>(patterns.getContext(),
                                                 /*benefit=*/2);
}

0 commit comments

Comments
 (0)