Add ability to shuffle (and reshuffle) batches #170

Open · wants to merge 2 commits into main

21 changes: 21 additions & 0 deletions xbatcher/generators.py
@@ -1,6 +1,7 @@
"""Classes for iterating through xarray datarrays / datasets in batches."""

import itertools
import random
import warnings
from operator import itemgetter
from typing import Any, Dict, Hashable, Iterator, List, Optional, Sequence, Union
@@ -45,6 +46,8 @@ class BatchSchema:
preload_batch : bool, optional
If ``True``, each batch will be loaded into memory before reshaping /
processing, triggering any dask arrays to be computed.
shuffle : bool, optional
If ``True``, batches will be yielded in a shuffled order.

Notes
-----
@@ -59,6 +62,7 @@ def __init__(
batch_dims: Optional[Dict[Hashable, int]] = None,
concat_input_bins: bool = True,
preload_batch: bool = True,
shuffle: bool = False,
):
if input_overlap is None:
input_overlap = {}
@@ -69,6 +73,7 @@ def __init__(
self.batch_dims = dict(batch_dims)
self.concat_input_dims = concat_input_bins
self.preload_batch = preload_batch
self.shuffle = shuffle
# Store helpful information based on arguments
self._duplicate_batch_dims: Dict[Hashable, int] = {
dim: length
@@ -98,6 +103,9 @@ def _gen_batch_selectors(
"""
# Create an iterator that returns an object usable for .isel in xarray
patch_selectors = self._gen_patch_selectors(ds)
if self.shuffle:
patch_selectors = list(patch_selectors)
random.shuffle(patch_selectors)
# Create the Dict containing batch selectors
if self.concat_input_dims: # Combine the patches into batches
return self._combine_patches_into_batch(ds, patch_selectors)
@@ -364,6 +372,8 @@ class BatchGenerator:
preload_batch : bool, optional
If ``True``, each batch will be loaded into memory before reshaping /
processing, triggering any dask arrays to be computed.
shuffle : bool, optional
If ``True``, batches will be randomly shuffled.

Yields
------
@@ -379,6 +389,7 @@ def __init__(
batch_dims: Dict[Hashable, int] = {},
concat_input_dims: bool = False,
preload_batch: bool = True,
shuffle: bool = False,
):
self.ds = ds
self._batch_selectors: BatchSchema = BatchSchema(
@@ -388,6 +399,7 @@ def __init__(
batch_dims=batch_dims,
concat_input_bins=concat_input_dims,
preload_batch=preload_batch,
shuffle=shuffle,
)

@property
@@ -410,6 +422,15 @@ def concat_input_dims(self):
def preload_batch(self):
return self._batch_selectors.preload_batch

def reshuffle(self):
shuffle_idx = list(self._batch_selectors.selectors)
random.shuffle(shuffle_idx)
self._batch_selectors.selectors = {
idx: self._batch_selectors.selectors[shuffled_idx]
for idx, shuffled_idx in zip(self._batch_selectors.selectors, shuffle_idx)
}


def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
for idx in self._batch_selectors.selectors:
yield self[idx]
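
For context, a minimal usage sketch of the options this PR adds (not part of the diff itself; the dataset, variable name, and dimension sizes below are purely illustrative, and it assumes xbatcher's existing BatchGenerator API):

import numpy as np
import xarray as xr
import xbatcher

# Illustrative dataset: 100 steps along "time", 10 points along "x".
ds = xr.Dataset(
    {"foo": (("time", "x"), np.random.rand(100, 10))},
    coords={"time": np.arange(100), "x": np.arange(10)},
)

# shuffle=True randomizes the order of the generated batches up front.
gen = xbatcher.BatchGenerator(ds, input_dims={"time": 10}, shuffle=True)

for epoch in range(2):
    for batch in gen:
        ...  # each batch is an xr.Dataset with dims (time: 10, x: 10)
    # reshuffle() re-randomizes the existing batch order in place,
    # e.g. between training epochs.
    gen.reshuffle()

Passing shuffle=True shuffles the patch selectors once when the BatchSchema is built, while reshuffle() remaps the already-built selectors to a new random order without regenerating them.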