Skip to content

Commit b6f65f0

Browse files
[SelectionDAG] Improve type legalisation for PARTIAL_REDUCE_MLA (#130935)
Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes. This makes the udot_8to64 and sdot_8to64 tests generate dot product instructions for when the new ISD nodes are used. --------- Co-authored-by: James Chesterman <james.chesterman@arm.com>
1 parent 001cc34 commit b6f65f0

File tree

5 files changed

+606
-679
lines changed

5 files changed

+606
-679
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

+38-3
Original file line numberDiff line numberDiff line change
@@ -3220,8 +3220,30 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
32203220
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
32213221
SDValue &Hi) {
32223222
SDLoc DL(N);
3223-
SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
3224-
std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
3223+
SDValue Acc = N->getOperand(0);
3224+
SDValue Input1 = N->getOperand(1);
3225+
SDValue Input2 = N->getOperand(2);
3226+
3227+
SDValue AccLo, AccHi;
3228+
std::tie(AccLo, AccHi) = DAG.SplitVector(Acc, DL);
3229+
unsigned Opcode = N->getOpcode();
3230+
3231+
// If the input types don't need splitting, just accumulate into the
3232+
// low part of the accumulator.
3233+
if (getTypeAction(Input1.getValueType()) != TargetLowering::TypeSplitVector) {
3234+
Lo = DAG.getNode(Opcode, DL, AccLo.getValueType(), AccLo, Input1, Input2);
3235+
Hi = AccHi;
3236+
return;
3237+
}
3238+
3239+
SDValue Input1Lo, Input1Hi;
3240+
SDValue Input2Lo, Input2Hi;
3241+
std::tie(Input1Lo, Input1Hi) = DAG.SplitVector(Input1, DL);
3242+
std::tie(Input2Lo, Input2Hi) = DAG.SplitVector(Input2, DL);
3243+
EVT ResultVT = AccLo.getValueType();
3244+
3245+
Lo = DAG.getNode(Opcode, DL, ResultVT, AccLo, Input1Lo, Input2Lo);
3246+
Hi = DAG.getNode(Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
32253247
}
32263248

32273249
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
@@ -4501,7 +4523,20 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
45014523
}
45024524

45034525
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
4504-
return TLI.expandPartialReduceMLA(N, DAG);
4526+
SDValue Acc = N->getOperand(0);
4527+
assert(getTypeAction(Acc.getValueType()) != TargetLowering::TypeSplitVector &&
4528+
"Accumulator should already be a legal type, and shouldn't need "
4529+
"further splitting");
4530+
4531+
SDLoc DL(N);
4532+
SDValue Input1Lo, Input1Hi, Input2Lo, Input2Hi;
4533+
std::tie(Input1Lo, Input1Hi) = DAG.SplitVector(N->getOperand(1), DL);
4534+
std::tie(Input2Lo, Input2Hi) = DAG.SplitVector(N->getOperand(2), DL);
4535+
unsigned Opcode = N->getOpcode();
4536+
EVT ResultVT = Acc.getValueType();
4537+
4538+
SDValue Lo = DAG.getNode(Opcode, DL, ResultVT, Acc, Input1Lo, Input2Lo);
4539+
return DAG.getNode(Opcode, DL, ResultVT, Lo, Input1Hi, Input2Hi);
45054540
}
45064541

45074542
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)