Skip to content

Commit 41c518e

Browse files
Generalize shmem size (#2313)
1 parent bb9aa81 commit 41c518e

File tree

3 files changed

+42
-33
lines changed

3 files changed

+42
-33
lines changed

ext/cuda/operators_fd_shmem.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ import ClimaCore.RecursiveApply: ⊟, ⊞
66

77
Base.@propagate_inbounds function fd_operator_shmem(
88
space,
9-
::Val{Nvt},
9+
shmem_params,
1010
op::Operators.DivergenceF2C,
1111
args...,
12-
) where {Nvt}
12+
)
1313
# allocate temp output
1414
RT = return_eltype(op, args...)
15-
Ju³ = CUDA.CuStaticSharedArray(RT, (Nvt,))
16-
lJu³ = CUDA.CuStaticSharedArray(RT, (1,))
17-
rJu³ = CUDA.CuStaticSharedArray(RT, (1,))
15+
Ju³ = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params))
16+
lJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params))
17+
rJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params))
1818
return (Ju³, lJu³, rJu³)
1919
end
2020

@@ -109,15 +109,15 @@ end
109109

110110
Base.@propagate_inbounds function fd_operator_shmem(
111111
space,
112-
::Val{Nvt},
112+
shmem_params,
113113
op::Operators.GradientC2F,
114114
args...,
115-
) where {Nvt}
115+
)
116116
# allocate temp output
117117
RT = return_eltype(op, args...)
118-
u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers
119-
lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary
120-
rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary
118+
u = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params)) # cell centers
119+
lb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # left boundary
120+
rb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # right boundary
121121
return (u, lb, rb)
122122
end
123123

@@ -202,15 +202,15 @@ end
202202

203203
Base.@propagate_inbounds function fd_operator_shmem(
204204
space,
205-
::Val{Nvt},
205+
shmem_params,
206206
op::Operators.InterpolateC2F,
207207
args...,
208-
) where {Nvt}
208+
)
209209
# allocate temp output
210210
RT = return_eltype(op, args...)
211-
u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers
212-
lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary
213-
rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary
211+
u = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params)) # cell centers
212+
lb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # left boundary
213+
rb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # right boundary
214214
return (u, lb, rb)
215215
end
216216

ext/cuda/operators_fd_shmem_common.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -209,19 +209,23 @@ Base.@propagate_inbounds function getidx(
209209
end
210210

211211
"""
212-
fd_allocate_shmem(Val(Nvt), b)
212+
fd_allocate_shmem(shmem_params, b)
213213
214214
Create a new broadcasted object with the necessary shared memory allocated,
215-
using `Nvt` nodal points per block.
215+
using the shared-memory array sizes given by `shmem_params`.
216216
"""
217-
@inline function fd_allocate_shmem(::Val{Nvt}, obj) where {Nvt}
217+
@inline function fd_allocate_shmem(::ShmemParams, obj)
218218
obj
219219
end
220220
@inline function fd_allocate_shmem(
221-
::Val{Nvt},
221+
shmem_params::ShmemParams,
222222
bc::Broadcasted{Style},
223-
) where {Nvt, Style}
224-
Broadcasted{Style}(bc.f, _fd_allocate_shmem(Val(Nvt), bc.args...), bc.axes)
223+
) where {Style}
224+
Broadcasted{Style}(
225+
bc.f,
226+
_fd_allocate_shmem(shmem_params, bc.args...),
227+
bc.axes,
228+
)
225229
end
226230

227231
######### MatrixFields
@@ -236,22 +240,22 @@ end
236240
#########
237241

238242
@inline function fd_allocate_shmem(
239-
::Val{Nvt},
243+
shmem_params::ShmemParams,
240244
sbc::StencilBroadcasted{Style},
241-
) where {Nvt, Style}
242-
args = _fd_allocate_shmem(Val(Nvt), sbc.args...)
245+
) where {Style}
246+
args = _fd_allocate_shmem(shmem_params, sbc.args...)
243247
work = if Operators.fd_shmem_is_supported(sbc)
244-
fd_operator_shmem(sbc.axes, Val(Nvt), sbc.op, args...)
248+
fd_operator_shmem(sbc.axes, shmem_params, sbc.op, args...)
245249
else
246250
nothing
247251
end
248252
StencilBroadcasted{Style}(sbc.op, args, sbc.axes, work)
249253
end
250254

251-
@inline _fd_allocate_shmem(::Val{Nvt}) where {Nvt} = ()
252-
@inline _fd_allocate_shmem(::Val{Nvt}, arg, xargs...) where {Nvt} = (
253-
fd_allocate_shmem(Val(Nvt), arg),
254-
_fd_allocate_shmem(Val(Nvt), xargs...)...,
255+
@inline _fd_allocate_shmem(::ShmemParams) = ()
256+
@inline _fd_allocate_shmem(shmem_params::ShmemParams, arg, xargs...) = (
257+
fd_allocate_shmem(shmem_params, arg),
258+
_fd_allocate_shmem(shmem_params, xargs...)...,
255259
)
256260

257261
"""

ext/cuda/operators_finite_difference.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ Base.Broadcast.BroadcastStyle(
2222

2323
include("operators_fd_shmem_is_supported.jl")
2424

25+
struct ShmemParams{Nv} end
26+
interior_size(::ShmemParams{Nv}) where {Nv} = (Nv,)
27+
boundary_size(::ShmemParams{Nv}) where {Nv} = (1,)
28+
2529
function Base.copyto!(
2630
out::Field,
2731
bc::Union{
@@ -56,6 +60,7 @@ function Base.copyto!(
5660
mask isa NoMask &&
5761
enough_shmem &&
5862
Operators.use_fd_shmem()
63+
shmem_params = ShmemParams{n_face_levels}()
5964
p = fd_shmem_stencil_partition(us, n_face_levels)
6065
args = (
6166
strip_space(out, space),
@@ -64,7 +69,7 @@ function Base.copyto!(
6469
bounds,
6570
us,
6671
mask,
67-
Val(p.Nvthreads),
72+
shmem_params,
6873
)
6974
auto_launch!(
7075
copyto_stencil_kernel_shmem!,
@@ -153,8 +158,8 @@ function copyto_stencil_kernel_shmem!(
153158
bds,
154159
us,
155160
mask,
156-
::Val{Nvt},
157-
) where {Nvt}
161+
shmem_params::ShmemParams,
162+
)
158163
@inbounds begin
159164
out_fv = Fields.field_values(out)
160165
us = DataLayouts.UniversalSize(out_fv)
@@ -165,7 +170,7 @@ function copyto_stencil_kernel_shmem!(
165170
hidx = (i, j, h)
166171
idx = v - 1 + li
167172
bc = Operators.reconstruct_placeholder_broadcasted(space, bc′)
168-
bc_shmem = fd_allocate_shmem(Val(Nvt), bc) # allocates shmem
173+
bc_shmem = fd_allocate_shmem(shmem_params, bc) # allocates shmem
169174

170175
fd_resolve_shmem!(bc_shmem, idx, hidx, bds) # recursively fills shmem
171176
CUDA.sync_threads()

0 commit comments

Comments
 (0)