Skip to content

Commit d0214a6

Browse files
Charlie Kawczynskicharleskawczynski
Charlie Kawczynski
authored andcommitted
Refactor FD shmem index management
1 parent b732f9f commit d0214a6

File tree

3 files changed

+112
-79
lines changed

3 files changed

+112
-79
lines changed

ext/cuda/operators_fd_shmem.jl

Lines changed: 64 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ import ClimaCore.RecursiveApply: ⊟, ⊞
66

77
Base.@propagate_inbounds function fd_operator_shmem(
88
space,
9-
::Val{Nvt},
9+
params,
1010
op::Operators.DivergenceF2C,
1111
args...,
12-
) where {Nvt}
12+
)
1313
# allocate temp output
1414
RT = return_eltype(op, args...)
15-
Ju³ = CUDA.CuStaticSharedArray(RT, (Nvt,))
16-
lJu³ = CUDA.CuStaticSharedArray(RT, (1,))
17-
rJu³ = CUDA.CuStaticSharedArray(RT, (1,))
15+
Ju³ = CUDA.CuStaticSharedArray(RT, shmem_size(params))
16+
lJu³ = CUDA.CuStaticSharedArray(RT, boundary_shmem_size())
17+
rJu³ = CUDA.CuStaticSharedArray(RT, boundary_shmem_size())
1818
return (Ju³, lJu³, rJu³)
1919
end
2020

@@ -29,20 +29,21 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
2929
arg,
3030
)
3131
@inbounds begin
32-
vt = threadIdx().x
32+
si = FDShmemIndex()
33+
bi = FDShmemBoundaryIndex()
3334
lg = Geometry.LocalGeometry(space, idx, hidx)
3435
if !on_boundary(idx, space, op)
3536
= Operators.getidx(space, arg, idx, hidx)
36-
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
37+
Ju³[si] = Geometry.Jcontravariant3(u³, lg)
3738
elseif on_left_boundary(idx, space, op)
3839
bloc = Operators.left_boundary_window(space)
3940
bc = Operators.get_boundary(op, bloc)
4041
ub = Operators.getidx(space, bc.val, nothing, hidx)
4142
bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
4243
if bc isa Operators.SetValue
43-
bJu³[1] = Geometry.Jcontravariant3(ub, lg)
44+
bJu³[bi] = Geometry.Jcontravariant3(ub, lg)
4445
elseif bc isa Operators.SetDivergence
45-
bJu³[1] = ub
46+
bJu³[bi] = ub
4647
elseif bc isa Operators.Extrapolate # no shmem needed
4748
end
4849
elseif on_right_boundary(idx, space, op)
@@ -51,9 +52,9 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
5152
ub = Operators.getidx(space, bc.val, nothing, hidx)
5253
bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
5354
if bc isa Operators.SetValue
54-
bJu³[1] = Geometry.Jcontravariant3(ub, lg)
55+
bJu³[bi] = Geometry.Jcontravariant3(ub, lg)
5556
elseif bc isa Operators.SetDivergence
56-
bJu³[1] = ub
57+
bJu³[bi] = ub
5758
elseif bc isa Operators.Extrapolate # no shmem needed
5859
end
5960
end
@@ -70,11 +71,12 @@ Base.@propagate_inbounds function fd_operator_evaluate(
7071
arg,
7172
)
7273
@inbounds begin
73-
vt = threadIdx().x
74+
si = FDShmemIndex()
75+
bi = FDShmemBoundaryIndex()
7476
lg = Geometry.LocalGeometry(space, idx, hidx)
7577
if !on_boundary(idx, space, op)
76-
Ju³₋ = Ju³[vt] # corresponds to idx - half
77-
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
78+
Ju³₋ = Ju³[si] # corresponds to idx - half
79+
Ju³₊ = Ju³[si + 1] # corresponds to idx + half
7880
return (Ju³₊ Ju³₋) lg.invJ
7981
else
8082
bloc =
@@ -85,8 +87,8 @@ Base.@propagate_inbounds function fd_operator_evaluate(
8587
@assert bc isa Operators.SetValue || bc isa Operators.SetDivergence
8688
if on_left_boundary(idx, space)
8789
if bc isa Operators.SetValue
88-
Ju³₋ = lJu³[1] # corresponds to idx - half
89-
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
90+
Ju³₋ = lJu³[bi] # corresponds to idx - half
91+
Ju³₊ = Ju³[si + 1] # corresponds to idx + half
9092
return (Ju³₊ Ju³₋) lg.invJ
9193
else
9294
# @assert bc isa Operators.SetDivergence
@@ -95,12 +97,12 @@ Base.@propagate_inbounds function fd_operator_evaluate(
9597
else
9698
@assert on_right_boundary(idx, space)
9799
if bc isa Operators.SetValue
98-
Ju³₋ = Ju³[vt] # corresponds to idx - half
99-
Ju³₊ = rJu³[1] # corresponds to idx + half
100+
Ju³₋ = Ju³[si] # corresponds to idx - half
101+
Ju³₊ = rJu³[bi] # corresponds to idx + half
100102
return (Ju³₊ Ju³₋) lg.invJ
101103
else
102104
@assert bc isa Operators.SetDivergence
103-
return rJu³[1]
105+
return rJu³[bi]
104106
end
105107
end
106108
end
@@ -109,15 +111,15 @@ end
109111

110112
Base.@propagate_inbounds function fd_operator_shmem(
111113
space,
112-
::Val{Nvt},
114+
params,
113115
op::Operators.GradientC2F,
114116
args...,
115-
) where {Nvt}
117+
)
116118
# allocate temp output
117119
RT = return_eltype(op, args...)
118-
u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers
119-
lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary
120-
rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary
120+
u = CUDA.CuStaticSharedArray(RT, shmem_size(params)) # cell centers
121+
lb = CUDA.CuStaticSharedArray(RT, boundary_shmem_size()) # left boundary
122+
rb = CUDA.CuStaticSharedArray(RT, boundary_shmem_size()) # right boundary
121123
return (u, lb, rb)
122124
end
123125

@@ -132,11 +134,12 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
132134
arg,
133135
)
134136
@inbounds begin
137+
si = FDShmemIndex(idx)
138+
bi = FDShmemBoundaryIndex()
135139
is_out_of_bounds(idx, space) && return nothing
136-
vt = threadIdx().x
137140
cov3 = Geometry.Covariant3Vector(1)
138141
if in_domain(idx, arg_space)
139-
u[vt] = cov3 Operators.getidx(space, arg, idx, hidx)
142+
u[si] = cov3 Operators.getidx(space, arg, idx, hidx)
140143
end
141144
if on_any_boundary(idx, space, op)
142145
lloc = Operators.left_boundary_window(space)
@@ -149,10 +152,10 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
149152
ub = Operators.getidx(space, bc.val, nothing, hidx)
150153
bu = on_left_boundary(idx, space) ? lb : rb
151154
if bc isa Operators.SetValue
152-
bu[1] = cov3 ub
155+
bu[bi] = cov3 ub
153156
elseif bc isa Operators.SetGradient
154157
lg = Geometry.LocalGeometry(space, idx, hidx)
155-
bu[1] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
158+
bu[bi] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
156159
elseif bc isa Operators.Extrapolate # no shmem needed
157160
end
158161
end
@@ -169,11 +172,12 @@ Base.@propagate_inbounds function fd_operator_evaluate(
169172
args...,
170173
)
171174
@inbounds begin
172-
vt = threadIdx().x
175+
si = FDShmemIndex()
176+
bi = FDShmemBoundaryIndex()
173177
lg = Geometry.LocalGeometry(space, idx, hidx)
174178
if !on_boundary(idx, space, op)
175-
u₋ = u[vt - 1] # corresponds to idx - half
176-
u₊ = u[vt] # corresponds to idx + half
179+
u₋ = u[si - 1] # corresponds to idx - half
180+
u₊ = u[si] # corresponds to idx + half
177181
return u₊ u₋
178182
else
179183
bloc =
@@ -184,15 +188,15 @@ Base.@propagate_inbounds function fd_operator_evaluate(
184188
@assert bc isa Operators.SetValue
185189
if on_left_boundary(idx, space)
186190
if bc isa Operators.SetValue
187-
u₋ = 2 * lb[1] # corresponds to idx - half
188-
u₊ = 2 * u[vt] # corresponds to idx + half
191+
u₋ = 2 * lb[bi] # corresponds to idx - half
192+
u₊ = 2 * u[si] # corresponds to idx + half
189193
return u₊ u₋
190194
end
191195
else
192196
@assert on_right_boundary(idx, space)
193197
if bc isa Operators.SetValue
194-
u₋ = 2 * u[vt - 1] # corresponds to idx - half
195-
u₊ = 2 * rb[1] # corresponds to idx + half
198+
u₋ = 2 * u[si - 1] # corresponds to idx - half
199+
u₊ = 2 * rb[bi] # corresponds to idx + half
196200
return u₊ u₋
197201
end
198202
end
@@ -202,15 +206,15 @@ end
202206

203207
Base.@propagate_inbounds function fd_operator_shmem(
204208
space,
205-
::Val{Nvt},
209+
params,
206210
op::Operators.InterpolateC2F,
207211
args...,
208-
) where {Nvt}
212+
)
209213
# allocate temp output
210214
RT = return_eltype(op, args...)
211-
u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers
212-
lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary
213-
rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary
215+
u = CUDA.CuStaticSharedArray(RT, shmem_size(params)) # cell centers
216+
lb = CUDA.CuStaticSharedArray(RT, boundary_shmem_size()) # left boundary
217+
rb = CUDA.CuStaticSharedArray(RT, boundary_shmem_size()) # right boundary
214218
return (u, lb, rb)
215219
end
216220

@@ -225,10 +229,12 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
225229
arg,
226230
)
227231
@inbounds begin
228-
is_out_of_bounds(idx, space) && return nothing
229232
ᶜidx = get_cent_idx(idx)
233+
si = FDShmemIndex(idx)
234+
bi = FDShmemBoundaryIndex()
235+
is_out_of_bounds(idx, space) && return nothing
230236
if in_domain(idx, arg_space)
231-
u[idx] = Operators.getidx(space, arg, idx, hidx)
237+
u[si] = Operators.getidx(space, arg, idx, hidx)
232238
else
233239
lloc = Operators.left_boundary_window(space)
234240
rloc = Operators.right_boundary_window(space)
@@ -242,16 +248,16 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
242248
bc isa Operators.NullBoundaryCondition
243249
if bc isa Operators.NullBoundaryCondition ||
244250
bc isa Operators.Extrapolate
245-
u[idx] = Operators.getidx(space, arg, idx, hidx)
251+
u[si] = Operators.getidx(space, arg, idx, hidx)
246252
return nothing
247253
end
248254
bu = on_left_boundary(idx, space) ? lb : rb
249255
ub = Operators.getidx(space, bc.val, nothing, hidx)
250256
if bc isa Operators.SetValue
251-
bu[1] = ub
257+
bu[bi] = ub
252258
elseif bc isa Operators.SetGradient
253259
lg = Geometry.LocalGeometry(space, idx, hidx)
254-
bu[1] = Geometry.covariant3(ub, lg)
260+
bu[bi] = Geometry.covariant3(ub, lg)
255261
end
256262
end
257263
end
@@ -267,12 +273,13 @@ Base.@propagate_inbounds function fd_operator_evaluate(
267273
args...,
268274
)
269275
@inbounds begin
270-
vt = threadIdx().x
271-
lg = Geometry.LocalGeometry(space, idx, hidx)
272276
ᶜidx = get_cent_idx(idx)
277+
si = FDShmemIndex(ᶜidx)
278+
bi = FDShmemBoundaryIndex()
279+
lg = Geometry.LocalGeometry(space, idx, hidx)
273280
if !on_boundary(idx, space, op)
274-
u₋ = u[ᶜidx - 1] # corresponds to idx - half
275-
u₊ = u[ᶜidx] # corresponds to idx + half
281+
u₋ = u[si - 1] # corresponds to idx - half
282+
u₊ = u[si] # corresponds to idx + half
276283
return RecursiveApply.rdiv(u₊ u₋, 2)
277284
else
278285
bloc =
@@ -285,26 +292,26 @@ Base.@propagate_inbounds function fd_operator_evaluate(
285292
bc isa Operators.Extrapolate
286293
if on_left_boundary(idx, space)
287294
if bc isa Operators.SetValue
288-
return lb[1]
295+
return lb[bi]
289296
elseif bc isa Operators.SetGradient
290-
u₋ = lb[1] # corresponds to idx - half
291-
u₊ = u[ᶜidx] # corresponds to idx + half
297+
u₋ = lb[bi] # corresponds to idx - half
298+
u₊ = u[si] # corresponds to idx + half
292299
return u₊ RecursiveApply.rdiv(u₋, 2)
293300
else
294301
@assert bc isa Operators.Extrapolate
295-
return u[ᶜidx]
302+
return u[si]
296303
end
297304
else
298305
@assert on_right_boundary(idx, space)
299306
if bc isa Operators.SetValue
300-
return rb[1]
307+
return rb[bi]
301308
elseif bc isa Operators.SetGradient
302-
u₋ = u[ᶜidx - 1] # corresponds to idx - half
303-
u₊ = rb[1] # corresponds to idx + half
309+
u₋ = u[si - 1] # corresponds to idx - half
310+
u₊ = rb[bi] # corresponds to idx + half
304311
return u₋ RecursiveApply.rdiv(u₊, 2)
305312
else
306313
@assert bc isa Operators.Extrapolate
307-
return u[ᶜidx - 1]
314+
return u[si - 1]
308315
end
309316
end
310317
end

ext/cuda/operators_fd_shmem_common.jl

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -209,19 +209,15 @@ Base.@propagate_inbounds function getidx(
209209
end
210210

211211
"""
212-
fd_allocate_shmem(Val(Nvt), b)
212+
fd_allocate_shmem(params, b)
213213
214-
Create a new broadcasted object with necessary share memory allocated,
215-
using `Nvt` nodal points per block.
214+
Create a new broadcasted object with necessary share memory allocated.
216215
"""
217-
@inline function fd_allocate_shmem(::Val{Nvt}, obj) where {Nvt}
216+
@inline function fd_allocate_shmem(params, obj)
218217
obj
219218
end
220-
@inline function fd_allocate_shmem(
221-
::Val{Nvt},
222-
bc::Broadcasted{Style},
223-
) where {Nvt, Style}
224-
Broadcasted{Style}(bc.f, _fd_allocate_shmem(Val(Nvt), bc.args...), bc.axes)
219+
@inline function fd_allocate_shmem(params, bc::Broadcasted{Style}) where {Style}
220+
Broadcasted{Style}(bc.f, _fd_allocate_shmem(params, bc.args...), bc.axes)
225221
end
226222

227223
######### MatrixFields
@@ -236,23 +232,21 @@ end
236232
#########
237233

238234
@inline function fd_allocate_shmem(
239-
::Val{Nvt},
235+
params,
240236
sbc::StencilBroadcasted{Style},
241-
) where {Nvt, Style}
242-
args = _fd_allocate_shmem(Val(Nvt), sbc.args...)
237+
) where {Style}
238+
args = _fd_allocate_shmem(params, sbc.args...)
243239
work = if Operators.fd_shmem_is_supported(sbc)
244-
fd_operator_shmem(sbc.axes, Val(Nvt), sbc.op, args...)
240+
fd_operator_shmem(sbc.axes, params, sbc.op, args...)
245241
else
246242
nothing
247243
end
248244
StencilBroadcasted{Style}(sbc.op, args, sbc.axes, work)
249245
end
250246

251-
@inline _fd_allocate_shmem(::Val{Nvt}) where {Nvt} = ()
252-
@inline _fd_allocate_shmem(::Val{Nvt}, arg, xargs...) where {Nvt} = (
253-
fd_allocate_shmem(Val(Nvt), arg),
254-
_fd_allocate_shmem(Val(Nvt), xargs...)...,
255-
)
247+
@inline _fd_allocate_shmem(params) = ()
248+
@inline _fd_allocate_shmem(params, arg, xargs...) =
249+
(fd_allocate_shmem(params, arg), _fd_allocate_shmem(params, xargs...)...)
256250

257251
"""
258252
fd_shmem_needed_per_column(::Base.Broadcast.Broadcasted)
@@ -378,6 +372,37 @@ Base.@propagate_inbounds fd_resolve_shmem!(bc::Broadcasted, idx, hidx, bds) =
378372
_fd_resolve_shmem!(idx, hidx, bds, bc.args...)
379373
@inline fd_resolve_shmem!(obj, idx, hidx, bds) = nothing
380374

375+
function shmem_size(params)
376+
return (unval(params).Nvt,)
377+
end
378+
function boundary_shmem_size()
379+
return (1,)
380+
end
381+
382+
struct ShmemIndex{T}
383+
v::T
384+
col_id::T
385+
end
386+
function FDShmemIndex()
387+
v = threadIdx().x
388+
return ShmemIndex(v, typeof(v)(1))
389+
end
390+
function FDShmemIndex(v)
391+
return ShmemIndex(v, typeof(v)(1))
392+
end
393+
FDShmemBoundaryIndex() = ShmemIndex(1, 1)
394+
395+
# Base.getindex(a::AbstractArray, si::ShmemIndex) = Base.getindex(a, si.v, si.col_id)
396+
# Base.setindex!(a::AbstractArray, val, si::ShmemIndex) = Base.setindex!(a, val, si.v, si.col_id)
397+
Base.getindex(a::AbstractArray, si::ShmemIndex) = Base.getindex(a, si.v)
398+
Base.setindex!(a::AbstractArray, val, si::ShmemIndex) =
399+
Base.setindex!(a, val, si.v)
400+
401+
@inline Base.:+(si::ShmemIndex{T}, i::Integer) where {T} =
402+
ShmemIndex{T}(si.v + i, si.col_id)
403+
@inline Base.:-(si::ShmemIndex{T}, i::Integer) where {T} =
404+
ShmemIndex{T}(si.v - i, si.col_id)
405+
381406
if hasfield(Method, :recursion_relation)
382407
dont_limit = (args...) -> true
383408
for m in methods(fd_resolve_shmem!)

0 commit comments

Comments
 (0)