From 8e87a7f3b1438d9542d28c90eb9593ebe8cf6500 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Tue, 8 Apr 2025 14:43:54 +0000 Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to SVE `*MMLA` instructions --- mlir/include/mlir/Conversion/Passes.td | 4 + .../Dialect/ArmSVE/Transforms/Transforms.h | 3 + .../Conversion/VectorToLLVM/CMakeLists.txt | 1 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 + .../LowerContractionToSMMLAPattern.cpp | 5 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 + .../LowerContractionToSVEI8MMPattern.cpp | 304 ++++++++++++++++++ .../Vector/CPU/ArmSVE/vector-smmla.mlir | 94 ++++++ .../Vector/CPU/ArmSVE/vector-summla.mlir | 85 +++++ .../Vector/CPU/ArmSVE/vector-ummla.mlir | 94 ++++++ .../Vector/CPU/ArmSVE/vector-usmmla.mlir | 95 ++++++ .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir | 117 +++++++ .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir | 159 +++++++++ .../CPU/ArmSVE/contraction-summla-4x8x4.mlir | 118 +++++++ .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir | 119 +++++++ .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir | 117 +++++++ 16 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..930d8b44abca0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, + Option<"armI8MM", "enable-arm-i8mm", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_I8MM instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 8665c8224cc45..232e2be29e574 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -20,6 +20,9 @@ class RewritePatternSet; void populateArmSVELegalizeForLLVMExportPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateLowerContractionToSVEI8MMPatternPatterns( + RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. 
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 330474a718e30..8e2620029c354 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRVectorToLLVM MLIRArmNeonDialect + MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7082b92c95d1d..1e6c8122b1d0e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); + if (armI8MM) { + if (armNeon) + arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns); + if (armSVE) + populateLowerContractionToSVEI8MMPatternPatterns(patterns); + } (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp index 2a1271dfd6bdf..e807b233aa7aa 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp @@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern // Avoid 0-D vectors and 1-D rhs: if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2) return failure(); + // Avoid scalable vectors. + if (lhsType.isScalable() || rhsType.isScalable()) + return failure(); auto dimM = lhsType.getRank() == 1 ? 
1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
 void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
 }
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
   LegalizeVectorStorage.cpp
+  LowerContractionToSVEI8MMPattern.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero-extend) from i8 to i32. Return the value before the extension.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+                         std::is_base_of_v<arith::ExtUIOp, T>),
+                        std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+  if (!extOp)
+    return {};
+
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy || inTy.getElementType() != i8Ty)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!outTy || outTy.getElementType() != i32Ty)
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    return rewriter.create<UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.create<UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+}
+
+class LowerContractionToSVEI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+
+    // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+    // eventually expect from MMT4D. M and N dimensions must be even and at
+    // least 2.
+    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+        rhsType.getRank() != 2)
+      return failure();
+
+    if (lhsType.isScalable() || !rhsType.isScalable())
+      return failure();
+
+    // M, N, and K are the conventional names for matrix dimensions in the
+    // context of matrix multiplication.
+    auto M = lhsType.getDimSize(0);
+    auto N = rhsType.getDimSize(0);
+    auto K = rhsType.getDimSize(1);
+
+    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+        N % 2 != 0 || !rhsType.getScalableDims()[0])
+      return failure();
+
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // Note: RHS is transposed.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return failure();
+
+    // Check iterator types for matrix multiplication.
+    auto itTypes = op.getIteratorTypesArray();
+    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+        itTypes[1] != vector::IteratorType::parallel ||
+        itTypes[2] != vector::IteratorType::reduction)
+      return failure();
+
+    // Check the combining kind is addition.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Check the output is a vector of i32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getType());
+    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+      return failure();
+
+    // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+    // before the extension. All four signed/unsigned combinations for input
+    // operands are supported, but they are lowered to different operations.
+    // Determine which is the appropriate operation to lower to.
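+    // The choice follows the MMLA enum above:
+    //   signed x signed     -> smmla  (MMLA::Signed)
+    //   unsigned x unsigned -> ummla  (MMLA::Unsigned)
+    //   unsigned x signed   -> usmmla (MMLA::Mixed)
+    //   signed x unsigned   -> usmmla with swapped LHS/RHS and transposed
+    //                          accumulator/result sub-tiles (MMLA::MixedSwapped)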
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = extractExtOperand<arith::ExtUIOp>(
+          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeLhs)
+      return failure();
+
+    auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = extractExtOperand<arith::ExtUIOp>(
+          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeRhs)
+      return failure();
+
+    // One-dimensional vector types for arm_sve.*mmla
+    auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+    auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+    // Extract LHS sub-tiles.
+    SmallVector<Value> lhsTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Extract two consecutive rows of the LHS tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i + 1});
+      // Concatenate to obtain a 16 x i8 flattened sub-tile.
+      auto t = rewriter.create<vector::ShuffleOp>(
+          loc, r0, r1,
+          llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+                                  13, 14, 15});
+      // Turn it into a scalable vector.
+      auto s = rewriter.create<vector::ScalableInsertOp>(
+          loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
+      // Replicate the sub-tile VSCALE times to fill the entire vector.
+      auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+      lhsTile.push_back(r);
+    }
+
+    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
+    auto RHS = rewriter.create<vector::ShapeCastOp>(
+        maybeRhs->getLoc(),
+        VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);
+
+    // Extract the RHS sub-tiles.
+    SmallVector<Value> rhsTile;
+    for (int64_t j = 0; j < N; j += 2)
+      rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+          loc, nxv16i8, RHS, j * 8));
+
+    // Handy types for packing/unpacking of the accumulator tile.
+    auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
+    auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
+    auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
+    auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});
+
+    // Extract and pack the ACC sub-tiles.
+    SmallVector<Value> accTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Extract two consecutive rows of the accumulator tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i + 1});
+      Value accTileVec;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        // We need to swap the positions of the LHS and RHS (since we don't
+        // have a signed * unsigned operation), but then each individual 2x2
+        // tile of the accumulator and (later) the result need to be
+        // transposed.
+        accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+      } else {
+        // Bitcast them to 64-bit elements, so subsequent
+        // interleave/deinterleave work on pairs of 32-bit numbers.
+        auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+        auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+        // Interleave the rows, effectively flattening each 2x2 tile into 4
+        // consecutive elements.
+        auto intr_i64 =
+            rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);
+
+        // Bitcast back to 32-bit elements.
+        accTileVec =
+            rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
+      }
+      // Extract ACC sub-tiles.
+      for (int64_t j = 0; j < N; j += 2)
+        accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+            loc, nxv4i32, accTileVec, j * 2));
+    }
+
+    // Emit sub-tile matrix multiplications.
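+    // Note: each *MMLA operation multiplies one (replicated) 2x8 LHS
+    // sub-tile by one [2]x8 RHS sub-tile and accumulates into a 2x[2]
+    // sub-tile of the output, so the output tile is assembled from
+    // M/2 by N/2 such multiplications.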
+    SmallVector<Value> outTile;
+    for (int64_t i = 0; i < M / 2; ++i)
+      for (int64_t j = 0; j < N / 2; ++j) {
+        Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
+                                accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
+        outTile.push_back(mmla);
+      }
+
+    // Unpack the OUT sub-tiles and insert into the result.
+    Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+    for (int64_t i = 0; i < M / 2; ++i) {
+      // Collect a number of sub-tiles in a row.
+      Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+      for (int64_t j = 0; j < N / 2; ++j)
+        row = rewriter.create<vector::ScalableInsertOp>(
+            loc, outTile[i * N / 2 + j], row, j * 4);
+
+      // Unpack the row to obtain two rows of the output. If we have the out
+      // sub-tiles transposed we obtain two consecutive output rows by
+      // separating even and odd elements, i.e. a simple deinterleave.
+      // Otherwise, the interleave is by pairs.
+      Value out0, out1;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+        out0 = tmp.getRes1();
+        out1 = tmp.getRes2();
+      } else {
+        // Deinterleave by pairs.
+        auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+        auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+        // Bitcast back into 32-bit elements and insert into the result.
+        out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes1());
+        out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes2());
+      }
+      result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+      result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
new file mode 100644
index 0000000000000..2535ee9181c13
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_smmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] =
"arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Extract sub-tiles from the RHS +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> + +// Extract accumulator rows 0 and 1 and pack (into "registers") +// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3. +// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Do the sub-tile matrix multiplications +// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : 
!llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+
+func.func @test_vector_contract_to_smmla(%lhs: vector<4x8xi8>,
+                                         %rhs: vector<[4]x8xi8>,
+                                         %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
+
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  return %2 : vector<4x[4]xi32>
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
new file mode 100644
index 0000000000000..b6285d068b0f8
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_usmmla_rev
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T1:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T1]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T5:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T1]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T5]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T0:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T0]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T21:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T19]], %[[T20]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
+// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.intr.vector.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T23:[0-9]+]] = llvm.intr.vector.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3.
+// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.extractvalue %[[T0]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.extractvalue %[[T0]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T26:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T24]], %[[T25]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
+// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.intr.vector.extract %[[T26]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.intr.vector.extract %[[T26]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Do the sub-tile matrix multiplications
+// CHECK-NEXT: %[[T29:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T22]], %[[T17]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T30:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T23]], %[[T18]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T31:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T27]], %[[T17]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T32:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T28]], %[[T18]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.insert %[[T29]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.insert %[[T30]], %[[T33]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T35:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T34]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T36:[0-9]+]] = llvm.extractvalue %[[T35]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T37:[0-9]+]] = llvm.extractvalue %[[T35]][1] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T38:[0-9]+]] = llvm.insertvalue %[[T36]], %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.insertvalue %[[T37]], %[[T38]][1] : !llvm.array<4 x vector<[4]xi32>>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T31]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.intr.vector.insert %[[T32]], %[[T40]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.insertvalue %[[T43]], %[[T39]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.insertvalue %[[T44]], %[[T45]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T47:[0-9]+]] = builtin.unrealized_conversion_cast %[[T46]] : !llvm.array<4 x vector<[4]xi32>> to vector<4x[4]xi32>
+
+func.func @test_vector_contract_to_usmmla_rev(
+    %lhs: vector<4x8xi8>,
+    %rhs: vector<[4]x8xi8>,
+    %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
+
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  return %2 : vector<4x[4]xi32>
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
new file mode 100644
index 0000000000000..cde57842295f7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_ummla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] =
"arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8> + +// Extract sub-tiles from the RHS +// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8> +// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8> +// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8> + +// Extract accumulator rows 0 and 1 and pack (into "registers") +// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Same for accumulator rows 2 and 3. +// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Do the sub-tile matrix multiplications +// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.ummla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.ummla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.ummla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.ummla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : 
!llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+
+func.func @test_vector_contract_to_ummla(%lhs: vector<4x8xi8>,
+                                         %rhs: vector<[4]x8xi8>,
+                                         %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
+
+  %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  return %2 : vector<4x[4]xi32>
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
new file mode 100644
index 0000000000000..d0eef9fb9769c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_usmmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and
3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
+// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3.
+// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64> +// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64> +// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32> +// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32> +// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32> + +// Do the sub-tile matrix multiplications +// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> +// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32> +// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>> +// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>> + +// Same for result rows 2 and 3 +// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32> +// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64> +// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)> +// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32> 
+// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+
+func.func @test_vector_contract_to_usmmla(
+    %lhs: vector<4x8xi8>,
+    %rhs: vector<[4]x8xi8>,
+    %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
+
+  %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  return %2 : vector<4x[4]xi32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17],
+                                   [-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999, 1941, 685, -2879 )
+// CHECK: ( -3705, 2952, 987, -685 )
+// CHECK: ( 2565, 4157, -1589, -357 )
+// CHECK: ( 2383, -2252, 32, -1365 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
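+// Note: this test sets the vector length to 256 bits via @setArmVLBits below,
+// i.e. vscale is 2 and each scalable dimension [4] holds 8 elements (two
+// 128-bit segments per SVE register) - hence the "-vs2" suffix of the test.
+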
+#packed_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> +] + +func.func private @setArmVLBits(%bits : i32) + +func.func @main() { + %c256 = arith.constant 256 : i32 + func.call @setArmVLBits(%c256) : (i32) -> () + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %c0_i8 = arith.constant 0 : i8 + + + // Accumulator test data + %acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26], + [-20, -36, -3, 39, -48, -31, -25, -21], + [-35, -27, -36, -31, 23, -34, -8, -33], + [-20, 17, -32, -47, 37, 22, -7, -21], + [ -7, -35, 20, -4, 39, 46, -23, 40], + [ 40, 27, 37, 43, 38, -6, 37, 49], + [-17, -50, -1, 48, -13, 22, 39, 33], + [-35, -24, 37, -32, 33, 30, -11, -17]]> : vector<8x8xi32> + %acc_m = memref.alloca() : memref<8x8xi32> + vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32> + + %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32> + %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32> + %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32> + + vector.print str "ACC:\n" + %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32> + %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32> + %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32> + %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32> + %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32> + %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32> + %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32> + %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32> + vector.print %acc0 : vector<[4]xi32> + vector.print %acc1 : vector<[4]xi32> + vector.print %acc2 : vector<[4]xi32> + vector.print %acc3 : vector<[4]xi32> + vector.print %acc4 : vector<[4]xi32> + vector.print %acc5 : vector<[4]xi32> + vector.print %acc6 : vector<[4]xi32> + vector.print %acc7 : vector<[4]xi32> + + // LHS test data + %lhs_cst = arith.constant dense<[[-28, 31, 3, -44, -15, -27, 22, 35], + [-23, 39, 48, 26, -23, 32, -39, -38], + [ -3, 9, 43, -30, -32, 39, 41, -39], + [-13, -21, -25, 27, 47, -36, -11, -11], + [ -4, -20, 36, 11, 13, -23, 24, -13], + [-20, 30, -5, 1, 42, -37, -22, 35], + [-22, 38, -4, 44, 25, -31, 23, -39], + [-45, -4, -31, -24, 14, -41, -47, 22]]> : vector<8x8xi8> + + %lhs_m = memref.alloca() : memref<8x8xi8> + vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8> + %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8> + + vector.print str "LHS:\n" + %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8> + %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8> + %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8> + %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8> + %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8> + %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8> + %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8> + %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8> + vector.print %lhs0 : vector<8xi8> + vector.print %lhs1 : vector<8xi8> + vector.print %lhs2 : vector<8xi8> + vector.print %lhs3 : vector<8xi8> + vector.print %lhs4 : vector<8xi8> + vector.print %lhs5 : vector<8xi8> + 
vector.print %lhs6 : vector<8xi8>
+  vector.print %lhs7 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-40, -11, -36,  36,  -1,  20,  14, -32],
+                                   [ 46, -45, -48, -46, -24,  31, -36,  22],
+                                   [  2,  36,  45, -29, -37, -49, -20, -35],
+                                   [ -6,  23,  23,  15,  20,   4,  -8,  -2],
+                                   [-35,  -6,  16,  49, -50,   9, -44,  13],
+                                   [ 24,   1,  -4, -44,  41,  15, -43,  44],
+                                   [ 44,   0, -10,  41,  22,  44, -40,   0],
+                                   [-33,  19,  27,  22,  38, -17,  23,  -9]]> : vector<8x8xi8>
+
+  %rhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+  vector.print %u4 : vector<[4]xi32>
+  vector.print %u5 : vector<[4]xi32>
+  vector.print %u6 : vector<[4]xi32>
+  vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 )
+// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 )
+// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 )
+// CHECK: ( 2034, -1768, -2092, 284, -792, -23, 668, 2172 )
+// CHECK: ( -248, -3728, 1214, 555, -668, -2114, -1794, 2560 )
+// CHECK: ( -1484, -2642, 297, 1551, -483, 3173, -576, 2570 )
+// CHECK: ( 3098, -7851, 1366, 1892, -427, -4533, -819, 4698 )
+// CHECK: ( -135, 1247, 765, -479, 1245, 3074, -2281, -23 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175,  82,  99],
+                                   [221,  25, 164,  97, 156, 221, 218, 177],
+                                   [171, 160, 219, 191, 144,  45, 161, 210],
+                                   [223, 165, 123,  99, 108,  86,  37,  92]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: ( -7613, -8386, -15938, -6521 )
+// CHECK: ( 9468, 18750, 9199, 5764 )
+// CHECK: ( 33655, 41064, 48900, 31627 )
+  return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..b5a7675f59881
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[16, 16, 48, 40],
+                                   [40, 24, 35, 12],
+                                   [33, 24, 29, 19],
+                                   [28, 13, 33, 18]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[35, 42, 37, 49, 36, 36, 23, 33],
+                                   [39, 34, 33, 45, 43, 10, 44, 47],
+                                   [18, 35, 29, 25, 36, 33, 28, 29],
+                                   [26, 49, 43, 32, 27, 16, 45, 33]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[18, 31, 37, 35, 44, 22, 37, 28],
+                                   [21, 22, 49, 39, 30, 28, 35, 37],
+                                   [21, 47, 39, 35, 23, 43, 24, 49],
+                                   [49, 49, 40, 32, 37, 20, 47, 40]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( 9183, 9513, 10460, 11314 )
+// CHECK: ( 9648, 9812, 10092, 12088 )
+// CHECK: ( 7548, 7625, 8398, 9044 )
+// CHECK: ( 8855, 9046, 9685, 11191 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir
new file mode 100644
index 0000000000000..a25a51dd7018c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[153, 161,  24, 157, 211, 154,  52,  27],
+                                   [168,  77, 136, 124, 249,  28,  13, 122],
+                                   [ 97,  82, 181,  39,  53,  25,  80, 240],
+                                   [184, 227, 106, 165, 126, 113, 121, 228]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[ 40,  27,  37,  43,  38,  -6,  37,  49],
+                                   [-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17],
+                                   [-28,  31,   3, -44, -15, -27,  22,  35]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
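+
+  // Note on the shapes (an explanatory aside, not part of the checked
+  // output): with the vector length pinned to 128 bits above, the scalable
+  // dimension [4] is exactly 4 lanes, so the flat vector<[32]xi8> read plus
+  // this shape_cast is what hands the fixed 4x8 RHS tile to the scalable
+  // contraction below.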
+
+  // Matrix multiplication
+  %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+  // CHECK: ( 28403, 445, -2759, -11409 )
+  // CHECK: ( 34908, 1047, 142, -7274 )
+  // CHECK: ( 31032, 6807, -2378, 7382 )
+  // CHECK: ( 44217, 6396, -10930, 623 )
+  return
+}
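+
+// A rough sketch of the lowering this test is meant to exercise: the extui
+// on the LHS and extsi on the RHS select the unsigned-by-signed FEAT_I8MM
+// case. Assuming the pattern emits the ArmSVE dialect op checked by the
+// vector-usmmla.mlir test in this patch, each 2x2 i32 tile of the
+// accumulator would be updated along the lines of
+//
+//   %tile = arm_sve.usmmla %acc_tile, %lhs_tile, %rhs_tile
+//             : vector<[16]xi8> to vector<[4]xi32>
+//
+// where %lhs_tile and %rhs_tile are hypothetical names for 2x8 i8 row pairs
+// packed from the operands, and %acc_tile for the corresponding 2x2 i32 tile.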