Skip to content

[HLSL][RootSignature] Add mandatory parameters for RootConstants #138002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/inbelic/pr-137999
Choose a base branch
from

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Apr 30, 2025

  • defines the parseRootConstantParams function and adds handling for the mandatory arguments of num32BitConstants and bReg

  • adds corresponding unit tests

Part two of implementing #126576

- defines the `parseRootConstantParams` function and adds handling for
the mandatory arguments of `num32BitConstants` and `bReg`

- adds corresponding unit tests

Part two of implementing
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Apr 30, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 30, 2025

@llvm/pr-subscribers-clang

Author: Finn Plummer (inbelic)

Changes
  • defines the parseRootConstantParams function and adds handling for the mandatory arguments of num32BitConstants and bReg

  • adds corresponding unit tests

Part two of implementing #126576


Full diff: https://github.com/llvm/llvm-project/pull/138002.diff

4 Files Affected:

  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+8-2)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+65-3)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+12-2)
  • (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+4-1)
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index efa735ea03d94..0f05b05ed4df6 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -77,8 +77,14 @@ class RootSignatureParser {
   parseDescriptorTableClause();
 
   /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  /// order and only exactly once. `ParsedClauseParams` denotes the current
-  /// state of parsed params
+  /// order and only exactly once. The following methods define a
+  /// `Parsed.*Params` struct to denote the current state of parsed params
+  struct ParsedConstantParams {
+    std::optional<llvm::hlsl::rootsig::Register> Reg;
+    std::optional<uint32_t> Num32BitConstants;
+  };
+  std::optional<ParsedConstantParams> parseRootConstantParams();
+
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
     std::optional<uint32_t> NumDescriptors;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 48d3e38b0519d..2ce8e6e5cca98 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -57,6 +57,27 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
 
   RootConstants Constants;
 
+  auto Params = parseRootConstantParams();
+  if (!Params.has_value())
+    return std::nullopt;
+
+  // Check mandatory parameters were provided
+  if (!Params->Num32BitConstants.has_value()) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << TokenKind::kw_num32BitConstants;
+    return std::nullopt;
+  }
+
+  Constants.Num32BitConstants = Params->Num32BitConstants.value();
+
+  if (!Params->Reg.has_value()) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << TokenKind::bReg;
+    return std::nullopt;
+  }
+
+  Constants.Reg = Params->Reg.value();
+
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/TokenKind::kw_RootConstants))
@@ -187,14 +208,55 @@ RootSignatureParser::parseDescriptorTableClause() {
   return Clause;
 }
 
+// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
+// order and only exactly once. The following methods will parse through as
+// many arguments as possible reporting an error if a duplicate is seen.
+std::optional<RootSignatureParser::ParsedConstantParams>
+RootSignatureParser::parseRootConstantParams() {
+  assert(CurToken.TokKind == TokenKind::pu_l_paren &&
+         "Expects to only be invoked starting at given token");
+
+  ParsedConstantParams Params;
+  do {
+    // `num32BitConstants` `=` POS_INT
+    if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+      if (Params.Num32BitConstants.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+
+      if (consumeExpectedToken(TokenKind::pu_equal))
+        return std::nullopt;
+
+      auto Num32BitConstants = parseUIntParam();
+      if (!Num32BitConstants.has_value())
+        return std::nullopt;
+      Params.Num32BitConstants = Num32BitConstants;
+    }
+
+    // `b` POS_INT
+    if (tryConsumeExpectedToken(TokenKind::bReg)) {
+      if (Params.Reg.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+      auto Reg = parseRegister();
+      if (!Reg.has_value())
+        return std::nullopt;
+      Params.Reg = Reg;
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+  return Params;
+}
+
 std::optional<RootSignatureParser::ParsedClauseParams>
 RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
          "Expects to only be invoked starting at given token");
 
-  // Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  // order and only exactly once. Parse through as many arguments as possible
-  // reporting an error if a duplicate is seen.
   ParsedClauseParams Params;
   do {
     // ( `b` | `t` | `u` | `s`) POS_INT
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 0a7d8ac86cc5f..336868b579866 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -254,7 +254,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
 
 TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
   const llvm::StringLiteral Source = R"cc(
-    RootConstants()
+    RootConstants(num32BitConstants = 1, b0),
+    RootConstants(b42, num32BitConstants = 4294967295)
   )cc";
 
   TrivialModuleLoader ModLoader;
@@ -270,10 +271,19 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
 
   ASSERT_FALSE(Parser.parse());
 
-  ASSERT_EQ(Elements.size(), 1u);
+  ASSERT_EQ(Elements.size(), 2u);
 
   RootElement Elem = Elements[0];
   ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
+  ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 1u);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 0u);
+
+  Elem = Elements[1];
+  ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
+  ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 4294967295u);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 42u);
 
   ASSERT_TRUE(Consumer->isSatisfied());
 }
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 05735fa75b318..a3f98a9f1944f 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -55,7 +55,10 @@ struct Register {
 };
 
 // Models the parameter values of root constants
-struct RootConstants {};
+struct RootConstants {
+  uint32_t Num32BitConstants;
+  Register Reg;
+};
 
 // Models the end of a descriptor table and stores its visibility
 struct DescriptorTable {

@llvmbot
Copy link
Member

llvmbot commented Apr 30, 2025

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

Changes
  • defines the parseRootConstantParams function and adds handling for the mandatory arguments of num32BitConstants and bReg

  • adds corresponding unit tests

Part two of implementing #126576


Full diff: https://github.com/llvm/llvm-project/pull/138002.diff

4 Files Affected:

  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+8-2)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+65-3)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+12-2)
  • (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+4-1)
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index efa735ea03d94..0f05b05ed4df6 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -77,8 +77,14 @@ class RootSignatureParser {
   parseDescriptorTableClause();
 
   /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  /// order and only exactly once. `ParsedClauseParams` denotes the current
-  /// state of parsed params
+  /// order and only exactly once. The following methods define a
+  /// `Parsed.*Params` struct to denote the current state of parsed params
+  struct ParsedConstantParams {
+    std::optional<llvm::hlsl::rootsig::Register> Reg;
+    std::optional<uint32_t> Num32BitConstants;
+  };
+  std::optional<ParsedConstantParams> parseRootConstantParams();
+
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
     std::optional<uint32_t> NumDescriptors;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 48d3e38b0519d..2ce8e6e5cca98 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -57,6 +57,27 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
 
   RootConstants Constants;
 
+  auto Params = parseRootConstantParams();
+  if (!Params.has_value())
+    return std::nullopt;
+
+  // Check mandatory parameters were provided
+  if (!Params->Num32BitConstants.has_value()) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << TokenKind::kw_num32BitConstants;
+    return std::nullopt;
+  }
+
+  Constants.Num32BitConstants = Params->Num32BitConstants.value();
+
+  if (!Params->Reg.has_value()) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << TokenKind::bReg;
+    return std::nullopt;
+  }
+
+  Constants.Reg = Params->Reg.value();
+
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/TokenKind::kw_RootConstants))
@@ -187,14 +208,55 @@ RootSignatureParser::parseDescriptorTableClause() {
   return Clause;
 }
 
+// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
+// order and only exactly once. The following methods will parse through as
+// many arguments as possible reporting an error if a duplicate is seen.
+std::optional<RootSignatureParser::ParsedConstantParams>
+RootSignatureParser::parseRootConstantParams() {
+  assert(CurToken.TokKind == TokenKind::pu_l_paren &&
+         "Expects to only be invoked starting at given token");
+
+  ParsedConstantParams Params;
+  do {
+    // `num32BitConstants` `=` POS_INT
+    if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+      if (Params.Num32BitConstants.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+
+      if (consumeExpectedToken(TokenKind::pu_equal))
+        return std::nullopt;
+
+      auto Num32BitConstants = parseUIntParam();
+      if (!Num32BitConstants.has_value())
+        return std::nullopt;
+      Params.Num32BitConstants = Num32BitConstants;
+    }
+
+    // `b` POS_INT
+    if (tryConsumeExpectedToken(TokenKind::bReg)) {
+      if (Params.Reg.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+      auto Reg = parseRegister();
+      if (!Reg.has_value())
+        return std::nullopt;
+      Params.Reg = Reg;
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+  return Params;
+}
+
 std::optional<RootSignatureParser::ParsedClauseParams>
 RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
          "Expects to only be invoked starting at given token");
 
-  // Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  // order and only exactly once. Parse through as many arguments as possible
-  // reporting an error if a duplicate is seen.
   ParsedClauseParams Params;
   do {
     // ( `b` | `t` | `u` | `s`) POS_INT
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 0a7d8ac86cc5f..336868b579866 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -254,7 +254,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
 
 TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
   const llvm::StringLiteral Source = R"cc(
-    RootConstants()
+    RootConstants(num32BitConstants = 1, b0),
+    RootConstants(b42, num32BitConstants = 4294967295)
   )cc";
 
   TrivialModuleLoader ModLoader;
@@ -270,10 +271,19 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
 
   ASSERT_FALSE(Parser.parse());
 
-  ASSERT_EQ(Elements.size(), 1u);
+  ASSERT_EQ(Elements.size(), 2u);
 
   RootElement Elem = Elements[0];
   ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
+  ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 1u);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 0u);
+
+  Elem = Elements[1];
+  ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
+  ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 4294967295u);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 42u);
 
   ASSERT_TRUE(Consumer->isSatisfied());
 }
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 05735fa75b318..a3f98a9f1944f 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -55,7 +55,10 @@ struct Register {
 };
 
 // Models the parameter values of root constants
-struct RootConstants {};
+struct RootConstants {
+  uint32_t Num32BitConstants;
+  Register Reg;
+};
 
 // Models the end of a descriptor table and stores its visibility
 struct DescriptorTable {

@inbelic inbelic force-pushed the inbelic/rs-mand-root-const branch from f801180 to 15857bf Compare April 30, 2025 18:10
Copy link
Contributor

@joaosaffran joaosaffran left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add tests to verify error scenarios as well?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants