Skip to content

[mlir][spirv] Add support for Constant Matrices #123334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

mishaobu
Copy link
Contributor

No description provided.

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: None (mishaobu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/123334.diff

7 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+43-12)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+3)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+16)
  • (modified) mlir/test/Dialect/SPIRV/IR/composite-ops.mlir (+29)
  • (modified) mlir/test/Dialect/SPIRV/IR/structure-ops.mlir (+27-1)
  • (modified) mlir/test/Target/SPIRV/composite-op.mlir (+5)
  • (modified) mlir/test/Target/SPIRV/constant.mlir (+11)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..ee7c7860b05c4e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -579,7 +579,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
 
 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
   printer << ' ' << getValue();
-  if (llvm::isa<spirv::ArrayType>(getType()))
+  if (llvm::isa<spirv::ArrayType, spirv::MatrixType>(getType()))
     printer << " : " << getType();
 }
 
@@ -626,18 +626,49 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
     }
     return success();
   }
-  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
-    auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
-    if (!arrayType)
-      return op.emitOpError(
-          "must have spirv.array result type for array value");
-    Type elemType = arrayType.getElementType();
-    for (Attribute element : arrayAttr.getValue()) {
-      // Verify array elements recursively.
-      if (failed(verifyConstantType(op, element, elemType)))
-        return failure();
+  if (auto arrayAttr = mlir::dyn_cast<ArrayAttr>(value)) {
+    // Case for Matrix result type
+    if (auto matrixType = mlir::dyn_cast<spirv::MatrixType>(opType)) {
+      unsigned numColumns = matrixType.getNumColumns();
+      unsigned numRows    = matrixType.getNumRows();
+      if (arrayAttr.size() != numColumns)
+        return op.emitOpError("expected ")
+              << numColumns << " columns in matrix constant, but got "
+              << arrayAttr.size();
+
+      Type elementTy = matrixType.getElementType();
+      for (auto [colIndex, colAttr] : llvm::enumerate(arrayAttr)) {
+        // Ensure each column is a dense array of the right shape/type
+        auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(colAttr);
+        if (!denseAttr)
+          return op.emitOpError("matrix column #")
+                << colIndex << " must be a DenseElementsAttr";
+
+        auto shapedTy = mlir::dyn_cast<ShapedType>(denseAttr.getType());
+        if (!shapedTy || shapedTy.getNumElements() != numRows)
+          return op.emitOpError("matrix column #")
+                << colIndex << " has incorrect size: expected "
+                << numRows << " elements";
+
+        if (shapedTy.getElementType() != elementTy)
+          return op.emitOpError("matrix column #")
+                << colIndex << " has incorrect element type: expected "
+                << elementTy << ", got " << shapedTy.getElementType();
+      }
+      return success();
     }
-    return success();
+    // Case for Array result type
+    if (auto arrayType = mlir::dyn_cast<spirv::ArrayType>(opType)) {
+      Type elemType = arrayType.getElementType();
+      for (Attribute element : arrayAttr.getValue()) {
+        // Verify array elements recursively.
+        if (failed(verifyConstantType(op, element, elemType)))
+          return failure();
+      }
+      return success();
+    }
+    return op.emitOpError(
+        "must have spirv.array or spirv.matrix result type for array value");
   }
   return op.emitOpError("cannot have attribute: ") << value;
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 04469f1933819b..ecc822e553aefc 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1442,6 +1442,9 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
     auto attr = opBuilder.getArrayAttr(elements);
     constantMap.try_emplace(resultID, attr, resultType);
+  } else if (auto matrixType = dyn_cast<spirv::MatrixType>(resultType)) {
+    auto attr = opBuilder.getArrayAttr(elements);
+    constantMap.try_emplace(resultID, attr, resultType);
   } else {
     return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
            << resultType;
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1f4f5d7f764db3..b5e3cd381ef822 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -782,6 +782,22 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
     SmallVector<uint64_t, 4> index(rank);
     resultID = prepareDenseElementsConstant(loc, constType, attr,
                                             /*dim=*/0, index);
+  } else if (isa<spirv::MatrixType>(constType)) {
+    if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+      resultID = getNextID();
+      SmallVector<uint32_t, 4> operands = {typeID, resultID};
+      operands.reserve(arrayAttr.size() + 2);
+      for (Attribute elementAttr : arrayAttr) {
+        if (auto elementID = prepareConstant(loc, 
+            cast<spirv::MatrixType>(constType).getColumnType(), elementAttr)) {
+          operands.push_back(elementID);
+        } else {
+          return 0;
+        }
+      }
+      spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
+      encodeInstructionInto(typesGlobalValues, opcode, operands);
+    }
   } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
     resultID = prepareArrayConstant(loc, constType, arrayAttr);
   }
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 3fc8dfb2767d1e..5c835d2e08de91 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve
   return %0: vector<3xf32>
 }
 
+// CHECK-LABEL: func @composite_construct_matrix
+func.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
+  // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+  %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+  return %0: !spirv.matrix<3 x vector<3xf32>>
+}
+
 // CHECK-LABEL: func @composite_construct_struct
 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
   // CHECK: spirv.CompositeConstruct
@@ -89,9 +96,31 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
   %0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
   return %0: vector<4xf32>
 }
+// -----
+
+func.func @composite_construct_matrix_wrong_column_count(%v1: vector<3xf32>, %v2: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
+  // expected-error @+1 {{'spirv.CompositeConstruct' op expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
+  %0 = spirv.CompositeConstruct %v1, %v2 : (vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+  return %0: !spirv.matrix<3 x vector<3xf32>>
+}
+
+// -----
+
+func.func @composite_construct_matrix_wrong_row_count(%v1: vector<4xf32>, %v2: vector<4xf32>, %v3: vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
+  // expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<4xf32>'}}
+  %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+  return %0: !spirv.matrix<3 x vector<3xf32>>
+}
 
 // -----
 
+func.func @composite_construct_matrix_wrong_element_type(%v1: vector<3xi32>, %v2: vector<3xi32>, %v3: vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>> {
+  // expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<3xi32>'}}
+  %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xi32>, vector<3xi32>, vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>>
+  return %0: !spirv.matrix<3 x vector<3xf32>>
+}
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.CompositeExtractOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 5e98b9fdb3c546..6003d2a3576b12 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -62,6 +62,7 @@ func.func @const() -> () {
   // CHECK: spirv.Constant dense<1.000000e+00> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
   // CHECK: spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
   // CHECK: spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+  // CHECK: spirv.Constant [dense<1.000000e+00> : vector<3xf32>, dense<2.000000e+00> : vector<3xf32>, dense<3.000000e+00> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
 
   %0 = spirv.Constant true
   %1 = spirv.Constant 42 : i32
@@ -73,6 +74,7 @@ func.func @const() -> () {
   %7 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
   %8 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
   %9 = spirv.Constant [[dense<3.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1xvector<2xf32>>>
+  %10 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>, dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
   return
 }
 
@@ -95,7 +97,7 @@ func.func @array_constant() -> () {
 // -----
 
 func.func @array_constant() -> () {
-  // expected-error @+1 {{must have spirv.array result type for array value}}
+  // expected-error @+1 {{'spirv.Constant' op must have spirv.array or spirv.matrix result type for array value}}
   %0 = spirv.Constant [dense<3.0> : vector<2xf32>] : !spirv.rtarray<vector<2xf32>>
   return
 }
@@ -132,6 +134,30 @@ func.func @value_result_num_elements_mismatch() -> () {
 
 // -----
 
+func.func @matrix_constant() -> () {
+  // CHECK: spirv.Constant [dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<3xf32>, dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : vector<3xf32>, dense<[7.000000e+00, 8.000000e+00, 9.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+  %0 = spirv.Constant [dense<[1.0, 2.0, 3.0]> : vector<3xf32>, dense<[4.0, 5.0, 6.0]> : vector<3xf32>, dense<[7.0, 8.0, 9.0]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+  return
+}
+
+// -----
+
+func.func @matrix_constant_wrong_column_count() -> () {
+  // expected-error @+1 {{expected 3 columns in matrix constant, but got 2}}
+  %0 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+  return
+}
+
+// -----
+
+func.func @matrix_constant_non_dense_column() -> () {
+  // expected-error @+1 {{matrix column #1 must be a DenseElementsAttr}}
+  %0 = spirv.Constant [dense<1.0> : vector<3xf32>, "wrong", dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.EntryPoint
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir
index 5f302fd0d38f8b..bafdb3340d0e79 100644
--- a/mlir/test/Target/SPIRV/composite-op.mlir
+++ b/mlir/test/Target/SPIRV/composite-op.mlir
@@ -11,6 +11,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
     spirv.ReturnValue %0: vector<3xf32>
   }
+  spirv.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> "None" {
+    // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+    %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+    spirv.ReturnValue %0: !spirv.matrix<3 x vector<3xf32>>
+  }
   spirv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {
     // CHECK: spirv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
     %0 = spirv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index f3950214a7f055..0fa70c7e5cdbb3 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -198,6 +198,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.Return
   }
 
+  // CHECK-LABEL: @matrix_const
+  spirv.func @matrix_const() -> () "None" {
+    // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<!spirv.matrix<3 x vector<3xf32>>, Function>
+    %0 = spirv.Variable : !spirv.ptr<!spirv.matrix<3 x vector<3xf32>>, Function>
+    // CHECK: %[[CST:.*]] = spirv.Constant [dense<[1.000000e+00, 0.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 1.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 0.000000e+00, 1.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+    %1 = spirv.Constant [dense<[1., 0., 0.]> : vector<3xf32>, dense<[0., 1., 0.]> : vector<3xf32>, dense<[0., 0., 1.]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+    // CHECK: spirv.Store "Function" %[[VAR]], %[[CST]] : !spirv.matrix<3 x vector<3xf32>>
+    spirv.Store "Function" %0, %1 : !spirv.matrix<3 x vector<3xf32>>
+    spirv.Return
+  }
+
   // CHECK-LABEL: @ui64_array_const
   spirv.func @ui64_array_const() -> (!spirv.array<3xui64>) "None" {
     // CHECK: spirv.Constant [5, 6, 7] : !spirv.array<3 x i64>

@mishaobu mishaobu marked this pull request as draft January 18, 2025 20:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants