//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to
// SVE I8MM operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Dialect/UB/IR/UBOps.h"

#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"

using namespace mlir;
using namespace mlir::arm_sve;

namespace {
// Check if the given value is a result of the operation `T` (which must be
// a sign- or zero-extension) from i8 to i32. Return the value before the
// extension.
template <typename T>
inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
                         std::is_base_of_v<arith::ExtUIOp, T>),
                        std::optional<Value>>
extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
  auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
  if (!extOp)
    return {};

  auto inOp = extOp.getIn();
  auto inTy = dyn_cast<VectorType>(inOp.getType());
  if (!inTy || inTy.getElementType() != i8Ty)
    return {};

  auto outTy = dyn_cast<VectorType>(extOp.getType());
  if (!outTy || outTy.getElementType() != i32Ty)
    return {};

  return inOp;
}
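
// For example (an illustrative shape, not taken from a specific test), given
//   %e = arith.extsi %v : vector<4x8xi8> to vector<4x8xi32>
// extractExtOperand<arith::ExtSIOp>(%e, i8, i32) returns %v, while the
// arith::ExtUIOp instantiation (or any source/result pair other than i8/i32)
// returns an empty optional.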

// Designate the operation (resp. instruction) used to do sub-tile matrix
// multiplications.
enum class MMLA {
  Signed,      // smmla
  Unsigned,    // ummla
  Mixed,       // usmmla
  MixedSwapped // usmmla with LHS and RHS swapped
};

// Create the matrix multiply and accumulate operation according to `op`.
Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
  switch (op) {
  case MMLA::Signed:
    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
  case MMLA::Unsigned:
    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
  case MMLA::Mixed:
    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
  case MMLA::MixedSwapped:
    // The accumulator comes transposed and the result will be transposed
    // later, so all we have to do here is swap the operands.
    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
  }
  llvm_unreachable("Unhandled MMLA operation");
}
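
// Illustrative example (shapes and value names are chosen for exposition
// only; any M and N satisfying the constraints checked below work) of a
// vector.contract op this pattern matches:
//
//   %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
//   %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
//   %2 = vector.contract {indexing_maps =
//          [affine_map<(d0, d1, d2) -> (d0, d2)>,
//           affine_map<(d0, d1, d2) -> (d1, d2)>,
//           affine_map<(d0, d1, d2) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>} %0, %1, %acc
//        : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
//
// It is rewritten into a sequence of arm_sve.smmla/ummla/usmmla operations,
// each multiplying a 2x8 sub-tile of the LHS by a [2]x8 sub-tile of the
// (transposed) RHS into a 2x[2] accumulator sub-tile.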

class LowerContractionToSVEI8MMPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    mlir::VectorType lhsType = op.getLhsType();
    mlir::VectorType rhsType = op.getRhsType();

    // For now handle LHS<Mx8> and RHS<[N]x8> (i.e. RHS transposed) - these
    // are the types we eventually expect from MMT4D. M and N dimensions must
    // be even and at least 2.
    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
        rhsType.getRank() != 2)
      return failure();

    if (lhsType.isScalable() || !rhsType.isScalable())
      return failure();

    // M, N, and K are the conventional names for matrix dimensions in the
    // context of matrix multiplication.
    auto M = lhsType.getDimSize(0);
    auto N = rhsType.getDimSize(0);
    auto K = rhsType.getDimSize(1);

    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
        N % 2 != 0 || !rhsType.getScalableDims()[0])
      return failure();

    // Check permutation maps. For now only accept
    //   lhs: (d0, d1, d2) -> (d0, d2)
    //   rhs: (d0, d1, d2) -> (d1, d2)
    //   acc: (d0, d1, d2) -> (d0, d1)
    // Note: RHS is transposed.
    if (op.getIndexingMapsArray()[0] !=
            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
                                                 op.getContext()) ||
        op.getIndexingMapsArray()[1] !=
            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
                                                 op.getContext()) ||
        op.getIndexingMapsArray()[2] !=
            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
                                                 op.getContext()))
      return failure();

    // Check iterator types for matrix multiplication.
    auto itTypes = op.getIteratorTypesArray();
    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
        itTypes[1] != vector::IteratorType::parallel ||
        itTypes[2] != vector::IteratorType::reduction)
      return failure();

    // Check the combining kind is addition.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // Check the output is a vector of i32 elements.
    auto outTy = dyn_cast<VectorType>(op.getType());
    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
      return failure();

    // Check inputs are sign-/zero-extensions from i8 to i32. Get the values
    // before the extension. All four signed/unsigned combinations for input
    // operands are supported, but they are lowered to different operations.
    // Determine the appropriate operation to lower to.
    MMLA mmlaOp = MMLA::Signed;
    auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
    if (!maybeLhs) {
      mmlaOp = MMLA::Unsigned;
      maybeLhs = extractExtOperand<arith::ExtUIOp>(
          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
    }
    if (!maybeLhs)
      return failure();

    auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
    if (maybeRhs) {
      if (mmlaOp == MMLA::Unsigned)
        mmlaOp = MMLA::Mixed;
    } else {
      if (mmlaOp == MMLA::Signed)
        mmlaOp = MMLA::MixedSwapped;
      maybeRhs = extractExtOperand<arith::ExtUIOp>(
          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
    }
    if (!maybeRhs)
      return failure();

    // One-dimensional vector types for arm_sve.*mmla
    auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
    auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});

    // Extract LHS sub-tiles.
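    // Each sub-tile is formed from two consecutive rows (16 x i8 in total);
    // e.g. for M == 4 (an illustrative size) rows {0,1} and rows {2,3} each
    // yield one sub-tile, which is then broadcast to every 128-bit segment
    // of a scalable vector.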
    SmallVector<Value> lhsTile;
    for (int64_t i = 0; i < M; i += 2) {
      // Extract two consecutive rows of the LHS tile.
      auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
                                                   ArrayRef<int64_t>{i});
      auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
                                                   ArrayRef<int64_t>{i + 1});
      // Concatenate to obtain a 16 x i8 flattened sub-tile.
      auto t = rewriter.create<vector::ShuffleOp>(
          loc, r0, r1,
          llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                                  14, 15});
      // Turn it into a scalable vector.
      auto s = rewriter.create<vector::ScalableInsertOp>(
          loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
      // Replicate the sub-tile VSCALE times to fill the entire vector.
      auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
      lhsTile.push_back(r);
    }

    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
    auto RHS = rewriter.create<vector::ShapeCastOp>(
        maybeRhs->getLoc(),
        VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);

    // Extract the RHS sub-tiles.
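    // Each sub-tile consists of two consecutive rows of the transposed RHS,
    // i.e. 16 consecutive bytes of the flattened vector; for N == 4
    // (illustrative) the two sub-tiles start at byte offsets 0 and 16.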
    SmallVector<Value> rhsTile;
    for (int64_t j = 0; j < N; j += 2)
      rhsTile.push_back(
          rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));

    // Handy types for packing/unpacking of the accumulator tile.
    auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
    auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
    auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
    auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});

    // Extract and pack the ACC sub-tiles.
    SmallVector<Value> accTile;
    for (int64_t i = 0; i < M; i += 2) {
      // Extract two consecutive rows of the accumulator tile.
      auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
                                                   ArrayRef<int64_t>{i});
      auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
                                                   ArrayRef<int64_t>{i + 1});
      Value accTileVec;
      if (mmlaOp == MMLA::MixedSwapped) {
        // We need to swap the positions of the LHS and RHS (since we don't
        // have a signed * unsigned operation), but then each individual 2x2
        // tile of the accumulator and (later) of the result needs to be
        // transposed.
        accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
      } else {
        // Bitcast them to 64-bit elements, so subsequent
        // interleave/deinterleave work on pairs of 32-bit numbers.
        auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
        auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);

        // Interleave the rows, effectively flattening each 2x2 tile into 4
        // consecutive elements.
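        // For example (with N == 4 for illustration), r0 = [a0 a1 a2 a3] and
        // r1 = [b0 b1 b2 b3] become [a0a1 a2a3] and [b0b1 b2b3] as i64
        // vectors; interleaving yields [a0a1 b0b1 a2a3 b2b3], which reads
        // back as [a0 a1 b0 b1 a2 a3 b2 b3] in i32 - the two 2x2 accumulator
        // sub-tiles now occupy lanes 0-3 and 4-7 respectively.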
        auto intr_i64 =
            rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);

        // Bitcast back to 32-bit elements.
        accTileVec =
            rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
      }
      // Extract ACC sub-tiles.
      for (int64_t j = 0; j < N; j += 2)
        accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
            loc, nxv4i32, accTileVec, j * 2));
    }

    // Emit sub-tile matrix multiplications.
    SmallVector<Value> outTile;
    for (int64_t i = 0; i < M / 2; ++i)
      for (int64_t j = 0; j < N / 2; ++j) {
        Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
                                accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
        outTile.push_back(mmla);
      }

    // Unpack the OUT sub-tiles and insert into the result.
    Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
    for (int64_t i = 0; i < M / 2; ++i) {
      // Collect a number of sub-tiles in a row.
      Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
      for (int64_t j = 0; j < N / 2; ++j)
        row = rewriter.create<vector::ScalableInsertOp>(
            loc, outTile[i * N / 2 + j], row, j * 4);

      // Unpack the row to obtain two rows of the output. If we have the out
      // sub-tiles transposed we obtain two consecutive output rows by
      // separating even and odd elements, i.e. a simple deinterleave.
      // Otherwise, the deinterleave is by pairs.
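      // Continuing the illustration above, the pairwise deinterleave of
      // [a0 a1 b0 b1 a2 a3 b2 b3] (viewed as i64 pairs) recovers
      // out0 = [a0 a1 a2 a3] and out1 = [b0 b1 b2 b3].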
      Value out0, out1;
      if (mmlaOp == MMLA::MixedSwapped) {
        auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
        out0 = tmp.getRes1();
        out1 = tmp.getRes2();
      } else {
        // Deinterleave by pairs.
        auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
        auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);

        // Bitcast back into 32-bit elements and insert into the result.
        out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
                                                  deintr64.getRes1());
        out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
                                                  deintr64.getRes2());
      }
      result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
      result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
}
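
// A minimal usage sketch (the surrounding pass and its op are assumptions,
// not part of this file): the patterns are intended to be applied with the
// greedy rewrite driver, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateLowerContractionToSVEI8MMPatternPatterns(patterns);
//   if (failed(applyPatternsGreedily(someFuncOp, std::move(patterns))))
//     signalPassFailure();
//
// Note: depending on the MLIR revision, the driver entry point may be
// spelled applyPatternsAndFoldGreedily.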