@@ -293,6 +293,10 @@ struct LinearizeVectorExtract final
293
293
LogicalResult
294
294
matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
295
295
ConversionPatternRewriter &rewriter) const override {
296
+ // Skip if result is not a vector type
297
+ if (!isa<VectorType>(extractOp.getType ()))
298
+ return rewriter.notifyMatchFailure (extractOp,
299
+ " scalar extract is not supported." );
296
300
Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
297
301
assert (dstTy && " expected 1-D vector type" );
298
302
@@ -415,6 +419,32 @@ struct LinearizeVectorBitCast final
415
419
}
416
420
};
417
421
422
+ // / This pattern converts the SplatOp to work on a linearized vector.
423
+ // / Following,
424
+ // / vector.splat %value : vector<4x4xf32>
425
+ // / is converted to:
426
+ // / %out_1d = vector.splat %value : vector<16xf32>
427
+ // / %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
428
+ struct LinearizeVectorSplat final
429
+ : public OpConversionPattern<vector::SplatOp> {
430
+ using OpConversionPattern::OpConversionPattern;
431
+
432
+ LinearizeVectorSplat (const TypeConverter &typeConverter, MLIRContext *context,
433
+ PatternBenefit benefit = 1 )
434
+ : OpConversionPattern(typeConverter, context, benefit) {}
435
+
436
+ LogicalResult
437
+ matchAndRewrite (vector::SplatOp splatOp, OpAdaptor adaptor,
438
+ ConversionPatternRewriter &rewriter) const override {
439
+ auto dstTy = getTypeConverter ()->convertType (splatOp.getType ());
440
+ if (!dstTy)
441
+ return rewriter.notifyMatchFailure (splatOp, " cannot convert type." );
442
+ rewriter.replaceOpWithNewOp <vector::SplatOp>(splatOp, adaptor.getInput (),
443
+ dstTy);
444
+ return success ();
445
+ }
446
+ };
447
+
418
448
} // namespace
419
449
420
450
// / Return true if the operation `op` does not support scalable vectors and
@@ -501,7 +531,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
501
531
const TypeConverter &typeConverter, const ConversionTarget &target,
502
532
RewritePatternSet &patterns) {
503
533
patterns.add <LinearizeConstantLike, LinearizeVectorizable,
504
- LinearizeVectorBitCast>(typeConverter, patterns.getContext ());
534
+ LinearizeVectorBitCast, LinearizeVectorSplat>(
535
+ typeConverter, patterns.getContext ());
505
536
}
506
537
507
538
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments