diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml
index d483372..5e5e9d2 100644
--- a/ci/requirements/doc.yml
+++ b/ci/requirements/doc.yml
@@ -15,6 +15,7 @@ dependencies:
   - sphinx-autosummary-accessors
   - sphinx-copybutton
   - xarray
+  - xarray-datatree
   # For examples
   - adlfs
   - ipykernel
diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml
index 3e0b102..7ca4663 100644
--- a/ci/requirements/environment.yml
+++ b/ci/requirements/environment.yml
@@ -16,6 +16,7 @@ dependencies:
   - pytest-cov
   - pytorch
   - tensorflow
+  - xarray-datatree
   - zarr
   # Style checks
   - black
diff --git a/pyproject.toml b/pyproject.toml
index 11ab076..99141fc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "xarray",
 ]
 [project.optional-dependencies]
+datatree = ["xarray-datatree"]
 torch = [
     "torch",
 ]
@@ -44,6 +45,7 @@ dev = [
     "pytest-cov",
     "tensorflow",
     "torch",
+    "xarray-datatree",
 ]
 [project.urls]
 documentation = "https://xbatcher.readthedocs.io/en/latest/"
@@ -58,7 +60,7 @@ fallback_version = "999"

 [tool.isort]
 profile = "black"
-known_third_party = ["numpy", "pandas", "pytest", "sphinx_autosummary_accessors", "torch", "xarray"]
+known_third_party = ["datatree", "numpy", "pandas", "pytest", "sphinx_autosummary_accessors", "torch", "xarray"]

 [tool.pytest.ini_options]
 log_cli = true
diff --git a/xbatcher/__init__.py b/xbatcher/__init__.py
index 6fb8d75..19f32ad 100644
--- a/xbatcher/__init__.py
+++ b/xbatcher/__init__.py
@@ -3,7 +3,11 @@
 from . import testing  # noqa: F401
 from .accessors import BatchAccessor  # noqa: F401
-from .generators import BatchGenerator, BatchSchema  # noqa: F401
+from .generators import (  # noqa: F401
+    BatchGenerator,
+    BatchSchema,
+    datatree_slice_generator,
+)
 from .util.print_versions import show_versions  # noqa: F401

 try:
diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 58799ba..07c3a1b 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -7,6 +7,7 @@
 import numpy as np
 import xarray as xr
+from datatree import DataTree

 PatchGenerator = Iterator[Dict[Hashable, slice]]
 BatchSelector = List[Dict[Hashable, slice]]
@@ -462,3 +463,73 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
             )
         else:
             raise IndexError("list index out of range")
+
+
+def datatree_slice_generator(
+    data_obj: DataTree,
+    dim_strides: Dict[str, int],
+    ref_node: str,
+    **kwargs,
+) -> Iterator[DataTree]:
+    """
+    Generator for iterating through an Xarray DataTree.
+
+    Parameters
+    ----------
+    data_obj : ``datatree.DataTree``
+        The multi-dimensional xarray data object to iterate over.
+
+    dim_strides : dict
+        A dictionary specifying the size of the stride in each dimension to
+        slice along, e.g. ``{'longitude': 30, 'latitude': 30}``. These are the
+        dimensions the machine learning library will see.
+
+    ref_node : str
+        The reference node in the xarray DataTree object whose minimum and
+        maximum coordinate bounds will be used to iteratively slice over.
+
+    kwargs : dict
+        Extra keyword arguments to pass into ``datatree.DataTree.sel``.
+
+    Yields
+    ------
+    xr_slice : ``datatree.DataTree``
+        A single slice of a multi-dimensional xarray data object.
+    """
+    # Get the coordinate positions to slice on for each dimension, e.g.
+    # {'y': [(45.0, 25.0), (25.0, 5.0)],
+    #  'x': [(95.0, 115.0), (115.0, 135.0), (135.0, 155.0)]}
+    slice_positions: Dict[str, list] = {}
+    for dim, stride in dim_strides.items():
+        first_coord: np.ndarray = data_obj[ref_node].isel({dim: 0})[dim].data
+        final_coord: np.ndarray = data_obj[ref_node].isel({dim: -1})[dim].data
+        slices: int = int(np.floor((final_coord - first_coord) / stride))
+
+        # Obtain slicing positions. Subtract and add half a pixel because
+        # xarray uses centre-based coordinates.
+        resolution: float = (final_coord - first_coord) / slices
+        start_positions: np.ndarray = np.linspace(
+            start=first_coord - (resolution / 2),
+            stop=final_coord - (resolution / 2),
+            num=slices + 1,
+        )
+        stop_positions: np.ndarray = np.linspace(
+            start=first_coord + (resolution / 2),
+            stop=final_coord + (resolution / 2),
+            num=slices + 1,
+        )
+        slice_positions[dim] = list(zip(start_positions, stop_positions))
+    assert slice_positions.keys() == dim_strides.keys()
+
+    # Iterate over combinations of slice positions across many dimensions, e.g.
+    # ((45.0, 25.0), (95.0, 115.0))
+    # ((45.0, 25.0), (115.0, 135.0))
+    # ((45.0, 25.0), (135.0, 155.0))
+    # ((25.0, 5.0), (95.0, 115.0))
+    # ((25.0, 5.0), (115.0, 135.0))
+    # ((25.0, 5.0), (135.0, 155.0))
+    for slice_pos in itertools.product(*slice_positions.values()):
+        indexers: Dict[str, slice] = {}
+        for i, dim in enumerate(dim_strides):
+            indexers[dim] = slice(*slice_pos[i])
+        yield data_obj.sel(indexers=indexers, **kwargs)
diff --git a/xbatcher/tests/test_datatree.py b/xbatcher/tests/test_datatree.py
new file mode 100644
index 0000000..65168d3
--- /dev/null
+++ b/xbatcher/tests/test_datatree.py
@@ -0,0 +1,80 @@
+import datatree.testing as dtt
+import numpy as np
+import pytest
+import xarray as xr
+from datatree import DataTree
+
+from xbatcher import datatree_slice_generator
+
+
+@pytest.fixture(scope="module")
+def sample_datatree() -> DataTree:
+    """
+    Sample multi-resolution DataTree for testing.
+
+    DataTree('None', parent=None)
+    ├── DataTree('grid10m')
+    │       Dimensions:  (y: 4, x: 6)
+    │       Coordinates:
+    │         * y        (y) float64 40.0 30.0 20.0 10.0
+    │         * x        (x) float64 100.0 110.0 120.0 130.0 140.0 150.0
+    │       Data variables:
+    │           grid10m  (y, x) int64 10 11 12 13 14 15 16 17 18 ... 26 27 28 29 30 31 32 33
+    └── DataTree('grid20m')
+            Dimensions:  (y: 2, x: 3)
+            Coordinates:
+              * y        (y) float64 35.0 15.0
+              * x        (x) float64 105.0 125.0 145.0
+            Data variables:
+                grid20m  (y, x) int64 0 1 2 3 4 5
+    """
+    grid20m = xr.DataArray(
+        data=np.arange(0, 6).reshape(2, 3),
+        dims=("y", "x"),
+        coords={"y": np.linspace(35.0, 15.0, num=2), "x": np.linspace(105, 145, num=3)},
+        name="grid20m",
+    )
+    grid10m = xr.DataArray(
+        data=np.arange(10, 34).reshape(4, 6),
+        dims=("y", "x"),
+        coords={"y": np.linspace(40.0, 10.0, num=4), "x": np.linspace(100, 150, num=6)},
+        name="grid10m",
+    )
+    dt = DataTree.from_dict(d={"grid10m": grid10m, "grid20m": grid20m})
+    return dt
+
+
+@pytest.fixture(scope="module")
+def expected_datatree() -> DataTree:
+    """
+    Expected DataTree for the final chip (the bottom-right corner of the
+    sample DataTree).
+    """
+    expected_datatree = DataTree.from_dict(
+        d={
+            "grid10m": xr.DataArray(
+                data=[[26, 27], [32, 33]],
+                dims=("y", "x"),
+                coords={"y": [20.0, 10.0], "x": [140.0, 150.0]},
+                name="grid10m",
+            ),
+            "grid20m": xr.DataArray(
+                data=[[5]],
+                dims=("y", "x"),
+                coords={"y": [15.0], "x": [145.0]},
+                name="grid20m",
+            ),
+        }
+    )
+    return expected_datatree
+
+
+def test_datatree(sample_datatree, expected_datatree):
+    """
+    Test slicing through a multi-resolution DataTree.
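+
+    Slicing the sample DataTree with 20 unit strides along ``x`` and ``y``
+    (negative along ``y`` because that coordinate is descending), using
+    ``grid20m`` as the reference node, should yield 6 chips; the final chip
+    is compared against ``expected_datatree``.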
+    """
+    generator = datatree_slice_generator(
+        data_obj=sample_datatree, dim_strides={"y": -20, "x": 20}, ref_node="grid20m"
+    )
+    for i, chip in enumerate(generator):
+        pass
+
+    assert i + 1 == 6  # number of chips
+    dtt.assert_identical(a=chip, b=expected_datatree)  # check the last returned DataTree chip
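
For reviewers, a minimal usage sketch of the new generator (an illustration only, assuming this patch is applied and the optional xarray-datatree dependency is installed); it mirrors the fixture in xbatcher/tests/test_datatree.py:

    import numpy as np
    import xarray as xr
    from datatree import DataTree

    from xbatcher import datatree_slice_generator

    # Two nodes covering the same spatial extent at 10 m and 20 m resolution,
    # as in the test fixture above.
    grid10m = xr.DataArray(
        np.arange(10, 34).reshape(4, 6),
        dims=("y", "x"),
        coords={"y": np.linspace(40.0, 10.0, num=4), "x": np.linspace(100, 150, num=6)},
        name="grid10m",
    )
    grid20m = xr.DataArray(
        np.arange(0, 6).reshape(2, 3),
        dims=("y", "x"),
        coords={"y": np.linspace(35.0, 15.0, num=2), "x": np.linspace(105, 145, num=3)},
        name="grid20m",
    )
    dt = DataTree.from_dict({"grid10m": grid10m, "grid20m": grid20m})

    # Stride of 20 coordinate units per step (negative along y because the y
    # coordinate is descending); the coarser node supplies the reference bounds.
    chips = list(
        datatree_slice_generator(
            data_obj=dt, dim_strides={"y": -20, "x": 20}, ref_node="grid20m"
        )
    )
    print(len(chips))  # 6 chips, each a DataTree with a 2x2 grid10m and a 1x1 grid20m slice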