Skip to content

[mlir][vector] Add use64bitIndex option for VectorToSPIRVPass #97061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

akroviakov
Copy link
Contributor

This PR adds support for use64bitIndex option when lowering vector to SPIRV (e.g, vector<...xindex> to vector<...xi64>), instead of the current default lowering to i32.

@llvmbot
Copy link
Member

llvmbot commented Jun 28, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Artem Kroviakov (akroviakov)

Changes

This PR adds support for use64bitIndex option when lowering vector to SPIRV (e.g, vector&lt;...xindex&gt; to vector&lt;...xi64&gt;), instead of the current default lowering to i32.


Full diff: https://github.com/llvm/llvm-project/pull/97061.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+5)
  • (modified) mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h (+3)
  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp (+3-1)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+43-13)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..44bc9b2e0a064 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1380,6 +1380,11 @@ def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
   let summary = "Convert Vector dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertVectorToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
+  let options = [
+    Option<"use64bitIndex", "use-64bit-index",
+           "bool", /*default=*/"false",
+           "Use 64-bit integers to convert index types">,
+  ];
 }
 
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index f8c02c54066b8..0df1afe196010 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -18,6 +18,9 @@
 namespace mlir {
 class SPIRVTypeConverter;
 
+#define GEN_PASS_DECL_CONVERTVECTORTOSPIRV
+#include "mlir/Conversion/Passes.h.inc"
+
 /// Appends to a pattern list additional patterns for translating Vector Ops to
 /// SPIR-V ops.
 void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 1932de1be603b..c9f8db36b4efd 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -40,7 +40,9 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
-  SPIRVTypeConverter typeConverter(targetAttr);
+  SPIRVConversionOptions options;
+  options.use64bitIndex = this->use64bitIndex;
+  SPIRVTypeConverter typeConverter(targetAttr, options);
 
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
   // patterns for other dialects.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..dd34aa5ae0b33 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1,4 +1,6 @@
 // RUN: mlir-opt -split-input-file -convert-vector-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv=use-64bit-index=false -verify-diagnostics %s -o - | FileCheck %s --check-prefix=INDEX32
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv=use-64bit-index=true -verify-diagnostics %s -o - | FileCheck %s --check-prefix=INDEX64
 
 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16], []>, #spirv.resource_limits<>> } {
 
@@ -182,12 +184,26 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 }
 
 // -----
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+  // CHECK-LABEL: @insert_index_vector
+  // INDEX32-LABEL: @insert_index_vector
+  // INDEX64-LABEL: @insert_index_vector
+  // CHECK-SAME: %[[IN_VEC:.*]]: vector<4xindex>
+  // INDEX32-SAME: %[[IN_VEC:.*]]: vector<4xindex>
+  // INDEX64-SAME: %[[IN_VEC:.*]]: vector<4xindex>
+  func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
+    // CHECK: builtin.unrealized_conversion_cast %[[IN_VEC]] : vector<4xindex> to vector<4xi32> 
+    // INDEX32: builtin.unrealized_conversion_cast %[[IN_VEC]] : vector<4xindex> to vector<4xi32> 
+    // INDEX64: builtin.unrealized_conversion_cast %[[IN_VEC]] : vector<4xindex> to vector<4xi64> 
 
-// CHECK-LABEL: @insert_index_vector
-//       CHECK:   spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
-func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
-  %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
-  return %1: vector<4xindex>
+    // CHECK:   spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+    // INDEX32: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+    // INDEX64: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i64 into vector<4xi64>
+    %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
+    return %1: vector<4xindex>
+  }
 }
 
 // -----
@@ -411,14 +427,28 @@ func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
 
 // -----
 
-// CHECK-LABEL:  func @shuffle_index_vector
-//  CHECK-SAME:  %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
-//   CHECK-DAG:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
-//   CHECK-DAG:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
-//       CHECK:    spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
-func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> {
-  %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex>
-  return %shuffle : vector<4xindex>
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+  // CHECK-LABEL:  func @shuffle_index_vector
+  // INDEX32-LABEL:  func @shuffle_index_vector
+  // INDEX64-LABEL:  func @shuffle_index_vector
+  //  CHECK-SAME:  %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+  //  INDEX32-SAME:  %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+  //  INDEX64-SAME:  %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+  //   CHECK-DAG:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  //   CHECK-DAG:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+  //   INDEX32-DAG:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  //   INDEX32-DAG:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+  //   INDEX64-DAG:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  //   INDEX64-DAG:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+  func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> {    
+    //  CHECK: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
+    //  INDEX32: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
+    //  INDEX64: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i64, i64, i64, i64) -> vector<4xi64>
+    %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex>
+    return %shuffle : vector<4xindex>
+  }
 }
 
 // -----

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants