[HLSL] Overloads for lerp with a scalar weight #137877

Merged — 1 commit merged into llvm:main on Apr 30, 2025

Conversation

@bogner (Contributor) commented on Apr 29, 2025:

This adds overloads for the lerp function that accept a scalar for the weight parameter by splatting it into the appropriate vector.

Fixes #137827
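
For illustration, a minimal HLSL sketch of the call shape this enables (the function and variable names here are hypothetical, not from the patch):

float3 fade_color(float3 lo, float3 hi, float weight) {
  // With the new overload, the scalar weight is splatted to float3 inside the
  // header, so callers no longer need to write float3(weight, weight, weight).
  return lerp(lo, hi, weight);
}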

@llvmbot added labels on Apr 29, 2025: clang (Clang issues not falling into any other category), backend:X86, clang:frontend (Language frontend issues, e.g. anything involving "Sema"), clang:headers (Headers provided by Clang, e.g. for intrinsics), HLSL (HLSL Language Support)
@llvmbot (Member) commented on Apr 29, 2025:

@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Justin Bogner (bogner)

Changes

This adds overloads for the lerp function that accept a scalar for the weight parameter by splatting it into the appropriate vector.

Fixes #137827


Full diff: https://github.com/llvm/llvm-project/pull/137877.diff

4 Files Affected:

  • (modified) clang/lib/Headers/hlsl/hlsl_compat_overloads.h (+6)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+2-1)
  • (modified) clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl (+18-8)
  • (modified) clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl (+11-11)
diff --git a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
index 47ae34adfe541..4874206d349c0 100644
--- a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
+++ b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
@@ -277,6 +277,12 @@ constexpr bool4 isinf(double4 V) { return isinf((float4)V); }
 // lerp builtins overloads
 //===----------------------------------------------------------------------===//
 
+template <typename T, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
+lerp(vector<T, N> x, vector<T, N> y, T s) {
+  return lerp(x, y, (vector<T, N>)s);
+}
+
 _DXC_COMPAT_TERNARY_DOUBLE_OVERLOADS(lerp)
 _DXC_COMPAT_TERNARY_INTEGER_OVERLOADS(lerp)
 
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 38322e6ba063b..0df27d9495109 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2555,7 +2555,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_lerp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
+        CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
index 6e452481e2fa2..e2935e5ffe593 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
@@ -1,11 +1,7 @@
-// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple \
-// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
-// RUN:   -o - | FileCheck %s --check-prefixes=CHECK \
-// RUN:   -DFNATTRS="noundef nofpclass(nan inf)" -DTARGET=dx
-// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple \
-// RUN:   spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
-// RUN:   -o - | FileCheck %s --check-prefixes=CHECK \
-// RUN:   -DFNATTRS="spir_func noundef nofpclass(nan inf)" -DTARGET=spv
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple  dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,NATIVE_HALF -DFNATTRS="noundef nofpclass(nan inf)" -DTARGET=dx
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple  dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF -DFNATTRS="noundef nofpclass(nan inf)" -DTARGET=dx
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple spirv-unknown-vulkan-compute %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,NATIVE_HALF -DFNATTRS="spir_func noundef nofpclass(nan inf)" -DTARGET=spv
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF -DFNATTRS="spir_func noundef nofpclass(nan inf)" -DTARGET=spv
 
 // CHECK-LABEL: test_lerp_double
 // CHECK: %hlsl.lerp = call reassoc nnan ninf nsz arcp afn float @llvm.[[TARGET]].lerp.f32(float %{{.*}}, float %{{.*}}, float %{{.*}})
@@ -106,3 +102,17 @@ float3 test_lerp_uint64_t3(uint64_t3 p0) { return lerp(p0, p0, p0); }
 // CHECK: %hlsl.lerp = call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.[[TARGET]].lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
 // CHECK: ret <4 x float> %hlsl.lerp
 float4 test_lerp_uint64_t4(uint64_t4 p0) { return lerp(p0, p0, p0); }
+
+// CHECK-LABEL: test_lerp_half_scalar
+// NATIVE_HALF: %hlsl.lerp = call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.[[TARGET]].lerp.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, <3 x half> %{{.*}})
+// NATIVE_HALF: ret <3 x half> %hlsl.lerp
+// NO_HALF: %hlsl.lerp = call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.[[TARGET]].lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
+// NO_HALF: ret <3 x float> %hlsl.lerp
+half3 test_lerp_half_scalar(half3 x, half3 y, half s) { return lerp(x, y, s); }
+
+// CHECK-LABEL: test_lerp_float_scalar
+// CHECK: %hlsl.lerp = call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.[[TARGET]].lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
+// CHECK: ret <3 x float> %hlsl.lerp
+float3 test_lerp_float_scalar(float3 x, float3 y, float s) {
+  return lerp(x, y, s);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
index 398d3c7f938c1..b4734a985f31c 100644
--- a/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
@@ -62,42 +62,42 @@ float2 test_lerp_element_type_mismatch(half2 p0, float2 p1) {
 
 float2 test_builtin_lerp_float2_splat(float p0, float2 p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float2 test_builtin_lerp_float2_splat2(double p0, double2 p1) {
   return __builtin_hlsl_lerp(p1, p0, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float2 test_builtin_lerp_float2_splat3(double p0, double2 p1) {
   return __builtin_hlsl_lerp(p1, p1, p0);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float3 test_builtin_lerp_float3_splat(float p0, float3 p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float4 test_builtin_lerp_float4_splat(float p0, float4 p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float2 test_lerp_float2_int_splat(float2 p0, int p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float3 test_lerp_float3_int_splat(float3 p0, int p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float2 test_builtin_lerp_int_vect_to_float_vec_promotion(int2 p0, float p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float test_builtin_lerp_bool_type_promotion(bool p0) {
@@ -107,17 +107,17 @@ float test_builtin_lerp_bool_type_promotion(bool p0) {
 
 float builtin_bool_to_float_type_promotion(float p0, bool p1) {
   return __builtin_hlsl_lerp(p0, p0, p1);
-  // expected-error@-1 {{3rd argument must be a scalar or vector of floating-point types (was 'bool')}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float builtin_bool_to_float_type_promotion2(bool p0, float p1) {
   return __builtin_hlsl_lerp(p1, p0, p1);
-  // expected-error@-1 {{2nd argument must be a scalar or vector of floating-point types (was 'bool')}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float builtin_lerp_int_to_float_promotion(float p0, int p1) {
   return __builtin_hlsl_lerp(p0, p0, p1);
-  // expected-error@-1 {{3rd argument must be a scalar or vector of floating-point types (was 'int')}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float4 test_lerp_int4(int4 p0, int4 p1, int4 p2) {

Comment on lines 109 to 172
// NO_HALF: %hlsl.lerp = call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.[[TARGET]].lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
// NO_HALF: ret <3 x float> %hlsl.lerp
half3 test_lerp_half_scalar(half3 x, half3 y, half s) { return lerp(x, y, s); }
A contributor commented:
Would we not want something to see how s gets turned into a 3 x float, or is that exercised enough elsewhere?

Author (@bogner) replied:
Yeah, that's not a bad idea. However, when I started to do that I realized this whole test file is a bit messed up - it's using -disable-llvm-passes, so none of the logic in the hlsl headers is inlined and we're actually splitting our checks across multiple functions in most of these cases o_O.

I'll put up an NFC PR to improve these tests and then follow up here shortly.

Author (@bogner) commented on Apr 29, 2025:
After rebasing on top of #137898, the tests now check for the conversion sequence.

@@ -277,6 +277,12 @@ constexpr bool4 isinf(double4 V) { return isinf((float4)V); }
// lerp builtins overloads
//===----------------------------------------------------------------------===//

template <typename T, uint N>
constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
A member commented:
I don't think compat overload is the right place for this. Seems like we might want to keep this around past 202x.

Author (@bogner) replied:
I also think it feels a bit odd, but I was following the examples from clamp and min/max. I think if we're going to move them, we should do all of these together.

@farzonl (Member) commented on Apr 29, 2025:
Since we are open to this for lerp, I think we should file a similar issue for pow(vector<*>, scalar); it would be much more user-friendly if we had an overload that let us use a scalar for the exponent.
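
For illustration, a hypothetical sketch of what such a pow overload could look like if it followed the same pattern as the lerp overload in this patch (no such overload exists in the tree; the exact constraints and where it would live would be decided in that issue):

template <typename T, uint N>
constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
pow(vector<T, N> x, T y) {
  // Splat the scalar exponent into a vector and defer to the existing vector
  // overload, mirroring the lerp overload added here.
  return pow(x, (vector<T, N>)y);
}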

@bogner changed the base branch from main to users/bogner/pr137898 on April 29, 2025 at 23:49
@bogner force-pushed the 2025-04-29-lerp-overloads branch from 64cc675 to 20a1723 on April 29, 2025 at 23:50
@bogner changed the base branch from users/bogner/pr137898 to main on April 30, 2025 at 19:28
This adds overloads for the `lerp` function that accept a scalar for the weight
parameter by splatting it into the appropriate vector.

Fixes llvm#137827
@bogner force-pushed the 2025-04-29-lerp-overloads branch from b333286 to a6c359a on April 30, 2025 at 19:28
@bogner merged commit ae6b4b2 into llvm:main on Apr 30, 2025 (7 of 10 checks passed)
Linked issue this pull request closes: #137827 — [HLSL] lerp overloads with mixed scalar and vector operands