Skip to content

Commit 395ba04

Browse files
committed
[DirectX] legalize memset
Fixes #136243. This change converts memset into a series of GEPs and stores. It is intentionally limited to memsets of fixed size. It also converts the byte stores to typed stores: DXIL does not support i8, and this also reduces the total number of GEP and store instructions. This change also moves DXILFinalizeLinkage to run after legalization to clean up any dead intrinsic definitions.
1 parent 1dbc8ef commit 395ba04

File tree

4 files changed

+161
-2
lines changed

4 files changed

+161
-2
lines changed

llvm/lib/Target/DirectX/DXILLegalizePass.cpp

+78
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "llvm/IR/IRBuilder.h"
1313
#include "llvm/IR/InstIterator.h"
1414
#include "llvm/IR/Instruction.h"
15+
#include "llvm/IR/Module.h"
1516
#include "llvm/Pass.h"
1617
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
1718
#include <functional>
@@ -151,6 +152,82 @@ downcastI64toI32InsertExtractElements(Instruction &I,
151152
}
152153
}
153154

155+
void emitMemset(IRBuilder<> &Builder, Value *Dst, Value *Val,
156+
ConstantInt *SizeCI) {
157+
LLVMContext &Ctx = Builder.getContext();
158+
[[maybe_unused]] DataLayout DL =
159+
Builder.GetInsertBlock()->getModule()->getDataLayout();
160+
[[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue();
161+
162+
AllocaInst *Alloca = dyn_cast<AllocaInst>(Dst);
163+
164+
assert(Alloca && "Expected memset on an Alloca");
165+
assert(OrigSize == Alloca->getAllocationSize(DL)->getFixedValue() &&
166+
"Expected for memset size to match DataLayout size");
167+
168+
Type *AllocatedTy = Alloca->getAllocatedType();
169+
ArrayType *ArrTy = dyn_cast<ArrayType>(AllocatedTy);
170+
assert(ArrTy && "Expected Alloca for an Array Type");
171+
172+
Type *ElemTy = ArrTy->getElementType();
173+
uint64_t Size = ArrTy->getArrayNumElements();
174+
175+
[[maybe_unused]] uint64_t ElemSize = DL.getTypeStoreSize(ElemTy);
176+
177+
assert(ElemSize > 0 && "Size must be set");
178+
assert(OrigSize == ElemSize * Size && "Size in bytes must match");
179+
180+
Value *TypedVal = Val;
181+
if (Val->getType() != ElemTy)
182+
TypedVal = Builder.CreateIntCast(Val, ElemTy,
183+
false); // Or use CreateBitCast for float
184+
185+
for (uint64_t I = 0; I < Size; ++I) {
186+
Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
187+
Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep");
188+
Builder.CreateStore(TypedVal, Ptr);
189+
}
190+
}
191+
192+
/// Queue the lifetime.start / lifetime.end markers attached to the
/// destination of \p Memset for removal.
///
/// \param Memset   A call to the llvm.memset intrinsic.
/// \param ToRemove Removal list the markers are appended to.
///
/// The destination is looked through pointer casts, and every lifetime marker
/// among its direct users is collected.
static void removeLifetimesForMemset(CallInst *Memset,
                                     SmallVectorImpl<Instruction *> &ToRemove) {
  // CallBase::getIntrinsicID() copes with a call that has no known callee,
  // unlike dereferencing getCalledFunction().
  assert(Memset->getIntrinsicID() == Intrinsic::memset &&
         "Expected a memset intrinsic");

  Value *DstPtr = Memset->getArgOperand(0)->stripPointerCasts();

  for (User *U : DstPtr->users()) {
    auto *CI = dyn_cast<CallInst>(U);
    if (!CI)
      continue;

    const Intrinsic::ID ID = CI->getIntrinsicID();
    if (ID != Intrinsic::lifetime_start && ID != Intrinsic::lifetime_end)
      continue;

    // Several memsets can target the same pointer; queue each marker once so
    // the consumer of ToRemove never sees the same instruction twice.
    if (!llvm::is_contained(ToRemove, CI))
      ToRemove.push_back(CI);
  }
}
211+
212+
/// Legalization step: replace a call to llvm.memset with explicit per-element
/// stores and queue the call (plus the lifetime markers on its destination)
/// for removal. Any other instruction is left untouched.
///
/// \param I        Instruction currently being visited.
/// \param ToRemove Removal list that receives the instructions made dead here.
///
/// The third (unnamed) parameter is unused by this step.
static void removeMemSet(Instruction &I,
                         SmallVectorImpl<Instruction *> &ToRemove,
                         DenseMap<Value *, Value *>) {
  auto *Call = dyn_cast<CallInst>(&I);
  if (!Call || Call->getIntrinsicID() != Intrinsic::memset)
    return;

  // memset operands: destination pointer, fill value, byte count.
  Value *Dest = Call->getArgOperand(0);
  Value *Fill = Call->getArgOperand(1);
  ConstantInt *ByteCount = dyn_cast<ConstantInt>(Call->getArgOperand(2));
  assert(ByteCount && "Expected Size to be a ConstantInt");

  // Emit the replacement stores right before the memset itself.
  IRBuilder<> Builder(&I);
  emitMemset(Builder, Dest, Fill, ByteCount);

  removeLifetimesForMemset(Call, ToRemove);
  ToRemove.push_back(Call);
}
230+
154231
namespace {
155232
class DXILLegalizationPipeline {
156233

@@ -181,6 +258,7 @@ class DXILLegalizationPipeline {
181258
LegalizationPipeline.push_back(fixI8TruncUseChain);
182259
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
183260
LegalizationPipeline.push_back(legalizeFreeze);
261+
LegalizationPipeline.push_back(removeMemSet);
184262
}
185263
};
186264

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ class DirectXPassConfig : public TargetPassConfig {
9696

9797
FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
9898
void addCodeGenPrepare() override {
99-
addPass(createDXILFinalizeLinkageLegacyPass());
10099
addPass(createDXILIntrinsicExpansionLegacyPass());
101100
addPass(createDXILCBufferAccessLegacyPass());
102101
addPass(createDXILDataScalarizationLegacyPass());
@@ -106,6 +105,7 @@ class DirectXPassConfig : public TargetPassConfig {
106105
DxilScalarOptions.ScalarizeLoadStore = true;
107106
addPass(createScalarizerPass(DxilScalarOptions));
108107
addPass(createDXILLegalizeLegacyPass());
108+
addPass(createDXILFinalizeLinkageLegacyPass());
109109
addPass(createDXILTranslateMetadataLegacyPass());
110110
addPass(createDXILOpLoweringLegacyPass());
111111
addPass(createDXILPrepareModulePass());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
4+
5+
; Input: a [2 x float] alloca zero-filled by an 8-byte llvm.memset that sits
; between lifetime.start/lifetime.end markers. The CHECK lines expect the
; memset and both lifetime calls to be gone, replaced by one getelementptr +
; `store float 0.0` per array element.
define void @replace_float_memset_test() {
6+
; CHECK-LABEL: define void @replace_float_memset_test() {
7+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x float], align 4
8+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 0
9+
; CHECK-NEXT: store float 0.000000e+00, ptr [[GEP]], align 4
10+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 1
11+
; CHECK-NEXT: store float 0.000000e+00, ptr [[GEP1]], align 4
12+
; CHECK-NEXT: ret void
13+
;
14+
%accum.i.flat = alloca [2 x float], align 4
15+
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %accum.i.flat)
16+
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 8, i1 false)
17+
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %accum.i.flat)
18+
ret void
19+
}
20+
21+
; Input: a [2 x half] alloca zero-filled by a 4-byte llvm.memset between
; lifetime markers. The CHECK lines expect one getelementptr +
; `store half 0xH0000` per element and no memset or lifetime calls.
; NOTE(review): the memset operand says dereferenceable(8) although the alloca
; is only 4 bytes - looks carried over from the float test; confirm.
define void @replace_half_memset_test() {
22+
; CHECK-LABEL: define void @replace_half_memset_test() {
23+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x half], align 4
24+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 0
25+
; CHECK-NEXT: store half 0xH0000, ptr [[GEP]], align 2
26+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 1
27+
; CHECK-NEXT: store half 0xH0000, ptr [[GEP1]], align 2
28+
; CHECK-NEXT: ret void
29+
;
30+
%accum.i.flat = alloca [2 x half], align 4
31+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
32+
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 4, i1 false)
33+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
34+
ret void
35+
}
36+
37+
; Input: a [2 x double] alloca zero-filled by a 16-byte llvm.memset between
; lifetime markers. The CHECK lines expect one getelementptr +
; `store double 0.0` per element and no memset or lifetime calls.
; NOTE(review): the alloca is `align 4` while the expected stores are
; `align 8`, and the memset operand says dereferenceable(8) for a 16-byte
; fill - confirm both are intended.
define void @replace_double_memset_test() {
38+
; CHECK-LABEL: define void @replace_double_memset_test() {
39+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x double], align 4
40+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 0
41+
; CHECK-NEXT: store double 0.000000e+00, ptr [[GEP]], align 8
42+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 1
43+
; CHECK-NEXT: store double 0.000000e+00, ptr [[GEP1]], align 8
44+
; CHECK-NEXT: ret void
45+
;
46+
%accum.i.flat = alloca [2 x double], align 4
47+
call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %accum.i.flat)
48+
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 16, i1 false)
49+
call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %accum.i.flat)
50+
ret void
51+
}
52+
53+
; Input: a [2 x i16] alloca zero-filled by a 4-byte llvm.memset between
; lifetime markers. The CHECK lines expect one getelementptr + `store i16 0`
; per element and no memset or lifetime calls.
define void @replace_int16_memset_test() {
54+
; CHECK-LABEL: define void @replace_int16_memset_test() {
55+
; CHECK-NEXT: [[CACHE_I:%.*]] = alloca [2 x i16], align 2
56+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 0
57+
; CHECK-NEXT: store i16 0, ptr [[GEP]], align 2
58+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 1
59+
; CHECK-NEXT: store i16 0, ptr [[GEP1]], align 2
60+
; CHECK-NEXT: ret void
61+
;
62+
%cache.i = alloca [2 x i16], align 2
63+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %cache.i)
64+
call void @llvm.memset.p0.i32(ptr nonnull align 2 dereferenceable(4) %cache.i, i8 0, i32 4, i1 false)
65+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %cache.i)
66+
ret void
67+
}
68+
69+
; Input: a single-element [1 x i32] alloca zero-filled by a 4-byte llvm.memset
; between lifetime markers. The CHECK lines expect exactly one getelementptr +
; `store i32 0` and no memset or lifetime calls.
; NOTE(review): the memset operand says dereferenceable(8) although the alloca
; is only 4 bytes - looks carried over from the float test; confirm.
define void @replace_int_memset_test() {
70+
; CHECK-LABEL: define void @replace_int_memset_test() {
71+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
72+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
73+
; CHECK-NEXT: store i32 0, ptr [[GEP]], align 4
74+
; CHECK-NEXT: ret void
75+
;
76+
%accum.i.flat = alloca [1 x i32], align 4
77+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
78+
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 4, i1 false)
79+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
80+
ret void
81+
}

llvm/test/CodeGen/DirectX/llc-pipeline.ll

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
; CHECK-OBJ-NEXT: Create Garbage Collector Module Metadata
1414

1515
; CHECK-NEXT: ModulePass Manager
16-
; CHECK-NEXT: DXIL Finalize Linkage
1716
; CHECK-NEXT: DXIL Intrinsic Expansion
1817
; CHECK-NEXT: DXIL CBuffer Access
1918
; CHECK-NEXT: DXIL Data Scalarization
@@ -23,6 +22,7 @@
2322
; CHECK-NEXT: Dominator Tree Construction
2423
; CHECK-NEXT: Scalarize vector operations
2524
; CHECK-NEXT: DXIL Legalizer
25+
; CHECK-NEXT: DXIL Finalize Linkage
2626
; CHECK-NEXT: DXIL Resource Binding Analysis
2727
; CHECK-NEXT: DXIL Module Metadata analysis
2828
; CHECK-NEXT: DXIL Shader Flag Analysis

0 commit comments

Comments
 (0)