Skip to content

WIP: DataTree slice generator function #171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/requirements/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- sphinx-autosummary-accessors
- sphinx-copybutton
- xarray
- xarray-datatree
# For examples
- adlfs
- ipykernel
Expand Down
1 change: 1 addition & 0 deletions ci/requirements/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies:
- pytest-cov
- pytorch
- tensorflow
- xarray-datatree
- zarr
# Style checks
- black
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"xarray",
]
[project.optional-dependencies]
datatree = ["xarray-datatree"]
torch = [
"torch",
]
Expand All @@ -44,6 +45,7 @@ dev = [
"pytest-cov",
"tensorflow",
"torch",
"xarray-datatree",
]
[project.urls]
documentation = "https://xbatcher.readthedocs.io/en/latest/"
Expand All @@ -58,7 +60,7 @@ fallback_version = "999"

[tool.isort]
profile = "black"
known_third_party = ["numpy", "pandas", "pytest", "sphinx_autosummary_accessors", "torch", "xarray"]
known_third_party = ["datatree", "numpy", "pandas", "pytest", "sphinx_autosummary_accessors", "torch", "xarray"]

[tool.pytest.ini_options]
log_cli = true
Expand Down
6 changes: 5 additions & 1 deletion xbatcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

from . import testing # noqa: F401
from .accessors import BatchAccessor # noqa: F401
from .generators import BatchGenerator, BatchSchema # noqa: F401
from .generators import ( # noqa: F401
BatchGenerator,
BatchSchema,
datatree_slice_generator,
)
from .util.print_versions import show_versions # noqa: F401

try:
Expand Down
71 changes: 71 additions & 0 deletions xbatcher/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import xarray as xr
from datatree import DataTree

PatchGenerator = Iterator[Dict[Hashable, slice]]
BatchSelector = List[Dict[Hashable, slice]]
Expand Down Expand Up @@ -462,3 +463,73 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
)
else:
raise IndexError("list index out of range")


def datatree_slice_generator(
    data_obj: DataTree,
    dim_strides: Dict[str, int],
    ref_node: str,
    **kwargs,
) -> Iterator[DataTree]:
    """
    Generator for iterating through an Xarray DataTree.

    Parameters
    ----------
    data_obj : ``datatree.DataTree``
        The multi-dimensional xarray data object to iterate over.

    dim_strides : dict
        A dictionary specifying the size of the stride in each dimension to
        slice along, e.g. ``{'longitude': 30, 'latitude': 30}``. These are the
        dimensions the machine learning library will see. Strides are
        expressed in coordinate units and must point in the same direction
        as the reference node's coordinate (i.e. use a negative stride for a
        descending coordinate).

    ref_node : str
        The reference node in the xarray DataTree object whose minimum and
        maximum coordinate bounds will be used to iteratively slice over.

    kwargs : dict
        Extra keyword arguments to pass into ``datatree.DataTree.sel``.

    Yields
    ------
    xr_slice : ``datatree.DataTree``
        A single slice of a multi-dimensional xarray data object.

    Raises
    ------
    ValueError
        If a stride is zero, or if it does not fit at least once between the
        first and last coordinate of ``ref_node`` along that dimension (too
        large, or of the wrong sign). Because this is a generator function,
        the error surfaces when iteration starts, not at call time.
    """
    # The reference node does not change between dimensions; look it up once.
    ref = data_obj[ref_node]

    # Get the coordinate positions to slice on for each dimension, e.g.
    # {'y': [(45.0, 25.0), (25.0, 5.0)],
    #  'x': [(95.0, 115.0), (115.0, 135.0), (135.0, 155.0)]}
    slice_positions: Dict[str, list] = {}
    for dim, stride in dim_strides.items():
        if stride == 0:
            raise ValueError(f"Stride for dimension {dim!r} must be non-zero")

        first_coord: np.ndarray = ref.isel({dim: 0})[dim].data
        final_coord: np.ndarray = ref.isel({dim: -1})[dim].data
        extent = final_coord - first_coord
        slices: int = int(np.floor(extent / stride))
        if slices < 1:
            # Without this guard, ``slices == 0`` would divide by zero below
            # (NaN/inf slice bounds) and a negative count would silently
            # produce no slices at all.
            raise ValueError(
                f"Stride {stride!r} for dimension {dim!r} does not fit between "
                f"the first ({first_coord}) and last ({final_coord}) "
                f"coordinates of node {ref_node!r}"
            )

        # Obtain slicing positions. Need to minus and plus half a pixel because
        # xarray uses centre-based coordinates. The "pixel" size used here is
        # the extent divided by the number of whole strides that fit in it.
        half_resolution = (extent / slices) / 2
        start_positions: np.ndarray = np.linspace(
            start=first_coord - half_resolution,
            stop=final_coord - half_resolution,
            num=slices + 1,
        )
        stop_positions: np.ndarray = np.linspace(
            start=first_coord + half_resolution,
            stop=final_coord + half_resolution,
            num=slices + 1,
        )
        slice_positions[dim] = list(zip(start_positions, stop_positions))

    # Iterate over combinations of slice positions across many dimensions, e.g.
    # ((45.0, 25.0), (95.0, 115.0))
    # ((45.0, 25.0), (115.0, 135.0))
    # ((45.0, 25.0), (135.0, 155.0))
    # ((25.0, 5.0), (95.0, 115.0))
    # ((25.0, 5.0), (115.0, 135.0))
    # ((25.0, 5.0), (135.0, 155.0))
    dims = list(slice_positions)
    for slice_pos in itertools.product(*slice_positions.values()):
        indexers: Dict[str, slice] = {
            dim: slice(*bounds) for dim, bounds in zip(dims, slice_pos)
        }
        yield data_obj.sel(indexers=indexers, **kwargs)
80 changes: 80 additions & 0 deletions xbatcher/tests/test_datatree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import datatree.testing as dtt
import numpy as np
import pytest
import xarray as xr
from datatree import DataTree

from xbatcher import datatree_slice_generator


@pytest.fixture(scope="module")
def sample_datatree() -> DataTree:
    """
    Two-node DataTree holding the same area at two grid spacings.

    The ``grid10m`` node has coordinates spaced 10 apart and the ``grid20m``
    node has coordinates spaced 20 apart:

    DataTree('None', parent=None)
    ├── DataTree('grid10m')
    │       Dimensions:  (y: 4, x: 6)
    │       Coordinates:
    │         * y        (y) float64 40.0 30.0 20.0 10.0
    │         * x        (x) float64 100.0 110.0 120.0 130.0 140.0 150.0
    │       Data variables:
    │           grid10m  (y, x) int64 10 11 12 13 14 15 16 17 18 ... 26 27 28 29 30 31 32 33
    └── DataTree('grid20m')
            Dimensions:  (y: 2, x: 3)
            Coordinates:
              * y        (y) float64 35.0 15.0
              * x        (x) float64 105.0 125.0 145.0
            Data variables:
                grid20m  (y, x) int64 0 1 2 3 4 5
    """
    fine_grid = xr.DataArray(
        data=np.arange(10, 34).reshape(4, 6),
        dims=("y", "x"),
        coords={
            "y": np.linspace(40.0, 10.0, num=4),
            "x": np.linspace(100, 150, num=6),
        },
        name="grid10m",
    )
    coarse_grid = xr.DataArray(
        data=np.arange(0, 6).reshape(2, 3),
        dims=("y", "x"),
        coords={
            "y": np.linspace(35.0, 15.0, num=2),
            "x": np.linspace(105, 145, num=3),
        },
        name="grid20m",
    )
    return DataTree.from_dict(d={"grid10m": fine_grid, "grid20m": coarse_grid})


@pytest.fixture(scope="module")
def expected_datatree() -> DataTree:
    """
    DataTree that ``test_datatree`` compares the last generated chip against.

    It contains a 2x2 ``grid10m`` node (y: 20.0, 10.0; x: 140.0, 150.0) and a
    1x1 ``grid20m`` node (y: 15.0; x: 145.0).
    """
    nodes = {
        "grid10m": xr.DataArray(
            data=[[26, 27], [32, 33]],
            dims=("y", "x"),
            coords={"y": [20.0, 10.0], "x": [140.0, 150.0]},
            name="grid10m",
        ),
        "grid20m": xr.DataArray(
            data=[[5]],
            dims=("y", "x"),
            coords={"y": [15.0], "x": [145.0]},
            name="grid20m",
        ),
    }
    return DataTree.from_dict(d=nodes)


def test_datatree(sample_datatree, expected_datatree):
    """
    Test slicing through a multi-resolution DataTree.

    Checks both the number of chips produced and the content of the last one.
    """
    # Materialise the chips so that an empty generator fails the count
    # assertion below instead of raising on an unbound loop variable.
    chips = list(
        datatree_slice_generator(
            data_obj=sample_datatree,
            dim_strides={"y": -20, "x": 20},
            ref_node="grid20m",
        )
    )

    assert len(chips) == 6  # number of chips
    dtt.assert_identical(a=chips[-1], b=expected_datatree)  # check returned DataTree