-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][vector] Support index type in ND to 1D vector linearization #118404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Currently index type is not supported because getElementTypeBitWidth aborts for index type. This patch adds indexBitWidth input to the vector linearization patterns.
@llvm/pr-subscribers-mlir Author: Amy Zhuang (ayzhuang) ChangesCurrently index type is not supported because getElementTypeBitWidth aborts for index type. This patch adds indexBitWidth input to the vector linearization patterns. Full diff: https://github.com/llvm/llvm-project/pull/118404.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..e3c19a078c18b0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..f0bf6276f0e659 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -25,34 +25,44 @@
using namespace mlir;
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
+ unsigned targetBitWidth) {
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
+ if (!vecType)
+ return false;
+ bool isIndexTy = vecType.getElementType().isIndex();
+ // Reject index if `indexBitWidth` is not supplied.
+ if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ vecType.getShape().back() *
+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
if (trailingVecDimBitWidth >= targetBitWidth)
return false;
}
return true;
}
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
+static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
+ unsigned targetBitWidth) {
VectorType vecType = dyn_cast<VectorType>(t);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
+ if (!vecType)
+ return false;
+ bool isIndexTy = vecType.getElementType().isIndex();
+ // Reject index if `indexBitWidth` is not supplied.
+ if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ vecType.getShape().back() *
+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
return trailingVecDimBitWidth <= targetBitWidth;
}
@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeConstant(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
- if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
public:
LinearizeVectorizable(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
shuffleOp.getV2VectorType().isScalable() ||
dstType.isScalable()) &&
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtract(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
cast<VectorType>(dstTy).isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorInsert(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
"scalable vectors are not supported.");
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
- targetVectorBitWidth))
+ indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
insertOp, "Can't flatten since targetBitWidth < OpSize");
@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth) {
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
@@ -488,7 +519,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
[=](Operation *op) -> std::optional<bool> {
if ((isa<arith::ConstantOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
- return (isLessThanTargetBitWidth(op, targetBitWidth)
+ return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
? typeConverter.isLegal(op)
: true);
}
@@ -496,15 +527,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
});
patterns.add<LinearizeConstant, LinearizeVectorizable>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned int targetBitWidth) {
+ ConversionTarget &target, unsigned indexBitWidth,
+ unsigned int targetBitWidth) {
target.addDynamicallyLegalOp<vector::ShuffleOp>(
[=](vector::ShuffleOp shuffleOp) -> bool {
- return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+ return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+ targetBitWidth)
? (typeConverter.isLegal(shuffleOp) &&
cast<mlir::VectorType>(shuffleOp.getResult().getType())
.getRank() == 1)
@@ -512,5 +545,5 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
});
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..fe169d3e16d683 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64
// ALL-LABEL: test_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
@@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
@@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// -----
-// ALL-LABEL: test_index_no_linearize
-func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
- // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+// ALL-LABEL: test_index_linearize
+func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+ // DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
@@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
// ALL: return %[[RES]] : vector<2x[2]xf32>
return %2 : vector<2x[2]xf32>
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f67a24755ac09a..2589782aee1449 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -853,6 +853,10 @@ struct TestVectorLinearize final
registry.insert<vector::VectorDialect>();
}
+ Option<unsigned> indexBitwidth{*this, "index-bitwidth",
+ llvm::cl::desc("Bitwidth of the index type"),
+ llvm::cl::init(0)};
+
Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
@@ -866,9 +870,9 @@ struct TestVectorLinearize final
ConversionTarget target(*context);
vector::populateVectorLinearizeTypeConversionsAndLegality(
- typeConverter, patterns, target, targetVectorBitwidth);
+ typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
vector::populateVectorLinearizeShuffleLikeOpsPatterns(
- typeConverter, patterns, target, targetVectorBitwidth);
+ typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
|
@llvm/pr-subscribers-mlir-vector Author: Amy Zhuang (ayzhuang) ChangesCurrently index type is not supported because getElementTypeBitWidth aborts for index type. This patch adds indexBitWidth input to the vector linearization patterns. Full diff: https://github.com/llvm/llvm-project/pull/118404.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..e3c19a078c18b0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..f0bf6276f0e659 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -25,34 +25,44 @@
using namespace mlir;
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
+ unsigned targetBitWidth) {
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
+ if (!vecType)
+ return false;
+ bool isIndexTy = vecType.getElementType().isIndex();
+ // Reject index if `indexBitWidth` is not supplied.
+ if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ vecType.getShape().back() *
+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
if (trailingVecDimBitWidth >= targetBitWidth)
return false;
}
return true;
}
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
+static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
+ unsigned targetBitWidth) {
VectorType vecType = dyn_cast<VectorType>(t);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
+ if (!vecType)
+ return false;
+ bool isIndexTy = vecType.getElementType().isIndex();
+ // Reject index if `indexBitWidth` is not supplied.
+ if (isIndexTy && indexBitWidth == 0)
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ vecType.getShape().back() *
+ (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
return trailingVecDimBitWidth <= targetBitWidth;
}
@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeConstant(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
- if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
public:
LinearizeVectorizable(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
shuffleOp.getV2VectorType().isScalable() ||
dstType.isScalable()) &&
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtract(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
cast<VectorType>(dstTy).isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+ if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+ targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
@@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorInsert(
const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned indexBitWidth = 0,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+ }
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
"scalable vectors are not supported.");
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
- targetVectorBitWidth))
+ indexBitWidth, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
insertOp, "Can't flatten since targetBitWidth < OpSize");
@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
}
private:
+ unsigned indexBitWidth;
unsigned targetVectorBitWidth;
};
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth) {
+ ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
@@ -488,7 +519,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
[=](Operation *op) -> std::optional<bool> {
if ((isa<arith::ConstantOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
- return (isLessThanTargetBitWidth(op, targetBitWidth)
+ return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
? typeConverter.isLegal(op)
: true);
}
@@ -496,15 +527,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
});
patterns.add<LinearizeConstant, LinearizeVectorizable>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned int targetBitWidth) {
+ ConversionTarget &target, unsigned indexBitWidth,
+ unsigned int targetBitWidth) {
target.addDynamicallyLegalOp<vector::ShuffleOp>(
[=](vector::ShuffleOp shuffleOp) -> bool {
- return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+ return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+ targetBitWidth)
? (typeConverter.isLegal(shuffleOp) &&
cast<mlir::VectorType>(shuffleOp.getResult().getType())
.getRank() == 1)
@@ -512,5 +545,5 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
});
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..fe169d3e16d683 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64
// ALL-LABEL: test_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
@@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
@@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// -----
-// ALL-LABEL: test_index_no_linearize
-func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
- // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+// ALL-LABEL: test_index_linearize
+func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+ // DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
@@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+ // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
// ALL: return %[[RES]] : vector<2x[2]xf32>
return %2 : vector<2x[2]xf32>
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f67a24755ac09a..2589782aee1449 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -853,6 +853,10 @@ struct TestVectorLinearize final
registry.insert<vector::VectorDialect>();
}
+ Option<unsigned> indexBitwidth{*this, "index-bitwidth",
+ llvm::cl::desc("Bitwidth of the index type"),
+ llvm::cl::init(0)};
+
Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
@@ -866,9 +870,9 @@ struct TestVectorLinearize final
ConversionTarget target(*context);
vector::populateVectorLinearizeTypeConversionsAndLegality(
- typeConverter, patterns, target, targetVectorBitwidth);
+ typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
vector::populateVectorLinearizeShuffleLikeOpsPatterns(
- typeConverter, patterns, target, targetVectorBitwidth);
+ typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
|
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns( | |||
/// the ops to get converted properly. | |||
void populateVectorLinearizeTypeConversionsAndLegality( | |||
TypeConverter &typeConverter, RewritePatternSet &patterns, | |||
ConversionTarget &target, unsigned targetBitWidth); | |||
ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is indexBitWidth
and targetBitWidth
different here? Aren't they representing the same thing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dcaballe Currently index type is not supported because we can't use getElementTypeBitWidth to get the bit width of index type. I add indexBitWidth argument to supply the bit width of index type. When it has non zero value and targetBitWidth is big enough, we can linearize vector of indices. Example: %0 = arith.addi %arg0, %arg1 : vector<2x2xindex> to %0 = arith.addi %arg0, %arg1 : vector<4xindex>.
@dcaballe Could you please review this PR? Thanks! |
Currently index type is not supported because getElementTypeBitWidth aborts for index type. This patch adds indexBitWidth input to the vector linearization patterns.