Skip to content

[mlir][vector] Support index type in ND to 1D vector linearization #118404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ayzhuang
Copy link
Contributor

@ayzhuang ayzhuang commented Dec 2, 2024

Currently index type is not supported because getElementTypeBitWidth aborts for index type. This patch adds indexBitWidth input to the vector linearization patterns.

Currently index type is not supported because getElementTypeBitWidth
aborts for index type. This patch adds indexBitWidth input to
the vector linearization patterns.
@llvmbot
Copy link
Member

llvmbot commented Dec 2, 2024

@llvm/pr-subscribers-mlir

Author: Amy Zhuang (ayzhuang)

Changes

Currently index type is not supported because getElementTypeBitWidth aborts for index type. This patch adds indexBitWidth input to the vector linearization patterns.


Full diff: https://github.com/llvm/llvm-project/pull/118404.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+2-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+59-26)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+12-3)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6-2)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..e3c19a078c18b0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
 /// the ops to get converted properly.
 void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
 
 /// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
 /// vector shuffle operations.
 void populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..f0bf6276f0e659 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -25,34 +25,44 @@
 
 using namespace mlir;
 
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
+                                     unsigned targetBitWidth) {
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
-    // Reject index since getElementTypeBitWidth will abort for Index types.
-    if (!vecType || vecType.getElementType().isIndex())
+    if (!vecType)
+      return false;
+    bool isIndexTy = vecType.getElementType().isIndex();
+    // Reject index if `indexBitWidth` is not supplied.
+    if (isIndexTy && indexBitWidth == 0)
       return false;
     // There are no dimension to fold if it is a 0-D vector.
     if (vecType.getRank() == 0)
       return false;
     unsigned trailingVecDimBitWidth =
-        vecType.getShape().back() * vecType.getElementTypeBitWidth();
+        vecType.getShape().back() *
+        (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
     if (trailingVecDimBitWidth >= targetBitWidth)
       return false;
   }
   return true;
 }
 
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
+static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
+                                            unsigned targetBitWidth) {
   VectorType vecType = dyn_cast<VectorType>(t);
-  // Reject index since getElementTypeBitWidth will abort for Index types.
-  if (!vecType || vecType.getElementType().isIndex())
+  if (!vecType)
+    return false;
+  bool isIndexTy = vecType.getElementType().isIndex();
+  // Reject index if `indexBitWidth` is not supplied.
+  if (isIndexTy && indexBitWidth == 0)
     return false;
   // There are no dimension to fold if it is a 0-D vector.
   if (vecType.getRank() == 0)
     return false;
   unsigned trailingVecDimBitWidth =
-      vecType.getShape().back() * vecType.getElementTypeBitWidth();
+      vecType.getShape().back() *
+      (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
   return trailingVecDimBitWidth <= targetBitWidth;
 }
 
@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeConstant(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
 
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
-    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
 public:
   LinearizeVectorizable(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtractStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
 
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+                                  targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
@@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
 
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
              shuffleOp.getV2VectorType().isScalable() ||
              dstType.isScalable()) &&
            "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+                                  targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
 
@@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtract(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
         cast<VectorType>(dstTy).isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+                                  targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
@@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorInsert(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
                                          "scalable vectors are not supported.");
 
     if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
-                                         targetVectorBitWidth))
+                                         indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           insertOp, "Can't flatten since targetBitWidth < OpSize");
 
@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth) {
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
 
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
@@ -488,7 +519,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       [=](Operation *op) -> std::optional<bool> {
         if ((isa<arith::ConstantOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
+          return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
                       ? typeConverter.isLegal(op)
                       : true);
         }
@@ -496,15 +527,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       });
 
   patterns.add<LinearizeConstant, LinearizeVectorizable>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned int targetBitWidth) {
+    ConversionTarget &target, unsigned indexBitWidth,
+    unsigned int targetBitWidth) {
   target.addDynamicallyLegalOp<vector::ShuffleOp>(
       [=](vector::ShuffleOp shuffleOp) -> bool {
-        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+        return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+                                        targetBitWidth)
                    ? (typeConverter.isLegal(shuffleOp) &&
                       cast<mlir::VectorType>(shuffleOp.getResult().getType())
                               .getRank() == 1)
@@ -512,5 +545,5 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
       });
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..fe169d3e16d683 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128  -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64
 
 // ALL-LABEL: test_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
 
   // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 
   // DEFAULT: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
@@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
 
   // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 
   // DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
@@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
 
 // -----
 
-// ALL-LABEL: test_index_no_linearize
-func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
-    // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+// ALL-LABEL: test_index_linearize
+func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+    // DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
     %0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
     return %0 : vector<2x2xindex>
 }
@@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
 
   // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
   // ALL: return %[[RES]] : vector<2x[2]xf32>
   return %2 : vector<2x[2]xf32>
 }
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f67a24755ac09a..2589782aee1449 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -853,6 +853,10 @@ struct TestVectorLinearize final
     registry.insert<vector::VectorDialect>();
   }
 
+  Option<unsigned> indexBitwidth{*this, "index-bitwidth",
+                                 llvm::cl::desc("Bitwidth of the index type"),
+                                 llvm::cl::init(0)};
+
   Option<unsigned> targetVectorBitwidth{
       *this, "target-vector-bitwidth",
       llvm::cl::desc(
@@ -866,9 +870,9 @@ struct TestVectorLinearize final
     ConversionTarget target(*context);
 
     vector::populateVectorLinearizeTypeConversionsAndLegality(
-        typeConverter, patterns, target, targetVectorBitwidth);
+        typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
     vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-        typeConverter, patterns, target, targetVectorBitwidth);
+        typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();

@llvmbot
Copy link
Member

llvmbot commented Dec 2, 2024

@llvm/pr-subscribers-mlir-vector

Author: Amy Zhuang (ayzhuang)

Changes

Currently index type is not supported because getElementTypeBitWidth aborts for index type. This patch adds indexBitWidth input to the vector linearization patterns.


Full diff: https://github.com/llvm/llvm-project/pull/118404.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+2-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+59-26)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+12-3)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6-2)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..e3c19a078c18b0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
 /// the ops to get converted properly.
 void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
 
 /// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
 /// vector shuffle operations.
 void populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..f0bf6276f0e659 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -25,34 +25,44 @@
 
 using namespace mlir;
 
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
+                                     unsigned targetBitWidth) {
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
-    // Reject index since getElementTypeBitWidth will abort for Index types.
-    if (!vecType || vecType.getElementType().isIndex())
+    if (!vecType)
+      return false;
+    bool isIndexTy = vecType.getElementType().isIndex();
+    // Reject index if `indexBitWidth` is not supplied.
+    if (isIndexTy && indexBitWidth == 0)
       return false;
     // There are no dimension to fold if it is a 0-D vector.
     if (vecType.getRank() == 0)
       return false;
     unsigned trailingVecDimBitWidth =
-        vecType.getShape().back() * vecType.getElementTypeBitWidth();
+        vecType.getShape().back() *
+        (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
     if (trailingVecDimBitWidth >= targetBitWidth)
       return false;
   }
   return true;
 }
 
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
+static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
+                                            unsigned targetBitWidth) {
   VectorType vecType = dyn_cast<VectorType>(t);
-  // Reject index since getElementTypeBitWidth will abort for Index types.
-  if (!vecType || vecType.getElementType().isIndex())
+  if (!vecType)
+    return false;
+  bool isIndexTy = vecType.getElementType().isIndex();
+  // Reject index if `indexBitWidth` is not supplied.
+  if (isIndexTy && indexBitWidth == 0)
     return false;
   // There are no dimension to fold if it is a 0-D vector.
   if (vecType.getRank() == 0)
     return false;
   unsigned trailingVecDimBitWidth =
-      vecType.getShape().back() * vecType.getElementTypeBitWidth();
+      vecType.getShape().back() *
+      (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
   return trailingVecDimBitWidth <= targetBitWidth;
 }
 
@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeConstant(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
 
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
-    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
 public:
   LinearizeVectorizable(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtractStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
 
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+                                  targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
@@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
 
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
              shuffleOp.getV2VectorType().isScalable() ||
              dstType.isScalable()) &&
            "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+                                  targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
 
@@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtract(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
         cast<VectorType>(dstTy).isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+                                  targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
@@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorInsert(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
                                          "scalable vectors are not supported.");
 
     if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
-                                         targetVectorBitWidth))
+                                         indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           insertOp, "Can't flatten since targetBitWidth < OpSize");
 
@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth) {
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
 
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
@@ -488,7 +519,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       [=](Operation *op) -> std::optional<bool> {
         if ((isa<arith::ConstantOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
+          return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
                       ? typeConverter.isLegal(op)
                       : true);
         }
@@ -496,15 +527,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       });
 
   patterns.add<LinearizeConstant, LinearizeVectorizable>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned int targetBitWidth) {
+    ConversionTarget &target, unsigned indexBitWidth,
+    unsigned int targetBitWidth) {
   target.addDynamicallyLegalOp<vector::ShuffleOp>(
       [=](vector::ShuffleOp shuffleOp) -> bool {
-        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+        return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+                                        targetBitWidth)
                    ? (typeConverter.isLegal(shuffleOp) &&
                       cast<mlir::VectorType>(shuffleOp.getResult().getType())
                               .getRank() == 1)
@@ -512,5 +545,5 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
       });
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..fe169d3e16d683 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128  -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64
 
 // ALL-LABEL: test_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
 
   // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 
   // DEFAULT: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
@@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
 
   // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 
   // DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
@@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
 
 // -----
 
-// ALL-LABEL: test_index_no_linearize
-func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
-    // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+// ALL-LABEL: test_index_linearize
+func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+    // DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
     %0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
     return %0 : vector<2x2xindex>
 }
@@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
 
   // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
   // ALL: return %[[RES]] : vector<2x[2]xf32>
   return %2 : vector<2x[2]xf32>
 }
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f67a24755ac09a..2589782aee1449 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -853,6 +853,10 @@ struct TestVectorLinearize final
     registry.insert<vector::VectorDialect>();
   }
 
+  Option<unsigned> indexBitwidth{*this, "index-bitwidth",
+                                 llvm::cl::desc("Bitwidth of the index type"),
+                                 llvm::cl::init(0)};
+
   Option<unsigned> targetVectorBitwidth{
       *this, "target-vector-bitwidth",
       llvm::cl::desc(
@@ -866,9 +870,9 @@ struct TestVectorLinearize final
     ConversionTarget target(*context);
 
     vector::populateVectorLinearizeTypeConversionsAndLegality(
-        typeConverter, patterns, target, targetVectorBitwidth);
+        typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
     vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-        typeConverter, patterns, target, targetVectorBitwidth);
+        typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();

@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth);
ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is indexBitWidth and targetBitWidth different here? Aren't they representing the same thing?

Copy link
Contributor Author

@ayzhuang ayzhuang Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dcaballe Currently index type is not supported because we can't use getElementTypeBitWidth to get the bit width of index type. I add indexBitWidth argument to supply the bit width of index type. When it has non zero value and targetBitWidth is big enough, we can linearize vector of indices. Example: %0 = arith.addi %arg0, %arg1 : vector<2x2xindex> to %0 = arith.addi %arg0, %arg1 : vector<4xindex>.

@ayzhuang
Copy link
Contributor Author

ayzhuang commented Jan 6, 2025

@dcaballe Could you please review this PR? Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants