[DirectX] Implement the DXILCBufferAccess pass #134571

Merged 3 commits on Apr 16, 2025

bogner
Contributor

@bogner bogner commented Apr 7, 2025

This introduces a pass that walks accesses to globals in cbuffers and replaces them with accesses via the cbuffer handle itself. The logic to interpret the cbuffer metadata is kept in lib/Frontend/HLSL so that it can be reused by other consumers of that metadata.

Fixes #124630.

@llvmbot
Member

llvmbot commented Apr 7, 2025

@llvm/pr-subscribers-backend-directx

Author: Justin Bogner (bogner)

Changes

This introduces a pass that walks accesses to globals in cbuffers and replaces them with accesses via the cbuffer handle itself. The logic to interpret the cbuffer metadata is kept in lib/Frontend/HLSL so that it can be reused by other consumers of that metadata.

Fixes #124630.


Patch is 41.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/134571.diff

15 Files Affected:

  • (added) llvm/include/llvm/Frontend/HLSL/CBuffer.h (+64)
  • (added) llvm/lib/Frontend/HLSL/CBuffer.cpp (+71)
  • (modified) llvm/lib/Frontend/HLSL/CMakeLists.txt (+1)
  • (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1)
  • (added) llvm/lib/Target/DirectX/DXILCBufferAccess.cpp (+209)
  • (added) llvm/lib/Target/DirectX/DXILCBufferAccess.h (+28)
  • (modified) llvm/lib/Target/DirectX/DirectX.h (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXPassRegistry.def (+1)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+3)
  • (added) llvm/test/CodeGen/DirectX/CBufferAccess/arrays.ll (+121)
  • (added) llvm/test/CodeGen/DirectX/CBufferAccess/float.ll (+22)
  • (added) llvm/test/CodeGen/DirectX/CBufferAccess/gep-ce-two-uses.ll (+32)
  • (added) llvm/test/CodeGen/DirectX/CBufferAccess/scalars.ll (+105)
  • (added) llvm/test/CodeGen/DirectX/CBufferAccess/vectors.ll (+116)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+1)
diff --git a/llvm/include/llvm/Frontend/HLSL/CBuffer.h b/llvm/include/llvm/Frontend/HLSL/CBuffer.h
new file mode 100644
index 0000000000000..cf45b5ff5a8e2
--- /dev/null
+++ b/llvm/include/llvm/Frontend/HLSL/CBuffer.h
@@ -0,0 +1,64 @@
+//===- CBuffer.h - HLSL constant buffer handling ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains utilities to work with constant buffers in HLSL.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_FRONTEND_HLSL_CBUFFER_H
+#define LLVM_FRONTEND_HLSL_CBUFFER_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include <optional>
+
+namespace llvm {
+class Module;
+class GlobalVariable;
+class NamedMDNode;
+
+namespace hlsl {
+
+struct CBufferMember {
+  CBufferMember(GlobalVariable *GV, size_t Offset) : GV(GV), Offset(Offset) {}
+
+  GlobalVariable *GV;
+  size_t Offset;
+};
+
+struct CBufferMapping {
+  CBufferMapping(GlobalVariable *Handle) : Handle(Handle) {}
+
+  GlobalVariable *Handle;
+  SmallVector<CBufferMember> Members;
+};
+
+class CBufferMetadata {
+  NamedMDNode *MD;
+  SmallVector<CBufferMapping> Mappings;
+
+  CBufferMetadata(NamedMDNode *MD) : MD(MD) {}
+
+public:
+  static std::optional<CBufferMetadata> get(Module &M);
+
+  using iterator = SmallVector<CBufferMapping>::iterator;
+  iterator begin() { return Mappings.begin(); }
+  iterator end() { return Mappings.end(); }
+
+  void eraseFromModule();
+};
+
+APInt translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
+                               ArrayType *Ty);
+
+} // namespace hlsl
+} // namespace llvm
+
+#endif // LLVM_FRONTEND_HLSL_CBUFFER_H
diff --git a/llvm/lib/Frontend/HLSL/CBuffer.cpp b/llvm/lib/Frontend/HLSL/CBuffer.cpp
new file mode 100644
index 0000000000000..b311f6aea9636
--- /dev/null
+++ b/llvm/lib/Frontend/HLSL/CBuffer.cpp
@@ -0,0 +1,71 @@
+//===- CBuffer.cpp - HLSL constant buffer handling ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Frontend/HLSL/CBuffer.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+
+using namespace llvm;
+using namespace llvm::hlsl;
+
+static size_t getMemberOffset(GlobalVariable *Handle, size_t Index) {
+  auto *HandleTy = cast<TargetExtType>(Handle->getValueType());
+  assert(HandleTy->getName().ends_with(".CBuffer") && "Not a cbuffer type");
+  assert(HandleTy->getNumTypeParameters() == 1 && "Expected layout type");
+
+  auto *LayoutTy = cast<TargetExtType>(HandleTy->getTypeParameter(0));
+  assert(LayoutTy->getName().ends_with(".Layout") && "Not a layout type");
+
+  // Skip the "size" parameter.
+  size_t ParamIndex = Index + 1;
+  assert(LayoutTy->getNumIntParameters() > ParamIndex &&
+         "Not enough parameters");
+
+  return LayoutTy->getIntParameter(ParamIndex);
+}
+
+std::optional<CBufferMetadata> CBufferMetadata::get(Module &M) {
+  NamedMDNode *CBufMD = M.getNamedMetadata("hlsl.cbs");
+  if (!CBufMD)
+    return std::nullopt;
+
+  std::optional<CBufferMetadata> Result({CBufMD});
+
+  for (const MDNode *MD : CBufMD->operands()) {
+    assert(MD->getNumOperands() && "Invalid cbuffer metadata");
+
+    auto *Handle = cast<GlobalVariable>(
+        cast<ValueAsMetadata>(MD->getOperand(0))->getValue());
+    CBufferMapping &Mapping = Result->Mappings.emplace_back(Handle);
+
+    for (int I = 1, E = MD->getNumOperands(); I < E; ++I) {
+      Metadata *OpMD = MD->getOperand(I);
+      // Some members may be null if they've been optimized out.
+      if (!OpMD)
+        continue;
+      auto *V = cast<GlobalVariable>(cast<ValueAsMetadata>(OpMD)->getValue());
+      Mapping.Members.emplace_back(V, getMemberOffset(Handle, I - 1));
+    }
+  }
+
+  return Result;
+}
+
+
+void CBufferMetadata::eraseFromModule() {
+  // Remove the cbs named metadata
+  MD->eraseFromParent();
+}
+
+APInt hlsl::translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
+                                     ArrayType *Ty) {
+  int64_t TypeSize = DL.getTypeSizeInBits(Ty->getElementType()) / 8;
+  int64_t RoundUp = alignTo(TypeSize, Align(16));
+  return Offset.udiv(TypeSize) * RoundUp;
+}
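The offset logic in CBuffer.cpp can be sketched without any LLVM types. The snippet below is an illustration only, not the code from the patch: `LayoutParams`, `memberOffset`, and `translateArrayOffset` are hypothetical stand-ins that reproduce the arithmetic of `getMemberOffset` and `translateCBufArrayOffset` using plain integers. It assumes the layout's integer parameters are the total size followed by one byte offset per member, as in the `dx.Layout` type used by arrays.ll.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the integer parameters of a "dx.Layout" type:
// entry 0 is the total size, the remaining entries are member byte offsets.
using LayoutParams = std::vector<uint64_t>;

// Same idea as getMemberOffset: skip the leading "size" parameter.
// at() throws std::out_of_range if the layout has too few parameters.
uint64_t memberOffset(const LayoutParams &Params, size_t Index) {
  return Params.at(Index + 1);
}

// Same arithmetic as translateCBufArrayOffset: a tightly packed byte offset
// is turned into an element index, then scaled by the element size rounded
// up to a multiple of 16 bytes.
uint64_t translateArrayOffset(uint64_t PackedOffset, uint64_t EltSize) {
  uint64_t Stride = (EltSize + 15) / 16 * 16;
  return PackedOffset / EltSize * Stride;
}
```

With the arrays.ll layout parameters `{708, 0, 48, 112, 176, 224, 608, 624, 656}`, member 0 (`a1`) starts at byte 0. A load from byte 4 of `@a1`, a `[3 x float]`, translates to `memberOffset(CB, 0) + translateArrayOffset(4, 4)`, which is 16, so the access lands at the start of row 1. That agrees with the `i32 1` row argument in the first CHECK block of the test.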
diff --git a/llvm/lib/Frontend/HLSL/CMakeLists.txt b/llvm/lib/Frontend/HLSL/CMakeLists.txt
index eda6cb8e69a49..07a0c845ceef6 100644
--- a/llvm/lib/Frontend/HLSL/CMakeLists.txt
+++ b/llvm/lib/Frontend/HLSL/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_llvm_component_library(LLVMFrontendHLSL
+  CBuffer.cpp
   HLSLResource.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 13f8adbe4f132..c55028bc75dd6 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -20,6 +20,7 @@ add_llvm_target(DirectXCodeGen
   DirectXTargetMachine.cpp
   DirectXTargetTransformInfo.cpp
   DXContainerGlobals.cpp
+  DXILCBufferAccess.cpp
   DXILDataScalarization.cpp
   DXILFinalizeLinkage.cpp
   DXILFlattenArrays.cpp
diff --git a/llvm/lib/Target/DirectX/DXILCBufferAccess.cpp b/llvm/lib/Target/DirectX/DXILCBufferAccess.cpp
new file mode 100644
index 0000000000000..f8771efeac991
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILCBufferAccess.cpp
@@ -0,0 +1,209 @@
+//===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILCBufferAccess.h"
+#include "DirectX.h"
+#include "llvm/Frontend/HLSL/CBuffer.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/Local.h"
+
+#define DEBUG_TYPE "dxil-cbuffer-access"
+using namespace llvm;
+
+namespace {
+/// Helper for building a `load.cbufferrow` intrinsic given a simple type.
+struct CBufferRowIntrin {
+  Intrinsic::ID IID;
+  Type *RetTy;
+  unsigned int EltSize;
+  unsigned int NumElts;
+
+  CBufferRowIntrin(const DataLayout &DL, Type *Ty) {
+    assert(Ty == Ty->getScalarType() && "Expected scalar type");
+
+    switch (DL.getTypeSizeInBits(Ty)) {
+    case 16:
+      IID = Intrinsic::dx_resource_load_cbufferrow_8;
+      RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty);
+      EltSize = 2;
+      NumElts = 8;
+      break;
+    case 32:
+      IID = Intrinsic::dx_resource_load_cbufferrow_4;
+      RetTy = StructType::get(Ty, Ty, Ty, Ty);
+      EltSize = 4;
+      NumElts = 4;
+      break;
+    case 64:
+      IID = Intrinsic::dx_resource_load_cbufferrow_2;
+      RetTy = StructType::get(Ty, Ty);
+      EltSize = 8;
+      NumElts = 2;
+      break;
+    default:
+      llvm_unreachable("Only 16, 32, and 64 bit types supported");
+    }
+  }
+};
+} // namespace
+
+static size_t getOffsetForCBufferGEP(GEPOperator *GEP, GlobalVariable *Global,
+                                     const DataLayout &DL) {
+  // Since we should always have a constant offset, we should only ever have a
+  // single GEP of indirection from the Global.
+  assert(GEP->getPointerOperand() == Global &&
+         "Indirect access to resource handle");
+
+  APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
+  bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);
+  (void)Success;
+  assert(Success && "Offsets into cbuffer globals must be constant");
+
+  if (auto *ATy = dyn_cast<ArrayType>(Global->getValueType()))
+    ConstantOffset = hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);
+
+  return ConstantOffset.getZExtValue();
+}
+
+/// Replace access via cbuffer global with a load from the cbuffer handle
+/// itself.
+static void replaceAccess(LoadInst *LI, GlobalVariable *Global,
+                          GlobalVariable *HandleGV, size_t BaseOffset,
+                          SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
+  const DataLayout &DL = HandleGV->getDataLayout();
+
+  size_t Offset = BaseOffset;
+  if (auto *GEP = dyn_cast<GEPOperator>(LI->getPointerOperand()))
+    Offset += getOffsetForCBufferGEP(GEP, Global, DL);
+  else if (LI->getPointerOperand() != Global)
+    llvm_unreachable("Load instruction doesn't reference cbuffer global");
+
+  IRBuilder<> Builder(LI);
+  auto *Handle = Builder.CreateLoad(HandleGV->getValueType(), HandleGV,
+                                    HandleGV->getName());
+
+  Type *Ty = LI->getType();
+  CBufferRowIntrin Intrin(DL, Ty->getScalarType());
+  // The cbuffer consists of some number of 16-byte rows.
+  unsigned int CurrentRow = Offset / 16;
+  unsigned int CurrentIndex = (Offset % 16) / Intrin.EltSize;
+
+  auto *CBufLoad = Builder.CreateIntrinsic(
+      Intrin.RetTy, Intrin.IID,
+      {Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
+      LI->getName());
+  auto *Elt =
+      Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, LI->getName());
+
+  Value *Result = nullptr;
+  unsigned int Remaining =
+      ((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
+  if (Remaining == 0) {
+    // We only have a single element, so we're done.
+    Result = Elt;
+
+    // However, if we loaded a <1 x T>, then we need to adjust the type here.
+    if (auto *VT = dyn_cast<FixedVectorType>(LI->getType()))
+      if (VT->getNumElements() == 1)
+        Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
+                                             Builder.getInt32(0));
+  } else {
+    // Walk each element and extract it, wrapping to new rows as needed.
+    SmallVector<Value *> Extracts{Elt};
+    while (Remaining--) {
+      CurrentIndex %= Intrin.NumElts;
+
+      if (CurrentIndex == 0)
+        CBufLoad = Builder.CreateIntrinsic(
+            Intrin.RetTy, Intrin.IID,
+            {Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
+            nullptr, LI->getName());
+
+      Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
+                                                    LI->getName()));
+    }
+
+    // Finally, we build up the original loaded value.
+    Result = PoisonValue::get(Ty);
+    for (int I = 0, E = Extracts.size(); I < E; ++I)
+      Result =
+          Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I));
+  }
+
+  LI->replaceAllUsesWith(Result);
+  DeadInsts.push_back(LI);
+}
+
+static void replaceAccessesWithHandle(GlobalVariable *Global,
+                                      GlobalVariable *HandleGV,
+                                      size_t BaseOffset) {
+  SmallVector<WeakTrackingVH> DeadInsts;
+
+  SmallVector<User *> ToProcess{Global->users()};
+  while (!ToProcess.empty()) {
+    User *Cur = ToProcess.pop_back_val();
+
+    // If we have a load instruction, replace the access.
+    if (auto *LI = dyn_cast<LoadInst>(Cur)) {
+      replaceAccess(LI, Global, HandleGV, BaseOffset, DeadInsts);
+      continue;
+    }
+
+    // Otherwise, walk users looking for a load...
+    ToProcess.append(Cur->user_begin(), Cur->user_end());
+  }
+  RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
+}
+
+static bool replaceCBufferAccesses(Module &M) {
+  std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
+  if (!CBufMD)
+    return false;
+
+  for (const hlsl::CBufferMapping &Mapping : *CBufMD)
+    for (const hlsl::CBufferMember &Member : Mapping.Members) {
+      replaceAccessesWithHandle(Member.GV, Mapping.Handle, Member.Offset);
+      Member.GV->removeFromParent();
+    }
+
+  CBufMD->eraseFromModule();
+  return true;
+}
+
+PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) {
+  PreservedAnalyses PA;
+  bool Changed = replaceCBufferAccesses(M);
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+  return PA;
+}
+
+namespace {
+class DXILCBufferAccessLegacy : public ModulePass {
+public:
+  bool runOnModule(Module &M) override {
+    return replaceCBufferAccesses(M);
+  }
+  StringRef getPassName() const override { return "DXIL CBuffer Access"; }
+  DXILCBufferAccessLegacy() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+char DXILCBufferAccessLegacy::ID = 0;
+} // end anonymous namespace
+
+INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access",
+                false, false)
+
+ModulePass *llvm::createDXILCBufferAccessLegacyPass() {
+  return new DXILCBufferAccessLegacy();
+}
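The row walk in `replaceAccess` can be modeled as a list of (row, element) slots. The following is a minimal sketch under stated assumptions, not the pass itself: `rowSlots` is a hypothetical helper that only computes which 16-byte row and which element within that row each scalar of a load would come from. It assumes the element size is 2, 4, or 8 bytes, matching the three `load.cbufferrow` variants selected by `CBufferRowIntrin`.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// For a load of NumValues scalars of EltSize bytes starting at byte Offset,
// return the (row, index) slot of each scalar. A row holds 16 bytes, so it
// contains 8, 4, or 2 elements for 2-, 4-, or 8-byte scalars.
std::vector<std::pair<unsigned, unsigned>>
rowSlots(uint64_t Offset, unsigned EltSize, unsigned NumValues) {
  const unsigned RowBytes = 16;
  const unsigned EltsPerRow = RowBytes / EltSize;

  unsigned Row = static_cast<unsigned>(Offset / RowBytes);
  unsigned Index = static_cast<unsigned>((Offset % RowBytes) / EltSize);

  std::vector<std::pair<unsigned, unsigned>> Slots;
  for (unsigned I = 0; I < NumValues; ++I) {
    // Wrap to the next row once the current one is exhausted; this is the
    // point where the pass would emit another row load.
    if (Index == EltsPerRow) {
      ++Row;
      Index = 0;
    }
    Slots.emplace_back(Row, Index++);
  }
  return Slots;
}
```

As a usage example, `rowSlots(80, 8, 3)` describes a `<3 x double>` load and yields slots (5, 0), (5, 1), (6, 0). The byte offset 80 is an assumption chosen because it starts at row 5; the resulting pattern of two extracts from row 5 followed by one from row 6 is the same shape as the second CHECK block in arrays.ll.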
diff --git a/llvm/lib/Target/DirectX/DXILCBufferAccess.h b/llvm/lib/Target/DirectX/DXILCBufferAccess.h
new file mode 100644
index 0000000000000..6c1cde164004e
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILCBufferAccess.h
@@ -0,0 +1,28 @@
+//===- DXILCBufferAccess.h - Translate CBuffer Loads ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file Pass for replacing loads from cbuffers in the cbuffer address space to
+// cbuffer load intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class DXILCBufferAccess : public PassInfoMixin<DXILCBufferAccess> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 96a8a08c875f8..c0eb221d12203 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -35,6 +35,12 @@ void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &);
 /// Pass to expand intrinsic operations that lack DXIL opCodes
 ModulePass *createDXILIntrinsicExpansionLegacyPass();
 
+/// Initializer for DXIL CBuffer Access Pass
+void initializeDXILCBufferAccessLegacyPass(PassRegistry &);
+
+/// Pass to translate loads in the cbuffer address space to intrinsics
+ModulePass *createDXILCBufferAccessLegacyPass();
+
 /// Initializer for DXIL Data Scalarization Pass
 void initializeDXILDataScalarizationLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index 87d91ead1896f..37093f16680a9 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -23,6 +23,7 @@ MODULE_ANALYSIS("dxil-root-signature-analysis", dxil::RootSignatureAnalysis())
 #ifndef MODULE_PASS
 #define MODULE_PASS(NAME, CREATE_PASS)
 #endif
+MODULE_PASS("dxil-cbuffer-access", DXILCBufferAccess())
 MODULE_PASS("dxil-data-scalarization", DXILDataScalarization())
 MODULE_PASS("dxil-flatten-arrays", DXILFlattenArrays())
 MODULE_PASS("dxil-intrinsic-expansion", DXILIntrinsicExpansion())
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index ce408b4034f83..40ae1a3062704 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DirectXTargetMachine.h"
+#include "DXILCBufferAccess.h"
 #include "DXILDataScalarization.h"
 #include "DXILFlattenArrays.h"
 #include "DXILIntrinsicExpansion.h"
@@ -64,6 +65,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeShaderFlagsAnalysisWrapperPass(*PR);
   initializeRootSignatureAnalysisWrapperPass(*PR);
   initializeDXILFinalizeLinkageLegacyPass(*PR);
+  initializeDXILCBufferAccessLegacyPass(*PR);
 }
 
 class DXILTargetObjectFile : public TargetLoweringObjectFile {
@@ -95,6 +97,7 @@ class DirectXPassConfig : public TargetPassConfig {
   void addCodeGenPrepare() override {
     addPass(createDXILFinalizeLinkageLegacyPass());
     addPass(createDXILIntrinsicExpansionLegacyPass());
+    addPass(createDXILCBufferAccessLegacyPass());
     addPass(createDXILDataScalarizationLegacyPass());
     addPass(createDXILFlattenArraysLegacyPass());
     addPass(createDXILResourceAccessLegacyPass());
diff --git a/llvm/test/CodeGen/DirectX/CBufferAccess/arrays.ll b/llvm/test/CodeGen/DirectX/CBufferAccess/arrays.ll
new file mode 100644
index 0000000000000..7478cc5f362dc
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/CBufferAccess/arrays.ll
@@ -0,0 +1,121 @@
+; RUN: opt -S -dxil-cbuffer-access -mtriple=dxil--shadermodel6.3-library %s | FileCheck %s
+
+; cbuffer CB : register(b0) {
+;   float a1[3];
+;   double3 a2[2];
+;   float16_t a3[2][2];
+;   uint64_t a4[3];
+;   int4 a5[2][3][4];
+;   uint16_t a6[1];
+;   int64_t a7[2];
+;   bool a8[4];
+; }
+%__cblayout_CB = type <{ [3 x float], [2 x <3 x double>], [2 x [2 x half]], [3 x i64], [2 x [3 x [4 x <4 x i32>]]], [1 x i16], [2 x i64], [4 x i32] }>
+%struct.S = type { float, <3 x double>, half, i64, <4 x i32>, i16, i64, i32, [12 x i8] }
+
+@CB.cb = local_unnamed_addr global target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 708, 0, 48, 112, 176, 224, 608, 624, 656)) poison
+@a1 = external local_unnamed_addr addrspace(2) global [3 x float], align 4
+@a2 = external local_unnamed_addr addrspace(2) global [2 x <3 x double>], align 32
+@a3 = external local_unnamed_addr addrspace(2) global [2 x [2 x half]], align 2
+@a4 = external local_unnamed_addr addrspace(2) global [3 x i64], align 8
+@a5 = external local_unnamed_addr addrspace(2) global [2 x [3 x [4 x <4 x i32>]]], align 16
+@a6 = external local_unnamed_addr addrspace(2) global [1 x i16], align 2
+@a7 = external local_unnamed_addr addrspace(2) global [2 x i64], align 8
+@a8 = external local_unnamed_addr addrspace(2) global [4 x i32], align 4
+
+define void @f(ptr %dst) {
+entry:
+  %CB.cb_h.i.i = tail call target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 708, 0, 48, 112, 176, 224, 608, 624, 656)) @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 1, i32 0, i1 false)
+  store target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 708, 0, 48, 112, 176, 224, 608, 624, 656)) %CB.cb_h.i.i, ptr @CB.cb, align 4
+
+  ; CHECK: [[CB:%.*]] = load target("dx.CBuffer", {{.*}})), ptr @CB.cb
+  ; CHECK: [[LOAD:%.*]] = call { float, float, float, float } @llvm.dx.resource.load.cbufferrow.4.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 1)
+  ; CHECK: [[X:%.*]] = extractvalue { float, float, float, float } [[LOAD]], 0
+  ; CHECK: store float [[X]], ptr %dst
+  %a1 = load float, ptr addrspace(2) getelementptr inbounds nuw (i8, ptr addrspace(2) @a1, i32 4), align 4
+  store float %a1, ptr %dst, align 32
+
+  ; CHECK: [[CB:%.*]] = load target("dx.CBuffer", {{.*}})), ptr @CB.cb
+  ; CHECK: [[LOAD:%.*]] = call { double, double } @llvm.dx.resource.load.cbufferrow.2.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 5)
+  ; CHECK: [[X:%.*]] = extractvalue { double, double } [[LOAD]], 0
+  ; CHECK: [[Y:%.*]] = extractvalue { double, double } [[LOAD]], 1
+  ; CHECK: [[LOAD:%.*]] = call { double, double } @llvm.dx.resource.load.cbufferrow.2.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 6)
+  ; CHECK: [[Z:%.*]] = extractvalue { double, double } [[LOAD]], 0
+  ; CHECK: [[VEC0:%.*]] = insertelement <3 x double> poison, double [[X]], i32 0
+  ; CHECK: [[VEC1:%.*]] = insertelement <3 x double> [[VEC0]], double [[Y]], i32 1
+  ; CHECK: [[VEC2:%.*]] = insertelement <3 x double> [[VEC1]], double [[Z]], i32 2
+  ; CHECK: [[PTR:%.*]] = getelementptr inbounds nuw i8, ptr %dst, i32 8
+  ; CHECK: store <3 x double> [[VEC2]], ptr [[PTR]]
+  %...
[truncated]

@llvmbot
Member

llvmbot commented Apr 7, 2025

@llvm/pr-subscribers-hlsl

+      ((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
+  if (Remaining == 0) {
+    // We only have a single element, so we're done.
+    Result = Elt;
+
+    // However, if we loaded a <1 x T>, then we need to adjust the type here.
+    if (auto *VT = dyn_cast<FixedVectorType>(LI->getType()))
+      if (VT->getNumElements() == 1)
+        Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
+                                             Builder.getInt32(0));
+  } else {
+    // Walk each element and extract it, wrapping to new rows as needed.
+    SmallVector<Value *> Extracts{Elt};
+    while (Remaining--) {
+      CurrentIndex %= Intrin.NumElts;
+
+      if (CurrentIndex == 0)
+        CBufLoad = Builder.CreateIntrinsic(
+            Intrin.RetTy, Intrin.IID,
+            {Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
+            nullptr, LI->getName());
+
+      Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
+                                                    LI->getName()));
+    }
+
+    // Finally, we build up the original loaded value.
+    Result = PoisonValue::get(Ty);
+    for (int I = 0, E = Extracts.size(); I < E; ++I)
+      Result =
+          Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I));
+  }
+
+  LI->replaceAllUsesWith(Result);
+  DeadInsts.push_back(LI);
+}
+
+static void replaceAccessesWithHandle(GlobalVariable *Global,
+                                      GlobalVariable *HandleGV,
+                                      size_t BaseOffset) {
+  SmallVector<WeakTrackingVH> DeadInsts;
+
+  SmallVector<User *> ToProcess{Global->users()};
+  while (!ToProcess.empty()) {
+    User *Cur = ToProcess.pop_back_val();
+
+    // If we have a load instruction, replace the access.
+    if (auto *LI = dyn_cast<LoadInst>(Cur)) {
+      replaceAccess(LI, Global, HandleGV, BaseOffset, DeadInsts);
+      continue;
+    }
+
+    // Otherwise, walk users looking for a load...
+    ToProcess.append(Cur->user_begin(), Cur->user_end());
+  }
+  RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
+}
+
+static bool replaceCBufferAccesses(Module &M) {
+  std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
+  if (!CBufMD)
+    return false;
+
+  for (const hlsl::CBufferMapping &Mapping : *CBufMD)
+    for (const hlsl::CBufferMember &Member : Mapping.Members) {
+      replaceAccessesWithHandle(Member.GV, Mapping.Handle, Member.Offset);
+      Member.GV->removeFromParent();
+    }
+
+  CBufMD->eraseFromModule();
+  return true;
+}
+
+PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) {
+  PreservedAnalyses PA;
+  bool Changed = replaceCBufferAccesses(M);
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+  return PA;
+}
+
+namespace {
+class DXILCBufferAccessLegacy : public ModulePass {
+public:
+  bool runOnModule(Module &M) override {
+    return replaceCBufferAccesses(M);
+  }
+  StringRef getPassName() const override { return "DXIL CBuffer Access"; }
+  DXILCBufferAccessLegacy() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+char DXILCBufferAccessLegacy::ID = 0;
+} // end anonymous namespace
+
+INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access",
+                false, false)
+
+ModulePass *llvm::createDXILCBufferAccessLegacyPass() {
+  return new DXILCBufferAccessLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILCBufferAccess.h b/llvm/lib/Target/DirectX/DXILCBufferAccess.h
new file mode 100644
index 0000000000000..6c1cde164004e
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILCBufferAccess.h
@@ -0,0 +1,28 @@
+//===- DXILCBufferAccess.h - Translate CBuffer Loads ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file Pass for replacing loads from cbuffers in the cbuffer address space
+// with cbuffer load intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class DXILCBufferAccess : public PassInfoMixin<DXILCBufferAccess> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 96a8a08c875f8..c0eb221d12203 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -35,6 +35,12 @@ void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &);
 /// Pass to expand intrinsic operations that lack DXIL opCodes
 ModulePass *createDXILIntrinsicExpansionLegacyPass();
 
+/// Initializer for DXIL CBuffer Access Pass
+void initializeDXILCBufferAccessLegacyPass(PassRegistry &);
+
+/// Pass to translate loads in the cbuffer address space to intrinsics
+ModulePass *createDXILCBufferAccessLegacyPass();
+
 /// Initializer for DXIL Data Scalarization Pass
 void initializeDXILDataScalarizationLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index 87d91ead1896f..37093f16680a9 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -23,6 +23,7 @@ MODULE_ANALYSIS("dxil-root-signature-analysis", dxil::RootSignatureAnalysis())
 #ifndef MODULE_PASS
 #define MODULE_PASS(NAME, CREATE_PASS)
 #endif
+MODULE_PASS("dxil-cbuffer-access", DXILCBufferAccess())
 MODULE_PASS("dxil-data-scalarization", DXILDataScalarization())
 MODULE_PASS("dxil-flatten-arrays", DXILFlattenArrays())
 MODULE_PASS("dxil-intrinsic-expansion", DXILIntrinsicExpansion())
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index ce408b4034f83..40ae1a3062704 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DirectXTargetMachine.h"
+#include "DXILCBufferAccess.h"
 #include "DXILDataScalarization.h"
 #include "DXILFlattenArrays.h"
 #include "DXILIntrinsicExpansion.h"
@@ -64,6 +65,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeShaderFlagsAnalysisWrapperPass(*PR);
   initializeRootSignatureAnalysisWrapperPass(*PR);
   initializeDXILFinalizeLinkageLegacyPass(*PR);
+  initializeDXILCBufferAccessLegacyPass(*PR);
 }
 
 class DXILTargetObjectFile : public TargetLoweringObjectFile {
@@ -95,6 +97,7 @@ class DirectXPassConfig : public TargetPassConfig {
   void addCodeGenPrepare() override {
     addPass(createDXILFinalizeLinkageLegacyPass());
     addPass(createDXILIntrinsicExpansionLegacyPass());
+    addPass(createDXILCBufferAccessLegacyPass());
     addPass(createDXILDataScalarizationLegacyPass());
     addPass(createDXILFlattenArraysLegacyPass());
     addPass(createDXILResourceAccessLegacyPass());
diff --git a/llvm/test/CodeGen/DirectX/CBufferAccess/arrays.ll b/llvm/test/CodeGen/DirectX/CBufferAccess/arrays.ll
new file mode 100644
index 0000000000000..7478cc5f362dc
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/CBufferAccess/arrays.ll
@@ -0,0 +1,121 @@
+; RUN: opt -S -dxil-cbuffer-access -mtriple=dxil--shadermodel6.3-library %s | FileCheck %s
+
+; cbuffer CB : register(b0) {
+;   float a1[3];
+;   double3 a2[2];
+;   float16_t a3[2][2];
+;   uint64_t a4[3];
+;   int4 a5[2][3][4];
+;   uint16_t a6[1];
+;   int64_t a7[2];
+;   bool a8[4];
+; }
+%__cblayout_CB = type <{ [3 x float], [2 x <3 x double>], [2 x [2 x half]], [3 x i64], [2 x [3 x [4 x <4 x i32>]]], [1 x i16], [2 x i64], [4 x i32] }>
+%struct.S = type { float, <3 x double>, half, i64, <4 x i32>, i16, i64, i32, [12 x i8] }
+
+@CB.cb = local_unnamed_addr global target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 708, 0, 48, 112, 176, 224, 608, 624, 656)) poison
+@a1 = external local_unnamed_addr addrspace(2) global [3 x float], align 4
+@a2 = external local_unnamed_addr addrspace(2) global [2 x <3 x double>], align 32
+@a3 = external local_unnamed_addr addrspace(2) global [2 x [2 x half]], align 2
+@a4 = external local_unnamed_addr addrspace(2) global [3 x i64], align 8
+@a5 = external local_unnamed_addr addrspace(2) global [2 x [3 x [4 x <4 x i32>]]], align 16
+@a6 = external local_unnamed_addr addrspace(2) global [1 x i16], align 2
+@a7 = external local_unnamed_addr addrspace(2) global [2 x i64], align 8
+@a8 = external local_unnamed_addr addrspace(2) global [4 x i32], align 4
+
+define void @f(ptr %dst) {
+entry:
+  %CB.cb_h.i.i = tail call target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 708, 0, 48, 112, 176, 224, 608, 624, 656)) @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 1, i32 0, i1 false)
+  store target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 708, 0, 48, 112, 176, 224, 608, 624, 656)) %CB.cb_h.i.i, ptr @CB.cb, align 4
+
+  ; CHECK: [[CB:%.*]] = load target("dx.CBuffer", {{.*}})), ptr @CB.cb
+  ; CHECK: [[LOAD:%.*]] = call { float, float, float, float } @llvm.dx.resource.load.cbufferrow.4.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 1)
+  ; CHECK: [[X:%.*]] = extractvalue { float, float, float, float } [[LOAD]], 0
+  ; CHECK: store float [[X]], ptr %dst
+  %a1 = load float, ptr addrspace(2) getelementptr inbounds nuw (i8, ptr addrspace(2) @a1, i32 4), align 4
+  store float %a1, ptr %dst, align 32
+
+  ; CHECK: [[CB:%.*]] = load target("dx.CBuffer", {{.*}})), ptr @CB.cb
+  ; CHECK: [[LOAD:%.*]] = call { double, double } @llvm.dx.resource.load.cbufferrow.2.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 5)
+  ; CHECK: [[X:%.*]] = extractvalue { double, double } [[LOAD]], 0
+  ; CHECK: [[Y:%.*]] = extractvalue { double, double } [[LOAD]], 1
+  ; CHECK: [[LOAD:%.*]] = call { double, double } @llvm.dx.resource.load.cbufferrow.2.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 6)
+  ; CHECK: [[Z:%.*]] = extractvalue { double, double } [[LOAD]], 0
+  ; CHECK: [[VEC0:%.*]] = insertelement <3 x double> poison, double [[X]], i32 0
+  ; CHECK: [[VEC1:%.*]] = insertelement <3 x double> [[VEC0]], double [[Y]], i32 1
+  ; CHECK: [[VEC2:%.*]] = insertelement <3 x double> [[VEC1]], double [[Z]], i32 2
+  ; CHECK: [[PTR:%.*]] = getelementptr inbounds nuw i8, ptr %dst, i32 8
+  ; CHECK: store <3 x double> [[VEC2]], ptr [[PTR]]
+  %...
[truncated]
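
For readers skimming the patch, the size-to-row-shape mapping in `CBufferRowIntrin` can be modelled without LLVM. The sketch below is a standalone approximation: `RowShape`, `rowShapeForBits`, and `fillsOneRow` are hypothetical names that do not appear in the patch, and only the numbers (16-byte rows holding 8, 4, or 2 elements) are taken from the diff above.

```cpp
// Hypothetical stand-in for the EltSize/NumElts fields of CBufferRowIntrin.
struct RowShape {
  unsigned EltSize; // bytes per element
  unsigned NumElts; // elements in one 16-byte row; 0 means "unsupported"
};

// Maps a scalar width in bits to the shape of one 16-byte cbuffer row.
// Unsupported widths return {0, 0} here; the patch uses llvm_unreachable.
inline RowShape rowShapeForBits(unsigned Bits) {
  switch (Bits) {
  case 16:
    return {2, 8};
  case 32:
    return {4, 4};
  case 64:
    return {8, 2};
  default:
    return {0, 0};
  }
}

// True when the shape fills exactly one 16-byte row.
inline bool fillsOneRow(RowShape S) { return S.EltSize * S.NumElts == 16; }
```

All three supported widths satisfy `fillsOneRow`, which is why a single intrinsic call per row is enough regardless of the scalar type.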


github-actions bot commented Apr 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

Is this the right place for this transformation? I expect that we would want this pass to run for all backends. We would definitely want it for SPIR-V. Could we move it into an HLSL directory as in #134260?

Contributor Author

My thinking here is that the generic logic that's helpful for all backends should belong in Frontend/HLSL/CBuffer.h, but the pass itself is fairly DirectX specific. While all of the logic to figure out offsets and memory layout is necessary for all targets, the details of what this should transform into aren't necessarily compatible.

For example, the DirectX backend is constrained in that it needs to access the cbuffers via an operation that loads a single 16-byte row, so we need to lower to this series of dx.resource.load.cbufferrow operations and then piece together the data we actually want to load. I can't imagine that this would be the best way to represent this for SPIR-V, where all we really care about is where the object ended up in memory and how it's padded, and we can use normal load operations from there.

So trying to put all of this in a generic pass that's aware of the various backends and their different target intrinsics feels like it would be wrong.

All that said, I'm not sure if the balance between what's in the pass itself and what's done in lib/Frontend is correct here - we may want to move more stuff to there when we implement a similar change for SPIR-V.
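
The row-at-a-time constraint described above can be illustrated with a small standalone model of the index arithmetic. This is a sketch, not the pass itself: `EltRef` and `planLoads` are hypothetical names, and the only facts taken from the patch are the 16-byte row size and the rule that extraction wraps to the next row once a row is exhausted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Position of one scalar inside the cbuffer: which row, and which slot of it.
struct EltRef {
  unsigned Row;
  unsigned Index;
};

// Lists the (row, slot) pairs needed to read Count scalars of EltSize bytes
// starting at byte Offset, moving to the next row whenever a row runs out.
inline std::vector<EltRef> planLoads(std::size_t Offset, unsigned EltSize,
                                     unsigned Count) {
  const unsigned RowSize = 16; // cbuffer rows are 16 bytes
  assert(EltSize != 0 && RowSize % EltSize == 0 && "element must tile a row");
  assert(Offset % EltSize == 0 && "offset must be element-aligned");
  const unsigned EltsPerRow = RowSize / EltSize;

  unsigned Row = static_cast<unsigned>(Offset / RowSize);
  unsigned Index = static_cast<unsigned>((Offset % RowSize) / EltSize);

  std::vector<EltRef> Plan;
  for (unsigned I = 0; I < Count; ++I) {
    if (Index == EltsPerRow) { // ran off the end of the current row
      Index = 0;
      ++Row;
    }
    Plan.push_back({Row, Index++});
  }
  return Plan;
}
```

Starting at byte offset 80 (row 5) with three 8-byte elements, the plan visits row 5 slots 0 and 1 and then row 6 slot 0, which lines up with the `i32 5` and `i32 6` row operands in the `<3 x double>` CHECK lines of arrays.ll.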

Member

@hekota hekota left a comment

Looks pretty good, I just have a question about array indexing. I don't think it will work for loads like these:

%1 = load float, ptr addrspace(2) getelementptr inbounds ([3 x float], ptr addrspace(2) @a1, i32 0, i32 1), align 4

Member

@hekota hekota left a comment

LGTM! :)

store target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 708, 0, 48, 112, 176, 224, 608, 624, 656)) %CB.cb_h.i.i, ptr @CB.cb, align 4

; CHECK: [[CB:%.*]] = load target("dx.CBuffer", {{.*}})), ptr @CB.cb
; CHECK: [[LOAD:%.*]] = call { float, float, float, float } @llvm.dx.resource.load.cbufferrow.4.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 1)
Contributor

@inbelic inbelic Apr 14, 2025

Suggested change
- ; CHECK: [[LOAD:%.*]] = call { float, float, float, float } @llvm.dx.resource.load.cbufferrow.4.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 1)
+ ; CHECK: [[LOAD:%.*]] = call { float, float, float, float } @llvm.dx.resource.load.cbufferrow.4.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 0)

When we lower this to CBufferLoadLegacy this corresponds to the "0-based row index". Why do we expect this to be 1 instead of 0? Presumably this is the first (0 indexed) row?

For context, arrays.ll and gep-ce-two-uses.ll seem to start at row index 1. But float.ll, scalars.ll, and vectors.ll all seem to start at 0.

Contributor Author

Here we're loading a1[1], so the index of 1 is correct. You can see this from the input if you take a look at the GEP constant expression in the load:

%a1 = load float, ptr addrspace(2) getelementptr inbounds nuw (i8, ptr addrspace(2) @a1, i32 4), align 4

Here, we're accessing a value that's four bytes past the @a1 pointer, i.e. the float at a1[1]. Each array element in a cbuffer starts on its own 16-byte row, so the cbuffer access pass needs to translate this into loading element 1 of the array, which is in the second row (row index 1).
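
The step from "four bytes past @a1" to "row 1" can be sketched as a standalone calculation. `alignUp`, `translateArrayOffset`, and `rowOf` are hypothetical helpers written for illustration; it is only an assumption that `hlsl::translateCBufArrayOffset` from this patch behaves the same way, since its body is not shown here.

```cpp
#include <cassert>
#include <cstddef>

// Rounds Value up to the next multiple of Align.
inline std::size_t alignUp(std::size_t Value, std::size_t Align) {
  assert(Align != 0 && "alignment must be non-zero");
  return (Value + Align - 1) / Align * Align;
}

// Converts a byte offset into a tightly packed array into the offset of the
// same element when every element starts on its own 16-byte row.
inline std::size_t translateArrayOffset(std::size_t PackedOffset,
                                        std::size_t EltSize) {
  assert(EltSize != 0 && "element size must be non-zero");
  std::size_t EltIndex = PackedOffset / EltSize;
  return EltIndex * alignUp(EltSize, 16);
}

// Row index of a byte offset within the cbuffer.
inline std::size_t rowOf(std::size_t CBufferOffset) {
  return CBufferOffset / 16;
}
```

With `PackedOffset = 4` and `EltSize = 4`, as in the `a1[1]` load, the translated offset is 16 bytes, so `rowOf` yields row 1 rather than row 0.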

@farzonl
Member

farzonl commented Apr 14, 2025

I'm seeing a pretty large regression in the DML shaders that are able to compile in the form Load of {{.*}} is not a global resource handle.

@llvm-beanz
Collaborator

I'm seeing a pretty large regression in the DML shaders that are able to compile in the form Load of {{.*}} is not a global resource handle.

The shaders failing to compile instead of compiling garbage is not a regression.

Collaborator

@llvm-beanz llvm-beanz left a comment

This looks reasonable. I think we're going to have to think about how to handle cases where LLVM or Clang might do silly things (like generate memcpy) because there isn't a sane way to represent that memory layout is different in different address spaces.

PR #134174 does cause clang to generate memcpy for array->rvalue casts, but the same will occur for aggregate assignments (see this C++ example).

We will either need to teach clang not to generate memcpy or we'll need to have some way to legalize it.
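
To make the aggregate-assignment concern concrete, here is a minimal standalone C++ sketch. `Params`, `copyParams`, and `sameParams` are hypothetical names. The assignment in `copyParams` is the kind of copy a compiler may lower to a single byte-for-byte `memcpy`, which is only valid when source and destination share one memory layout; that assumption would not hold if the source lived in a cbuffer where each array element occupies its own 16-byte row.

```cpp
#include <cstddef>

// Hypothetical aggregate; in HLSL the source object would live in a cbuffer.
struct Params {
  float Values[8];
};

// Whole-aggregate assignment: a candidate for lowering to a byte copy.
inline void copyParams(Params &Dst, const Params &Src) { Dst = Src; }

// Element-wise comparison used to check the result of the copy.
inline bool sameParams(const Params &A, const Params &B) {
  for (std::size_t I = 0; I < 8; ++I)
    if (A.Values[I] != B.Values[I])
      return false;
  return true;
}
```

On a host where both objects use the same packed layout the copy is trivially correct; the open question in this thread is what should happen when the two sides do not.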

@bogner
Contributor Author

bogner commented Apr 16, 2025

I'm seeing a pretty large regression in the DML shaders that are able to compile in the form Load of {{.*}} is not a global resource handle.

I looked into this, and every case that regresses would error in the frontend if we implement #135909. Unfortunately there are a lot of uses of implicit bindings in that set of shaders, so it isn't exercising much of this code at all.

bogner added 3 commits April 15, 2025 22:28
This introduces a pass that walks accesses to globals in cbuffers and
replaces them with accesses via the cbuffer handle itself. The logic to
interpret the cbuffer metadata is kept in `lib/Frontend/HLSL` so that it
can be reused by other consumers of that metadata.

Fixes llvm#124630.
@bogner bogner force-pushed the 2025-04-07-cbuffer-access branch from c9be1d4 to 0290631 Compare April 16, 2025 05:29
@bogner bogner merged commit 3de88fe into llvm:main Apr 16, 2025
7 of 11 checks passed
@llvm-ci
Collaborator

llvm-ci commented Apr 16, 2025

LLVM Buildbot has detected a new failure on builder lld-x86_64-win running on as-worker-93 while building llvm at step 7 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/146/builds/2719

Here is the relevant piece of the build log for the reference
Step 7 (test-build-unified-tree-check-all) failure: test (failure)
******************** TEST 'LLVM-Unit :: Support/./SupportTests.exe/86/95' FAILED ********************
Script(shard):
--
GTEST_OUTPUT=json:C:\a\lld-x86_64-win\build\unittests\Support\.\SupportTests.exe-LLVM-Unit-7624-86-95.json GTEST_SHUFFLE=0 GTEST_TOTAL_SHARDS=95 GTEST_SHARD_INDEX=86 C:\a\lld-x86_64-win\build\unittests\Support\.\SupportTests.exe
--

Script:
--
C:\a\lld-x86_64-win\build\unittests\Support\.\SupportTests.exe --gtest_filter=ProgramEnvTest.CreateProcessLongPath
--
C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp(160): error: Expected equality of these values:
  0
  RC
    Which is: -2

C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp(163): error: fs::remove(Twine(LongPath)): did not return errc::success.
error number: 13
error message: permission denied



C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp:160
Expected equality of these values:
  0
  RC
    Which is: -2

C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp:163
fs::remove(Twine(LongPath)): did not return errc::success.
error number: 13
error message: permission denied




********************


var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
@damyanp damyanp moved this to Closed in HLSL Support Apr 25, 2025
Labels
backend:DirectX HLSL HLSL Language Support
Projects
Status: Closed
Development

Successfully merging this pull request may close these issues.

[HLSL] Create HLSLConstantAccess pass to translate hlsl_constant loads
8 participants