From 97a6c57217bab6a815ffcd9bd12905af4d5fca1a Mon Sep 17 00:00:00 2001 From: nbpatel <nishant.b.patel@intel.com> Date: Fri, 25 Apr 2025 17:32:50 +0000 Subject: [PATCH 1/4] Add linearization pattern for vector.splat --- .../Vector/Transforms/VectorLinearize.cpp | 63 ++++++++++++++++--- mlir/test/Dialect/Vector/linearize.mlir | 17 +++++ 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index a009aa03aaf64..45c7e37738898 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -26,6 +26,9 @@ using namespace mlir; +constexpr unsigned defaultTargetVectorBitWidth = + std::numeric_limits<unsigned>::max(); + static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { auto resultTypes = op->getResultTypes(); for (auto resType : resultTypes) { @@ -82,7 +85,7 @@ struct LinearizeConstantLike final LinearizeConstantLike( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpTraitConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -136,7 +139,7 @@ struct LinearizeVectorizable final public: LinearizeVectorizable( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpTraitConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -175,7 +178,7 @@ struct LinearizeVectorExtractStridedSlice final using OpConversionPattern::OpConversionPattern; LinearizeVectorExtractStridedSlice( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -289,7 +292,7 @@ struct LinearizeVectorShuffle final using OpConversionPattern::OpConversionPattern; LinearizeVectorShuffle( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -362,13 +365,17 @@ struct LinearizeVectorExtract final using OpConversionPattern::OpConversionPattern; LinearizeVectorExtract( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // Skip if result is not a vector type + if (!isa<VectorType>(extractOp.getType())) + return rewriter.notifyMatchFailure(extractOp, + "scalar extract is not supported."); Type dstTy = getTypeConverter()->convertType(extractOp.getType()); if (!dstTy) return rewriter.notifyMatchFailure(extractOp, @@ -425,7 +432,7 @@ struct LinearizeVectorInsert final using OpConversionPattern::OpConversionPattern; LinearizeVectorInsert( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -506,7 +513,7 @@ struct LinearizeVectorBitCast final using OpConversionPattern::OpConversionPattern; LinearizeVectorBitCast( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -531,12 +538,48 @@ struct LinearizeVectorBitCast final unsigned targetVectorBitWidth; }; +/// This pattern converts the SplatOp to work on a linearized vector. +/// Following, +/// vector.splat %value : vector<4x4xf32> +/// is converted to: +/// %out_1d = vector.splat %value : vector<16xf32> +/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> +/// It ensures that the operation is compatible with the target vector +/// bit width and replaces the original operation with a new SplatOp +/// that operates on the converted type. +struct LinearizeVectorSplat final + : public OpConversionPattern<vector::SplatOp> { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorSplat( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} + + LogicalResult + matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstTy = getTypeConverter()->convertType(splatOp.getType()); + if (!dstTy) + return rewriter.notifyMatchFailure(splatOp, "cannot convert type."); + rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(), + dstTy); + return success(); + } + +private: + unsigned targetVectorBitWidth; +}; + } // namespace void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth) { + typeConverter.addConversion([](Type type) -> Type { return type; }); typeConverter.addConversion([](VectorType type) -> std::optional<Type> { if (!isLinearizableVector(type)) return type; @@ -557,7 +600,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( typeConverter.addTargetMaterialization(materializeCast); target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional<bool> { - if ((isa<vector::BitCastOp>(op) || + if ((isa<vector::BitCastOp, vector::SplatOp>(op) || op->hasTrait<OpTrait::ConstantLike>() || op->hasTrait<OpTrait::Vectorizable>())) { return (isLessThanTargetBitWidth(op, targetBitWidth) @@ -568,8 +611,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( }); patterns.add<LinearizeConstantLike, LinearizeVectorizable, - LinearizeVectorBitCast>(typeConverter, patterns.getContext(), - targetBitWidth); + LinearizeVectorBitCast, LinearizeVectorSplat>( + typeConverter, patterns.getContext(), targetBitWidth); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 9052c6440e6ac..89f01abb79a74 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -399,3 +399,20 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16> return %1 : vector<[4]x4xf16> } + +// ----- +// ALL-LABEL: linearize_vector_splat +// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32> +func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> { + // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32> + // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32> + // DEFAULT: return %[[CAST]] : vector<4x2xi32> + // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32> + // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32> + // BW-128: return %[[CAST]] : vector<4x2xi32> + + // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32> + // BW-0: return %[[SPLAT]] : vector<4x2xi32> + %0 = vector.splat %arg0 : vector<4x2xi32> + return %0 : vector<4x2xi32> +} From ff82c484ce1b2a7e3cc137c6c77b9253cd1b3f8a Mon Sep 17 00:00:00 2001 From: nbpatel <nishant.b.patel@intel.com> Date: Thu, 1 May 2025 18:08:39 +0000 Subject: [PATCH 2/4] add newline --- mlir/test/Dialect/Vector/linearize.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index e3af10be7fd61..20169c15eb2c1 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -446,4 +446,4 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { // BW-0: return %[[SPLAT]] : vector<4x[2]xi32> %0 = vector.splat %arg0 : vector<4x[2]xi32> return %0 : vector<4x[2]xi32> -} \ No newline at end of file +} From 9b21851cfdc503bde21b0c2e83f37d228e314b3b Mon Sep 17 00:00:00 2001 From: nbpatel <nishant.b.patel@intel.com> Date: Thu, 1 May 2025 18:10:36 +0000 Subject: [PATCH 3/4] Remove targetVectorBitWidth --- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index dbed6d5a4cd75..c2c9c206dc2b2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -446,9 +446,6 @@ struct LinearizeVectorSplat final dstTy); return success(); } - -private: - unsigned targetVectorBitWidth; }; } // namespace From 8dbe3ccc3844f42607c8a75ff862f54938f8700a Mon Sep 17 00:00:00 2001 From: nbpatel <nishant.b.patel@intel.com> Date: Thu, 1 May 2025 20:37:39 +0000 Subject: [PATCH 4/4] Fix comments --- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index c2c9c206dc2b2..b9cef003fa365 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -425,9 +425,6 @@ struct LinearizeVectorBitCast final /// is converted to: /// %out_1d = vector.splat %value : vector<16xf32> /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> -/// It ensures that the operation is compatible with the target vector -/// bit width and replaces the original operation with a new SplatOp -/// that operates on the converted type. struct LinearizeVectorSplat final : public OpConversionPattern<vector::SplatOp> { using OpConversionPattern::OpConversionPattern;