Skip to content

Commit edace61

Browse files
committed
DAG: Move scalarizeExtractedVectorLoad to TargetLowering
SimplifyDemandedVectorElts should be able to use this on loads
1 parent 077e0c1 commit edace61

File tree

3 files changed

+93
-2
lines changed

3 files changed

+93
-2
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

+12
Original file line numberDiff line numberDiff line change
@@ -5622,6 +5622,18 @@ class TargetLowering : public TargetLoweringBase {
56225622
// joining their results. SDValue() is returned when expansion did not happen.
56235623
SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;
56245624

5625+
/// Replace an extraction of a load with a narrowed load.
5626+
///
5627+
/// \param ResultVT type of the result extraction.
5628+
/// \param InVecVT type of the input vector to with bitcasts resolved.
5629+
/// \param EltNo index of the vector element to load.
5630+
/// \param OriginalLoad vector load that to be replaced.
5631+
/// \returns \p ResultVT Load on success SDValue() on failure.
5632+
SDValue scalarizeExtractedVectorLoad(EVT ResultVT, const SDLoc &DL,
5633+
EVT InVecVT, SDValue EltNo,
5634+
LoadSDNode *OriginalLoad,
5635+
SelectionDAG &DAG) const;
5636+
56255637
private:
56265638
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
56275639
const SDLoc &DL, DAGCombinerInfo &DCI) const;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -23272,8 +23272,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
2327223272
ISD::isNormalLoad(VecOp.getNode()) &&
2327323273
!Index->hasPredecessor(VecOp.getNode())) {
2327423274
auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
23275-
if (VecLoad && VecLoad->isSimple())
23276-
return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
23275+
if (VecLoad && VecLoad->isSimple()) {
23276+
if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
23277+
ExtVT, SDLoc(N), VecVT, Index, VecLoad, DAG)) {
23278+
++OpsNarrowed;
23279+
return Scalarized;
23280+
}
23281+
}
2327723282
}
2327823283

2327923284
// Perform only after legalization to ensure build_vector / vector_shuffle

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

+74
Original file line numberDiff line numberDiff line change
@@ -12114,3 +12114,77 @@ SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
1211412114
SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
1211512115
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
1211612116
}
12117+
12118+
SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
12119+
const SDLoc &DL,
12120+
EVT InVecVT, SDValue EltNo,
12121+
LoadSDNode *OriginalLoad,
12122+
SelectionDAG &DAG) const {
12123+
assert(OriginalLoad->isSimple());
12124+
12125+
EVT VecEltVT = InVecVT.getVectorElementType();
12126+
12127+
// If the vector element type is not a multiple of a byte then we are unable
12128+
// to correctly compute an address to load only the extracted element as a
12129+
// scalar.
12130+
if (!VecEltVT.isByteSized())
12131+
return SDValue();
12132+
12133+
ISD::LoadExtType ExtTy =
12134+
ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
12135+
if (!isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
12136+
!shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
12137+
return SDValue();
12138+
12139+
Align Alignment = OriginalLoad->getAlign();
12140+
MachinePointerInfo MPI;
12141+
if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
12142+
int Elt = ConstEltNo->getZExtValue();
12143+
unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
12144+
MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
12145+
Alignment = commonAlignment(Alignment, PtrOff);
12146+
} else {
12147+
// Discard the pointer info except the address space because the memory
12148+
// operand can't represent this new access since the offset is variable.
12149+
MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
12150+
Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
12151+
}
12152+
12153+
unsigned IsFast = 0;
12154+
if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
12155+
OriginalLoad->getAddressSpace(), Alignment,
12156+
OriginalLoad->getMemOperand()->getFlags(), &IsFast) ||
12157+
!IsFast)
12158+
return SDValue();
12159+
12160+
SDValue NewPtr =
12161+
getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo);
12162+
12163+
// We are replacing a vector load with a scalar load. The new load must have
12164+
// identical memory op ordering to the original.
12165+
SDValue Load;
12166+
if (ResultVT.bitsGT(VecEltVT)) {
12167+
// If the result type of vextract is wider than the load, then issue an
12168+
// extending load instead.
12169+
ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT)
12170+
? ISD::ZEXTLOAD
12171+
: ISD::EXTLOAD;
12172+
Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
12173+
NewPtr, MPI, VecEltVT, Alignment,
12174+
OriginalLoad->getMemOperand()->getFlags(),
12175+
OriginalLoad->getAAInfo());
12176+
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
12177+
} else {
12178+
// The result type is narrower or the same width as the vector element
12179+
Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
12180+
Alignment, OriginalLoad->getMemOperand()->getFlags(),
12181+
OriginalLoad->getAAInfo());
12182+
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
12183+
if (ResultVT.bitsLT(VecEltVT))
12184+
Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
12185+
else
12186+
Load = DAG.getBitcast(ResultVT, Load);
12187+
}
12188+
12189+
return Load;
12190+
}

0 commit comments

Comments
 (0)