Skip to content

[MLIR][ArmSVE] Add initial lowering of vector.contract to SVE *MMLA instructions #135636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/momchil-velikov/svusmmla
Choose a base branch
from

Conversation

momchil-velikov
Copy link
Collaborator

Supersedes #135359

@llvmbot
Copy link
Member

llvmbot commented Apr 14, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-neon

Author: Momchil Velikov (momchil-velikov)

Changes

Supersedes #135359


Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135636.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+4)
  • (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3)
  • (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
        "dialect.">,
+    Option<"armI8MM", "enable-arm-i8mm",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_I8MM instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
+    if (armI8MM) {
+      if (armNeon)
+        arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+      if (armSVE)
+        populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+    }
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b233aa7aa 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
     // Avoid 0-D vectors and 1-D rhs:
     if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
       return failure();
+    // Avoid scalable vectors.
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return failure();
     auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
 void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
 }
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
   LegalizeVectorStorage.cpp
+  LowerContractionToSVEI8MMPattern.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+                         std::is_base_of_v<arith::ExtUIOp, T>),
+                        std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+  if (!extOp)
+    return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+
+    // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+    // eventually expect from MMT4D. M and N dimensions must be even and at
+    // least 2.
+    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+        rhsType.getRank() != 2)
+      return failure();
+
+    if (lhsType.isScalable() || !rhsType.isScalable())
+      return failure();
+
+    // M, N, and K are the conventional names for matrix dimensions in the
+    // context of matrix multiplication.
+    auto M = lhsType.getDimSize(0);
+    auto N = rhsType.getDimSize(0);
+    auto K = rhsType.getDimSize(1);
+
+    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+        N % 2 != 0 || !rhsType.getScalableDims()[0])
+      return failure();
+
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // Note: RHS is transposed.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return failure();
+
+    // Check iterator types for matrix multiplication.
+    auto itTypes = op.getIteratorTypesArray();
+    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+        itTypes[1] != vector::IteratorType::parallel ||
+        itTypes[2] != vector::IteratorType::reduction)
+      return failure();
+
+    // Check the combining kind is addition.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Check the output is a vector of i32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getType());
+    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+      return failure();
+
+    // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+    // before the extension. All four signed/unsigned combinations for input
+    // operands are supported, but they are lowered to different operations.
+    // Determina which is the appropriate operation to lower to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = extractExtOperand<arith::ExtUIOp>(
+          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeLhs)
+      return failure();
+
+    auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = extractExtOperand<arith::ExtUIOp>(
+          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeRhs)
+      return failure();
+
+    // One-dimensional vector types for arm_sve.*mmla
+    auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+    auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+    // Extract LHS sub-tiles.
+    SmallVector<Value> lhsTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Exract two consective rows of the LHS tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i + 1});
+      // Concatenate to obtain a 16 x i8 flattened sub-tile.
+      auto t = rewriter.create<vector::ShuffleOp>(
+          loc, r0, r1,
+          llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+                                  14, 15});
+      // Turn it into a scalable vector.
+      auto s = rewriter.create<vector::ScalableInsertOp>(
+          loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
+      // Replicate the sub-tile VSCALE times to fill the entire vector.
+      auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+      lhsTile.push_back(r);
+    }
+
+    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
+    auto RHS = rewriter.create<vector::ShapeCastOp>(
+        maybeRhs->getLoc(),
+        VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);
+
+    // Extract the RHS sub-tiles.
+    SmallVector<Value> rhsTile;
+    for (int64_t j = 0; j < N; j += 2)
+      rhsTile.push_back(
+          rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
+
+    // Handy types for packing/unpacking of the accumulator tile.
+    auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
+    auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
+    auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
+    auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});
+
+    // Extract and pack the ACC sub-tiles.
+    SmallVector<Value> accTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Extract two consecutive rows of the accumulator tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i + 1});
+      Value accTileVec;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        // We need to swap the positions of the LHS and RHS (since we don't have
+        // a signed * unsigned operation), but then each individual 2x2 tile of
+        // the acumulator and (later) the result need to be transposed.
+        accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+      } else {
+        // Bitcast them to 64-bit elements, so subsequent
+        // interleave/deinterleave work on pairs of 32-bit numbers.
+        auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+        auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+        // Interleave the rows, effectively flattening each 2x2 tile into 4
+        // consecutive elements.
+        auto intr_i64 =
+            rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);
+
+        // Bitcast back to 32-bit elements.
+        accTileVec =
+            rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
+      }
+      // Extract ACC sub-tiles.
+      for (int64_t j = 0; j < N; j += 2)
+        accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+            loc, nxv4i32, accTileVec, j * 2));
+    }
+
+    // Emit sub-tile matrix multiplications.
+    SmallVector<Value> outTile;
+    for (int64_t i = 0; i < M / 2; ++i)
+      for (int64_t j = 0; j < N / 2; ++j) {
+        Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
+                                accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
+        outTile.push_back(mmla);
+      }
+
+    // Unpack the OUT sub-tiles and insert into the result.
+    Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+    for (int64_t i = 0; i < M / 2; ++i) {
+      // Collect a number of sub-tiles in a row.
+      Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+      for (int64_t j = 0; j < N / 2; ++j)
+        row = rewriter.create<vector::ScalableInsertOp>(
+            loc, outTile[i * N / 2 + j], row, j * 4);
+
+      // Unpack the row to obtain two rows of the output. If we have the out
+      // sub-tiles transposed we obtain two consecutive output rows by
+      // separating even and odd elements, i.e. a simple deinterleave.
+      // Otherwise, the interleave is by pairs.
+      Value out0, out1;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+        out0 = tmp.getRes1();
+        out1 = tmp.getRes2();
+      } else {
+        // Deinterleave by pairs.
+        auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+        auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+        // Bitcast back into 32-bit elements and insert into the result.
+        out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes1());
+        out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes2());
+      }
+      result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+      result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
new file mode 100644
index 0000000000000..2535ee9181c13
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -0,0 +1,94 @@
+// RUN:  mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_smmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalabale vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 4
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extract...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Apr 14, 2025

@llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

Changes

Supersedes #135359


Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135636.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+4)
  • (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3)
  • (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
        "dialect.">,
+    Option<"armI8MM", "enable-arm-i8mm",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_I8MM instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
+    if (armI8MM) {
+      if (armNeon)
+        arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+      if (armSVE)
+        populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+    }
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b233aa7aa 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
     // Avoid 0-D vectors and 1-D rhs:
     if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
       return failure();
+    // Avoid scalable vectors.
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return failure();
     auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
 void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
 }
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
   LegalizeVectorStorage.cpp
+  LowerContractionToSVEI8MMPattern.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+                         std::is_base_of_v<arith::ExtUIOp, T>),
+                        std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+  if (!extOp)
+    return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+
+    // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+    // eventually expect from MMT4D. M and N dimensions must be even and at
+    // least 2.
+    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+        rhsType.getRank() != 2)
+      return failure();
+
+    if (lhsType.isScalable() || !rhsType.isScalable())
+      return failure();
+
+    // M, N, and K are the conventional names for matrix dimensions in the
+    // context of matrix multiplication.
+    auto M = lhsType.getDimSize(0);
+    auto N = rhsType.getDimSize(0);
+    auto K = rhsType.getDimSize(1);
+
+    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+        N % 2 != 0 || !rhsType.getScalableDims()[0])
+      return failure();
+
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // Note: RHS is transposed.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return failure();
+
+    // Check iterator types for matrix multiplication.
+    auto itTypes = op.getIteratorTypesArray();
+    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+        itTypes[1] != vector::IteratorType::parallel ||
+        itTypes[2] != vector::IteratorType::reduction)
+      return failure();
+
+    // Check the combining kind is addition.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Check the output is a vector of i32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getType());
+    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+      return failure();
+
+    // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+    // before the extension. All four signed/unsigned combinations for input
+    // operands are supported, but they are lowered to different operations.
+    // Determina which is the appropriate operation to lower to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = extractExtOperand<arith::ExtUIOp>(
+          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeLhs)
+      return failure();
+
+    auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = extractExtOperand<arith::ExtUIOp>(
+          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeRhs)
+      return failure();
+
+    // One-dimensional vector types for arm_sve.*mmla
+    auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+    auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+    // Extract LHS sub-tiles.
+    SmallVector<Value> lhsTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Exract two consective rows of the LHS tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i + 1});
+      // Concatenate to obtain a 16 x i8 flattened sub-tile.
+      auto t = rewriter.create<vector::ShuffleOp>(
+          loc, r0, r1,
+          llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+                                  14, 15});
+      // Turn it into a scalable vector.
+      auto s = rewriter.create<vector::ScalableInsertOp>(
+          loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
+      // Replicate the sub-tile VSCALE times to fill the entire vector.
+      auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+      lhsTile.push_back(r);
+    }
+
+    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
+    auto RHS = rewriter.create<vector::ShapeCastOp>(
+        maybeRhs->getLoc(),
+        VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);
+
+    // Extract the RHS sub-tiles.
+    SmallVector<Value> rhsTile;
+    for (int64_t j = 0; j < N; j += 2)
+      rhsTile.push_back(
+          rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
+
+    // Handy types for packing/unpacking of the accumulator tile.
+    auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
+    auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
+    auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
+    auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});
+
+    // Extract and pack the ACC sub-tiles.
+    SmallVector<Value> accTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Extract two consecutive rows of the accumulator tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i + 1});
+      Value accTileVec;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        // We need to swap the positions of the LHS and RHS (since we don't have
+        // a signed * unsigned operation), but then each individual 2x2 tile of
+        // the acumulator and (later) the result need to be transposed.
+        accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+      } else {
+        // Bitcast them to 64-bit elements, so subsequent
+        // interleave/deinterleave work on pairs of 32-bit numbers.
+        auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+        auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+        // Interleave the rows, effectively flattening each 2x2 tile into 4
+        // consecutive elements.
+        auto intr_i64 =
+            rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);
+
+        // Bitcast back to 32-bit elements.
+        accTileVec =
+            rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
+      }
+      // Extract ACC sub-tiles.
+      for (int64_t j = 0; j < N; j += 2)
+        accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+            loc, nxv4i32, accTileVec, j * 2));
+    }
+
+    // Emit sub-tile matrix multiplications.
+    SmallVector<Value> outTile;
+    for (int64_t i = 0; i < M / 2; ++i)
+      for (int64_t j = 0; j < N / 2; ++j) {
+        Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
+                                accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
+        outTile.push_back(mmla);
+      }
+
+    // Unpack the OUT sub-tiles and insert into the result.
+    Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+    for (int64_t i = 0; i < M / 2; ++i) {
+      // Collect a number of sub-tiles in a row.
+      Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+      for (int64_t j = 0; j < N / 2; ++j)
+        row = rewriter.create<vector::ScalableInsertOp>(
+            loc, outTile[i * N / 2 + j], row, j * 4);
+
+      // Unpack the row to obtain two rows of the output. If we have the out
+      // sub-tiles transposed we obtain two consecutive output rows by
+      // separating even and odd elements, i.e. a simple deinterleave.
+      // Otherwise, the interleave is by pairs.
+      Value out0, out1;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+        out0 = tmp.getRes1();
+        out1 = tmp.getRes2();
+      } else {
+        // Deinterleave by pairs.
+        auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+        auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+        // Bitcast back into 32-bit elements and insert into the result.
+        out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes1());
+        out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes2());
+      }
+      result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+      result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
new file mode 100644
index 0000000000000..2535ee9181c13
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -0,0 +1,94 @@
+// RUN:  mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_smmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalabale vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 4
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extract...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Apr 14, 2025

@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

Changes

Supersedes #135359


Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135636.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+4)
  • (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3)
  • (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94)
  • (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
        "dialect.">,
+    Option<"armI8MM", "enable-arm-i8mm",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_I8MM instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
+    if (armI8MM) {
+      if (armNeon)
+        arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+      if (armSVE)
+        populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+    }
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b233aa7aa 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
     // Avoid 0-D vectors and 1-D rhs:
     if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
       return failure();
+    // Avoid scalable vectors.
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return failure();
     auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
 void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
 }
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
   LegalizeVectorStorage.cpp
+  LowerContractionToSVEI8MMPattern.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+                         std::is_base_of_v<arith::ExtUIOp, T>),
+                        std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+  if (!extOp)
+    return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+
+    // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+    // eventually expect from MMT4D. M and N dimensions must be even and at
+    // least 2.
+    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+        rhsType.getRank() != 2)
+      return failure();
+
+    if (lhsType.isScalable() || !rhsType.isScalable())
+      return failure();
+
+    // M, N, and K are the conventional names for matrix dimensions in the
+    // context of matrix multiplication.
+    auto M = lhsType.getDimSize(0);
+    auto N = rhsType.getDimSize(0);
+    auto K = rhsType.getDimSize(1);
+
+    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+        N % 2 != 0 || !rhsType.getScalableDims()[0])
+      return failure();
+
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // Note: RHS is transposed.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return failure();
+
+    // Check iterator types for matrix multiplication.
+    auto itTypes = op.getIteratorTypesArray();
+    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+        itTypes[1] != vector::IteratorType::parallel ||
+        itTypes[2] != vector::IteratorType::reduction)
+      return failure();
+
+    // Check the combining kind is addition.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Check the output is a vector of i32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getType());
+    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+      return failure();
+
+    // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+    // before the extension. All four signed/unsigned combinations for input
+    // operands are supported, but they are lowered to different operations.
+    // Determina which is the appropriate operation to lower to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = extractExtOperand<arith::ExtUIOp>(
+          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeLhs)
+      return failure();
+
+    auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = extractExtOperand<arith::ExtUIOp>(
+          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeRhs)
+      return failure();
+
+    // One-dimensional vector types for arm_sve.*mmla
+    auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+    auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+    // Extract LHS sub-tiles.
+    SmallVector<Value> lhsTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Exract two consective rows of the LHS tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i + 1});
+      // Concatenate to obtain a 16 x i8 flattened sub-tile.
+      auto t = rewriter.create<vector::ShuffleOp>(
+          loc, r0, r1,
+          llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+                                  14, 15});
+      // Turn it into a scalable vector.
+      auto s = rewriter.create<vector::ScalableInsertOp>(
+          loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
+      // Replicate the sub-tile VSCALE times to fill the entire vector.
+      auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+      lhsTile.push_back(r);
+    }
+
+    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
+    auto RHS = rewriter.create<vector::ShapeCastOp>(
+        maybeRhs->getLoc(),
+        VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);
+
+    // Extract the RHS sub-tiles.
+    SmallVector<Value> rhsTile;
+    for (int64_t j = 0; j < N; j += 2)
+      rhsTile.push_back(
+          rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
+
+    // Handy types for packing/unpacking of the accumulator tile.
+    auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
+    auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
+    auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
+    auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});
+
+    // Extract and pack the ACC sub-tiles.
+    SmallVector<Value> accTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Extract two consecutive rows of the accumulator tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i + 1});
+      Value accTileVec;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        // We need to swap the positions of the LHS and RHS (since we don't have
+        // a signed * unsigned operation), but then each individual 2x2 tile of
+        // the acumulator and (later) the result need to be transposed.
+        accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+      } else {
+        // Bitcast them to 64-bit elements, so subsequent
+        // interleave/deinterleave work on pairs of 32-bit numbers.
+        auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+        auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+        // Interleave the rows, effectively flattening each 2x2 tile into 4
+        // consecutive elements.
+        auto intr_i64 =
+            rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);
+
+        // Bitcast back to 32-bit elements.
+        accTileVec =
+            rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
+      }
+      // Extract ACC sub-tiles.
+      for (int64_t j = 0; j < N; j += 2)
+        accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+            loc, nxv4i32, accTileVec, j * 2));
+    }
+
+    // Emit sub-tile matrix multiplications.
+    SmallVector<Value> outTile;
+    for (int64_t i = 0; i < M / 2; ++i)
+      for (int64_t j = 0; j < N / 2; ++j) {
+        Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
+                                accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
+        outTile.push_back(mmla);
+      }
+
+    // Unpack the OUT sub-tiles and insert into the result.
+    Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+    for (int64_t i = 0; i < M / 2; ++i) {
+      // Collect a number of sub-tiles in a row.
+      Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+      for (int64_t j = 0; j < N / 2; ++j)
+        row = rewriter.create<vector::ScalableInsertOp>(
+            loc, outTile[i * N / 2 + j], row, j * 4);
+
+      // Unpack the row to obtain two rows of the output. If we have the out
+      // sub-tiles transposed we obtain two consecutive output rows by
+      // separating even and odd elements, i.e. a simple deinterleave.
+      // Otherwise, the interleave is by pairs.
+      Value out0, out1;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+        out0 = tmp.getRes1();
+        out1 = tmp.getRes2();
+      } else {
+        // Deinterleave by pairs.
+        auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+        auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+        // Bitcast back into 32-bit elements and insert into the result.
+        out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes1());
+        out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes2());
+      }
+      result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+      result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
new file mode 100644
index 0000000000000..2535ee9181c13
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -0,0 +1,94 @@
+// RUN:  mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_smmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalabale vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 4
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extract...
[truncated]

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! This one is a bit longer, so I may need to wait till Thursday before I can review.

One high-level question - would sharing some code between NEON and SVE be possible?

@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/vector-contract-i8mm branch from 2e61d3e to 8e87a7f Compare April 15, 2025 15:54
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/svusmmla branch from 71e2f13 to 5e91c2e Compare April 15, 2025 15:54
@momchil-velikov
Copy link
Collaborator Author

One high-level question - would sharing some code between NEON and SVE be possible?

No, I can't see it happening and resulting in less, or simpler, or easier to maintain code.
However, it might be possible to add Neon lowering to this patch and see if the result is any good.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Momchil - this is great!

I skimmed through the pattern logic, and it's very neatly written. It's actually quite easy to follow, despite the underlying logic being a bit convoluted - well done! I've left a few minor suggestions, but nothing major.

Also, it seems like we should be able to extend this fairly easily to support NEON as well. Worth thinking about 🙂

Now, overall this patch is quite large, and I’d suggest extracting the end-to-end / integration tests into a separate PR. Additionally, the remaining tests currently use --convert-vector-to-llvm=, which lowers all the way to LLVM (i.e., it exercises a lot of patterns). Instead, I’d recommend testing LowerContractionToSVEI8MMPattern in isolation and only verifying that the correct sequence of ArmSVE ops (plus some Vector ops) is generated - for example:

    (...)
    %33 = arm_sve.smmla %23, %7, %15 : vector<[16]xi8> to vector<[4]xi32>
    %34 = arm_sve.smmla %24, %7, %16 : vector<[16]xi8> to vector<[4]xi32>
    %35 = arm_sve.smmla %31, %13, %15 : vector<[16]xi8> to vector<[4]xi32>
    %36 = arm_sve.smmla %32, %13, %16 : vector<[16]xi8> to vector<[4]xi32>

That way, we will:

  • reduce noise in the test output (by focusing on a single pattern),
  • simplify expected output (fewer ops to match),
  • avoid re-testing functionality already covered elsewhere (e.g., arm_sve.smmlaarm_sve.intr.smmla lowering).

Btw, this is already looking great, and I know I’m asking for a bit of a rewrite (especially around the tests), but I really think it’ll help with long-term maintainability.

Comment on lines +36 to +38
inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
std::is_base_of_v<arith::ExtUIOp, T>),
std::optional<Value>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not simple isa<arith::ExtSIOp, arith::ExtUIOp>(v.getDefinitionOp()) inside the function instead of this? That's more common from what I've seen (there's very little SFINAE in the Dialect code).

}
}

class LowerContractionToSVEI8MMPattern
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a very long pattern. Could you document the high-level logic?

mlir::VectorType rhsType = op.getRhsType();

// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
// eventually expect from MMT4D. M and N dimensions must be even and at
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] We shouldn't be concerned with MMT4D in this dialect - it's a much higher-level abstraction and this logic should be valid irrespective of how the input is generated.

// least 2.
if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
rhsType.getRank() != 2)
return failure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use notifyMatchFailure with some descriptive error message instead? Thanks! Some comment for other instances of failure.

// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
// eventually expect from MMT4D. M and N dimensions must be even and at
// least 2.
if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, inputs to vector.contract are required to be vectors, hence lhsType.hasRank() should always be true, no?

return failure();

// One-dimensional vector types for arm_sve.*mmla
auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar suggestion for other instances of VectorType::get

Suggested change
auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
auto nxv16i8 = VectorType::get(/*shape=*/16, rewriter.getI8Type(), /*scalableDims=*/{true});

loc, r0, r1,
llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15});
// Turn it into a scalable vector.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] This is "broadcasting", so a bit more involved/nuanced than just turning into a scalable vector. Perhaps expand?

auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});

// Extract LHS sub-tiles.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Could you specify the dims? That would be helpful.

// This file implements lowering patterns from vector.contract to
// SVE I8MM operations.
//
//===---
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
//===---
//===----------------------------------------------------------------------===//```

Comment on lines +9 to +10
// This file implements lowering patterns from vector.contract to
// SVE I8MM operations.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a note that vector.contract needs to be accompanied by arith.extsi (or arith.extui) Ops? Also, is I8MM the official name? Shouldn't that be FEAT_I8MM?

Basically, could we document a bit more?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants