Distributed regridding v2 - source data on distributed space #1175

Open
wants to merge 1 commit into base: main
204 changes: 169 additions & 35 deletions lib/ClimaCoreTempestRemap/src/onlineremap.jl
@@ -5,7 +5,7 @@ using MPI


"""
-LinearMap{T, S, M, C, V, M}
+LinearMap{S, T, W, I, V, M}

stores information on the TempestRemap map and the source and target data:

@@ -17,6 +17,7 @@ where:
- `col_indices` are the source column indices from TempestRemap. (length = number of overlap-mesh nodes)
- `row_indices` are the target row indices from TempestRemap. (length = number of overlap-mesh nodes)
- `out_type` string that defines the output type
+- `source_global_elem_lidx` is a mapping from global to local indices on the source space
- `target_global_elem_lidx` is a mapping from global to local indices on the target space

"""
@@ -29,6 +30,7 @@ struct LinearMap{S, T, W, I, V, M} # make consistent with / move to regridding.j
col_indices::V
row_indices::V
out_type::String
+source_global_elem_lidx::M
target_global_elem_lidx::M
end
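
Note: the `weights`, `col_indices`, `row_indices` triplet is simply a sparse remapping matrix stored in coordinate form, so applying the map amounts to a sparse matrix-vector product. A minimal serial sketch of that idea (hypothetical standalone Julia, not code from this PR; it assumes `source` and `target` are flat vectors of node values):

# target[r] accumulates w * source[c] for every stored weight w = S[r, c]
function apply_weights!(target::Vector{Float64}, weights::Vector{Float64},
                        row_indices::Vector{Int}, col_indices::Vector{Int},
                        source::Vector{Float64})
    fill!(target, 0.0)
    for (w, r, c) in zip(weights, row_indices, col_indices)
        target[r] += w * source[c]
    end
    return target
end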

@@ -73,20 +75,36 @@ function remap!(
return target
end

+# This version of this function is used for distributed remapping
function remap!(target::Fields.Field, R::LinearMap, source::Fields.Field)
+# Serial remapping case
if Spaces.topology(axes(target)).context isa
ClimaComms.SingletonCommsContext &&
Spaces.topology(axes(source)).context isa
ClimaComms.SingletonCommsContext
@assert axes(source) == R.source_space
@assert axes(target) == R.target_space
# we use the tempestremap cgll representation
# it will set the redundant nodes to zero
remap!(Fields.field_values(target), R, Fields.field_values(source))
return target
+# Mixed serial/distributed case - error
+elseif Spaces.topology(axes(target)).context isa
+ClimaComms.SingletonCommsContext ||
+Spaces.topology(axes(source)).context isa
+ClimaComms.SingletonCommsContext
+error(
+"Remapping is only possible between two serial spaces or two distributed spaces.",
+)
+# Distributed remapping case
else
-# For now, the source data must be on a non-distributed space
-@assert Spaces.topology(axes(source)).context isa
-ClimaComms.SingletonCommsContext
+@assert !(
+Spaces.topology(axes(source)).context isa
+ClimaComms.SingletonCommsContext
+)
+@assert !(
+Spaces.topology(axes(target)).context isa
+ClimaComms.SingletonCommsContext
+)

target_array = parent(target)
source_array = parent(source)
@@ -99,28 +117,40 @@ function remap!(target::Fields.Field, R::LinearMap, source::Fields.Field)
for (n, wt) in enumerate(R.weights)
-# choose all global source indices
-# (for simple distr. remapping with broadcasted source data, no halo exchange)
-is, js, es = map(
-x -> x[1],
-(
-view(R.source_idxs[1], n),
-view(R.source_idxs[2], n),
-view(R.source_idxs[3], n),
-),
-)
+# TODO check: when sending all source data to all processes, this should give same result as below selection
+# is, js, es = map(
+# x -> x[1],
+# (
+# view(R.source_idxs[1], n),
+# view(R.source_idxs[2], n),
+# view(R.source_idxs[3], n),
+# ),
+# )

-# choose only the target inds local to this process (skip inds from other processes)
-target_elem_gidx = view(R.target_idxs[3], n)[1]
-if !(target_elem_gidx in keys(R.target_global_elem_lidx))
+# choose only the source inds local to this process (skip inds from other processes)
+source_elem_gidx = view(R.source_idxs[3], n)[1]
+if !(source_elem_gidx in keys(R.source_global_elem_lidx))
continue
else
# get global i, j inds but local elem index e
-it = view(R.target_idxs[1], n)[1]
-jt = view(R.target_idxs[2], n)[1]
-et = R.target_global_elem_lidx[target_elem_gidx]
+is = view(R.source_idxs[1], n)[1]
+js = view(R.source_idxs[2], n)[1]
+es = R.source_global_elem_lidx[source_elem_gidx]

+# choose only the target inds local to this process (skip inds from other processes)
+target_elem_gidx = view(R.target_idxs[3], n)[1]
+if !(target_elem_gidx in keys(R.target_global_elem_lidx))
+continue
+else
+# get global i, j inds but local elem index e
+it = view(R.target_idxs[1], n)[1]
+jt = view(R.target_idxs[2], n)[1]
+et = R.target_global_elem_lidx[target_elem_gidx]

-for f in 1:Nf
-target_array[it, jt, f, et] +=
-wt * source_array[is, js, f, es]
+for f in 1:Nf
+target_array[it, jt, f, et] +=
+wt * source_array[is, js, f, es]
+end
end
end
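
Note: the loop above hinges on the global-to-local lookups: each process walks the full weight list but keeps only the entries whose source and target elements are both locally owned, translating global element ids to local storage indices before accumulating. A condensed sketch of the same pattern (hypothetical standalone Julia, with flat arrays standing in for the `LinearMap` fields):

# Skip weights whose source or target element lives on another rank;
# otherwise map global element ids to local ones and accumulate.
function accumulate_local!(target::Vector{Float64}, source::Vector{Float64},
                           weights::Vector{Float64},
                           src_gidx::Vector{Int}, tgt_gidx::Vector{Int},
                           src_g2l::Dict{Int, Int}, tgt_g2l::Dict{Int, Int})
    for (n, w) in enumerate(weights)
        haskey(src_g2l, src_gidx[n]) || continue
        haskey(tgt_g2l, tgt_gidx[n]) || continue
        target[tgt_g2l[tgt_gidx[n]]] += w * source[src_g2l[src_gidx[n]]]
    end
    return target
end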

Expand Down Expand Up @@ -153,31 +183,32 @@ end


"""
-generate_map(target_space, source_space; in_type="cgll", out_type="cgll")
+generate_map(comms_ctx, target_space, source_space; in_type="cgll", out_type="cgll")

Generate the remapping weights from TempestRemap, returning a `LinearMap` object. This should only be called once.
"""
+# TODO change order of target, source args
function generate_map(
+comms_ctx::ClimaCommsMPI.MPICommsContext,
target_space::Spaces.SpectralElementSpace2D,
source_space::Spaces.SpectralElementSpace2D;
target_space_distr = nothing,
+source_space_distr = nothing,
meshfile_source = tempname(),
meshfile_target = tempname(),
meshfile_overlap = tempname(),
weightfile = tempname(),
in_type = "cgll",
out_type = "cgll",
)
-if (target_space_distr != nothing)
-comms_ctx = target_space_distr.topology.context
-else
-comms_ctx = target_space.topology.context
-end
+# TODO change all target_space, source_space uses to _distr so we can remove serial spaces
+@assert target_space_distr.topology.context == comms_ctx
+@assert source_space_distr.topology.context == comms_ctx

if ClimaComms.iamroot(comms_ctx)
# write meshes and generate weights on root process (using global indices)
-write_exodus(meshfile_source, source_space.topology)
-write_exodus(meshfile_target, target_space.topology)
+write_exodus(meshfile_source, source_space_distr.topology)
+write_exodus(meshfile_target, target_space_distr.topology)
overlap_mesh(meshfile_overlap, meshfile_source, meshfile_target)
remap_weights(
weightfile,
@@ -186,11 +217,11 @@ function generate_map(
meshfile_overlap;
in_type = in_type,
in_np = Spaces.Quadratures.degrees_of_freedom(
-source_space.quadrature_style,
+source_space_distr.quadrature_style,
),
out_type = out_type,
out_np = Spaces.Quadratures.degrees_of_freedom(
-target_space.quadrature_style,
+target_space_distr.quadrature_style,
),
)

@@ -210,12 +241,13 @@ function generate_map(


# we need to be able to look up the indices of unique nodes
+# TODO extend unique_nodes for distributed spaces
source_unique_idxs =
in_type == "cgll" ? collect(Spaces.unique_nodes(source_space)) :
-collect(Spaces.all_nodes(source_space))
+collect(Spaces.all_nodes(source_space_distr))
target_unique_idxs =
out_type == "cgll" ? collect(Spaces.unique_nodes(target_space)) :
-collect(Spaces.all_nodes(target_space))
+collect(Spaces.all_nodes(target_space_distr))

# re-order to avoid unnecessary allocations
source_unique_idxs_i =
@@ -255,7 +287,17 @@ function generate_map(
end
ClimaComms.barrier(comms_ctx)

-# Create mapping from global to local element indices (for distributed remapping)
+# Create mappings from global to local element indices (for distributed remapping)
+if (source_space_distr != nothing)
+source_local_elem_gidx = source_space_distr.topology.local_elem_gidx # gidx = local_elem_gidx[lidx]
+source_global_elem_lidx = Dict{Int, Int}() # inverse of local_elem_gidx: lidx = global_elem_lidx[gidx]
+for (lidx, gidx) in enumerate(source_local_elem_gidx)
+source_global_elem_lidx[gidx] = lidx
+end
+else
+source_global_elem_lidx = nothing
+end

if (target_space_distr != nothing)
target_local_elem_gidx = target_space_distr.topology.local_elem_gidx # gidx = local_elem_gidx[lidx]
target_global_elem_lidx = Dict{Int, Int}() # inverse of local_elem_gidx: lidx = global_elem_lidx[gidx]
@@ -266,6 +308,97 @@ function generate_map(
target_global_elem_lidx = nothing
end

+return LinearMap(
+source_space_distr,
+target_space_distr,
+weights,
+source_unique_idxs,
+target_unique_idxs,
+col_indices,
+row_indices,
+out_type,
+source_global_elem_lidx,
+target_global_elem_lidx,
+)
+end

"""
generate_map(comms_ctx, target_space, source_space; in_type="cgll", out_type="cgll")

Generate the remapping weights from TempestRemap, returning a `LinearMap` object. This should only be called once.
"""
function generate_map(
comms_ctx::ClimaComms.SingletonCommsContext,
target_space::Spaces.SpectralElementSpace2D,
source_space::Spaces.SpectralElementSpace2D;
meshfile_source = tempname(),
meshfile_target = tempname(),
meshfile_overlap = tempname(),
weightfile = tempname(),
in_type = "cgll",
out_type = "cgll",
)
# write meshes and generate weights on root process (using global indices)
write_exodus(meshfile_source, source_space.topology)
write_exodus(meshfile_target, target_space.topology)
overlap_mesh(meshfile_overlap, meshfile_source, meshfile_target)
remap_weights(
weightfile,
meshfile_source,
meshfile_target,
meshfile_overlap;
in_type = in_type,
in_np = Spaces.Quadratures.degrees_of_freedom(
source_space.quadrature_style,
),
out_type = out_type,
out_np = Spaces.Quadratures.degrees_of_freedom(
target_space.quadrature_style,
),
)

# read weight data
weights, col_indices, row_indices = NCDataset(weightfile, "r") do ds_wt
(Array(ds_wt["S"]), Array(ds_wt["col"]), Array(ds_wt["row"]))
end
# TempestRemap exports in CSR format (i.e. row_indices is sorted)

# TODO: add in extra rows to avoid DSS step
# - for each unique node, we would add extra rows for all the duplicates, with the same column
# - ideally keep in CSR format
# e.g. iterate by row (i.e. target node):
# - if new node, then append it
# - if not new, copy previous entries
# - requires a mechanism to query whether find the first common node of a given node

# we need to be able to look up the indices of unique nodes
source_unique_idxs =
in_type == "cgll" ? collect(Spaces.unique_nodes(source_space)) :
collect(Spaces.all_nodes(source_space))
target_unique_idxs =
out_type == "cgll" ? collect(Spaces.unique_nodes(target_space)) :
collect(Spaces.all_nodes(target_space))

# re-order to avoid unnecessary allocations
source_unique_idxs_i =
map(col -> source_unique_idxs[col][1][1], col_indices)
source_unique_idxs_j =
map(col -> source_unique_idxs[col][1][2], col_indices)
source_unique_idxs_e = map(col -> source_unique_idxs[col][2], col_indices)
target_unique_idxs_i =
map(row -> target_unique_idxs[row][1][1], row_indices)
target_unique_idxs_j =
map(row -> target_unique_idxs[row][1][2], row_indices)
target_unique_idxs_e = map(row -> target_unique_idxs[row][2], row_indices)

source_unique_idxs =
(source_unique_idxs_i, source_unique_idxs_j, source_unique_idxs_e)
target_unique_idxs =
(target_unique_idxs_i, target_unique_idxs_j, target_unique_idxs_e)

source_global_elem_lidx = nothing
target_global_elem_lidx = nothing

return LinearMap(
source_space,
target_space,
@@ -275,6 +408,7 @@ function generate_map(
col_indices,
row_indices,
out_type,
+source_global_elem_lidx,
target_global_elem_lidx,
)
end
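
Note: the `*_global_elem_lidx` dictionaries built in `generate_map` are just the inverse of `topology.local_elem_gidx`, which maps a rank's local element index to its global index. A small worked sketch of the inversion (the example ownership below is hypothetical):

local_elem_gidx = [4, 7, 9]          # e.g. this rank owns global elements 4, 7, 9
global_elem_lidx = Dict{Int, Int}()  # inverse: global id -> local id
for (lidx, gidx) in enumerate(local_elem_gidx)
    global_elem_lidx[gidx] = lidx
end
@assert global_elem_lidx[7] == 2        # global element 7 is local element 2
@assert !haskey(global_elem_lidx, 5)    # global element 5 lives on another rank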
13 changes: 9 additions & 4 deletions lib/ClimaCoreTempestRemap/test/mpi_tests/online_remap.jl
@@ -52,9 +52,11 @@ using Test
Spaces.SpectralElementSpace2D(topology_o_singleton, quad_o)

# generate test data in the Field format (on non-distributed topo. for non-halo exchange version)
-field_i_singleton = sind.(Fields.coordinate_field(space_i_singleton).long)
+# field_i_singleton = sind.(Fields.coordinate_field(space_i_singleton).long)
+field_i_distr = sind.(Fields.coordinate_field(space_i_distr).long)

# global exchange (no buffer fill/send here) - convert to superhalo in next implementation
+# TODO update
root_pid = 0
ClimaComms.gather(comms_ctx, parent(field_i_singleton))
field_i_singleton =
@@ -64,10 +66,11 @@ using Test
space_o_singleton,
space_i_singleton,
target_space_distr = space_o_distr,
+source_space_distr = space_i_distr,
)

if ClimaComms.iamroot(comms_ctx)
-# remap without MPI (for testing comparison) and plot solution
+# remap without MPI (for testing comparison)
field_o_singleton = Fields.zeros(space_o_singleton)
CCTR.remap!(field_o_singleton, R, field_i_singleton)
end
@@ -76,7 +79,7 @@ using Test
field_o_distr = Fields.zeros(space_o_distr)

# apply the remapping to field_i_singleton and store the result in field_o_distr
-CCTR.remap!(field_o_distr, R, field_i_singleton)
+CCTR.remap!(field_o_distr, R, field_i_distr)

# compute analytical solution for comparison
field_ref = sind.(Fields.coordinate_field(space_o_distr).long)
@@ -129,8 +132,10 @@ using Test
# savefig(field_ref_fig, OUTPUT_DIR * "/target_data_soln")
# savefig(field_o_distr_fig, OUTPUT_DIR * "/target_data_mpi")

-# compare distributed and serial solutions
+# compare distributed solution to serial and analytical solutions
@test parent(restart_field_o_distr) ≈ parent(field_o_singleton) atol =
1e-20
+@test parent(restart_field_o_distr) ≈ parent(restart_field_ref) atol =
+1e-20
end
end
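
Note: the intended call pattern is to build the `LinearMap` once and then apply it repeatedly, mirroring the test above. A hedged usage sketch (space and comms-context construction elided; the function names and keyword arguments come from this diff, while the variable names are assumptions):

import ClimaCoreTempestRemap as CCTR
using ClimaCore: Fields

# One-time setup: write meshes, run TempestRemap, and read back the weights.
R = CCTR.generate_map(
    comms_ctx,                                 # MPI or singleton comms context
    target_space,                              # serial spaces (used for mesh I/O)
    source_space,
    target_space_distr = target_space_distr,   # distributed counterparts
    source_space_distr = source_space_distr,
)

# Per-step application: source and target fields both live on distributed spaces.
field_out = Fields.zeros(target_space_distr)
CCTR.remap!(field_out, R, field_in)   # field_in: a Field on source_space_distr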