diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index c29cba6f675c5..f069685a9a2ce 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -244,11 +244,10 @@ class PointerReplacer { void replacePointer(Value *V); private: - bool collectUsersRecursive(Instruction &I); void replace(Instruction *I); - Value *getReplacement(Value *I); + Value *getReplacement(Value *V) const { return WorkMap.lookup(V); } bool isAvailable(Instruction *I) const { - return I == &Root || Worklist.contains(I); + return I == &Root || UsersToReplace.contains(I); } bool isEqualOrValidAddrSpaceCast(const Instruction *I, @@ -260,8 +259,7 @@ class PointerReplacer { return (FromAS == ToAS) || IC.isValidAddrSpaceCast(FromAS, ToAS); } - SmallPtrSet ValuesToRevisit; - SmallSetVector Worklist; + SmallSetVector UsersToReplace; MapVector WorkMap; InstCombinerImpl &IC; Instruction &Root; @@ -270,80 +268,131 @@ class PointerReplacer { } // end anonymous namespace bool PointerReplacer::collectUsers() { - if (!collectUsersRecursive(Root)) - return false; - - // Ensure that all outstanding (indirect) users of I - // are inserted into the Worklist. Return false - // otherwise. - return llvm::set_is_subset(ValuesToRevisit, Worklist); -} + SmallVector Worklist; + SmallSetVector ValuesToRevisit; + + auto PushUsersToWorklist = [&](Instruction *Inst) { + for (auto *U : Inst->users()) { + if (auto *I = dyn_cast(U)) { + if (!isAvailable(I) && !ValuesToRevisit.contains(I)) + Worklist.emplace_back(I); + } + } + }; -bool PointerReplacer::collectUsersRecursive(Instruction &I) { - for (auto *U : I.users()) { - auto *Inst = cast(&*U); + PushUsersToWorklist(&Root); + while (!Worklist.empty()) { + auto *Inst = Worklist.pop_back_val(); + if (!Inst) + return false; + if (isAvailable(Inst)) + continue; if (auto *Load = dyn_cast(Inst)) { if (Load->isVolatile()) return false; - Worklist.insert(Load); + UsersToReplace.insert(Load); } else if (auto *PHI = dyn_cast(Inst)) { // All incoming values must be instructions for replacability if (any_of(PHI->incoming_values(), [](Value *V) { return !isa(V); })) return false; - // If at least one incoming value of the PHI is not in Worklist, - // store the PHI for revisiting and skip this iteration of the - // loop. - if (any_of(PHI->incoming_values(), [this](Value *V) { - return !isAvailable(cast(V)); - })) { - ValuesToRevisit.insert(Inst); + // If all incoming values are available, mark this PHI as + // replacable and push it's users into the worklist. + if (all_of(PHI->incoming_values(), + [&](Value *V) { return isAvailable(cast(V)); })) { + UsersToReplace.insert(PHI); + PushUsersToWorklist(PHI); continue; } - Worklist.insert(PHI); - if (!collectUsersRecursive(*PHI)) - return false; - } else if (auto *SI = dyn_cast(Inst)) { - if (!isa(SI->getTrueValue()) || - !isa(SI->getFalseValue())) + // Not all incoming values are available. If this PHI was already + // visited prior to this iteration, return false. + if (!ValuesToRevisit.insert(PHI)) return false; - if (!isAvailable(cast(SI->getTrueValue())) || - !isAvailable(cast(SI->getFalseValue()))) { - ValuesToRevisit.insert(Inst); - continue; + // Push PHI back into the stack, followed by unavailable + // incoming values. + Worklist.emplace_back(PHI); + for (unsigned Idx = 0; Idx < PHI->getNumIncomingValues(); ++Idx) { + auto *IncomingValue = cast(PHI->getIncomingValue(Idx)); + if (UsersToReplace.contains(IncomingValue)) + continue; + if (!ValuesToRevisit.insert(IncomingValue)) + return false; + Worklist.emplace_back(IncomingValue); } - Worklist.insert(SI); - if (!collectUsersRecursive(*SI)) - return false; - } else if (isa(Inst)) { - Worklist.insert(Inst); - if (!collectUsersRecursive(*Inst)) + } else if (auto *SI = dyn_cast(Inst)) { + auto *TrueInst = dyn_cast(SI->getTrueValue()); + auto *FalseInst = dyn_cast(SI->getFalseValue()); + if (!TrueInst || !FalseInst) return false; + + UsersToReplace.insert(SI); + PushUsersToWorklist(SI); + } else if (auto *GEP = dyn_cast(Inst)) { + UsersToReplace.insert(GEP); + PushUsersToWorklist(GEP); } else if (auto *MI = dyn_cast(Inst)) { if (MI->isVolatile()) return false; - Worklist.insert(Inst); + UsersToReplace.insert(Inst); } else if (isEqualOrValidAddrSpaceCast(Inst, FromAS)) { - Worklist.insert(Inst); - if (!collectUsersRecursive(*Inst)) - return false; + UsersToReplace.insert(Inst); + PushUsersToWorklist(Inst); } else if (Inst->isLifetimeStartOrEnd()) { continue; } else { // TODO: For arbitrary uses with address space mismatches, should we check // if we can introduce a valid addrspacecast? - LLVM_DEBUG(dbgs() << "Cannot handle pointer user: " << *U << '\n'); + LLVM_DEBUG(dbgs() << "Cannot handle pointer user: " << *Inst << '\n'); return false; } } - return true; + return llvm::set_is_subset(ValuesToRevisit, UsersToReplace); } -Value *PointerReplacer::getReplacement(Value *V) { return WorkMap.lookup(V); } +void PointerReplacer::replacePointer(Value *V) { +#ifndef NDEBUG + auto *PT = cast(Root.getType()); + auto *NT = cast(V->getType()); + assert(PT != NT && "Invalid usage"); +#endif + + WorkMap[&Root] = V; + SmallVector Worklist; + SetVector PostOrderWorklist; + SmallPtrSet Visited; + + // Perform a postorder traversal of the users of Root. + Worklist.push_back(&Root); + while (!Worklist.empty()) { + Instruction *I = Worklist.back(); + + // If I has not been processed before, push each of its + // replacable users into the worklist. + if (Visited.insert(I).second) { + for (auto *U : I->users()) { + assert(isa(U) && + "User should not have been" + " collected as it is not an instruction."); + auto *UserInst = cast(U); + if (UsersToReplace.contains(UserInst)) + Worklist.push_back(UserInst); + } + // Otherwise, users of I have already been pushed into + // the PostOrderWorklist. Push I as well. + } else { + PostOrderWorklist.insert(I); + Worklist.pop_back(); + } + } + + // Replace pointers in reverse-postorder. + for (Instruction *I : llvm::reverse(PostOrderWorklist)) + replace(I); +} void PointerReplacer::replace(Instruction *I) { if (getReplacement(I)) @@ -365,12 +414,20 @@ void PointerReplacer::replace(Instruction *I) { // replacement (new value). WorkMap[NewI] = NewI; } else if (auto *PHI = dyn_cast(I)) { - Type *NewTy = getReplacement(PHI->getIncomingValue(0))->getType(); - auto *NewPHI = PHINode::Create(NewTy, PHI->getNumIncomingValues(), - PHI->getName(), PHI->getIterator()); + // Create a new PHI by replacing any incoming value that is a user of the + // root pointer and has a replacement. + SmallVector IncomingValues; + for (unsigned int I = 0; I < PHI->getNumIncomingValues(); ++I) { + Value *V = getReplacement(PHI->getIncomingValue(I)); + if (!V) + V = PHI->getIncomingValue(I); + IncomingValues.push_back(V); + } + auto *NewPHI = PHINode::Create(IncomingValues[0]->getType(), + PHI->getNumIncomingValues(), + PHI->getName() + ".r", PHI->getIterator()); for (unsigned int I = 0; I < PHI->getNumIncomingValues(); ++I) - NewPHI->addIncoming(getReplacement(PHI->getIncomingValue(I)), - PHI->getIncomingBlock(I)); + NewPHI->addIncoming(IncomingValues[I], PHI->getIncomingBlock(I)); WorkMap[PHI] = NewPHI; } else if (auto *GEP = dyn_cast(I)) { auto *V = getReplacement(GEP->getPointerOperand()); @@ -431,22 +488,12 @@ void PointerReplacer::replace(Instruction *I) { } } else { + dbgs() << "Instruction " << *I + << " is not supported in PointerReplacer::replace\n"; llvm_unreachable("should never reach here"); } } -void PointerReplacer::replacePointer(Value *V) { -#ifndef NDEBUG - auto *PT = cast(Root.getType()); - auto *NT = cast(V->getType()); - assert(PT != NT && "Invalid usage"); -#endif - WorkMap[&Root] = V; - - for (Instruction *Workitem : Worklist) - replace(Workitem); -} - Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { if (auto *I = simplifyAllocaArraySize(*this, AI, DT)) return I; diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/ptr-replace-alloca.ll b/llvm/test/Transforms/InstCombine/AMDGPU/ptr-replace-alloca.ll new file mode 100644 index 0000000000000..1a7a961a9b93a --- /dev/null +++ b/llvm/test/Transforms/InstCombine/AMDGPU/ptr-replace-alloca.ll @@ -0,0 +1,48 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=instcombine -S < %s | FileCheck %s + +%struct.type = type { [256 x <2 x i64>] } +@g1 = external hidden addrspace(3) global %struct.type, align 16 + +; This test requires the PtrReplacer to replace users in an RPO traversal. +; Furthermore, %ptr.else need not to be replaced so it must be retained in +; %ptr.sink. +define <2 x i64> @func(ptr addrspace(4) byref(%struct.type) align 16 %0, i1 %cmp.0) { +; CHECK-LABEL: define <2 x i64> @func( +; CHECK-SAME: ptr addrspace(4) byref([[STRUCT_TYPE:%.*]]) align 16 [[TMP0:%.*]], i1 [[CMP_0:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: br i1 [[CMP_0]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] +; CHECK: [[IF_THEN]]: +; CHECK-NEXT: [[VAL_THEN:%.*]] = addrspacecast ptr addrspace(4) [[TMP0]] to ptr +; CHECK-NEXT: br label %[[SINK:.*]] +; CHECK: [[IF_ELSE]]: +; CHECK-NEXT: [[PTR_ELSE:%.*]] = load ptr, ptr addrspace(3) getelementptr inbounds nuw (i8, ptr addrspace(3) @g1, i32 32), align 16 +; CHECK-NEXT: br label %[[SINK]] +; CHECK: [[SINK]]: +; CHECK-NEXT: [[PTR_SINK:%.*]] = phi ptr [ [[PTR_ELSE]], %[[IF_ELSE]] ], [ [[VAL_THEN]], %[[IF_THEN]] ] +; CHECK-NEXT: [[VAL_SINK:%.*]] = load <2 x i64>, ptr [[PTR_SINK]], align 16 +; CHECK-NEXT: ret <2 x i64> [[VAL_SINK]] +; +entry: + %coerce = alloca %struct.type, align 16, addrspace(5) + call void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) align 16 %coerce, ptr addrspace(4) align 16 %0, i64 4096, i1 false) + br i1 %cmp.0, label %if.then, label %if.else + +if.then: ; preds = %entry + %ptr.then = getelementptr inbounds i8, ptr addrspace(5) %coerce, i64 0 + %val.then = addrspacecast ptr addrspace(5) %ptr.then to ptr + br label %sink + +if.else: ; preds = %entry + %ptr.else = load ptr, ptr addrspace(3) getelementptr inbounds nuw (i8, ptr addrspace(3) @g1, i32 32), align 16 + %val.else = getelementptr inbounds nuw i8, ptr %ptr.else, i64 0 + br label %sink + +sink: + %ptr.sink = phi ptr [ %val.else, %if.else ], [ %val.then, %if.then ] + %val.sink = load <2 x i64>, ptr %ptr.sink, align 16 + ret <2 x i64> %val.sink +} + +; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: readwrite) +declare void @llvm.memcpy.p5.p4.i64(ptr addrspace(5) noalias writeonly captures(none), ptr addrspace(4) noalias readonly captures(none), i64, i1 immarg) #0