New workgroup reduce + scan using subgroup2 #870


Closed
wants to merge 33 commits

Changes from all commits (33 commits)
859c313
added OpSelect intrinsic for mix, fix mix behavior with bool
keptsecret Apr 4, 2025
c5a3223
use mix instead of ternary op
keptsecret Apr 11, 2025
87bca2b
fixes to subgroup2 funcs
keptsecret Apr 11, 2025
49fd605
changes to handle coalesced data loads
keptsecret Apr 21, 2025
4ae51a1
merge master, fix example conflicts
keptsecret Apr 21, 2025
609ad85
fixes to inclusive_scan for coalesced
keptsecret Apr 21, 2025
6b692f4
removed redundant code
keptsecret Apr 21, 2025
d0acb31
enabled handling vectors in spirv group ops with templates and enable_if
keptsecret Apr 23, 2025
fc92538
added impl component wise inclusive scan for inclusive scan
keptsecret Apr 23, 2025
a3d8509
workgroup reduce with subgroup2 stuff
keptsecret Apr 25, 2025
8ad4843
revert to scans using consecutive data loads
keptsecret Apr 25, 2025
cc47a7b
Merge branch 'improve_scan' into improve-workgroup-scan
keptsecret Apr 25, 2025
4c6495d
minor fixes, example
keptsecret Apr 28, 2025
549c6be
bug fixes and example
keptsecret Apr 28, 2025
5167919
fix to data accessor indexing
keptsecret Apr 29, 2025
7f4b238
added template spec for vector dim 1
keptsecret Apr 29, 2025
41cfb13
added inclusive scan
keptsecret Apr 29, 2025
8d7003b
exclusive scan working
keptsecret Apr 30, 2025
7a14716
removed outdated comment
keptsecret Apr 30, 2025
9cae616
minor changes to config usage
keptsecret May 1, 2025
68b4c60
add 1 level scans
keptsecret May 1, 2025
65c747b
fixes to 1 level scans
keptsecret May 2, 2025
9dfd319
added handling >1 vectors on level 1 scan (untested)
keptsecret May 2, 2025
cb842bb
move load/store smem into scan funcs, setup config for 3 levels
keptsecret May 5, 2025
74352c2
change to use coalesced indexing for 2-level scans
keptsecret May 6, 2025
dc09cf3
added 3-level scans
keptsecret May 6, 2025
d8f4ee3
minor bug fixes
keptsecret May 6, 2025
52aaff9
merge master, fix conflicts
keptsecret May 7, 2025
47bfc9e
latest examples
keptsecret May 7, 2025
61c7f5c
changes to data accessor usage
keptsecret May 7, 2025
de882fb
wg reduction uses reduce instead of scan
keptsecret May 8, 2025
3dfee91
fixes to calculating levels in config
keptsecret May 9, 2025
2375935
fixes to 3-level scan
keptsecret May 12, 2025
192 changes: 192 additions & 0 deletions include/nbl/builtin/hlsl/bxdf/fresnel.hlsl
@@ -33,6 +33,20 @@ struct orientedEtas<float>
rcpOrientedEta = backside ? eta : rcpEta;
return backside;
}

static T diffuseFresnelCorrectionFactor(T n, T n2)
{
// assert(n*n==n2);
vector<bool,vector_traits<T>::Dimension> TIR = n < (T)1.0;
T invdenum = nbl::hlsl::mix<T>(hlsl::promote<T>(1.0), hlsl::promote<T>(1.0) / (n2 * n2 * (hlsl::promote<T>(554.33) - 380.7 * n)), TIR);
T num = n * nbl::hlsl::mix<T>(hlsl::promote<T>(0.1921156102251088), n * 298.25 - 261.38 * n2 + 138.43, TIR);
num += nbl::hlsl::mix<T>(hlsl::promote<T>(0.8078843897748912), hlsl::promote<T>(-1.67), TIR);
return num * invdenum;
}

T value;
T rcp;
bool backside;
};

template<>
@@ -140,6 +154,184 @@ struct refract
scalar_type rcpOrientedEta2;
};

template<typename T NBL_PRIMARY_REQUIRES(concepts::Vectorial<T> && vector_traits<T>::Dimension == 3)
struct ReflectRefract
{
using this_t = ReflectRefract<T>;
using vector_type = T;
using scalar_type = typename vector_traits<T>::scalar_type;

static this_t create(bool refract, NBL_CONST_REF_ARG(vector_type) I, NBL_CONST_REF_ARG(vector_type) N, scalar_type NdotI, scalar_type NdotTorR, scalar_type rcpOrientedEta)
{
this_t retval;
retval.refract = refract;
retval.I = I;
retval.N = N;
retval.NdotI = NdotI;
retval.NdotTorR = NdotTorR;
retval.rcpOrientedEta = rcpOrientedEta;
return retval;
}

static this_t create(bool r, NBL_CONST_REF_ARG(Refract<vector_type>) refract)
{
this_t retval;
retval.refract = r;
retval.I = refract.I;
retval.N = refract.N;
retval.NdotI = refract.NdotI;
retval.NdotTorR = r ? Refract<vector_type>::computeNdotT(refract.backside, refract.NdotI2, refract.rcpOrientedEta2) : refract.NdotI;
retval.rcpOrientedEta = refract.rcpOrientedEta;
return retval;
}

vector_type operator()()
{
return N * (NdotI * (hlsl::mix<scalar_type>(1.0f, rcpOrientedEta, refract)) + NdotTorR) - I * (hlsl::mix<scalar_type>(1.0f, rcpOrientedEta, refract));
}

bool refract;
vector_type I;
vector_type N;
scalar_type NdotI;
scalar_type NdotTorR;
scalar_type rcpOrientedEta;
};


namespace fresnel
{

template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T> || is_vector_v<T>)
struct Schlick
{
using scalar_type = typename vector_traits<T>::scalar_type;

static Schlick<T> create(NBL_CONST_REF_ARG(T) F0, scalar_type VdotH)
{
Schlick<T> retval;
retval.F0 = F0;
retval.VdotH = VdotH;
return retval;
}

T operator()()
{
T x = 1.0 - VdotH;
return F0 + (1.0 - F0) * x*x*x*x*x;
}

T F0;
scalar_type VdotH;
};

template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T> || is_vector_v<T>)
struct Conductor
{
using scalar_type = typename vector_traits<T>::scalar_type;

static Conductor<T> create(NBL_CONST_REF_ARG(T) eta, NBL_CONST_REF_ARG(T) etak, scalar_type cosTheta)
{
Conductor<T> retval;
retval.eta = eta;
retval.etak = etak;
retval.cosTheta = cosTheta;
return retval;
}

T operator()()
{
const scalar_type cosTheta2 = cosTheta * cosTheta;
//const float sinTheta2 = 1.0 - cosTheta2;

const T etaLen2 = eta * eta + etak * etak;
const T etaCosTwice = eta * cosTheta * 2.0f;

const T rs_common = etaLen2 + (T)(cosTheta2);
const T rs2 = (rs_common - etaCosTwice) / (rs_common + etaCosTwice);

const T rp_common = etaLen2 * cosTheta2 + (T)(1.0);
const T rp2 = (rp_common - etaCosTwice) / (rp_common + etaCosTwice);

return (rs2 + rp2) * 0.5f;
}

T eta;
T etak;
scalar_type cosTheta;
};

template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T> || is_vector_v<T>)
struct Dielectric
{
using scalar_type = typename vector_traits<T>::scalar_type;

static Dielectric<T> create(NBL_CONST_REF_ARG(T) eta, scalar_type cosTheta)
{
Dielectric<T> retval;
OrientedEtas<T> orientedEta = OrientedEtas<T>::create(cosTheta, eta);
retval.eta2 = orientedEta.value * orientedEta.value;
retval.cosTheta = cosTheta;
return retval;
}

static T __call(NBL_CONST_REF_ARG(T) orientedEta2, scalar_type absCosTheta)
{
const scalar_type sinTheta2 = 1.0 - absCosTheta * absCosTheta;

// the max() clamping can handle TIR when orientedEta2<1.0
const T t0 = hlsl::sqrt<T>(hlsl::max<T>(orientedEta2 - sinTheta2, hlsl::promote<T>(0.0)));
const T rs = (hlsl::promote<T>(absCosTheta) - t0) / (hlsl::promote<T>(absCosTheta) + t0);

const T t2 = orientedEta2 * absCosTheta;
const T rp = (t0 - t2) / (t0 + t2);

return (rs * rs + rp * rp) * 0.5f;
}

T operator()()
{
return __call(eta2, cosTheta);
}

T eta2;
scalar_type cosTheta;
};

template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T> || is_vector_v<T>)
struct DielectricFrontFaceOnly
{
using scalar_type = typename vector_traits<T>::scalar_type;

static DielectricFrontFaceOnly<T> create(NBL_CONST_REF_ARG(T) orientedEta2, scalar_type absCosTheta)
{
DielectricFrontFaceOnly<T> retval;
retval.orientedEta2 = orientedEta2;
retval.absCosTheta = hlsl::abs<scalar_type>(absCosTheta);
return retval;
}

T operator()()
{
return Dielectric<T>::__call(orientedEta2, absCosTheta);
}

T orientedEta2;
scalar_type absCosTheta;
};


// gets the sum of all R, T R T, T R^3 T, T R^5 T, ... paths
template<typename T>
struct ThinDielectricInfiniteScatter
{
T operator()(T singleInterfaceReflectance)
{
const T doubleInterfaceReflectance = singleInterfaceReflectance * singleInterfaceReflectance;
return hlsl::mix<T>(hlsl::promote<T>(1.0), (singleInterfaceReflectance - doubleInterfaceReflectance) / (hlsl::promote<T>(1.0) - doubleInterfaceReflectance) * 2.0f, doubleInterfaceReflectance > hlsl::promote<T>(0.9999));
}
};

}

}
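
For orientation, a hedged caller-side sketch of the new `fresnel` functors above (illustrative only, not part of the diff; inputs are assumed pre-clamped):

```hlsl
#include "nbl/builtin/hlsl/bxdf/fresnel.hlsl"

using namespace nbl::hlsl;

// F0 = reflectance at normal incidence, VdotH assumed clamped to [0,1]
float32_t3 thinSlabReflectance(float32_t3 F0, float32_t VdotH)
{
    // Schlick approximation: F0 + (1 - F0) * (1 - VdotH)^5
    bxdf::fresnel::Schlick<float32_t3> fr = bxdf::fresnel::Schlick<float32_t3>::create(F0, VdotH);
    const float32_t3 R = fr();
    // sums the R, T R T, T R^3 T, ... bounce paths inside a thin dielectric slab
    bxdf::fresnel::ThinDielectricInfiniteScatter<float32_t3> thin;
    return thin(R);
}
```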
291 changes: 291 additions & 0 deletions include/nbl/builtin/hlsl/bxdf/geom_smith.hlsl
@@ -0,0 +1,291 @@
// Copyright (C) 2018-2023 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_BXDF_GEOM_INCLUDED_
#define _NBL_BUILTIN_HLSL_BXDF_GEOM_INCLUDED_

#include "nbl/builtin/hlsl/bxdf/ndf.hlsl"

namespace nbl
{
namespace hlsl
{
namespace bxdf
{
namespace smith
{

template<typename NDF>
typename NDF::scalar_type VNDF_pdf_wo_clamps(typename NDF::scalar_type ndf, typename NDF::scalar_type lambda_V, typename NDF::scalar_type maxNdotV, NBL_REF_ARG(typename NDF::scalar_type) onePlusLambda_V)
{
onePlusLambda_V = 1.0 + lambda_V;
ndf::microfacet_to_light_measure_transform<NDF,ndf::REFLECT_BIT> transform = ndf::microfacet_to_light_measure_transform<NDF,ndf::REFLECT_BIT>::create(ndf / onePlusLambda_V, maxNdotV);
return transform();
}

template<typename NDF>
typename NDF::scalar_type VNDF_pdf_wo_clamps(typename NDF::scalar_type ndf, typename NDF::scalar_type lambda_V, typename NDF::scalar_type absNdotV, bool transmitted, typename NDF::scalar_type VdotH, typename NDF::scalar_type LdotH, typename NDF::scalar_type VdotHLdotH, typename NDF::scalar_type orientedEta, typename NDF::scalar_type reflectance, NBL_REF_ARG(typename NDF::scalar_type) onePlusLambda_V)
{
onePlusLambda_V = 1.0 + lambda_V;
ndf::microfacet_to_light_measure_transform<NDF,ndf::REFLECT_REFRACT_BIT> transform
= ndf::microfacet_to_light_measure_transform<NDF,ndf::REFLECT_REFRACT_BIT>::create((transmitted ? (1.0 - reflectance) : reflectance) * ndf / onePlusLambda_V, absNdotV, transmitted, VdotH, LdotH, VdotHLdotH, orientedEta);
return transform();
}

template<typename T NBL_FUNC_REQUIRES(is_scalar_v<T>)
T VNDF_pdf_wo_clamps(T ndf, T G1_over_2NdotV)
{
return ndf * 0.5 * G1_over_2NdotV;
}

template<typename T NBL_FUNC_REQUIRES(is_scalar_v<T>)
T FVNDF_pdf_wo_clamps(T fresnel_ndf, T G1_over_2NdotV, T absNdotV, bool transmitted, T VdotH, T LdotH, T VdotHLdotH, T orientedEta)
{
T FNG = fresnel_ndf * G1_over_2NdotV;
T factor = 0.5;
if (transmitted)
{
const T VdotH_etaLdotH = (VdotH + orientedEta * LdotH);
// VdotHLdotH is negative under transmission, so the negation keeps the factor positive
factor *= -2.0 * VdotHLdotH / (VdotH_etaLdotH * VdotH_etaLdotH);
}
return FNG * factor;
}

template<typename T NBL_FUNC_REQUIRES(is_scalar_v<T>)
T VNDF_pdf_wo_clamps(T ndf, T G1_over_2NdotV, T absNdotV, bool transmitted, T VdotH, T LdotH, T VdotHLdotH, T orientedEta, T reflectance)
{
T FN = (transmitted ? (1.0 - reflectance) : reflectance) * ndf;
return FVNDF_pdf_wo_clamps<T>(FN, G1_over_2NdotV, absNdotV, transmitted, VdotH, LdotH, VdotHLdotH, orientedEta);
}


template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T>)
struct SIsotropicParams
{
using this_t = SIsotropicParams<T>;

static this_t create(T a2, T NdotV2, T NdotL2, T lambdaV_plus_one) // beckmann
{
this_t retval;
retval.a2 = a2;
retval.NdotV2 = NdotV2;
retval.NdotL2 = NdotL2;
retval.lambdaV_plus_one = lambdaV_plus_one;
return retval;
}

static this_t create(T a2, T NdotV, T NdotV2, T NdotL, T NdotL2) // ggx
{
this_t retval;
retval.a2 = a2;
retval.NdotV = NdotV;
retval.NdotV2 = NdotV2;
retval.NdotL = NdotL;
retval.NdotL2 = NdotL2;
retval.one_minus_a2 = 1.0 - a2;
return retval;
}

T a2;
T NdotV;
T NdotL;
T NdotV2;
T NdotL2;
T lambdaV_plus_one;
T one_minus_a2;
};

template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T>)
struct SAnisotropicParams
{
using this_t = SAnisotropicParams<T>;

static this_t create(T ax2, T ay2, T TdotV2, T BdotV2, T NdotV2, T TdotL2, T BdotL2, T NdotL2, T lambdaV_plus_one) // beckmann
{
this_t retval;
retval.ax2 = ax2;
retval.ay2 = ay2;
retval.TdotV2 = TdotV2;
retval.BdotV2 = BdotV2;
retval.NdotV2 = NdotV2;
retval.TdotL2 = TdotL2;
retval.BdotL2 = BdotL2;
retval.NdotL2 = NdotL2;
retval.lambdaV_plus_one = lambdaV_plus_one;
return retval;
}

static this_t create(T ax2, T ay2, T NdotV, T TdotV2, T BdotV2, T NdotV2, T NdotL, T TdotL2, T BdotL2, T NdotL2) // ggx
{
this_t retval;
retval.ax2 = ax2;
retval.ay2 = ay2;
retval.NdotL = NdotL;
retval.NdotV = NdotV;
retval.TdotV2 = TdotV2;
retval.BdotV2 = BdotV2;
retval.NdotV2 = NdotV2;
retval.TdotL2 = TdotL2;
retval.BdotL2 = BdotL2;
retval.NdotL2 = NdotL2;
return retval;
}

T ax2;
T ay2;
T NdotV;
T NdotL;
T TdotV2;
T BdotV2;
T NdotV2;
T TdotL2;
T BdotL2;
T NdotL2;
T lambdaV_plus_one;
};


// beckmann
template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T>)
struct Beckmann
{
using scalar_type = T;

scalar_type G1(scalar_type lambda)
{
return 1.0 / (1.0 + lambda);
}

scalar_type C2(scalar_type NdotX2, scalar_type a2)
{
return NdotX2 / (a2 * (1.0 - NdotX2));
}

scalar_type C2(scalar_type TdotX2, scalar_type BdotX2, scalar_type NdotX2, scalar_type ax2, scalar_type ay2)
{
return NdotX2 / (TdotX2 * ax2 + BdotX2 * ay2);
}

scalar_type Lambda(scalar_type c2)
{
scalar_type c = sqrt<scalar_type>(c2);
scalar_type nom = 1.0 - 1.259 * c + 0.396 * c2;
scalar_type denom = 2.181 * c2 + 3.535 * c;
return hlsl::mix<scalar_type>(0.0, nom / denom, c < 1.6);
}

scalar_type Lambda(scalar_type NdotX2, scalar_type a2)
{
return Lambda(C2(NdotX2, a2));
}

scalar_type Lambda(scalar_type TdotX2, scalar_type BdotX2, scalar_type NdotX2, scalar_type ax2, scalar_type ay2)
{
return Lambda(C2(TdotX2, BdotX2, NdotX2, ax2, ay2));
}

scalar_type correlated(SIsotropicParams<scalar_type> params)
{
scalar_type c2 = C2(params.NdotV2, params.a2);
scalar_type L_v = Lambda(c2);
c2 = C2(params.NdotL2, params.a2);
scalar_type L_l = Lambda(c2);
return G1(L_v + L_l);
}

scalar_type correlated(SAnisotropicParams<scalar_type> params)
{
scalar_type c2 = C2(params.TdotV2, params.BdotV2, params.NdotV2, params.ax2, params.ay2);
scalar_type L_v = Lambda(c2);
c2 = C2(params.TdotL2, params.BdotL2, params.NdotL2, params.ax2, params.ay2);
scalar_type L_l = Lambda(c2);
return G1(L_v + L_l);
}

scalar_type G2_over_G1(SIsotropicParams<scalar_type> params)
{
scalar_type lambdaL = Lambda(params.NdotL2, params.a2);
return params.lambdaV_plus_one / (params.lambdaV_plus_one + lambdaL);
}

scalar_type G2_over_G1(SAnisotropicParams<scalar_type> params)
{
scalar_type c2 = C2(params.TdotL2, params.BdotL2, params.NdotL2, params.ax2, params.ay2);
scalar_type lambdaL = Lambda(c2);
return params.lambdaV_plus_one / (params.lambdaV_plus_one + lambdaL);
}
};


// ggx
template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T>)
struct GGX
{
using scalar_type = T;

scalar_type devsh_part(scalar_type NdotX2, scalar_type a2, scalar_type one_minus_a2)
{
return sqrt(a2 + one_minus_a2 * NdotX2);
}

scalar_type devsh_part(scalar_type TdotX2, scalar_type BdotX2, scalar_type NdotX2, scalar_type ax2, scalar_type ay2)
{
return sqrt(TdotX2 * ax2 + BdotX2 * ay2 + NdotX2);
}

scalar_type G1_wo_numerator(scalar_type NdotX, scalar_type NdotX2, scalar_type a2, scalar_type one_minus_a2)
{
return 1.0 / (NdotX + devsh_part(NdotX2,a2,one_minus_a2));
}

scalar_type G1_wo_numerator(scalar_type NdotX, scalar_type TdotX2, scalar_type BdotX2, scalar_type NdotX2, scalar_type ax2, scalar_type ay2)
{
return 1.0 / (NdotX + devsh_part(TdotX2, BdotX2, NdotX2, ax2, ay2));
}

scalar_type G1_wo_numerator(scalar_type NdotX, scalar_type devsh_part)
{
return 1.0 / (NdotX + devsh_part);
}

scalar_type correlated_wo_numerator(SIsotropicParams<scalar_type> params)
{
scalar_type Vterm = params.NdotL * devsh_part(params.NdotV2, params.a2, params.one_minus_a2);
scalar_type Lterm = params.NdotV * devsh_part(params.NdotL2, params.a2, params.one_minus_a2);
return 0.5 / (Vterm + Lterm);
}

scalar_type correlated_wo_numerator(SAnisotropicParams<scalar_type> params)
{
scalar_type Vterm = params.NdotL * devsh_part(params.TdotV2, params.BdotV2, params.NdotV2, params.ax2, params.ay2);
scalar_type Lterm = params.NdotV * devsh_part(params.TdotL2, params.BdotL2, params.NdotL2, params.ax2, params.ay2);
return 0.5 / (Vterm + Lterm);
}

scalar_type G2_over_G1(SIsotropicParams<scalar_type> params)
{
scalar_type devsh_v = devsh_part(params.NdotV2, params.a2, params.one_minus_a2);
scalar_type G2_over_G1 = params.NdotL * (devsh_v + params.NdotV); // alternative: Vterm + NdotL*NdotV; NdotL*NdotV could come as a parameter
G2_over_G1 /= params.NdotV * devsh_part(params.NdotL2, params.a2, params.one_minus_a2) + params.NdotL * devsh_v;

return G2_over_G1;
}

scalar_type G2_over_G1(SAnisotropicParams<scalar_type> params)
{
scalar_type devsh_v = devsh_part(params.TdotV2, params.BdotV2, params.NdotV2, params.ax2, params.ay2);
scalar_type G2_over_G1 = params.NdotL * (devsh_v + params.NdotV);
G2_over_G1 /= params.NdotV * devsh_part(params.TdotL2, params.BdotL2, params.NdotL2, params.ax2, params.ay2) + params.NdotL * devsh_v;

return G2_over_G1;
}

};

}
}
}
}

#endif
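
A hedged usage sketch of the Smith structs above (parameter packing follows the `create` overloads in this diff; inputs are assumed pre-clamped):

```hlsl
#include "nbl/builtin/hlsl/bxdf/geom_smith.hlsl"

using namespace nbl::hlsl;

// height-correlated GGX visibility term, i.e. G2 / (4 NdotV NdotL)
float32_t ggxVisibility(float32_t alpha, float32_t NdotV, float32_t NdotL)
{
    const float32_t a2 = alpha * alpha;
    // ggx overload: create(a2, NdotV, NdotV2, NdotL, NdotL2)
    bxdf::smith::SIsotropicParams<float32_t> params =
        bxdf::smith::SIsotropicParams<float32_t>::create(a2, NdotV, NdotV * NdotV, NdotL, NdotL * NdotL);
    bxdf::smith::GGX<float32_t> ggx;
    return ggx.correlated_wo_numerator(params);
}
```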
16 changes: 10 additions & 6 deletions include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl
@@ -240,13 +240,17 @@ struct mix_helper<T, T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::fMix<T>(e
 }
 };
 
-template<typename T> NBL_PARTIAL_REQ_TOP(concepts::FloatingPointScalar<T>)
-struct mix_helper<T, bool NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<T>) >
+template<typename T, typename U>
+NBL_PARTIAL_REQ_TOP((concepts::Scalar<T> || concepts::Vectorial<T>) && !concepts::Boolean<T> && concepts::Boolean<U>)
+struct mix_helper<T, U NBL_PARTIAL_REQ_BOT((concepts::Scalar<T> || concepts::Vectorial<T>) && !concepts::Boolean<T> && concepts::Boolean<U>) >
 {
 using return_t = conditional_t<is_vector_v<T>, vector<typename vector_traits<T>::scalar_type, vector_traits<T>::Dimension>, T>;
-static inline return_t __call(const T x, const T y, const bool a)
+// for a component of a that is false, the corresponding component of x is returned
+// for a component of a that is true, the corresponding component of y is returned
+// so we make sure this is correct when calling the operation
+static inline return_t __call(const T x, const T y, const U a)
 {
-return a ? x : y;
+return spirv::select<T, U>(a, y, x);
 }
 };

@@ -862,8 +866,8 @@ struct mix_helper<T, T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
 };
 
 template<typename T, typename U>
-NBL_PARTIAL_REQ_TOP(concepts::Vectorial<T> && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension)
-struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::Vectorial<T> && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension) >
+NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension)
+struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension) >
 {
 using return_t = T;
 static return_t __call(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(U) a)
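
The effect of the new `mix_helper` specializations: `hlsl::mix(x, y, a)` with a boolean mask `a` now lowers to `OpSelect` instead of a ternary, selecting per component. A behavioral sketch with illustrative values (not from the PR):

```hlsl
// component-wise bool-mask mix: false picks x, true picks y
float32_t3 maskedMix()
{
    const float32_t3 x = float32_t3(1.0, 2.0, 3.0);
    const float32_t3 y = float32_t3(10.0, 20.0, 30.0);
    const bool3 a = bool3(false, true, false);
    return nbl::hlsl::mix(x, y, a); // (1.0, 20.0, 3.0)
}
```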
14 changes: 14 additions & 0 deletions include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl
@@ -346,6 +346,20 @@ template<typename T NBL_FUNC_REQUIRES(concepts::UnsignedIntegral<T>)
[[vk::ext_instruction(spv::OpISubBorrow)]]
SubBorrowOutput<T> subBorrow(T operand1, T operand2);


template<typename T NBL_FUNC_REQUIRES(is_integral_v<T> && !is_matrix_v<T>)
[[vk::ext_instruction(spv::OpIEqual)]]
conditional_t<is_vector_v<T>, vector<bool, vector_traits<T>::Dimension>, bool> IEqual(T lhs, T rhs);

template<typename T NBL_FUNC_REQUIRES(is_floating_point_v<T> && !is_matrix_v<T>)
[[vk::ext_instruction(spv::OpFOrdEqual)]]
conditional_t<is_vector_v<T>, vector<bool, vector_traits<T>::Dimension>, bool> FOrdEqual(T lhs, T rhs);


template<typename T, typename U NBL_FUNC_REQUIRES(!is_matrix_v<T> && !is_matrix_v<U> && is_same_v<typename vector_traits<U>::scalar_type, bool>)
[[vk::ext_instruction(spv::OpSelect)]]
T select(U a, T x, T y);

}

#endif
36 changes: 20 additions & 16 deletions include/nbl/builtin/hlsl/spirv_intrinsics/subgroup_arithmetic.hlsl
@@ -17,25 +17,23 @@ namespace hlsl
 namespace spirv
 {
 
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformIAdd )]]
-int32_t groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
-[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
-[[vk::ext_instruction( spv::OpGroupNonUniformIAdd )]]
-uint32_t groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
+enable_if_t<!is_matrix_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformFAdd )]]
-float32_t groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
+enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
 
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformIMul )]]
-int32_t groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
-[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
-[[vk::ext_instruction( spv::OpGroupNonUniformIMul )]]
-uint32_t groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
+enable_if_t<!is_matrix_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformFMul )]]
-float32_t groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
+enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
 
 template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
@@ -54,25 +52,31 @@ T groupBitwiseXor(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T
 
 // The MIN and MAX operations in SPIR-V have different Ops for each arithmetic type
 // so we implement them distinctly
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformSMin )]]
-int32_t groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
+enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformUMin )]]
-uint32_t groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
+enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformFMin )]]
-float32_t groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
+enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
 
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformSMax )]]
-int32_t groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
+enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformUMax )]]
-uint32_t groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
+enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformFMax )]]
-float32_t groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
+enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);

}
}
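
With the declarations templated, one intrinsic wrapper now covers scalars and vectors of either arithmetic class; `enable_if_t` on the scalar type picks the right SPIR-V Op. A hedged call-site sketch (the `spv::` enum spellings are assumed to follow the SPIR-V headers Nabla already includes):

```hlsl
// subgroup-wide sum of a float32_t4, one call regardless of T
float32_t4 subgroupSum4(float32_t4 v)
{
    return nbl::hlsl::spirv::groupAdd<float32_t4>(spv::ScopeSubgroup, spv::GroupOperationReduce, v);
}
```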
1 change: 1 addition & 0 deletions include/nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl
@@ -9,6 +9,7 @@

#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
#include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
#include "nbl/builtin/hlsl/concepts.hlsl"


namespace nbl
45 changes: 45 additions & 0 deletions include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl
@@ -0,0 +1,45 @@
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_


#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"

#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl"
#include "nbl/builtin/hlsl/concepts.hlsl"


namespace nbl
{
namespace hlsl
{
namespace subgroup2
{

template<typename Config, class BinOp, int32_t _ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
struct ArithmeticParams
{
using config_t = Config;
using binop_t = BinOp;
using scalar_t = typename BinOp::type_t; // BinOp should be with scalar type
using type_t = vector<scalar_t, _ItemsPerInvocation>;

NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /* && some heuristic for when it's faster */;
};

template<typename Params>
struct reduction : impl::reduction<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
template<typename Params>
struct inclusive_scan : impl::inclusive_scan<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
template<typename Params>
struct exclusive_scan : impl::exclusive_scan<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};

}
}
}

#endif
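
To tie the pieces together, a hypothetical instantiation (names follow this header; `void` is the declared capabilities default and is assumed to select the portability path):

```hlsl
using config_t = nbl::hlsl::subgroup2::Configuration<5>; // 32-wide subgroup
// sum float32_t values, 2 items per invocation, no device capabilities assumed
using params_t = nbl::hlsl::subgroup2::ArithmeticParams<config_t, nbl::hlsl::plus<float32_t>, 2, void>;

float32_t2 prefixSums(float32_t2 v)
{
    nbl::hlsl::subgroup2::inclusive_scan<params_t> scan;
    return scan(v); // running sums across the subgroup, 2 consecutive items per lane
}
```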
228 changes: 228 additions & 0 deletions include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl
@@ -0,0 +1,228 @@
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_

#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"

#include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"

#include "nbl/builtin/hlsl/functional.hlsl"
#include "nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl"

namespace nbl
{
namespace hlsl
{
namespace subgroup2
{

namespace impl
{

// forward declarations
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
struct inclusive_scan;

template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
struct exclusive_scan;

template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
struct reduction;


// BinOp needed to specialize native
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
struct inclusive_scan
{
using type_t = typename Params::type_t;
using scalar_t = typename Params::scalar_t;
using binop_t = typename Params::binop_t;
// assert binop_t == BinOp
using exclusive_scan_op_t = exclusive_scan<Params, binop_t, 1, native>;

// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;

type_t operator()(NBL_CONST_REF_ARG(type_t) value)
{
binop_t binop;
type_t retval;
retval[0] = value[0];
[unroll]
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
retval[i] = binop(retval[i-1], value[i]);

exclusive_scan_op_t op;
scalar_t exclusive = op(retval[ItemsPerInvocation-1]);

[unroll]
for (uint32_t i = 0; i < ItemsPerInvocation; i++)
retval[i] = binop(retval[i], exclusive);
return retval;
}
};

template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
struct exclusive_scan
{
using type_t = typename Params::type_t;
using scalar_t = typename Params::scalar_t;
using binop_t = typename Params::binop_t;
using inclusive_scan_op_t = inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;

// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;

type_t operator()(type_t value)
{
inclusive_scan_op_t op;
value = op(value);

type_t left = glsl::subgroupShuffleUp<type_t>(value,1);

type_t retval;
retval[0] = hlsl::mix(binop_t::identity, left[ItemsPerInvocation-1], bool(glsl::gl_SubgroupInvocationID()));
[unroll]
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
retval[i] = value[i-1];
return retval;
}
};

template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
struct reduction
{
using type_t = typename Params::type_t;
using scalar_t = typename Params::scalar_t;
using binop_t = typename Params::binop_t;
using op_t = reduction<Params, binop_t, 1, native>;

// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;

scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
{
binop_t binop;
op_t op;
scalar_t retval = value[0];
[unroll]
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
retval = binop(retval, value[i]);
return op(retval);
}
};


// specs for N=1 uses subgroup funcs
// specialize native
#define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<class Params, typename T> struct NAME<Params,BINOP<T>,1,true> \
{ \
using type_t = T; \
\
type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
}

#define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);

SPECIALIZE_ALL(bit_and,And);
SPECIALIZE_ALL(bit_or,Or);
SPECIALIZE_ALL(bit_xor,Xor);

SPECIALIZE_ALL(plus,Add);
SPECIALIZE_ALL(multiplies,Mul);

SPECIALIZE_ALL(minimum,Min);
SPECIALIZE_ALL(maximum,Max);

#undef SPECIALIZE_ALL
#undef SPECIALIZE

// specialize portability
template<class Params, class BinOp>
struct inclusive_scan<Params, BinOp, 1, false>
{
using type_t = typename Params::type_t;
using scalar_t = typename Params::scalar_t;
using binop_t = typename Params::binop_t;
// assert T == scalar type, binop::type == T
using config_t = typename Params::config_t;

// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;

scalar_t operator()(scalar_t value)
{
return __call(value);
}

static scalar_t __call(scalar_t value)
{
binop_t op;
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();

scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));

const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
[unroll]
for (uint32_t i = 1; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
{
const uint32_t step = 1u << i;
rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
}
return value;
}
};

template<class Params, class BinOp>
struct exclusive_scan<Params, BinOp, 1, false>
{
using type_t = typename Params::type_t;
using scalar_t = typename Params::scalar_t;
using binop_t = typename Params::binop_t;

scalar_t operator()(scalar_t value)
{
value = inclusive_scan<Params, BinOp, 1, false>::__call(value);
// can't risk getting short-circuited, need to store to a var
scalar_t left = glsl::subgroupShuffleUp<scalar_t>(value,1);
// the first invocation doesn't have anything to its left, so we set it to the binop's identity value for exclusive scan
return hlsl::mix(binop_t::identity, left, bool(glsl::gl_SubgroupInvocationID()));
}
};

template<class Params, class BinOp>
struct reduction<Params, BinOp, 1, false>
{
using type_t = typename Params::type_t;
using scalar_t = typename Params::scalar_t;
using binop_t = typename Params::binop_t;
using config_t = typename Params::config_t;

// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;

scalar_t operator()(scalar_t value)
{
binop_t op;

const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
[unroll]
for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
value = op(glsl::subgroupShuffleXor<scalar_t>(value,0x1u<<i),value);

return value;
}
};

}

}
}
}

#endif
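
The portability `inclusive_scan` above is a Kogge-Stone scan: the pre-loop shuffle handles stride 1, then each iteration folds in the value from `1u << i` lanes to the left, with `hlsl::mix` substituting `binop_t::identity` whenever the shuffle would read past lane 0. A worked trace for an assumed 4-wide subgroup with `plus`:

```
lane:           0    1      2        3
input:          a    b      c        d
stride 1:       a    a+b    b+c      c+d
stride 2:       a    a+b    a+b+c    a+b+c+d   (inclusive prefix sums)
```

The `reduction` specialization instead XOR-shuffles at strides 1, 2, 4, ..., so every lane ends up holding the full subgroup reduction without needing a final broadcast.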
36 changes: 36 additions & 0 deletions include/nbl/builtin/hlsl/subgroup2/ballot.hlsl
@@ -0,0 +1,36 @@
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_
#define _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_

namespace nbl
{
namespace hlsl
{
namespace subgroup2
{

template<uint32_t SubgroupSizeLog2>
struct Configuration
{
using mask_t = conditional_t<SubgroupSizeLog2 < 7, conditional_t<SubgroupSizeLog2 < 6, uint32_t1, uint32_t2>, uint32_t4>;

NBL_CONSTEXPR_STATIC_INLINE uint16_t SizeLog2 = uint16_t(SubgroupSizeLog2);
NBL_CONSTEXPR_STATIC_INLINE uint16_t Size = uint16_t(0x1u) << SubgroupSizeLog2;
};

template<class T>
struct is_configuration : bool_constant<false> {};

template<uint32_t N>
struct is_configuration<Configuration<N> > : bool_constant<true> {};

template<typename T>
NBL_CONSTEXPR bool is_configuration_v = is_configuration<T>::value;

}
}
}

#endif
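
The `mask_t` selection sizes the ballot mask to the subgroup: up to 32 lanes fit one `uint32_t`, 64 lanes need two, 128 need four. A hypothetical spot-check:

```hlsl
using cfg32 = nbl::hlsl::subgroup2::Configuration<5>; // Size == 32, mask_t == uint32_t1
using cfg64 = nbl::hlsl::subgroup2::Configuration<6>; // Size == 64, mask_t == uint32_t2
// is_configuration_v guards templates like ArithmeticParams
NBL_CONSTEXPR bool ok = nbl::hlsl::subgroup2::is_configuration_v<cfg32>; // true
```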
1 change: 1 addition & 0 deletions include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl
@@ -28,6 +28,7 @@ struct vector_traits<vector<T, DIMENSION> >\
NBL_CONSTEXPR_STATIC_INLINE bool IsVector = true;\
};\

DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(1)
DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(2)
DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(3)
DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(4)
58 changes: 58 additions & 0 deletions include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl
@@ -0,0 +1,58 @@
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_INCLUDED_
#define _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_INCLUDED_


#include "nbl/builtin/hlsl/functional.hlsl"
#include "nbl/builtin/hlsl/workgroup/ballot.hlsl"
#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl"
#include "nbl/builtin/hlsl/workgroup2/shared_scan.hlsl"


namespace nbl
{
namespace hlsl
{
namespace workgroup2
{

template<class Config, class BinOp, class device_capabilities=void>
struct reduction
{
template<class DataAccessor, class ScratchAccessor>
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
{
impl::reduce<Config,BinOp,Config::LevelCount,device_capabilities> fn;
fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
}
};

template<class Config, class BinOp, class device_capabilities=void>
struct inclusive_scan
{
template<class DataAccessor, class ScratchAccessor>
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
{
impl::scan<Config,BinOp,false,Config::LevelCount,device_capabilities> fn;
fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
}
};

template<class Config, class BinOp, class device_capabilities=void>
struct exclusive_scan
{
template<class DataAccessor, class ScratchAccessor>
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
{
impl::scan<Config,BinOp,true,Config::LevelCount,device_capabilities> fn;
fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
}
};

}
}
}

#endif
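
Since GitHub does not render the `shared_scan.hlsl` diff below, only the call shape these wrappers expect can be sketched, and every name here is hypothetical: a data accessor supplying per-invocation values and a scratch accessor wrapping `groupshared` memory for the cross-subgroup levels.

```hlsl
// illustrative only; the real Config and accessor interfaces are defined in
// workgroup2/shared_scan.hlsl, whose diff is not rendered on this page
[numthreads(256, 1, 1)]
void main()
{
    MyDataAccessor data;       // reads inputs, receives the scanned outputs
    MyScratchAccessor scratch; // groupshared storage used when Config::LevelCount > 1
    nbl::hlsl::workgroup2::inclusive_scan<MyConfig, nbl::hlsl::plus<float32_t> >::__call(data, scratch);
}
```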
466 changes: 466 additions & 0 deletions include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Large diffs are not rendered by default.