From 9f9ad7f6b6548269294e304bb53c2b4cabb2cb3e Mon Sep 17 00:00:00 2001 From: rbajpai Date: Tue, 22 Apr 2025 12:28:20 +0530 Subject: [PATCH] [NVPTX] Add fma mix precision intrinsics This change adds "fma" mix precision operations. --- llvm/include/llvm/IR/IntrinsicsNVVM.td | 20 ++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 21 ++ llvm/test/CodeGen/NVPTX/fma-mix-precision.ll | 278 +++++++++++++++++++ 3 files changed, 319 insertions(+) create mode 100644 llvm/test/CodeGen/NVPTX/fma-mix-precision.ll diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index d09e1da457249..5d717bf11e3da 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1180,6 +1180,26 @@ let TargetPrefix = "nvvm" in { [IntrNoMem, IntrSpeculatable]>; } + // Mixed-precision fma intrinsics for half and bfloat16 to float + foreach rnd = ["rn", "rz", "rm", "rp"] in { + foreach sat = ["", "_sat"] in { + // Half-precision to float + def int_nvvm_fma_#rnd#sat#_h_f + : ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_h_f">, + DefaultAttrsIntrinsic<[llvm_float_ty], + [llvm_half_ty, llvm_half_ty, llvm_float_ty], + [IntrNoMem, IntrSpeculatable]>; + + // BFloat16 to float + def int_nvvm_fma_#rnd#sat#_bf_f + : ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_bf_f">, + DefaultAttrsIntrinsic<[llvm_float_ty], + [llvm_bfloat_ty, llvm_bfloat_ty, + llvm_float_ty], + [IntrNoMem, IntrSpeculatable]>; + } + } + // // Rcp // diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 4ba3e6f06bb5f..4b0693ac04671 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1510,6 +1510,27 @@ multiclass FMA_INST { defm INT_NVVM_FMA : FMA_INST; +// Define mixed-precision fma instructions for half and bfloat16 to float +foreach rnd = ["rn", "rz", "rm", "rp"] in { + foreach sat = ["", "_sat"] in { + // Half-precision to float + def INT_NVVM_FMA_#!toupper(rnd#sat)#_H_F + : F_MATH_3<"fma."#rnd#!subst( + "_", ".", sat)#".f32.f16 \t$dst, $src0, $src1, $src2;", + Float32Regs, Int16Regs, Int16Regs, Float32Regs, + !cast("int_nvvm_fma_"#rnd#sat#"_h_f"), + [hasPTX<86>, hasSM<100>]>; + + // BFloat16 to float + def INT_NVVM_FMA_#!toupper(rnd#sat)#_BF_F + : F_MATH_3<"fma."#rnd#!subst( + "_", ".", sat)#".f32.bf16 \t$dst, $src0, $src1, $src2;", + Float32Regs, Int16Regs, Int16Regs, Float32Regs, + !cast("int_nvvm_fma_"#rnd#sat#"_bf_f"), + [hasPTX<86>, hasSM<100>]>; + } +} + // // Rcp // diff --git a/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll b/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll new file mode 100644 index 0000000000000..6c9341488fde1 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll @@ -0,0 +1,278 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck %s + +; Basic f32.f16 variants with different rounding modes +define float @test_fma_rn_h_f(half %a, half %b, float %c) { +; CHECK-LABEL: test_fma_rn_h_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_h_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_h_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_h_f_param_2]; +; CHECK-NEXT: fma.rn.f32.f16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rn.h.f(half %a, half %b, float %c) + ret float %res +} + +define float @test_fma_rz_h_f(half %a, half %b, float %c) { +; CHECK-LABEL: test_fma_rz_h_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_h_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_h_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_h_f_param_2]; +; CHECK-NEXT: fma.rz.f32.f16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rz.h.f(half %a, half %b, float %c) + ret float %res +} + +define float @test_fma_rm_h_f(half %a, half %b, float %c) { +; CHECK-LABEL: test_fma_rm_h_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_h_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_h_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_h_f_param_2]; +; CHECK-NEXT: fma.rm.f32.f16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rm.h.f(half %a, half %b, float %c) + ret float %res +} + +define float @test_fma_rp_h_f(half %a, half %b, float %c) { +; CHECK-LABEL: test_fma_rp_h_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_h_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_h_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_h_f_param_2]; +; CHECK-NEXT: fma.rp.f32.f16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rp.h.f(half %a, half %b, float %c) + ret float %res +} + +; Basic f32.bf16 variants with different rounding modes +define float @test_fma_rn_bf_f(bfloat %a, bfloat %b, float %c) { +; CHECK-LABEL: test_fma_rn_bf_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_bf_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_bf_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_bf_f_param_2]; +; CHECK-NEXT: fma.rn.f32.bf16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rn.bf.f(bfloat %a, bfloat %b, float %c) + ret float %res +} + +define float @test_fma_rz_bf_f(bfloat %a, bfloat %b, float %c) { +; CHECK-LABEL: test_fma_rz_bf_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_bf_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_bf_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_bf_f_param_2]; +; CHECK-NEXT: fma.rz.f32.bf16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rz.bf.f(bfloat %a, bfloat %b, float %c) + ret float %res +} + +define float @test_fma_rm_bf_f(bfloat %a, bfloat %b, float %c) { +; CHECK-LABEL: test_fma_rm_bf_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_bf_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_bf_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_bf_f_param_2]; +; CHECK-NEXT: fma.rm.f32.bf16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rm.bf.f(bfloat %a, bfloat %b, float %c) + ret float %res +} + +define float @test_fma_rp_bf_f(bfloat %a, bfloat %b, float %c) { +; CHECK-LABEL: test_fma_rp_bf_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_bf_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_bf_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_bf_f_param_2]; +; CHECK-NEXT: fma.rp.f32.bf16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rp.bf.f(bfloat %a, bfloat %b, float %c) + ret float %res +} + +; f32.f16 variants with sat flag +define float @test_fma_rn_sat_h_f(half %a, half %b, float %c) { +; CHECK-LABEL: test_fma_rn_sat_h_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_sat_h_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_sat_h_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_sat_h_f_param_2]; +; CHECK-NEXT: fma.rn.sat.f32.f16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rn.sat.h.f(half %a, half %b, float %c) + ret float %res +} + +define float @test_fma_rz_sat_h_f(half %a, half %b, float %c) { +; CHECK-LABEL: test_fma_rz_sat_h_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_sat_h_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_sat_h_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_sat_h_f_param_2]; +; CHECK-NEXT: fma.rz.sat.f32.f16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rz.sat.h.f(half %a, half %b, float %c) + ret float %res +} + +define float @test_fma_rm_sat_h_f(half %a, half %b, float %c) { +; CHECK-LABEL: test_fma_rm_sat_h_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_sat_h_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_sat_h_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_sat_h_f_param_2]; +; CHECK-NEXT: fma.rm.sat.f32.f16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rm.sat.h.f(half %a, half %b, float %c) + ret float %res +} + +define float @test_fma_rp_sat_h_f(half %a, half %b, float %c) { +; CHECK-LABEL: test_fma_rp_sat_h_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_sat_h_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_sat_h_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_sat_h_f_param_2]; +; CHECK-NEXT: fma.rp.sat.f32.f16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rp.sat.h.f(half %a, half %b, float %c) + ret float %res +} + +; f32.bf16 variants with sat flag +define float @test_fma_rn_sat_bf_f(bfloat %a, bfloat %b, float %c) { +; CHECK-LABEL: test_fma_rn_sat_bf_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rn_sat_bf_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rn_sat_bf_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rn_sat_bf_f_param_2]; +; CHECK-NEXT: fma.rn.sat.f32.bf16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rn.sat.bf.f(bfloat %a, bfloat %b, float %c) + ret float %res +} + +define float @test_fma_rz_sat_bf_f(bfloat %a, bfloat %b, float %c) { +; CHECK-LABEL: test_fma_rz_sat_bf_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rz_sat_bf_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rz_sat_bf_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rz_sat_bf_f_param_2]; +; CHECK-NEXT: fma.rz.sat.f32.bf16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rz.sat.bf.f(bfloat %a, bfloat %b, float %c) + ret float %res +} + +define float @test_fma_rm_sat_bf_f(bfloat %a, bfloat %b, float %c) { +; CHECK-LABEL: test_fma_rm_sat_bf_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rm_sat_bf_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rm_sat_bf_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rm_sat_bf_f_param_2]; +; CHECK-NEXT: fma.rm.sat.f32.bf16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rm.sat.bf.f(bfloat %a, bfloat %b, float %c) + ret float %res +} + +define float @test_fma_rp_sat_bf_f(bfloat %a, bfloat %b, float %c) { +; CHECK-LABEL: test_fma_rp_sat_bf_f( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .f32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_rp_sat_bf_f_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_rp_sat_bf_f_param_1]; +; CHECK-NEXT: ld.param.f32 %f1, [test_fma_rp_sat_bf_f_param_2]; +; CHECK-NEXT: fma.rp.sat.f32.bf16 %f2, %rs1, %rs2, %f1; +; CHECK-NEXT: st.param.f32 [func_retval0], %f2; +; CHECK-NEXT: ret; + %res = call float @llvm.nvvm.fma.rp.sat.bf.f(bfloat %a, bfloat %b, float %c) + ret float %res +}