
[MLIR][ArmSVE] Add initial lowering of vector.contract to SVE *MMLA instructions #135636


Open: wants to merge 1 commit into base branch users/momchil-velikov/svusmmla
4 changes: 4 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
Option<"armI8MM", "enable-arm-i8mm",
"bool", /*default=*/"false",
"Enables the use of Arm FEAT_I8MM instructions while lowering "
"the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
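Note: the new lowering is opt-in; as the pass gating below shows, the I8MM option only takes effect together with the corresponding dialect option, e.g. mlir-opt -convert-vector-to-llvm='enable-arm-sve=1 enable-arm-i8mm=1' (the exact spelling of the pre-existing SVE flag is assumed here from its description above).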
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
void populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);

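/// Collect patterns to lower a vector.contract whose operands are
/// sign- or zero-extended from i8 to i32 into a sequence of ArmSVE
/// *MMLA operations (these require the FEAT_I8MM architecture extension).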
void populateLowerContractionToSVEI8MMPatternPatterns(
RewritePatternSet &patterns);

/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
1 change: 1 addition & 0 deletions mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRVectorToLLVM

MLIRArmNeonDialect
MLIRArmNeonTransforms
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
if (armSVE)
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
// Avoid 0-D vectors and 1-D rhs:
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
return failure();
// Avoid scalable vectors.
if (lhsType.isScalable() || rhsType.isScalable())
return failure();
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
- patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+ patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
LegalizeVectorStorage.cpp
LowerContractionToSVEI8MMPattern.cpp

DEPENDS
MLIRArmSVEConversionsIncGen
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to
// SVE I8MM operations.

[Review comment on lines +9 to +10, Contributor]: Could you add a note that vector.contract needs to be accompanied by arith.extsi (or arith.extui) ops? Also, is I8MM the official name? Shouldn't that be FEAT_I8MM? Basically, could we document a bit more?

//
//===---

[Review comment, Contributor] Suggested change:
- //===---
+ //===----------------------------------------------------------------------===//
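
To make the conditions matched by this pattern concrete, here is a sketch of a contraction it is meant to fire on (names and shapes are illustrative, chosen to satisfy the constraints checked below: M=4, K=8, scalable N; note the mandatory arith.extsi/arith.extui on both operands):

  %lhs = arith.extsi %a : vector<4x8xi8> to vector<4x8xi32>
  %rhs = arith.extsi %b : vector<[4]x8xi8> to vector<[4]x8xi32>
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>}
      %lhs, %rhs, %acc : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>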


#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Dialect/UB/IR/UBOps.h"

#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"

using namespace mlir;
using namespace mlir::arm_sve;

namespace {
// Check if the given value is the result of the operation `T` (which must be
// a sign- or zero-extension) from i8 to i32. Return the value before the
// extension.
template <typename T>
inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
std::is_base_of_v<arith::ExtUIOp, T>),
std::optional<Value>>

[Review comment on lines +36 to +38, Contributor]: Why not a simple isa<arith::ExtSIOp, arith::ExtUIOp>(v.getDefiningOp()) inside the function instead of this? That's more common from what I've seen (there's very little SFINAE in the Dialect code).

extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
if (!extOp)
return {};

auto inOp = extOp.getIn();
auto inTy = dyn_cast<VectorType>(inOp.getType());
if (!inTy || inTy.getElementType() != i8Ty)
return {};

auto outTy = dyn_cast<VectorType>(extOp.getType());
if (!outTy || outTy.getElementType() != i32Ty)
return {};

return inOp;
}

// Designate the operation (and corresponding instruction) used to perform
// the sub-tile matrix multiplications.
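// FEAT_I8MM provides signed*signed (smmla), unsigned*unsigned (ummla), and
// unsigned-by-signed (usmmla) forms, but no signed-by-unsigned form. The
// latter is obtained by swapping the operands of usmmla, at the price of
// transposing the accumulator and result sub-tiles (MixedSwapped below).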
enum class MMLA {
Signed, // smmla
Unsigned, // ummla
Mixed, // usmmla
MixedSwapped // usmmla with LHS and RHS swapped
};

// Create the matrix multiply-and-accumulate operation according to `op`.
Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
switch (op) {
case MMLA::Signed:
return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
case MMLA::Unsigned:
return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
case MMLA::Mixed:
return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
case MMLA::MixedSwapped:
// The accumulator comes transposed and the result will be transposed
// later, so all we have to do here is swap the operands.
return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
}
}
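
For reference, a single sub-tile multiply-accumulate emitted here looks like this in IR (a sketch, assuming the ArmSVE dialect's assembly format; values are illustrative):

  %r = arm_sve.smmla %acc, %lhs, %rhs : vector<[16]xi8> to vector<[4]xi32>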

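// The high-level logic of the lowering, in brief: split the MxK LHS tile
// into 2x8 sub-tiles and the (transposed) KxN RHS tile into 8x2 sub-tiles,
// both flattened into scalable 16 x i8 vectors; pack the accumulator into
// 2x2 sub-tiles; emit one *mmla operation per pair of row/column indices;
// finally, unpack the 2x2 result sub-tiles back into rows of the output.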
class LowerContractionToSVEI8MMPattern

[Review comment, Contributor]: It's a very long pattern. Could you document the high-level logic?

: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();

// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
// eventually expect from MMT4D. M and N dimensions must be even and at
// least 2.

[Review comment, Contributor]: [nit] We shouldn't be concerned with MMT4D in this dialect - it's a much higher-level abstraction and this logic should be valid irrespective of how the input is generated.

if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||

[Review comment, Contributor]: IIRC, inputs to vector.contract are required to be vectors, hence lhsType.hasRank() should always be true, no?

rhsType.getRank() != 2)
return failure();

[Review comment, Contributor]: Could you use notifyMatchFailure with some descriptive error message instead? Thanks! Same comment for other instances of failure.


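// The LHS tile must be fixed-size, while the RHS tile must be scalable,
// in line with the shapes accepted above (LHS<Mx8>, RHS<8x[N]>).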
if (lhsType.isScalable() || !rhsType.isScalable())
return failure();

// M, N, and K are the conventional names for matrix dimensions in the
// context of matrix multiplication.
auto M = lhsType.getDimSize(0);
auto N = rhsType.getDimSize(0);
auto K = rhsType.getDimSize(1);

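// Each *MMLA instruction multiplies a 2x8 sub-tile of the LHS by an 8x2
// sub-tile of the RHS, producing a 2x2 sub-tile of the output, hence the
// requirements below: K must be exactly 8, and M and N must be even so
// the tiles decompose into whole sub-tiles.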
if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
N % 2 != 0 || !rhsType.getScalableDims()[0])
return failure();

// Check permutation maps. For now only accept
// lhs: (d0, d1, d2) -> (d0, d2)
// rhs: (d0, d1, d2) -> (d1, d2)
// acc: (d0, d1, d2) -> (d0, d1)
// Note: RHS is transposed.
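// I.e. the computation is C[m, n] += A[m, k] * B[n, k], a matmul with the
// RHS stored transposed (storage row n holds column n of the logical
// 8x[N] RHS matrix).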
if (op.getIndexingMapsArray()[0] !=
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
op.getContext()) ||
op.getIndexingMapsArray()[1] !=
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
op.getContext()) ||
op.getIndexingMapsArray()[2] !=
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
op.getContext()))
return failure();

// Check iterator types for matrix multiplication.
auto itTypes = op.getIteratorTypesArray();
if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
itTypes[1] != vector::IteratorType::parallel ||
itTypes[2] != vector::IteratorType::reduction)
return failure();

// Check the combining kind is addition.
if (op.getKind() != vector::CombiningKind::ADD)
return failure();

// Check the output is a vector of i32 elements.
auto outTy = dyn_cast<VectorType>(op.getType());
if (!outTy || outTy.getElementType() != rewriter.getI32Type())
return failure();

// Check inputs are sign-/zero- extensions from i8 to i32. Get the values
// before the extension. All four signed/unsigned combinations for input
// operands are supported, but they are lowered to different operations.
// Determine which is the appropriate operation to lower to.
MMLA mmlaOp = MMLA::Signed;
auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
if (!maybeLhs) {
mmlaOp = MMLA::Unsigned;
maybeLhs = extractExtOperand<arith::ExtUIOp>(
op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
}
if (!maybeLhs)
return failure();

auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
if (maybeRhs) {
if (mmlaOp == MMLA::Unsigned)
mmlaOp = MMLA::Mixed;
} else {
if (mmlaOp == MMLA::Signed)
mmlaOp = MMLA::MixedSwapped;
maybeRhs = extractExtOperand<arith::ExtUIOp>(
op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
}
if (!maybeRhs)
return failure();

// One-dimensional vector types for arm_sve.*mmla
auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});

[Review comment, Contributor]: Similar suggestion for other instances of VectorType::get.

Suggested change:
- auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+ auto nxv16i8 = VectorType::get(/*shape=*/16, rewriter.getI8Type(), /*scalableDims=*/{true});

auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});

// Extract the LHS sub-tiles: each is a 2x8 block of i8, flattened to 16 x i8.

[Review comment, Contributor]: [nit] Could you specify the dims? That would be helpful.

SmallVector<Value> lhsTile;
for (int64_t i = 0; i < M; i += 2) {
// Extract two consecutive rows of the LHS tile.
auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
ArrayRef<int64_t>{i});
auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
ArrayRef<int64_t>{i + 1});
// Concatenate to obtain a 16 x i8 flattened sub-tile.
auto t = rewriter.create<vector::ShuffleOp>(
loc, r0, r1,
llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15});
// Insert the sub-tile into the lowest 128 bits of an (otherwise poison)
// scalable vector.

[Review comment, Contributor]: [nit] This is "broadcasting", so a bit more involved/nuanced than just turning into a scalable vector. Perhaps expand?

auto s = rewriter.create<vector::ScalableInsertOp>(
loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
// Replicate the sub-tile VSCALE times to fill the entire vector.
auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
lhsTile.push_back(r);
}

// "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
auto RHS = rewriter.create<vector::ShapeCastOp>(
maybeRhs->getLoc(),
VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);

// Extract the RHS sub-tiles.
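// Each extracted nxv16i8 value holds two consecutive rows of the [N]x8
// RHS storage, i.e. (the RHS being transposed) one 8x2 sub-tile of the
// logical 8x[N] RHS matrix per 128-bit segment.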
SmallVector<Value> rhsTile;
for (int64_t j = 0; j < N; j += 2)
rhsTile.push_back(
rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));

// Handy types for packing/unpacking of the accumulator tile.
auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});

// Extract and pack the ACC sub-tiles.
SmallVector<Value> accTile;
for (int64_t i = 0; i < M; i += 2) {
// Extract two consecutive rows of the accumulator tile.
auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
ArrayRef<int64_t>{i});
auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
ArrayRef<int64_t>{i + 1});
Value accTileVec;
if (mmlaOp == MMLA::MixedSwapped) {
// We need to swap the positions of the LHS and RHS (since we don't have
// a signed * unsigned operation), but then each individual 2x2 tile of
// the accumulator and (later) the result need to be transposed.
accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
} else {
// Bitcast them to 64-bit elements, so subsequent
// interleave/deinterleave work on pairs of 32-bit numbers.
auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);

// Interleave the rows, effectively flattening each 2x2 tile into 4
// consecutive elements.
auto intr_i64 =
rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);

// Bitcast back to 32-bit elements.
accTileVec =
rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
}
// Extract ACC sub-tiles.
for (int64_t j = 0; j < N; j += 2)
accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
loc, nxv4i32, accTileVec, j * 2));
}

// Emit sub-tile matrix multiplications.
SmallVector<Value> outTile;
for (int64_t i = 0; i < M / 2; ++i)
for (int64_t j = 0; j < N / 2; ++j) {
Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
outTile.push_back(mmla);
}

// Unpack the OUT sub-tiles and insert into the result.
Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
for (int64_t i = 0; i < M / 2; ++i) {
// Collect a number of sub-tiles in a row.
Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
for (int64_t j = 0; j < N / 2; ++j)
row = rewriter.create<vector::ScalableInsertOp>(
loc, outTile[i * N / 2 + j], row, j * 4);

// Unpack the row to obtain two rows of the output. If we have the out
// sub-tiles transposed we obtain two consecutive output rows by
// separating even and odd elements, i.e. a simple deinterleave.
// Otherwise, the deinterleave is by pairs.
Value out0, out1;
if (mmlaOp == MMLA::MixedSwapped) {
auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
out0 = tmp.getRes1();
out1 = tmp.getRes2();
} else {
// Deinterleave by pairs.
auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);

// Bitcast back into 32-bit elements and insert into the result.
out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
deintr64.getRes1());
out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
deintr64.getRes2());
}
result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
}

rewriter.replaceOp(op, result);
return success();
}
};

} // namespace

void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
}