@@ -3220,8 +3220,30 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
3220
3220
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA (SDNode *N, SDValue &Lo,
3221
3221
SDValue &Hi) {
3222
3222
SDLoc DL (N);
3223
- SDValue Expanded = TLI.expandPartialReduceMLA (N, DAG);
3224
- std::tie (Lo, Hi) = DAG.SplitVector (Expanded, DL);
3223
+ SDValue Acc = N->getOperand (0 );
3224
+ SDValue Input1 = N->getOperand (1 );
3225
+ SDValue Input2 = N->getOperand (2 );
3226
+
3227
+ SDValue AccLo, AccHi;
3228
+ std::tie (AccLo, AccHi) = DAG.SplitVector (Acc, DL);
3229
+ unsigned Opcode = N->getOpcode ();
3230
+
3231
+ // If the input types don't need splitting, just accumulate into the
3232
+ // low part of the accumulator.
3233
+ if (getTypeAction (Input1.getValueType ()) != TargetLowering::TypeSplitVector) {
3234
+ Lo = DAG.getNode (Opcode, DL, AccLo.getValueType (), AccLo, Input1, Input2);
3235
+ Hi = AccHi;
3236
+ return ;
3237
+ }
3238
+
3239
+ SDValue Input1Lo, Input1Hi;
3240
+ SDValue Input2Lo, Input2Hi;
3241
+ std::tie (Input1Lo, Input1Hi) = DAG.SplitVector (Input1, DL);
3242
+ std::tie (Input2Lo, Input2Hi) = DAG.SplitVector (Input2, DL);
3243
+ EVT ResultVT = AccLo.getValueType ();
3244
+
3245
+ Lo = DAG.getNode (Opcode, DL, ResultVT, AccLo, Input1Lo, Input2Lo);
3246
+ Hi = DAG.getNode (Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
3225
3247
}
3226
3248
3227
3249
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE (SDNode *N) {
@@ -4501,7 +4523,20 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
4501
4523
}
4502
4524
4503
4525
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA (SDNode *N) {
4504
- return TLI.expandPartialReduceMLA (N, DAG);
4526
+ SDValue Acc = N->getOperand (0 );
4527
+ assert (getTypeAction (Acc.getValueType ()) != TargetLowering::TypeSplitVector &&
4528
+ " Accumulator should already be a legal type, and shouldn't need "
4529
+ " further splitting" );
4530
+
4531
+ SDLoc DL (N);
4532
+ SDValue Input1Lo, Input1Hi, Input2Lo, Input2Hi;
4533
+ std::tie (Input1Lo, Input1Hi) = DAG.SplitVector (N->getOperand (1 ), DL);
4534
+ std::tie (Input2Lo, Input2Hi) = DAG.SplitVector (N->getOperand (2 ), DL);
4535
+ unsigned Opcode = N->getOpcode ();
4536
+ EVT ResultVT = Acc.getValueType ();
4537
+
4538
+ SDValue Lo = DAG.getNode (Opcode, DL, ResultVT, Acc, Input1Lo, Input2Lo);
4539
+ return DAG.getNode (Opcode, DL, ResultVT, Lo, Input1Hi, Input2Hi);
4505
4540
}
4506
4541
4507
4542
// ===----------------------------------------------------------------------===//
0 commit comments