Skip to content

Commit 02e316c

Browse files
authored
[DirectX] legalize memset (#136244)
fixes #136243 This change converts memset into a series of geps and stores It is intentionally limited to memsets of fixed size It also converts the byte stores to type stores. DXIL does not support i8 plus this reduces the total number of gep and store instructions. This change also moves DXILFinalizeLinkage to run after Legalization to clean up any dead intrinsic definitions.
1 parent 60c9f3f commit 02e316c

File tree

4 files changed

+210
-6
lines changed

4 files changed

+210
-6
lines changed

llvm/lib/Target/DirectX/DXILLegalizePass.cpp

+83-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "llvm/IR/InstIterator.h"
1414
#include "llvm/IR/Instruction.h"
1515
#include "llvm/IR/Instructions.h"
16+
#include "llvm/IR/Module.h"
1617
#include "llvm/Pass.h"
1718
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
1819
#include <functional>
@@ -174,16 +175,22 @@ static void upcastI8AllocasAndUses(Instruction &I,
174175

175176
Type *SmallestType = nullptr;
176177

177-
// Gather all cast targets
178178
for (User *U : AI->users()) {
179179
auto *Load = dyn_cast<LoadInst>(U);
180180
if (!Load)
181181
continue;
182182
for (User *LU : Load->users()) {
183-
auto *Cast = dyn_cast<CastInst>(LU);
184-
if (!Cast)
183+
Type *Ty = nullptr;
184+
if (auto *Cast = dyn_cast<CastInst>(LU))
185+
Ty = Cast->getType();
186+
if (CallInst *CI = dyn_cast<CallInst>(LU)) {
187+
if (CI->getIntrinsicID() == Intrinsic::memset)
188+
Ty = Type::getInt32Ty(CI->getContext());
189+
}
190+
191+
if (!Ty)
185192
continue;
186-
Type *Ty = Cast->getType();
193+
187194
if (!SmallestType ||
188195
Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
189196
SmallestType = Ty;
@@ -239,6 +246,77 @@ downcastI64toI32InsertExtractElements(Instruction &I,
239246
}
240247
}
241248

249+
static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
250+
ConstantInt *SizeCI,
251+
DenseMap<Value *, Value *> &ReplacedValues) {
252+
LLVMContext &Ctx = Builder.getContext();
253+
[[maybe_unused]] const DataLayout &DL =
254+
Builder.GetInsertBlock()->getModule()->getDataLayout();
255+
[[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue();
256+
257+
AllocaInst *Alloca = dyn_cast<AllocaInst>(Dst);
258+
259+
assert(Alloca && "Expected memset on an Alloca");
260+
assert(OrigSize == Alloca->getAllocationSize(DL)->getFixedValue() &&
261+
"Expected for memset size to match DataLayout size");
262+
263+
Type *AllocatedTy = Alloca->getAllocatedType();
264+
ArrayType *ArrTy = dyn_cast<ArrayType>(AllocatedTy);
265+
assert(ArrTy && "Expected Alloca for an Array Type");
266+
267+
Type *ElemTy = ArrTy->getElementType();
268+
uint64_t Size = ArrTy->getArrayNumElements();
269+
270+
[[maybe_unused]] uint64_t ElemSize = DL.getTypeStoreSize(ElemTy);
271+
272+
assert(ElemSize > 0 && "Size must be set");
273+
assert(OrigSize == ElemSize * Size && "Size in bytes must match");
274+
275+
Value *TypedVal = Val;
276+
277+
if (Val->getType() != ElemTy) {
278+
if (ReplacedValues[Val]) {
279+
// Note for i8 replacements if we know them we should use them.
280+
// Further if this is a constant ReplacedValues will return null
281+
// so we will stick to TypedVal = Val
282+
TypedVal = ReplacedValues[Val];
283+
284+
} else {
285+
// This case Val is a ConstantInt so the cast folds away.
286+
// However if we don't do the cast the store below ends up being
287+
// an i8.
288+
TypedVal = Builder.CreateIntCast(Val, ElemTy, false);
289+
}
290+
}
291+
292+
for (uint64_t I = 0; I < Size; ++I) {
293+
Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
294+
Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep");
295+
Builder.CreateStore(TypedVal, Ptr);
296+
}
297+
}
298+
299+
static void removeMemSet(Instruction &I,
300+
SmallVectorImpl<Instruction *> &ToRemove,
301+
DenseMap<Value *, Value *> &ReplacedValues) {
302+
303+
CallInst *CI = dyn_cast<CallInst>(&I);
304+
if (!CI)
305+
return;
306+
307+
Intrinsic::ID ID = CI->getIntrinsicID();
308+
if (ID != Intrinsic::memset)
309+
return;
310+
311+
IRBuilder<> Builder(&I);
312+
Value *Dst = CI->getArgOperand(0);
313+
Value *Val = CI->getArgOperand(1);
314+
ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(2));
315+
assert(Size && "Expected Size to be a ConstantInt");
316+
emitMemsetExpansion(Builder, Dst, Val, Size, ReplacedValues);
317+
ToRemove.push_back(CI);
318+
}
319+
242320
namespace {
243321
class DXILLegalizationPipeline {
244322

@@ -270,6 +348,7 @@ class DXILLegalizationPipeline {
270348
LegalizationPipeline.push_back(fixI8UseChain);
271349
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
272350
LegalizationPipeline.push_back(legalizeFreeze);
351+
LegalizationPipeline.push_back(removeMemSet);
273352
}
274353
};
275354

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ class DirectXPassConfig : public TargetPassConfig {
9898

9999
FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
100100
void addCodeGenPrepare() override {
101-
addPass(createDXILFinalizeLinkageLegacyPass());
102101
addPass(createDXILIntrinsicExpansionLegacyPass());
103102
addPass(createDXILCBufferAccessLegacyPass());
104103
addPass(createDXILDataScalarizationLegacyPass());
@@ -109,6 +108,7 @@ class DirectXPassConfig : public TargetPassConfig {
109108
addPass(createScalarizerPass(DxilScalarOptions));
110109
addPass(createDXILForwardHandleAccessesLegacyPass());
111110
addPass(createDXILLegalizeLegacyPass());
111+
addPass(createDXILFinalizeLinkageLegacyPass());
112112
addPass(createDXILTranslateMetadataLegacyPass());
113113
addPass(createDXILOpLoweringLegacyPass());
114114
addPass(createDXILPrepareModulePass());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -dxil-legalize -dxil-finalize-linkage -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
4+
define void @replace_float_memset_test() #0 {
5+
; CHECK-LABEL: define void @replace_float_memset_test(
6+
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
7+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x float], align 4
8+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 8, ptr nonnull [[ACCUM_I_FLAT]])
9+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 0
10+
; CHECK-NEXT: store float 0.000000e+00, ptr [[GEP]], align 4
11+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 1
12+
; CHECK-NEXT: store float 0.000000e+00, ptr [[GEP1]], align 4
13+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr nonnull [[ACCUM_I_FLAT]])
14+
; CHECK-NEXT: ret void
15+
;
16+
%accum.i.flat = alloca [2 x float], align 4
17+
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %accum.i.flat)
18+
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 8, i1 false)
19+
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %accum.i.flat)
20+
ret void
21+
}
22+
23+
define void @replace_half_memset_test() #0 {
24+
; CHECK-LABEL: define void @replace_half_memset_test(
25+
; CHECK-SAME: ) #[[ATTR0]] {
26+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x half], align 4
27+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
28+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 0
29+
; CHECK-NEXT: store half 0xH0000, ptr [[GEP]], align 2
30+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 1
31+
; CHECK-NEXT: store half 0xH0000, ptr [[GEP1]], align 2
32+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
33+
; CHECK-NEXT: ret void
34+
;
35+
%accum.i.flat = alloca [2 x half], align 4
36+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
37+
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 4, i1 false)
38+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
39+
ret void
40+
}
41+
42+
define void @replace_double_memset_test() #0 {
43+
; CHECK-LABEL: define void @replace_double_memset_test(
44+
; CHECK-SAME: ) #[[ATTR0]] {
45+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [2 x double], align 4
46+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 16, ptr nonnull [[ACCUM_I_FLAT]])
47+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 0
48+
; CHECK-NEXT: store double 0.000000e+00, ptr [[GEP]], align 8
49+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 1
50+
; CHECK-NEXT: store double 0.000000e+00, ptr [[GEP1]], align 8
51+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 16, ptr nonnull [[ACCUM_I_FLAT]])
52+
; CHECK-NEXT: ret void
53+
;
54+
%accum.i.flat = alloca [2 x double], align 4
55+
call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %accum.i.flat)
56+
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 16, i1 false)
57+
call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %accum.i.flat)
58+
ret void
59+
}
60+
61+
define void @replace_int16_memset_test() #0 {
62+
; CHECK-LABEL: define void @replace_int16_memset_test(
63+
; CHECK-SAME: ) #[[ATTR0]] {
64+
; CHECK-NEXT: [[CACHE_I:%.*]] = alloca [2 x i16], align 2
65+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[CACHE_I]])
66+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 0
67+
; CHECK-NEXT: store i16 0, ptr [[GEP]], align 2
68+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 1
69+
; CHECK-NEXT: store i16 0, ptr [[GEP1]], align 2
70+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[CACHE_I]])
71+
; CHECK-NEXT: ret void
72+
;
73+
%cache.i = alloca [2 x i16], align 2
74+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %cache.i)
75+
call void @llvm.memset.p0.i32(ptr nonnull align 2 dereferenceable(4) %cache.i, i8 0, i32 4, i1 false)
76+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %cache.i)
77+
ret void
78+
}
79+
80+
define void @replace_int_memset_test() #0 {
81+
; CHECK-LABEL: define void @replace_int_memset_test(
82+
; CHECK-SAME: ) #[[ATTR0]] {
83+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
84+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
85+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
86+
; CHECK-NEXT: store i32 0, ptr [[GEP]], align 4
87+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
88+
; CHECK-NEXT: ret void
89+
;
90+
%accum.i.flat = alloca [1 x i32], align 4
91+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
92+
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 4, i1 false)
93+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
94+
ret void
95+
}
96+
97+
define void @replace_int_memset_to_var_test() #0 {
98+
; CHECK-LABEL: define void @replace_int_memset_to_var_test(
99+
; CHECK-SAME: ) #[[ATTR0]] {
100+
; CHECK-NEXT: [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
101+
; CHECK-NEXT: [[I:%.*]] = alloca i32, align 4
102+
; CHECK-NEXT: store i32 1, ptr [[I]], align 4
103+
; CHECK-NEXT: [[I8_LOAD:%.*]] = load i32, ptr [[I]], align 4
104+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
105+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
106+
; CHECK-NEXT: store i32 [[I8_LOAD]], ptr [[GEP]], align 4
107+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
108+
; CHECK-NEXT: ret void
109+
;
110+
%accum.i.flat = alloca [1 x i32], align 4
111+
%i = alloca i8, align 4
112+
store i8 1, ptr %i
113+
%i8.load = load i8, ptr %i
114+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
115+
call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 %i8.load, i32 4, i1 false)
116+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
117+
ret void
118+
}
119+
120+
attributes #0 = {"hlsl.export"}
121+
122+
123+
declare void @llvm.lifetime.end.p0(i64 immarg, ptr captures(none))
124+
declare void @llvm.lifetime.start.p0(i64 immarg, ptr captures(none))
125+
declare void @llvm.memset.p0.i32(ptr writeonly captures(none), i8, i32, i1 immarg)

llvm/test/CodeGen/DirectX/llc-pipeline.ll

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
; CHECK-OBJ-NEXT: Create Garbage Collector Module Metadata
1414

1515
; CHECK-NEXT: ModulePass Manager
16-
; CHECK-NEXT: DXIL Finalize Linkage
1716
; CHECK-NEXT: DXIL Intrinsic Expansion
1817
; CHECK-NEXT: DXIL CBuffer Access
1918
; CHECK-NEXT: DXIL Data Scalarization
@@ -24,6 +23,7 @@
2423
; CHECK-NEXT: Scalarize vector operations
2524
; CHECK-NEXT: DXIL Forward Handle Accesses
2625
; CHECK-NEXT: DXIL Legalizer
26+
; CHECK-NEXT: DXIL Finalize Linkage
2727
; CHECK-NEXT: DXIL Resources Analysis
2828
; CHECK-NEXT: DXIL Module Metadata analysis
2929
; CHECK-NEXT: DXIL Shader Flag Analysis

0 commit comments

Comments
 (0)