Skip to content

[HLSL][RootSignature] Add mandatory parameters for RootConstants #138002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/inbelic/pr-137999
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions clang/include/clang/Parse/ParseHLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,14 @@ class RootSignatureParser {
parseDescriptorTableClause();

/// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
/// order and only exactly once. `ParsedClauseParams` denotes the current
/// state of parsed params
/// order and only exactly once. The following methods define a
/// `Parsed.*Params` struct to denote the current state of parsed params
struct ParsedConstantParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
std::optional<uint32_t> Num32BitConstants;
};
std::optional<ParsedConstantParams> parseRootConstantParams();

struct ParsedClauseParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
std::optional<uint32_t> NumDescriptors;
Expand Down
68 changes: 65 additions & 3 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {

RootConstants Constants;

auto Params = parseRootConstantParams();
if (!Params.has_value())
return std::nullopt;

// Check mandatory parameters were provided
if (!Params->Num32BitConstants.has_value()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
<< TokenKind::kw_num32BitConstants;
return std::nullopt;
}

Constants.Num32BitConstants = Params->Num32BitConstants.value();

if (!Params->Reg.has_value()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
<< TokenKind::bReg;
return std::nullopt;
}

Constants.Reg = Params->Reg.value();

if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/TokenKind::kw_RootConstants))
Expand Down Expand Up @@ -187,14 +208,55 @@ RootSignatureParser::parseDescriptorTableClause() {
return Clause;
}

// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
// order and only exactly once. The following methods will parse through as
// many arguments as possible reporting an error if a duplicate is seen.
std::optional<RootSignatureParser::ParsedConstantParams>
RootSignatureParser::parseRootConstantParams() {
assert(CurToken.TokKind == TokenKind::pu_l_paren &&
"Expects to only be invoked starting at given token");

ParsedConstantParams Params;
do {
// `num32BitConstants` `=` POS_INT
if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
if (Params.Num32BitConstants.has_value()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
<< CurToken.TokKind;
return std::nullopt;
}

if (consumeExpectedToken(TokenKind::pu_equal))
return std::nullopt;

auto Num32BitConstants = parseUIntParam();
if (!Num32BitConstants.has_value())
return std::nullopt;
Params.Num32BitConstants = Num32BitConstants;
}

// `b` POS_INT
if (tryConsumeExpectedToken(TokenKind::bReg)) {
if (Params.Reg.has_value()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
<< CurToken.TokKind;
return std::nullopt;
}
auto Reg = parseRegister();
if (!Reg.has_value())
return std::nullopt;
Params.Reg = Reg;
}
} while (tryConsumeExpectedToken(TokenKind::pu_comma));

return Params;
}

std::optional<RootSignatureParser::ParsedClauseParams>
RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
assert(CurToken.TokKind == TokenKind::pu_l_paren &&
"Expects to only be invoked starting at given token");

// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
// order and only exactly once. Parse through as many arguments as possible
// reporting an error if a duplicate is seen.
ParsedClauseParams Params;
do {
// ( `b` | `t` | `u` | `s`) POS_INT
Expand Down
14 changes: 12 additions & 2 deletions clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {

TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
const llvm::StringLiteral Source = R"cc(
RootConstants()
RootConstants(num32BitConstants = 1, b0),
RootConstants(b42, num32BitConstants = 4294967295)
)cc";

TrivialModuleLoader ModLoader;
Expand All @@ -270,10 +271,19 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {

ASSERT_FALSE(Parser.parse());

ASSERT_EQ(Elements.size(), 1u);
ASSERT_EQ(Elements.size(), 2u);

RootElement Elem = Elements[0];
ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 1u);
ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 0u);

Elem = Elements[1];
ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 4294967295u);
ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 42u);

ASSERT_TRUE(Consumer->isSatisfied());
}
Expand Down
5 changes: 4 additions & 1 deletion llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ struct Register {
};

// Models the parameter values of root constants
struct RootConstants {};
struct RootConstants {
uint32_t Num32BitConstants;
Register Reg;
};

// Models the end of a descriptor table and stores its visibility
struct DescriptorTable {
Expand Down
Loading