Skip to content

Commit 3a492ab

Browse files
authored
[mlir][vector] Add linearization pattern for vector.splat (#137651)
This PR is a breakdown [2 / 4] of the PR #136193 The PR adds linearization patterns for vector.splat.
1 parent c37b254 commit 3a492ab

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

+32-1
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,10 @@ struct LinearizeVectorExtract final
293293
LogicalResult
294294
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
295295
ConversionPatternRewriter &rewriter) const override {
296+
// Skip if result is not a vector type
297+
if (!isa<VectorType>(extractOp.getType()))
298+
return rewriter.notifyMatchFailure(extractOp,
299+
"scalar extract is not supported.");
296300
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
297301
assert(dstTy && "expected 1-D vector type");
298302

@@ -415,6 +419,32 @@ struct LinearizeVectorBitCast final
415419
}
416420
};
417421

422+
/// This pattern converts the SplatOp to work on a linearized vector.
423+
/// Following,
424+
/// vector.splat %value : vector<4x4xf32>
425+
/// is converted to:
426+
/// %out_1d = vector.splat %value : vector<16xf32>
427+
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
428+
struct LinearizeVectorSplat final
429+
: public OpConversionPattern<vector::SplatOp> {
430+
using OpConversionPattern::OpConversionPattern;
431+
432+
LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
433+
PatternBenefit benefit = 1)
434+
: OpConversionPattern(typeConverter, context, benefit) {}
435+
436+
LogicalResult
437+
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
438+
ConversionPatternRewriter &rewriter) const override {
439+
auto dstTy = getTypeConverter()->convertType(splatOp.getType());
440+
if (!dstTy)
441+
return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
442+
rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
443+
dstTy);
444+
return success();
445+
}
446+
};
447+
418448
} // namespace
419449

420450
/// Return true if the operation `op` does not support scalable vectors and
@@ -501,7 +531,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
501531
const TypeConverter &typeConverter, const ConversionTarget &target,
502532
RewritePatternSet &patterns) {
503533
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
504-
LinearizeVectorBitCast>(typeConverter, patterns.getContext());
534+
LinearizeVectorBitCast, LinearizeVectorSplat>(
535+
typeConverter, patterns.getContext());
505536
}
506537

507538
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

+34
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,37 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
413413
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
414414
return %1 : vector<[4]x4xf16>
415415
}
416+
417+
// -----
418+
// ALL-LABEL: linearize_vector_splat
419+
// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
420+
func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
421+
// DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
422+
// DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
423+
// DEFAULT: return %[[CAST]] : vector<4x2xi32>
424+
// BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
425+
// BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
426+
// BW-128: return %[[CAST]] : vector<4x2xi32>
427+
428+
// BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
429+
// BW-0: return %[[SPLAT]] : vector<4x2xi32>
430+
%0 = vector.splat %arg0 : vector<4x2xi32>
431+
return %0 : vector<4x2xi32>
432+
}
433+
434+
// -----
435+
// ALL-LABEL: linearize_scalable_vector_splat
436+
// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
437+
func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
438+
// DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
439+
// DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
440+
// DEFAULT: return %[[CAST]] : vector<4x[2]xi32>
441+
// BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
442+
// BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
443+
// BW-128: return %[[CAST]] : vector<4x[2]xi32>
444+
445+
// BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x[2]xi32>
446+
// BW-0: return %[[SPLAT]] : vector<4x[2]xi32>
447+
%0 = vector.splat %arg0 : vector<4x[2]xi32>
448+
return %0 : vector<4x[2]xi32>
449+
}

0 commit comments

Comments
 (0)