Skip to content

[NVPTX] Add fma mixed-precision intrinsics #136661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

rajatbajpai
Copy link
Contributor

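This change adds NVVM intrinsics and NVPTX lowering for the mixed-precision "fma" operations `fma.{rn,rz,rm,rp}{.sat}.f32.{f16,bf16}`, which take two half (or bfloat16) operands and a float operand and produce a float result. The instructions are gated on PTX 8.6 and sm_100.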

This change adds mixed-precision "fma" operations.
@rajatbajpai
Copy link
Contributor Author

@durga4github could you please review this change. Thanks!

@llvmbot
Copy link
Member

llvmbot commented Apr 22, 2025

@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-llvm-ir

Author: Rajat Bajpai (rajatbajpai)

Changes

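This change adds NVVM intrinsics and NVPTX lowering for the mixed-precision "fma" operations `fma.{rn,rz,rm,rp}{.sat}.f32.{f16,bf16}`, which take two half (or bfloat16) operands and a float operand and produce a float result. The instructions are gated on PTX 8.6 and sm_100.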


Full diff: https://github.com/llvm/llvm-project/pull/136661.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+20)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+21)
  • (added) llvm/test/CodeGen/NVPTX/fma-mix-precision.ll (+278)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index d09e1da457249..5d717bf11e3da 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1180,6 +1180,26 @@ let TargetPrefix = "nvvm" in {
         [IntrNoMem, IntrSpeculatable]>;
   }
 
+  // Mixed-precision fma: two half (or bfloat16) inputs, float input and result.
+  foreach rnd = ["rn", "rz", "rm", "rp"] in {
+    foreach sat = ["", "_sat"] in {
+      // (half, half, float) -> float
+      def int_nvvm_fma_#rnd#sat#_h_f
+          : ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_h_f">,
+            DefaultAttrsIntrinsic<[llvm_float_ty],
+                                  [llvm_half_ty, llvm_half_ty, llvm_float_ty],
+                                  [IntrNoMem, IntrSpeculatable]>;
+
+      // (bfloat, bfloat, float) -> float
+      def int_nvvm_fma_#rnd#sat#_bf_f
+          : ClangBuiltin<"__nvvm_fma_"#rnd#sat#"_bf_f">,
+            DefaultAttrsIntrinsic<[llvm_float_ty],
+                                  [llvm_bfloat_ty, llvm_bfloat_ty,
+                                   llvm_float_ty],
+                                  [IntrNoMem, IntrSpeculatable]>;
+    }
+  }
+
 //
 // Rcp
 //
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 4ba3e6f06bb5f..4b0693ac04671 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1510,6 +1510,27 @@ multiclass FMA_INST {
 
 defm INT_NVVM_FMA : FMA_INST;
 
+// Mixed-precision fma (f16/bf16 inputs, f32 result); gated on PTX 8.6, sm_100.
+foreach rnd = ["rn", "rz", "rm", "rp"] in {
+  foreach sat = ["", "_sat"] in {
+    // f16 variant: emits fma.<rnd>{.sat}.f32.f16 for int_nvvm_fma_<rnd><sat>_h_f
+    def INT_NVVM_FMA_#!toupper(rnd#sat)#_H_F
+        : F_MATH_3<"fma."#rnd#!subst(
+                       "_", ".", sat)#".f32.f16 \t$dst, $src0, $src1, $src2;",
+                   Float32Regs, Int16Regs, Int16Regs, Float32Regs,
+                   !cast<Intrinsic>("int_nvvm_fma_"#rnd#sat#"_h_f"),
+                   [hasPTX<86>, hasSM<100>]>;
+
+    // bf16 variant: emits fma.<rnd>{.sat}.f32.bf16 for ..._bf_f
+    def INT_NVVM_FMA_#!toupper(rnd#sat)#_BF_F
+        : F_MATH_3<"fma."#rnd#!subst(
+                       "_", ".", sat)#".f32.bf16 \t$dst, $src0, $src1, $src2;",
+                   Float32Regs, Int16Regs, Int16Regs, Float32Regs,
+                   !cast<Intrinsic>("int_nvvm_fma_"#rnd#sat#"_bf_f"),
+                   [hasPTX<86>, hasSM<100>]>;
+  }
+}
+
 //
 // Rcp
 //
diff --git a/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll b/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll
new file mode 100644
index 0000000000000..6c9341488fde1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-mix-precision.ll
@@ -0,0 +1,278 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck %s
+
+; Basic f32.f16 variants with different rounding modes
+define float @test_fma_rn_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rn_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rn_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rn_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rn_h_f_param_2];
+; CHECK-NEXT:    fma.rn.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rn.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rz_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rz_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rz_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rz_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rz_h_f_param_2];
+; CHECK-NEXT:    fma.rz.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rz.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rm_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rm_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rm_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rm_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rm_h_f_param_2];
+; CHECK-NEXT:    fma.rm.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rm.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rp_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rp_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rp_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rp_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rp_h_f_param_2];
+; CHECK-NEXT:    fma.rp.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rp.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+; Basic f32.bf16 variants with different rounding modes
+define float @test_fma_rn_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rn_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rn_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rn_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rn_bf_f_param_2];
+; CHECK-NEXT:    fma.rn.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rn.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rz_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rz_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rz_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rz_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rz_bf_f_param_2];
+; CHECK-NEXT:    fma.rz.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rz.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rm_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rm_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rm_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rm_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rm_bf_f_param_2];
+; CHECK-NEXT:    fma.rm.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rm.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rp_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rp_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rp_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rp_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rp_bf_f_param_2];
+; CHECK-NEXT:    fma.rp.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rp.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+; f32.f16 variants with sat flag
+define float @test_fma_rn_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rn_sat_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rn_sat_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rn_sat_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rn_sat_h_f_param_2];
+; CHECK-NEXT:    fma.rn.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rn.sat.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rz_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rz_sat_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rz_sat_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rz_sat_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rz_sat_h_f_param_2];
+; CHECK-NEXT:    fma.rz.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rz.sat.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rm_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rm_sat_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rm_sat_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rm_sat_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rm_sat_h_f_param_2];
+; CHECK-NEXT:    fma.rm.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rm.sat.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rp_sat_h_f(half %a, half %b, float %c) {
+; CHECK-LABEL: test_fma_rp_sat_h_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rp_sat_h_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rp_sat_h_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rp_sat_h_f_param_2];
+; CHECK-NEXT:    fma.rp.sat.f32.f16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rp.sat.h.f(half %a, half %b, float %c)
+  ret float %res
+}
+
+; f32.bf16 variants with sat flag
+define float @test_fma_rn_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rn_sat_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rn_sat_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rn_sat_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rn_sat_bf_f_param_2];
+; CHECK-NEXT:    fma.rn.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rn.sat.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rz_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rz_sat_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rz_sat_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rz_sat_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rz_sat_bf_f_param_2];
+; CHECK-NEXT:    fma.rz.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rz.sat.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rm_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rm_sat_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rm_sat_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rm_sat_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rm_sat_bf_f_param_2];
+; CHECK-NEXT:    fma.rm.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rm.sat.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}
+
+define float @test_fma_rp_sat_bf_f(bfloat %a, bfloat %b, float %c) {
+; CHECK-LABEL: test_fma_rp_sat_bf_f(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .f32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_fma_rp_sat_bf_f_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_fma_rp_sat_bf_f_param_1];
+; CHECK-NEXT:    ld.param.f32 %f1, [test_fma_rp_sat_bf_f_param_2];
+; CHECK-NEXT:    fma.rp.sat.f32.bf16 %f2, %rs1, %rs2, %f1;
+; CHECK-NEXT:    st.param.f32 [func_retval0], %f2;
+; CHECK-NEXT:    ret;
+  %res = call float @llvm.nvvm.fma.rp.sat.bf.f(bfloat %a, bfloat %b, float %c)
+  ret float %res
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants