Skip to content

Commit c105d3e

Browse files
Charlie Kawczynski (charleskawczynski)
Charlie Kawczynski
authored and committed
Refactor FD shmem index management
1 parent 41c518e commit c105d3e

File tree

2 files changed

+73
-41
lines changed

2 files changed

+73
-41
lines changed

ext/cuda/operators_fd_shmem.jl

+48-41
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,21 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
2929
arg,
3030
)
3131
@inbounds begin
32-
vt = threadIdx().x
32+
si = ShmemIndex()
33+
bi = FDShmemBoundaryIndex()
3334
lg = Geometry.LocalGeometry(space, idx, hidx)
3435
if !on_boundary(idx, space, op)
3536
u³ = Operators.getidx(space, arg, idx, hidx)
36-
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
37+
Ju³[si] = Geometry.Jcontravariant3(u³, lg)
3738
elseif on_left_boundary(idx, space, op)
3839
bloc = Operators.left_boundary_window(space)
3940
bc = Operators.get_boundary(op, bloc)
4041
ub = Operators.getidx(space, bc.val, nothing, hidx)
4142
bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
4243
if bc isa Operators.SetValue
43-
bJu³[1] = Geometry.Jcontravariant3(ub, lg)
44+
bJu³[bi] = Geometry.Jcontravariant3(ub, lg)
4445
elseif bc isa Operators.SetDivergence
45-
bJu³[1] = ub
46+
bJu³[bi] = ub
4647
elseif bc isa Operators.Extrapolate # no shmem needed
4748
end
4849
elseif on_right_boundary(idx, space, op)
@@ -51,9 +52,9 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
5152
ub = Operators.getidx(space, bc.val, nothing, hidx)
5253
bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
5354
if bc isa Operators.SetValue
54-
bJu³[1] = Geometry.Jcontravariant3(ub, lg)
55+
bJu³[bi] = Geometry.Jcontravariant3(ub, lg)
5556
elseif bc isa Operators.SetDivergence
56-
bJu³[1] = ub
57+
bJu³[bi] = ub
5758
elseif bc isa Operators.Extrapolate # no shmem needed
5859
end
5960
end
@@ -70,11 +71,12 @@ Base.@propagate_inbounds function fd_operator_evaluate(
7071
arg,
7172
)
7273
@inbounds begin
73-
vt = threadIdx().x
74+
si = ShmemIndex()
75+
bi = FDShmemBoundaryIndex()
7476
lg = Geometry.LocalGeometry(space, idx, hidx)
7577
if !on_boundary(idx, space, op)
76-
Ju³₋ = Ju³[vt] # corresponds to idx - half
77-
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
78+
Ju³₋ = Ju³[si] # corresponds to idx - half
79+
Ju³₊ = Ju³[si + 1] # corresponds to idx + half
7880
return (Ju³₊ Ju³₋) lg.invJ
7981
else
8082
bloc =
@@ -85,22 +87,22 @@ Base.@propagate_inbounds function fd_operator_evaluate(
8587
@assert bc isa Operators.SetValue || bc isa Operators.SetDivergence
8688
if on_left_boundary(idx, space)
8789
if bc isa Operators.SetValue
88-
Ju³₋ = lJu³[1] # corresponds to idx - half
89-
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
90+
Ju³₋ = lJu³[bi] # corresponds to idx - half
91+
Ju³₊ = Ju³[si + 1] # corresponds to idx + half
9092
return (Ju³₊ Ju³₋) lg.invJ
9193
else
9294
# @assert bc isa Operators.SetDivergence
93-
return lJu³[1]
95+
return lJu³[bi]
9496
end
9597
else
9698
@assert on_right_boundary(idx, space)
9799
if bc isa Operators.SetValue
98-
Ju³₋ = Ju³[vt] # corresponds to idx - half
99-
Ju³₊ = rJu³[1] # corresponds to idx + half
100+
Ju³₋ = Ju³[si] # corresponds to idx - half
101+
Ju³₊ = rJu³[bi] # corresponds to idx + half
100102
return (Ju³₊ Ju³₋) lg.invJ
101103
else
102104
@assert bc isa Operators.SetDivergence
103-
return rJu³[1]
105+
return rJu³[bi]
104106
end
105107
end
106108
end
@@ -133,10 +135,11 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
133135
)
134136
@inbounds begin
135137
is_out_of_bounds(idx, space) && return nothing
136-
vt = threadIdx().x
138+
si = ShmemIndex()
139+
bi = FDShmemBoundaryIndex()
137140
cov3 = Geometry.Covariant3Vector(1)
138141
if in_domain(idx, arg_space)
139-
u[vt] = cov3 Operators.getidx(space, arg, idx, hidx)
142+
u[si] = cov3 Operators.getidx(space, arg, idx, hidx)
140143
end
141144
if on_any_boundary(idx, space, op)
142145
lloc = Operators.left_boundary_window(space)
@@ -149,10 +152,10 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
149152
ub = Operators.getidx(space, bc.val, nothing, hidx)
150153
bu = on_left_boundary(idx, space) ? lb : rb
151154
if bc isa Operators.SetValue
152-
bu[1] = cov3 ub
155+
bu[bi] = cov3 ub
153156
elseif bc isa Operators.SetGradient
154157
lg = Geometry.LocalGeometry(space, idx, hidx)
155-
bu[1] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
158+
bu[bi] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
156159
elseif bc isa Operators.Extrapolate # no shmem needed
157160
end
158161
end
@@ -169,11 +172,12 @@ Base.@propagate_inbounds function fd_operator_evaluate(
169172
args...,
170173
)
171174
@inbounds begin
172-
vt = threadIdx().x
175+
si = ShmemIndex()
176+
bi = FDShmemBoundaryIndex()
173177
lg = Geometry.LocalGeometry(space, idx, hidx)
174178
if !on_boundary(idx, space, op)
175-
u₋ = u[vt - 1] # corresponds to idx - half
176-
u₊ = u[vt] # corresponds to idx + half
179+
u₋ = u[si - 1] # corresponds to idx - half
180+
u₊ = u[si] # corresponds to idx + half
177181
return u₊ u₋
178182
else
179183
bloc =
@@ -184,15 +188,15 @@ Base.@propagate_inbounds function fd_operator_evaluate(
184188
@assert bc isa Operators.SetValue
185189
if on_left_boundary(idx, space)
186190
if bc isa Operators.SetValue
187-
u₋ = 2 * lb[1] # corresponds to idx - half
188-
u₊ = 2 * u[vt] # corresponds to idx + half
191+
u₋ = 2 * lb[bi] # corresponds to idx - half
192+
u₊ = 2 * u[si] # corresponds to idx + half
189193
return u₊ u₋
190194
end
191195
else
192196
@assert on_right_boundary(idx, space)
193197
if bc isa Operators.SetValue
194-
u₋ = 2 * u[vt - 1] # corresponds to idx - half
195-
u₊ = 2 * rb[1] # corresponds to idx + half
198+
u₋ = 2 * u[si - 1] # corresponds to idx - half
199+
u₊ = 2 * rb[bi] # corresponds to idx + half
196200
return u₊ u₋
197201
end
198202
end
@@ -226,9 +230,10 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
226230
)
227231
@inbounds begin
228232
is_out_of_bounds(idx, space) && return nothing
229-
ᶜidx = get_cent_idx(idx)
233+
si = ShmemIndex(idx)
234+
bi = FDShmemBoundaryIndex()
230235
if in_domain(idx, arg_space)
231-
u[idx] = Operators.getidx(space, arg, idx, hidx)
236+
u[si] = Operators.getidx(space, arg, idx, hidx)
232237
else
233238
lloc = Operators.left_boundary_window(space)
234239
rloc = Operators.right_boundary_window(space)
@@ -242,16 +247,16 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
242247
bc isa Operators.NullBoundaryCondition
243248
if bc isa Operators.NullBoundaryCondition ||
244249
bc isa Operators.Extrapolate
245-
u[idx] = Operators.getidx(space, arg, idx, hidx)
250+
u[si] = Operators.getidx(space, arg, idx, hidx)
246251
return nothing
247252
end
248253
bu = on_left_boundary(idx, space) ? lb : rb
249254
ub = Operators.getidx(space, bc.val, nothing, hidx)
250255
if bc isa Operators.SetValue
251-
bu[1] = ub
256+
bu[bi] = ub
252257
elseif bc isa Operators.SetGradient
253258
lg = Geometry.LocalGeometry(space, idx, hidx)
254-
bu[1] = Geometry.covariant3(ub, lg)
259+
bu[bi] = Geometry.covariant3(ub, lg)
255260
end
256261
end
257262
end
@@ -270,9 +275,11 @@ Base.@propagate_inbounds function fd_operator_evaluate(
270275
vt = threadIdx().x
271276
lg = Geometry.LocalGeometry(space, idx, hidx)
272277
ᶜidx = get_cent_idx(idx)
278+
si = ShmemIndex(ᶜidx)
279+
bi = FDShmemBoundaryIndex()
273280
if !on_boundary(idx, space, op)
274-
u₋ = u[ᶜidx - 1] # corresponds to idx - half
275-
u₊ = u[ᶜidx] # corresponds to idx + half
281+
u₋ = u[si - 1] # corresponds to idx - half
282+
u₊ = u[si] # corresponds to idx + half
276283
return RecursiveApply.rdiv(u₊ u₋, 2)
277284
else
278285
bloc =
@@ -285,26 +292,26 @@ Base.@propagate_inbounds function fd_operator_evaluate(
285292
bc isa Operators.Extrapolate
286293
if on_left_boundary(idx, space)
287294
if bc isa Operators.SetValue
288-
return lb[1]
295+
return lb[bi]
289296
elseif bc isa Operators.SetGradient
290-
u₋ = lb[1] # corresponds to idx - half
291-
u₊ = u[ᶜidx] # corresponds to idx + half
297+
u₋ = lb[bi] # corresponds to idx - half
298+
u₊ = u[si] # corresponds to idx + half
292299
return u₊ RecursiveApply.rdiv(u₋, 2)
293300
else
294301
@assert bc isa Operators.Extrapolate
295-
return u[ᶜidx]
302+
return u[si]
296303
end
297304
else
298305
@assert on_right_boundary(idx, space)
299306
if bc isa Operators.SetValue
300-
return rb[1]
307+
return rb[bi]
301308
elseif bc isa Operators.SetGradient
302-
u₋ = u[ᶜidx - 1] # corresponds to idx - half
303-
u₊ = rb[1] # corresponds to idx + half
309+
u₋ = u[si - 1] # corresponds to idx - half
310+
u₊ = rb[bi] # corresponds to idx + half
304311
return u₋ RecursiveApply.rdiv(u₊, 2)
305312
else
306313
@assert bc isa Operators.Extrapolate
307-
return u[ᶜidx - 1]
314+
return u[si - 1]
308315
end
309316
end
310317
end

ext/cuda/operators_finite_difference.jl

+25
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,31 @@ struct ShmemParams{Nv} end
2626
interior_size(::ShmemParams{Nv}) where {Nv} = (Nv,)
2727
boundary_size(::ShmemParams{Nv}) where {Nv} = (1,)
2828

29+
"""
    ShmemIndex{T}

Index wrapper used by the finite-difference shared-memory kernels.

# Fields
- `v`: position forwarded to `getindex`/`setindex!` when an array is indexed
  with a `ShmemIndex`.
- `col_id`: column identifier carried alongside `v`. The indexing methods
  defined below do not use it; every constructor in this block sets it to 1.
"""
struct ShmemIndex{T}
    v::T
    col_id::T
end

# Build the index for the calling thread: the position is `threadIdx().x`
# and `col_id` is 1, converted to the same type as the position.
@inline function FDShmemIndex()
    v = threadIdx().x
    return ShmemIndex(v, typeof(v)(1))
end

# Build an index at an explicit position `v`, with `col_id` equal to
# `typeof(v)(1)` so that both fields share one type.
@inline function FDShmemIndex(v)
    return ShmemIndex(v, typeof(v)(1))
end

# The kernels in operators_fd_shmem.jl construct indices as `ShmemIndex()` and
# `ShmemIndex(idx)`. The struct's default constructors only accept
# `(v, col_id)`, so without these forwarding methods those calls would throw a
# `MethodError`.
@inline ShmemIndex() = FDShmemIndex()
@inline ShmemIndex(v) = FDShmemIndex(v)

# Index of the single boundary slot: position 1, column 1.
@inline FDShmemBoundaryIndex() = ShmemIndex(1, 1)

# Indexing an array with a `ShmemIndex` uses only the position `v`;
# `col_id` is ignored here.
Base.@propagate_inbounds Base.getindex(a::AbstractArray, si::ShmemIndex) =
    Base.getindex(a, si.v)
Base.@propagate_inbounds Base.setindex!(a::AbstractArray, val, si::ShmemIndex) =
    Base.setindex!(a, val, si.v)

# Offsetting shifts the position and keeps `col_id`. The offset is converted
# to `T` so that the index type does not change (e.g. `Int32 + Int`).
@inline Base.:+(si::ShmemIndex{T}, i::Integer) where {T} =
    ShmemIndex{T}(si.v + T(i), si.col_id)
@inline Base.:-(si::ShmemIndex{T}, i::Integer) where {T} =
    ShmemIndex{T}(si.v - T(i), si.col_id)
53+
2954
function Base.copyto!(
3055
out::Field,
3156
bc::Union{

0 commit comments

Comments
 (0)