Skip to content

Commit dd87127

Browse files
[DAGCombiner] Eliminate fp casts if we have the right fast math flags (#131345)
When floating-point operations are legalized to operations of a higher precision (e.g. f16 fadd being legalized to f32 fadd) then we get narrowing then widening operations between each operation. With the appropriate fast math flags (nnan ninf contract) we can eliminate these casts.
1 parent ec1016f commit dd87127

17 files changed

+1122
-462
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

+45
Original file line numberDiff line numberDiff line change
@@ -18591,7 +18591,45 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
1859118591
return SDValue();
1859218592
}
1859318593

18594+
// Eliminate a floating-point widening of a narrowed value if the fast math
18595+
// flags allow it.
18596+
static SDValue eliminateFPCastPair(SDNode *N) {
18597+
SDValue N0 = N->getOperand(0);
18598+
EVT VT = N->getValueType(0);
18599+
18600+
unsigned NarrowingOp;
18601+
switch (N->getOpcode()) {
18602+
case ISD::FP16_TO_FP:
18603+
NarrowingOp = ISD::FP_TO_FP16;
18604+
break;
18605+
case ISD::BF16_TO_FP:
18606+
NarrowingOp = ISD::FP_TO_BF16;
18607+
break;
18608+
case ISD::FP_EXTEND:
18609+
NarrowingOp = ISD::FP_ROUND;
18610+
break;
18611+
default:
18612+
llvm_unreachable("Expected widening FP cast");
18613+
}
18614+
18615+
if (N0.getOpcode() == NarrowingOp && N0.getOperand(0).getValueType() == VT) {
18616+
const SDNodeFlags NarrowFlags = N0->getFlags();
18617+
const SDNodeFlags WidenFlags = N->getFlags();
18618+
// Narrowing can introduce inf and change the encoding of a nan, so the
18619+
// widen must have the nnan and ninf flags to indicate that we don't need to
18620+
// care about that. We are also removing a rounding step, and that requires
18621+
// both the narrow and widen to allow contraction.
18622+
if (WidenFlags.hasNoNaNs() && WidenFlags.hasNoInfs() &&
18623+
NarrowFlags.hasAllowContract() && WidenFlags.hasAllowContract()) {
18624+
return N0.getOperand(0);
18625+
}
18626+
}
18627+
18628+
return SDValue();
18629+
}
18630+
1859418631
SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
18632+
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
1859518633
SDValue N0 = N->getOperand(0);
1859618634
EVT VT = N->getValueType(0);
1859718635
SDLoc DL(N);
@@ -18643,6 +18681,9 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
1864318681
if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
1864418682
return NewVSel;
1864518683

18684+
if (SDValue CastEliminated = eliminateFPCastPair(N))
18685+
return CastEliminated;
18686+
1864618687
return SDValue();
1864718688
}
1864818689

@@ -27407,6 +27448,7 @@ SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
2740727448
}
2740827449

2740927450
SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
27451+
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
2741027452
auto Op = N->getOpcode();
2741127453
assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
2741227454
"opcode should be FP16_TO_FP or BF16_TO_FP.");
@@ -27421,6 +27463,9 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
2742127463
}
2742227464
}
2742327465

27466+
if (SDValue CastEliminated = eliminateFPCastPair(N))
27467+
return CastEliminated;
27468+
2742427469
// Sometimes constants manage to survive very late in the pipeline, e.g.,
2742527470
// because they are wrapped inside the <1 x f16> type. Try one last time to
2742627471
// get rid of them.

llvm/test/CodeGen/AArch64/bf16_fast_math.ll

+400
Large diffs are not rendered by default.

llvm/test/CodeGen/AArch64/f16-instructions.ll

+11-5
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,8 @@ define half @test_fmadd(half %a, half %b, half %c) #0 {
8484
; CHECK-CVT-SD: // %bb.0:
8585
; CHECK-CVT-SD-NEXT: fcvt s1, h1
8686
; CHECK-CVT-SD-NEXT: fcvt s0, h0
87-
; CHECK-CVT-SD-NEXT: fmul s0, s0, s1
88-
; CHECK-CVT-SD-NEXT: fcvt s1, h2
89-
; CHECK-CVT-SD-NEXT: fcvt h0, s0
90-
; CHECK-CVT-SD-NEXT: fcvt s0, h0
91-
; CHECK-CVT-SD-NEXT: fadd s0, s0, s1
87+
; CHECK-CVT-SD-NEXT: fcvt s2, h2
88+
; CHECK-CVT-SD-NEXT: fmadd s0, s0, s1, s2
9289
; CHECK-CVT-SD-NEXT: fcvt h0, s0
9390
; CHECK-CVT-SD-NEXT: ret
9491
;
@@ -1248,6 +1245,15 @@ define half @test_atan(half %a) #0 {
12481245
}
12491246

12501247
define half @test_atan2(half %a, half %b) #0 {
1248+
; CHECK-LABEL: test_atan2:
1249+
; CHECK: // %bb.0:
1250+
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
1251+
; CHECK-NEXT: fcvt s0, h0
1252+
; CHECK-NEXT: fcvt s1, h1
1253+
; CHECK-NEXT: bl atan2f
1254+
; CHECK-NEXT: fcvt h0, s0
1255+
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
1256+
; CHECK-NEXT: ret
12511257
%r = call half @llvm.atan2.f16(half %a, half %b)
12521258
ret half %r
12531259
}

llvm/test/CodeGen/AArch64/fmla.ll

+2-5
Original file line numberDiff line numberDiff line change
@@ -1114,11 +1114,8 @@ define half @fmul_f16(half %a, half %b, half %c) {
11141114
; CHECK-SD-NOFP16: // %bb.0: // %entry
11151115
; CHECK-SD-NOFP16-NEXT: fcvt s1, h1
11161116
; CHECK-SD-NOFP16-NEXT: fcvt s0, h0
1117-
; CHECK-SD-NOFP16-NEXT: fmul s0, s0, s1
1118-
; CHECK-SD-NOFP16-NEXT: fcvt s1, h2
1119-
; CHECK-SD-NOFP16-NEXT: fcvt h0, s0
1120-
; CHECK-SD-NOFP16-NEXT: fcvt s0, h0
1121-
; CHECK-SD-NOFP16-NEXT: fadd s0, s0, s1
1117+
; CHECK-SD-NOFP16-NEXT: fcvt s2, h2
1118+
; CHECK-SD-NOFP16-NEXT: fmadd s0, s0, s1, s2
11221119
; CHECK-SD-NOFP16-NEXT: fcvt h0, s0
11231120
; CHECK-SD-NOFP16-NEXT: ret
11241121
;

llvm/test/CodeGen/AArch64/fp16_fast_math.ll

+109
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,112 @@ entry:
8888
%add = fadd ninf half %x, %y
8989
ret half %add
9090
}
91+
92+
; Check that when we have the right fast math flags the converts in between the
93+
; two fadds are removed.
94+
95+
define half @normal_fadd_sequence(half %x, half %y, half %z) {
96+
; CHECK-CVT-LABEL: name: normal_fadd_sequence
97+
; CHECK-CVT: bb.0.entry:
98+
; CHECK-CVT-NEXT: liveins: $h0, $h1, $h2
99+
; CHECK-CVT-NEXT: {{ $}}
100+
; CHECK-CVT-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
101+
; CHECK-CVT-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
102+
; CHECK-CVT-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
103+
; CHECK-CVT-NEXT: [[FCVTSHr:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr [[COPY1]], implicit $fpcr
104+
; CHECK-CVT-NEXT: [[FCVTSHr1:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr [[COPY2]], implicit $fpcr
105+
; CHECK-CVT-NEXT: [[FADDSrr:%[0-9]+]]:fpr32 = nofpexcept FADDSrr killed [[FCVTSHr1]], killed [[FCVTSHr]], implicit $fpcr
106+
; CHECK-CVT-NEXT: [[FCVTHSr:%[0-9]+]]:fpr16 = nofpexcept FCVTHSr killed [[FADDSrr]], implicit $fpcr
107+
; CHECK-CVT-NEXT: [[FCVTSHr2:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr killed [[FCVTHSr]], implicit $fpcr
108+
; CHECK-CVT-NEXT: [[FCVTSHr3:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr [[COPY]], implicit $fpcr
109+
; CHECK-CVT-NEXT: [[FADDSrr1:%[0-9]+]]:fpr32 = nofpexcept FADDSrr killed [[FCVTSHr2]], killed [[FCVTSHr3]], implicit $fpcr
110+
; CHECK-CVT-NEXT: [[FCVTHSr1:%[0-9]+]]:fpr16 = nofpexcept FCVTHSr killed [[FADDSrr1]], implicit $fpcr
111+
; CHECK-CVT-NEXT: $h0 = COPY [[FCVTHSr1]]
112+
; CHECK-CVT-NEXT: RET_ReallyLR implicit $h0
113+
;
114+
; CHECK-FP16-LABEL: name: normal_fadd_sequence
115+
; CHECK-FP16: bb.0.entry:
116+
; CHECK-FP16-NEXT: liveins: $h0, $h1, $h2
117+
; CHECK-FP16-NEXT: {{ $}}
118+
; CHECK-FP16-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
119+
; CHECK-FP16-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
120+
; CHECK-FP16-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
121+
; CHECK-FP16-NEXT: [[FADDHrr:%[0-9]+]]:fpr16 = nofpexcept FADDHrr [[COPY2]], [[COPY1]], implicit $fpcr
122+
; CHECK-FP16-NEXT: [[FADDHrr1:%[0-9]+]]:fpr16 = nofpexcept FADDHrr killed [[FADDHrr]], [[COPY]], implicit $fpcr
123+
; CHECK-FP16-NEXT: $h0 = COPY [[FADDHrr1]]
124+
; CHECK-FP16-NEXT: RET_ReallyLR implicit $h0
125+
entry:
126+
%add1 = fadd half %x, %y
127+
%add2 = fadd half %add1, %z
128+
ret half %add2
129+
}
130+
131+
define half @nnan_ninf_contract_fadd_sequence(half %x, half %y, half %z) {
132+
; CHECK-CVT-LABEL: name: nnan_ninf_contract_fadd_sequence
133+
; CHECK-CVT: bb.0.entry:
134+
; CHECK-CVT-NEXT: liveins: $h0, $h1, $h2
135+
; CHECK-CVT-NEXT: {{ $}}
136+
; CHECK-CVT-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
137+
; CHECK-CVT-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
138+
; CHECK-CVT-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
139+
; CHECK-CVT-NEXT: [[FCVTSHr:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FCVTSHr [[COPY1]], implicit $fpcr
140+
; CHECK-CVT-NEXT: [[FCVTSHr1:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FCVTSHr [[COPY2]], implicit $fpcr
141+
; CHECK-CVT-NEXT: [[FADDSrr:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FADDSrr killed [[FCVTSHr1]], killed [[FCVTSHr]], implicit $fpcr
142+
; CHECK-CVT-NEXT: [[FCVTSHr2:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FCVTSHr [[COPY]], implicit $fpcr
143+
; CHECK-CVT-NEXT: [[FADDSrr1:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FADDSrr killed [[FADDSrr]], killed [[FCVTSHr2]], implicit $fpcr
144+
; CHECK-CVT-NEXT: [[FCVTHSr:%[0-9]+]]:fpr16 = nnan ninf contract nofpexcept FCVTHSr killed [[FADDSrr1]], implicit $fpcr
145+
; CHECK-CVT-NEXT: $h0 = COPY [[FCVTHSr]]
146+
; CHECK-CVT-NEXT: RET_ReallyLR implicit $h0
147+
;
148+
; CHECK-FP16-LABEL: name: nnan_ninf_contract_fadd_sequence
149+
; CHECK-FP16: bb.0.entry:
150+
; CHECK-FP16-NEXT: liveins: $h0, $h1, $h2
151+
; CHECK-FP16-NEXT: {{ $}}
152+
; CHECK-FP16-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
153+
; CHECK-FP16-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
154+
; CHECK-FP16-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
155+
; CHECK-FP16-NEXT: [[FADDHrr:%[0-9]+]]:fpr16 = nnan ninf contract nofpexcept FADDHrr [[COPY2]], [[COPY1]], implicit $fpcr
156+
; CHECK-FP16-NEXT: [[FADDHrr1:%[0-9]+]]:fpr16 = nnan ninf contract nofpexcept FADDHrr killed [[FADDHrr]], [[COPY]], implicit $fpcr
157+
; CHECK-FP16-NEXT: $h0 = COPY [[FADDHrr1]]
158+
; CHECK-FP16-NEXT: RET_ReallyLR implicit $h0
159+
entry:
160+
%add1 = fadd nnan ninf contract half %x, %y
161+
%add2 = fadd nnan ninf contract half %add1, %z
162+
ret half %add2
163+
}
164+
165+
define half @ninf_fadd_sequence(half %x, half %y, half %z) {
166+
; CHECK-CVT-LABEL: name: ninf_fadd_sequence
167+
; CHECK-CVT: bb.0.entry:
168+
; CHECK-CVT-NEXT: liveins: $h0, $h1, $h2
169+
; CHECK-CVT-NEXT: {{ $}}
170+
; CHECK-CVT-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
171+
; CHECK-CVT-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
172+
; CHECK-CVT-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
173+
; CHECK-CVT-NEXT: [[FCVTSHr:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr [[COPY1]], implicit $fpcr
174+
; CHECK-CVT-NEXT: [[FCVTSHr1:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr [[COPY2]], implicit $fpcr
175+
; CHECK-CVT-NEXT: [[FADDSrr:%[0-9]+]]:fpr32 = ninf nofpexcept FADDSrr killed [[FCVTSHr1]], killed [[FCVTSHr]], implicit $fpcr
176+
; CHECK-CVT-NEXT: [[FCVTHSr:%[0-9]+]]:fpr16 = ninf nofpexcept FCVTHSr killed [[FADDSrr]], implicit $fpcr
177+
; CHECK-CVT-NEXT: [[FCVTSHr2:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr killed [[FCVTHSr]], implicit $fpcr
178+
; CHECK-CVT-NEXT: [[FCVTSHr3:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr [[COPY]], implicit $fpcr
179+
; CHECK-CVT-NEXT: [[FADDSrr1:%[0-9]+]]:fpr32 = ninf nofpexcept FADDSrr killed [[FCVTSHr2]], killed [[FCVTSHr3]], implicit $fpcr
180+
; CHECK-CVT-NEXT: [[FCVTHSr1:%[0-9]+]]:fpr16 = ninf nofpexcept FCVTHSr killed [[FADDSrr1]], implicit $fpcr
181+
; CHECK-CVT-NEXT: $h0 = COPY [[FCVTHSr1]]
182+
; CHECK-CVT-NEXT: RET_ReallyLR implicit $h0
183+
;
184+
; CHECK-FP16-LABEL: name: ninf_fadd_sequence
185+
; CHECK-FP16: bb.0.entry:
186+
; CHECK-FP16-NEXT: liveins: $h0, $h1, $h2
187+
; CHECK-FP16-NEXT: {{ $}}
188+
; CHECK-FP16-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
189+
; CHECK-FP16-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
190+
; CHECK-FP16-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
191+
; CHECK-FP16-NEXT: [[FADDHrr:%[0-9]+]]:fpr16 = ninf nofpexcept FADDHrr [[COPY2]], [[COPY1]], implicit $fpcr
192+
; CHECK-FP16-NEXT: [[FADDHrr1:%[0-9]+]]:fpr16 = ninf nofpexcept FADDHrr killed [[FADDHrr]], [[COPY]], implicit $fpcr
193+
; CHECK-FP16-NEXT: $h0 = COPY [[FADDHrr1]]
194+
; CHECK-FP16-NEXT: RET_ReallyLR implicit $h0
195+
entry:
196+
%add1 = fadd ninf half %x, %y
197+
%add2 = fadd ninf half %add1, %z
198+
ret half %add2
199+
}

0 commit comments

Comments
 (0)