[mlir][vector] Address linearization comments (post commit) #138075
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-scf
Author: James Newling (newling)
Changes
This PR adds some documentation to address comments in #136581 (FYI @banach-space).
This PR adds a test for linearization across scf.for. This new test might be considered redundant by more experienced MLIRers, but might help newer users understand how to linearize scf/cf/func operations easily (thanks for pointing this out @Hardcode84).
The documentation added in this PR also tightens our definition of linearization, to now exclude unrolling (which creates multiple ops from 1 op). We hadn't really specified what linearization meant before. (FYI @nbpatel)
Full diff: https://github.com/llvm/llvm-project/pull/138075.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f1100d5cf8b68..34a94e6ea7051 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -407,13 +407,22 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-/// This registers (1) which operations are legal and hence should not be
-/// linearized, (2) what converted types are (rank-1 vectors) and how to
+///
+/// Definition: here 'linearization' means converting a single operation with
+/// 1+ vector operand/result of rank>1, into a new single operation whose
+/// vector operands and results are all of rank<=1.
+///
+/// This function registers (1) which operations are legal, and hence should not
+/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
/// materialze the conversion (with shape_cast)
///
/// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, but adding additional
+/// certain rank>1 vectors are considered valid, by adding additional
/// dynamically legal ops to `conversionTarget`.
+///
+/// Further note: the choice to use a dialect conversion design for
+/// linearization is to make it easy to reuse generic structural type
+/// conversions for linearizing scf/cf/func operations
void populateForVectorLinearize(TypeConverter &typeConverter,
ConversionTarget &conversionTarget);
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 40d2e254fb7dd..09326242eec2a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -99,7 +99,7 @@ class ConvertForOpTypes
// PR47938 tracks this issue, but it seems hard to fix. Instead, we need
// to clone the op.
//
- // 2. We need to resue the original region instead of cloning it, otherwise
+ // 2. We need to reuse the original region instead of cloning it, otherwise
// the dialect conversion framework thinks that we just inserted all the
// cloned child ops. But what we want is to "take" the child regions and let
// the dialect conversion framework continue recursively into ops inside
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 67e15852dc5ea..62aeb2473f651 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -135,9 +135,6 @@ struct LinearizeVectorExtractStridedSlice final
VectorType dstType =
getTypeConverter()->convertType<VectorType>(extractOp.getType());
assert(dstType && "vector type destination expected.");
- if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
- return rewriter.notifyMatchFailure(extractOp,
- "scalable vectors are not supported.");
ArrayAttr offsets = extractOp.getOffsets();
ArrayAttr sizes = extractOp.getSizes();
@@ -426,18 +423,22 @@ struct LinearizeVectorBitCast final
} // namespace
-/// Return true if the operation `op` does not support scalable vectors and
-/// has at least 1 scalable vector result. These ops should all eventually
-/// support scalable vectors, and this function should be removed.
+/// Some operations currently will not be linearized if they have scalable
+/// vector results, although support should be added in the future. This
+/// function returns true if `op` is such an operation.
static bool isNotLinearizableBecauseScalable(Operation *op) {
bool unsupported =
isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
op);
+
+ // Case where linearization is possible even when there are scalable vector
+ // results.
if (!unsupported)
return false;
- // Check if any of the results is a scalable vector type.
+ // Check if any of the results is a scalable vector type, and if there are
+ // return true (not linearizable).
auto types = op->getResultTypes();
bool containsScalableResult =
std::any_of(types.begin(), types.end(), [](Type type) {
@@ -448,10 +449,16 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
return containsScalableResult;
}
+/// This method defines a set of operations that are not linearizable, and hence
+/// they are considered legal for the conversion target. These ops are
+/// currently,
+///
+/// 1) ones that are not in the vector dialect, are not ConstantLike, and are
+/// not Vectorizable, or
+///
+/// 2) have scalable vector results, for which support has not yet been added.
static bool isNotLinearizable(Operation *op) {
- // Only ops that are in the vector dialect, are ConstantLike, or
- // are Vectorizable might be linearized currently.
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
StringRef opDialect = op->getDialect()->getNamespace();
bool unsupported = (opDialect != vectorDialect) &&
@@ -460,6 +467,10 @@ static bool isNotLinearizable(Operation *op) {
if (unsupported)
return true;
+ // vector.shape_cast cannot be linearized.
+ if (isa<vector::ShapeCastOp>(op))
+ return true;
+
// Some ops currently don't support scalable vectors.
if (isNotLinearizableBecauseScalable(op))
return true;
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 06eaf58b225ae..1c7427cccddc8 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -413,3 +413,25 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
return %1 : vector<[4]x4xf16>
}
+
+// -----
+
+// DEFAULT-LABEL: test_linearize_across_for
+func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
+ %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ // DEFAULT: scf.for {{.*}} -> (vector<4xi8>)
+ %1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) {
+
+ // DEFAULT: arith.addi {{.*}} : vector<4xi8>
+ %2 = arith.addi %arg1, %0 : vector<2x2xi8>
+
+ // DEFAULT: scf.yield {{.*}} : vector<4xi8>
+ scf.yield %2 : vector<2x2xi8>
+ }
+ %3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8>
+ return %3 : vector<4xi8>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..03f8a04a0ba7a 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -836,8 +837,7 @@ struct TestVectorEmulateMaskedLoadStore final
}
};
-// TODO: move this code into the user project.
-namespace vendor {
+namespace bit_width_constrained_linearization {
/// Get the set of operand/result types to check for sufficiently
/// small inner-most dimension size.
@@ -960,7 +960,7 @@ struct TestVectorBitWidthLinearize final
}
};
-} // namespace vendor
+} // namespace bit_width_constrained_linearization
struct TestVectorLinearize final
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
@@ -987,6 +987,8 @@ struct TestVectorLinearize final
vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
patterns);
+ mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+ converter, patterns, target);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -1067,7 +1069,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorLinearize>();
- PassRegistration<vendor::TestVectorBitWidthLinearize>();
+ PassRegistration<
+ bit_width_constrained_linearization::TestVectorBitWidthLinearize>();
PassRegistration<TestEliminateVectorMasks>();
}
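The legality rules documented in `isNotLinearizable` and `isNotLinearizableBecauseScalable` can be read as a small decision procedure. The following is a minimal standalone sketch of that logic in plain C++, not the MLIR implementation: the `OpInfo` struct and all of its fields are hypothetical stand-ins for the queries the real code makes on an `Operation *` (dialect namespace, traits, `isa<...>` checks, result types).

```cpp
#include <cassert>

// Toy description of an operation. Every field is a hypothetical stand-in
// for a query that the real code performs on an mlir::Operation.
struct OpInfo {
  bool inVectorDialect = false;
  bool isConstantLike = false;
  bool isVectorizable = false;
  bool isShapeCast = false;              // vector.shape_cast
  bool isScalableRestrictedKind = false; // extract_strided_slice/extract/insert
  bool hasScalableResult = false;
};

// Mirrors the documented behaviour: only the listed op kinds are blocked,
// and only when at least one result is a scalable vector.
bool isNotLinearizableBecauseScalable(const OpInfo &op) {
  if (!op.isScalableRestrictedKind)
    return false;
  return op.hasScalableResult;
}

// Returns true if the op is left alone, i.e. treated as legal by the
// conversion target in this toy model.
bool isNotLinearizable(const OpInfo &op) {
  // Case 1: not in the vector dialect, not ConstantLike, not Vectorizable.
  bool unsupported =
      !op.inVectorDialect && !op.isConstantLike && !op.isVectorizable;
  if (unsupported)
    return true;
  // vector.shape_cast is never linearized.
  if (op.isShapeCast)
    return true;
  // Case 2: ops that do not yet support scalable vector results.
  return isNotLinearizableBecauseScalable(op);
}
```

In this model an op that only has the Vectorizable trait (such as the `arith.addi` in the new test) is a candidate for linearization, whereas `vector.shape_cast` and ops outside all three categories are reported as not linearizable.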
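To make the tightened definition concrete (one op in, one op out, with all vector operands and results of rank <= 1), here is a minimal standalone sketch in plain C++ of what happens to a vector type's shape, for example `vector<2x2xi8>` becoming `vector<4xi8>` as in the new test. It does not use the MLIR API. The names `linearizedNumElements` and `flattenIndex` are hypothetical, and the row-major element order is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model only: a fixed-size vector shape, e.g. {2, 2} for vector<2x2xi8>.
using Shape = std::vector<int64_t>;

// Number of elements in the rank-1 type that replaces `shape`,
// e.g. {2, 2} -> 4, mirroring vector<2x2xi8> -> vector<4xi8>.
int64_t linearizedNumElements(const Shape &shape) {
  int64_t n = 1;
  for (int64_t dim : shape) {
    assert(dim > 0 && "dimensions must be positive");
    n *= dim;
  }
  return n;
}

// Position of a multi-dimensional index in the flattened vector,
// assuming row-major order (hypothetical helper, for illustration only).
int64_t flattenIndex(const Shape &shape, const std::vector<int64_t> &index) {
  assert(shape.size() == index.size() && "rank mismatch");
  int64_t flat = 0;
  for (std::size_t d = 0; d < shape.size(); ++d) {
    assert(index[d] >= 0 && index[d] < shape[d] && "index out of bounds");
    flat = flat * shape[d] + index[d];
  }
  return flat;
}
```

For instance, `linearizedNumElements({2, 2})` gives the length of the single rank-1 vector that replaces a 2x2 vector. Unrolling would instead split one op into several smaller ops, which is why it falls outside this definition.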