Skip to content

[HLSL][DXIL] Implement refract intrinsic #136026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/BuiltinsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
let Prototype = "void(...)";
}

def SPIRVRefract : Builtin {
let Spellings = ["__builtin_spirv_refract"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def SPIRVSmoothStep : Builtin {
let Spellings = ["__builtin_spirv_smoothstep"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
Expand Down
15 changes: 15 additions & 0 deletions clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
case SPIRV::BI__builtin_spirv_refract: {
Value *I = EmitScalarExpr(E->getArg(0));
Value *N = EmitScalarExpr(E->getArg(1));
Value *eta = EmitScalarExpr(E->getArg(2));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
E->getArg(2)->getType()->hasFloatingRepresentation() &&
Copy link
Member

@farzonl farzonl Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should just check E->getArg(2)->getType() is a float or half scalar might be worth putting this in its own assert so you can have a different message.

"refract operands must have a float representation");
assert(E->getArg(0)->getType()->isVectorType() &&
E->getArg(1)->getType()->isVectorType() &&
"refract I and N operands must be a vector");
return Builder.CreateIntrinsic(
/*ReturnType=*/I->getType(), Intrinsic::spv_refract,
ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
}
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
Expand Down
25 changes: 23 additions & 2 deletions clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#define _HLSL_HLSL_INTRINSIC_HELPERS_H_

namespace hlsl {
namespace __detail {
namespace __dETAil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks like you did find replace for eta. Please revert this change. In the future you can use clangd' vscode plugin to apply llvm c++'s stye guide.


constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
// Use the same scaling factor used by FXC, and DXC for DXIL
Expand Down Expand Up @@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}

template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
T K = 1 - Eta * Eta * (1 - (N * I * N * I));
if (K < 0)
return 0;
else
return (Eta * I - (Eta * N * I + sqrt(K)) * N);
}

template <typename T, int L>
constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
#if (__has_builtin(__builtin_spirv_refract))
return __builtin_spirv_refract(I, N, Eta);
#else
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
if (K < 0)
return 0;
else
return (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
#endif
}

template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
Expand Down Expand Up @@ -126,7 +147,7 @@ template <typename T> constexpr vector<T, 4> lit_impl(T NDotL, T NDotH, T M) {
return Result;
}

} // namespace __detail
} // namespace __dETAil
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix this too

} // namespace hlsl

#endif // _HLSL_HLSL_INTRINSIC_HELPERS_H_
59 changes: 59 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}

//===----------------------------------------------------------------------===//
// refract builtin
//===----------------------------------------------------------------------===//

/// \fn T refract(T I, T N, T eta)
/// \brief Returns a refraction using an entering ray, \a I, a surface
/// normal, \a N and refraction index \a eta
/// \param I The entering ray.
/// \param N The surface normal.
/// \param eta The refraction index.
///
/// The return value is a floating-point vector that represents the refraction
/// using the refraction index, \a eta, for the direction of the entering ray,
/// \a I, off a surface with the normal \a N.
///
/// This function calculates the refraction vector using the following formulas:
/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
/// if k < 0.0 the result is 0.0
/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
///
/// I and N must already be normalized in order to achieve the desired result.
///
/// I and N must be a scalar or vector whose component type is
/// floating-point.
///
/// eta must be a 16-bit or 32-bit floating-point scalar.
///
/// Result type, the type of I, and the type of N must all be the same type.

template <typename T>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
__detail::is_same<half, T>::value,
T> refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <typename T>
const inline __detail::enable_if_t<
__detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <int L>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
__detail::HLSL_FIXED_VECTOR<half, L> I,
__detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
return __detail::refract_vec_impl(I, N, eta);
}

template <int L>
const inline __detail::HLSL_FIXED_VECTOR<float, L>
refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
__detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
return __detail::refract_vec_impl(I, N, eta);
}

//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
Expand Down
45 changes: 43 additions & 2 deletions clang/lib/Sema/SemaSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
Expand Down Expand Up @@ -69,6 +69,47 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_refract: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;

ExprResult A = TheCall->getArg(0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could do for (unsigned i = 0; i < TheCall->getNumArgs()-1; ++i) { so you don't have to repeat the code for ExprResult A and B.

QualType ArgTyA = A.get()->getType();
auto *VTyA = ArgTyA->getAs<VectorType>();
if (VTyA == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}

ExprResult B = TheCall->getArg(1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be a helper function since reflect is doing the same thing.

QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}

ExprResult C = TheCall->getArg(2);
QualType ArgTyC = C.get()->getType();
if (!ArgTyC->hasFloatingRepresentation()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hasFloatingRepresentation by itself is wrong here because its going to include float\half vectors.

SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
Copy link
Member

@farzonl farzonl Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this wouldn't be scalar or vector needs to just be scalar float\half.

<< ArgTyC;
return true;
}

QualType RetTy = ArgTyA;
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_reflect: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
Expand All @@ -89,7 +130,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
Expand Down
Loading