Skip to content

[DirectX] legalize memset #136244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 83 additions & 4 deletions llvm/lib/Target/DirectX/DXILLegalizePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <functional>
Expand Down Expand Up @@ -174,16 +175,22 @@ static void upcastI8AllocasAndUses(Instruction &I,

Type *SmallestType = nullptr;

// Gather all cast targets
for (User *U : AI->users()) {
auto *Load = dyn_cast<LoadInst>(U);
if (!Load)
continue;
for (User *LU : Load->users()) {
auto *Cast = dyn_cast<CastInst>(LU);
if (!Cast)
Type *Ty = nullptr;
if (auto *Cast = dyn_cast<CastInst>(LU))
Ty = Cast->getType();
if (CallInst *CI = dyn_cast<CallInst>(LU)) {
if (CI->getIntrinsicID() == Intrinsic::memset)
Ty = Type::getInt32Ty(CI->getContext());
}

if (!Ty)
continue;
Type *Ty = Cast->getType();

if (!SmallestType ||
Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
SmallestType = Ty;
Expand Down Expand Up @@ -239,6 +246,77 @@ downcastI64toI32InsertExtractElements(Instruction &I,
}
}

static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
ConstantInt *SizeCI,
DenseMap<Value *, Value *> &ReplacedValues) {
LLVMContext &Ctx = Builder.getContext();
[[maybe_unused]] const DataLayout &DL =
Builder.GetInsertBlock()->getModule()->getDataLayout();
[[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue();

AllocaInst *Alloca = dyn_cast<AllocaInst>(Dst);

assert(Alloca && "Expected memset on an Alloca");
assert(OrigSize == Alloca->getAllocationSize(DL)->getFixedValue() &&
"Expected for memset size to match DataLayout size");

Type *AllocatedTy = Alloca->getAllocatedType();
ArrayType *ArrTy = dyn_cast<ArrayType>(AllocatedTy);
assert(ArrTy && "Expected Alloca for an Array Type");

Type *ElemTy = ArrTy->getElementType();
uint64_t Size = ArrTy->getArrayNumElements();

[[maybe_unused]] uint64_t ElemSize = DL.getTypeStoreSize(ElemTy);

assert(ElemSize > 0 && "Size must be set");
assert(OrigSize == ElemSize * Size && "Size in bytes must match");

Value *TypedVal = Val;

if (Val->getType() != ElemTy) {
if (ReplacedValues[Val]) {
// Note for i8 replacements if we know them we should use them.
// Further if this is a constant ReplacedValues will return null
// so we will stick to TypedVal = Val
TypedVal = ReplacedValues[Val];

} else {
// This case Val is a ConstantInt so the cast folds away.
// However if we don't do the cast the store below ends up being
// an i8.
TypedVal = Builder.CreateIntCast(Val, ElemTy, false);
}
}

for (uint64_t I = 0; I < Size; ++I) {
Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep");
Builder.CreateStore(TypedVal, Ptr);
}
}

static void removeMemSet(Instruction &I,
SmallVectorImpl<Instruction *> &ToRemove,
DenseMap<Value *, Value *> &ReplacedValues) {

CallInst *CI = dyn_cast<CallInst>(&I);
if (!CI)
return;

Intrinsic::ID ID = CI->getIntrinsicID();
if (ID != Intrinsic::memset)
return;

IRBuilder<> Builder(&I);
Value *Dst = CI->getArgOperand(0);
Value *Val = CI->getArgOperand(1);
ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(2));
assert(Size && "Expected Size to be a ConstantInt");
emitMemsetExpansion(Builder, Dst, Val, Size, ReplacedValues);
ToRemove.push_back(CI);
}

namespace {
class DXILLegalizationPipeline {

Expand Down Expand Up @@ -270,6 +348,7 @@ class DXILLegalizationPipeline {
LegalizationPipeline.push_back(fixI8UseChain);
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
LegalizationPipeline.push_back(legalizeFreeze);
LegalizationPipeline.push_back(removeMemSet);
}
};

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ class DirectXPassConfig : public TargetPassConfig {

FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
void addCodeGenPrepare() override {
addPass(createDXILFinalizeLinkageLegacyPass());
addPass(createDXILIntrinsicExpansionLegacyPass());
addPass(createDXILCBufferAccessLegacyPass());
addPass(createDXILDataScalarizationLegacyPass());
Expand All @@ -109,6 +108,7 @@ class DirectXPassConfig : public TargetPassConfig {
addPass(createScalarizerPass(DxilScalarOptions));
addPass(createDXILForwardHandleAccessesLegacyPass());
addPass(createDXILLegalizeLegacyPass());
addPass(createDXILFinalizeLinkageLegacyPass());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the legalize legacy pass is a function level pass so it leaves behind the intrinsic decleration in the module. finalize linkage has a cleanup that will remove these dead intrinsics declarations so doing this after legalization works better for cleanup purposes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems fine, but do note that if we do #134260 then we can't rely on this ordering.

Copy link
Member Author

@farzonl farzonl Apr 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do that we can make legalization a module level pass and then we can do the cleanup of declares without FinalizeLinkage.

addPass(createDXILTranslateMetadataLegacyPass());
addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILPrepareModulePass());
Expand Down
125 changes: 125 additions & 0 deletions llvm/test/CodeGen/DirectX/legalize-memset.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -dxil-legalize -dxil-finalize-linkage -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s

define void @replace_float_memset_test() #0 {
; CHECK-LABEL: define void @replace_float_memset_test(
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x float], align 4
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 8, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: [[GEP:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 0
; CHECK-NEXT: store float 0.000000e+00, ptr [[GEP]], align 4
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 1
; CHECK-NEXT: store float 0.000000e+00, ptr [[GEP1]], align 4
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: ret void
;
%accum.i.flat = alloca [2 x float], align 4
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %accum.i.flat)
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 8, i1 false)
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %accum.i.flat)
ret void
}

define void @replace_half_memset_test() #0 {
; CHECK-LABEL: define void @replace_half_memset_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x half], align 4
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: [[GEP:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 0
; CHECK-NEXT: store half 0xH0000, ptr [[GEP]], align 2
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 1
; CHECK-NEXT: store half 0xH0000, ptr [[GEP1]], align 2
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: ret void
;
%accum.i.flat = alloca [2 x half], align 4
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 4, i1 false)
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
ret void
}

define void @replace_double_memset_test() #0 {
; CHECK-LABEL: define void @replace_double_memset_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x double], align 4
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 16, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: [[GEP:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 0
; CHECK-NEXT: store double 0.000000e+00, ptr [[GEP]], align 8
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 1
; CHECK-NEXT: store double 0.000000e+00, ptr [[GEP1]], align 8
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 16, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: ret void
;
%accum.i.flat = alloca [2 x double], align 4
call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %accum.i.flat)
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 16, i1 false)
call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %accum.i.flat)
ret void
}

define void @replace_int16_memset_test() #0 {
; CHECK-LABEL: define void @replace_int16_memset_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[CACHE_I:%.*]] = alloca [2 x i16], align 2
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[CACHE_I]])
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 0
; CHECK-NEXT: store i16 0, ptr [[GEP]], align 2
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 1
; CHECK-NEXT: store i16 0, ptr [[GEP1]], align 2
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[CACHE_I]])
; CHECK-NEXT: ret void
;
%cache.i = alloca [2 x i16], align 2
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %cache.i)
call void @llvm.memset.p0.i32(ptr nonnull align 2 dereferenceable(4) %cache.i, i8 0, i32 4, i1 false)
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %cache.i)
ret void
}

define void @replace_int_memset_test() #0 {
; CHECK-LABEL: define void @replace_int_memset_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
; CHECK-NEXT: store i32 0, ptr [[GEP]], align 4
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: ret void
;
%accum.i.flat = alloca [1 x i32], align 4
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 4, i1 false)
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
ret void
}

define void @replace_int_memset_to_var_test() #0 {
; CHECK-LABEL: define void @replace_int_memset_to_var_test(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
; CHECK-NEXT: [[I:%.*]] = alloca i32, align 4
; CHECK-NEXT: store i32 1, ptr [[I]], align 4
; CHECK-NEXT: [[I8_LOAD:%.*]] = load i32, ptr [[I]], align 4
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
; CHECK-NEXT: store i32 [[I8_LOAD]], ptr [[GEP]], align 4
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
; CHECK-NEXT: ret void
;
%accum.i.flat = alloca [1 x i32], align 4
%i = alloca i8, align 4
store i8 1, ptr %i
%i8.load = load i8, ptr %i
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 %i8.load, i32 4, i1 false)
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
ret void
}

attributes #0 = {"hlsl.export"}


declare void @llvm.lifetime.end.p0(i64 immarg, ptr captures(none))
declare void @llvm.lifetime.start.p0(i64 immarg, ptr captures(none))
declare void @llvm.memset.p0.i32(ptr writeonly captures(none), i8, i32, i1 immarg)
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/DirectX/llc-pipeline.ll
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
; CHECK-OBJ-NEXT: Create Garbage Collector Module Metadata

; CHECK-NEXT: ModulePass Manager
; CHECK-NEXT: DXIL Finalize Linkage
; CHECK-NEXT: DXIL Intrinsic Expansion
; CHECK-NEXT: DXIL CBuffer Access
; CHECK-NEXT: DXIL Data Scalarization
Expand All @@ -24,6 +23,7 @@
; CHECK-NEXT: Scalarize vector operations
; CHECK-NEXT: DXIL Forward Handle Accesses
; CHECK-NEXT: DXIL Legalizer
; CHECK-NEXT: DXIL Finalize Linkage
; CHECK-NEXT: DXIL Resources Analysis
; CHECK-NEXT: DXIL Module Metadata analysis
; CHECK-NEXT: DXIL Shader Flag Analysis
Expand Down
Loading