Handle reading multi-band datasets #146

Draft · wants to merge 1 commit into main

5 changes: 4 additions & 1 deletion .flake8
@@ -2,4 +2,7 @@
max-line-length = 120
exclude =
.pyi
typings
typings
ignore =
E203 # whitespace before ':'
W503 # line break before binary operator
34 changes: 34 additions & 0 deletions stackstac/nodata.py
@@ -0,0 +1,34 @@
from typing import Tuple, Union
import re

import numpy as np
from rasterio.windows import Window

State = Tuple[np.dtype, Union[int, float]]


def nodata_for_window(
    nbands: int, window: Window, fill_value: Union[int, float], dtype: np.dtype
):
    return np.full((nbands, window.height, window.width), fill_value, dtype)
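# e.g. (illustrative values):
# nodata_for_window(3, Window(col_off=0, row_off=0, width=512, height=256), np.nan, np.dtype("float64"))
# -> float64 array of shape (3, 256, 512), filled with NaN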


def exception_matches(e: Exception, patterns: Tuple[Exception, ...]) -> bool:
"""
Whether an exception matches one of the pattern exceptions

Parameters
----------
e:
The exception to check
patterns:
Instances of an Exception type to catch, where ``str(exception_pattern)``
is a regex pattern to match against ``str(e)``.
"""
e_type = type(e)
e_msg = str(e)
for pattern in patterns:
if issubclass(e_type, type(pattern)):
if re.match(str(pattern), e_msg):
return True
return False
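
For illustration, a minimal usage sketch; the `RasterioIOError` and the 404 pattern are hypothetical examples of what a caller might pass, not values from this diff:

    from rasterio.errors import RasterioIOError
    from stackstac.nodata import exception_matches

    # str(pattern) is the regex, matched against str(e) with re.match
    patterns = (RasterioIOError(r".*HTTP response code: 404.*"),)

    assert exception_matches(RasterioIOError("CURL error: HTTP response code: 404"), patterns)
    assert not exception_matches(ValueError("HTTP response code: 404"), patterns)  # wrong exception type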
65 changes: 0 additions & 65 deletions stackstac/nodata_reader.py

This file was deleted.

45 changes: 41 additions & 4 deletions stackstac/prepare.py
@@ -27,7 +27,16 @@
from .stac_types import ItemSequence
from . import accumulate_metadata, geom_utils

ASSET_TABLE_DT = np.dtype([("url", object), ("bounds", "float64", 4)])
ASSET_TABLE_DT = np.dtype(
[("url", object), ("bounds", "float64", 4), ("bands", object)]
)
# ^ NOTE: `bands` should be a `Sequence[int]` of _1-indexed_ bands to fetch from the asset.
# We support specifying a sequence of band indices (rather than just the number of bands,
# and always doing a `read()` of all bands) for future optimizations to support fetching
# (and possibly reordering?) a subset of bands per asset. This could be done either via
# another argument to `stack` (please no!) or a custom Dask optimization, akin to column
# projection for DataFrames.
# But at the moment, `bands == list(range(1, ds.count + 1))`.
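
As a sketch of how a populated entry looks (hypothetical URL, bounds, and band count):

    import numpy as np

    table = np.full((2, 1), None, dtype=ASSET_TABLE_DT)  # shape: (items, assets)
    # A 3-band asset: `bands` is 1-indexed, and currently always the full range.
    table[0, 0] = ("https://example.com/item0/visual.tif", (0.0, 0.0, 10.0, 10.0), [1, 2, 3])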


class Mimetype(NamedTuple):
@@ -64,7 +73,7 @@ def prepare_items(
bounds: Optional[Bbox] = None,
bounds_latlon: Optional[Bbox] = None,
snap_bounds: bool = True,
) -> Tuple[np.ndarray, RasterSpec, List[str], ItemSequence]:
) -> Tuple[np.ndarray, RasterSpec, List[str], ItemSequence, tuple[int, ...]]:

if bounds is not None and bounds_latlon is not None:
raise ValueError(
@@ -119,6 +128,7 @@ def prepare_items(
asset_ids = assets

asset_table = np.full((len(items), len(asset_ids)), None, dtype=ASSET_TABLE_DT)
nbands_per_asset: list[int | None] = [None] * len(asset_ids)

# TODO support item-assets https://github.com/radiantearth/stac-spec/tree/master/extensions/item-assets

@@ -321,7 +331,25 @@ def prepare_items(
)

# Phew, we figured out all the spatial stuff! Now actually store the information we care about.
asset_table[item_i, asset_i] = (asset["href"], asset_bbox_proj)

bands: Optional[Sequence[int]] = None
Author comment:
FYI, this is where we'd actually figure out band counts from STAC metadata

# ^ TODO actually determine this from `eo:bands` or `raster:bands`
# https://github.com/gjoseph92/stackstac/issues/62
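# A hypothetical sketch of that TODO, assuming asset-level `raster:bands`/`eo:bands`
# metadata (neither is actually read anywhere in this diff):
# band_meta = asset.get("raster:bands") or asset.get("eo:bands")
# if band_meta:
#     bands = list(range(1, len(band_meta) + 1))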

nbands = 1 if bands is None else len(bands)
prev_nbands = nbands_per_asset[asset_i]
if prev_nbands is None:
nbands_per_asset[asset_i] = nbands
else:
if prev_nbands != nbands:
raise ValueError(
f"The asset {id!r} has {nbands} band(s) on item {item_i} {item['id']!r}, "
f"but on all previous items, it had {prev_nbands}."
# TODO improve this error message with something actionable
# (it's probably a data provider issue), once multi-band is actually supported.
)

asset_table[item_i, asset_i] = (asset["href"], asset_bbox_proj, bands)
# ^ NOTE: If `asset_bbox_proj` is None, NumPy automatically converts it to NaNs

# At this point, everything has been set (or there was an error)
@@ -346,9 +374,18 @@ def prepare_items(
if item_isnan.any() or asset_id_isnan.any():
asset_table = asset_table[np.ix_(~item_isnan, ~asset_id_isnan)]
asset_ids = [id for id, isnan in zip(asset_ids, asset_id_isnan) if not isnan]
nbands_per_asset = [
n for n, isnan in zip(nbands_per_asset, asset_id_isnan) if not isnan
]
items = [item for item, isnan in zip(items, item_isnan) if not isnan]

return asset_table, spec, asset_ids, items
# Being for the benefit of Mr. Typechecker
nbpa = tuple(x for x in nbands_per_asset if x is not None)
assert len(nbpa) == len(
nbands_per_asset
), f"Some `nbands_per_asset` are None: {nbands_per_asset}"

return asset_table, spec, asset_ids, items, nbpa


def to_coords(
27 changes: 22 additions & 5 deletions stackstac/reader_protocol.py
@@ -1,5 +1,14 @@
from __future__ import annotations
from typing import Optional, Protocol, Tuple, Type, TYPE_CHECKING, TypeVar, Union
from typing import (
Optional,
Protocol,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)

import numpy as np

@@ -30,6 +39,7 @@ def __init__(
self,
*,
url: str,
bands: Optional[Sequence[int]],
spec: RasterSpec,
resampling: Resampling,
dtype: np.dtype,
@@ -45,6 +55,9 @@ def __init__(
----------
url:
Fetch data from the asset at this URL.
bands:
List of (one-indexed!) band indices to read, or None for all bands.
If None, the asset must have exactly one band.
spec:
Reproject data to match this georeferencing information.
resampling:
@@ -69,7 +82,6 @@ def __init__(
where ``str(exception_pattern)`` is a regex pattern to match against
``str(raised_exception)``.
"""
# TODO colormaps?

def read(self, window: Window) -> np.ndarray:
"""
@@ -87,7 +99,7 @@ def read(self, window: Window) -> np.ndarray:

Returns
-------
array: The window of data read
array: The window of data read from all bands, as a 3D array
"""
...

@@ -113,11 +125,16 @@ class FakeReader:
or inherent to the dask graph.
"""

def __init__(self, *, dtype: np.dtype, **kwargs) -> None:
def __init__(
self, *, bands: Optional[Sequence[int]], dtype: np.dtype, **kwargs
) -> None:
self.dtype = dtype
self.ndim = len(bands) if bands is not None else 1
Author comment:
Oof, ndim is not the right name for this, since this is just the length of one dimension... it was a late night.


def read(self, window: Window, **kwargs) -> np.ndarray:
return np.random.random((window.height, window.width)).astype(self.dtype)
return np.random.random((self.ndim, window.height, window.width)).astype(
self.dtype
)

def close(self) -> None:
pass
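
A minimal sketch of exercising `FakeReader`'s new multi-band shape (window values are arbitrary):

    import numpy as np
    from rasterio.windows import Window
    from stackstac.reader_protocol import FakeReader

    fake = FakeReader(bands=[1, 2, 3], dtype=np.dtype("float64"))
    chunk = fake.read(Window(col_off=0, row_off=0, width=64, height=32))
    assert chunk.shape == (3, 32, 64)  # (bands, height, width)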
97 changes: 82 additions & 15 deletions stackstac/rio_reader.py
@@ -3,7 +3,16 @@
import logging
import threading
import warnings
from typing import TYPE_CHECKING, Optional, Protocol, Tuple, Type, TypedDict, Union
from typing import (
TYPE_CHECKING,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypedDict,
Union,
)

import numpy as np
import rasterio as rio
@@ -13,7 +22,7 @@
from .timer import time
from .reader_protocol import Reader
from .raster_spec import RasterSpec
from .nodata_reader import NodataReader, exception_matches, nodata_for_window
from .nodata import exception_matches, nodata_for_window

if TYPE_CHECKING:
from rasterio.enums import Resampling
@@ -70,7 +79,7 @@ def _curthread():
class ThreadsafeRioDataset(Protocol):
scale_offset: Tuple[float, float]

def read(self, window: Window, **kwargs) -> np.ndarray:
def read(self, bands: Sequence[int], window: Window, **kwargs) -> np.ndarray:
...

def close(self) -> None:
@@ -99,11 +108,11 @@ def __init__(

self._lock = threading.Lock()

def read(self, window: Window, **kwargs) -> np.ndarray:
def read(self, bands: Sequence[int], window: Window, **kwargs) -> np.ndarray:
"Acquire the lock, then read from the dataset"
reader = self.vrt or self.ds
with self._lock, self.env.read:
return reader.read(1, window=window, **kwargs)
return reader.read(bands, window=window, **kwargs)

def close(self) -> None:
"Acquire the lock, then close the dataset"
@@ -220,11 +229,11 @@ def dataset(self) -> Union[SelfCleaningDatasetReader, WarpedVRT]:
except AttributeError:
return self._open()

def read(self, window: Window, **kwargs) -> np.ndarray:
def read(self, bands: Sequence[int], window: Window, **kwargs) -> np.ndarray:
"Read from the current thread's dataset, opening a new copy of the dataset on first access from each thread."
with time(f"Read {self._url!r} in {_curthread()}: {{t}}"):
with self._env.read:
return self.dataset.read(1, window=window, **kwargs)
return self.dataset.read(bands, window=window, **kwargs)

def close(self) -> None:
"""
@@ -274,8 +283,29 @@ def __del__(self):
self.close()


class Nodataset:
"`ThreadsafeRioDataset` that returns a constant (nodata) value for all reads"
scale_offset = (1.0, 0.0)

def __init__(
self,
*,
dtype: np.dtype,
fill_value: Union[int, float],
) -> None:
self.dtype = dtype
self.fill_value = fill_value

def read(self, bands: Sequence[int], window: Window, **kwargs) -> np.ndarray:
return nodata_for_window(len(bands), window, self.fill_value, self.dtype)

def close(self) -> None:
pass


class PickleState(TypedDict):
url: str
bands: Optional[Sequence[int]]
spec: RasterSpec
resampling: Resampling
dtype: np.dtype
@@ -295,10 +325,22 @@ class AutoParallelRioReader:
for non-thread-safe drivers.
"""

url: str
bands: Sequence[int]
exactly_one_band: bool
spec: RasterSpec
resampling: Resampling
dtype: np.dtype
fill_value: Union[int, float]
rescale: bool
gdal_env: LayeredEnv
errors_as_nodata: Tuple[Exception, ...]

def __init__(
self,
*,
url: str,
bands: Optional[Sequence[int]],
spec: RasterSpec,
resampling: Resampling,
dtype: np.dtype,
@@ -308,6 +350,8 @@ def __init__(
errors_as_nodata: Tuple[Exception, ...] = (),
) -> None:
self.url = url
self.bands = bands if bands is not None else (1,)
self.exactly_one_band = bands is None
self.spec = spec
self.resampling = resampling
self.dtype = dtype
@@ -330,17 +374,34 @@ def _open(self) -> ThreadsafeRioDataset:
msg = f"Error opening {self.url!r}: {e!r}"
if exception_matches(e, self.errors_as_nodata):
warnings.warn(msg)
return NodataReader(
dtype=self.dtype, fill_value=self.fill_value
return Nodataset(
dtype=self.dtype,
fill_value=self.fill_value,
)

raise RuntimeError(msg) from e
if ds.count != 1:

if self.exactly_one_band:
# Unknown band count. If the asset actually has 3 bands, we don't want to
# silently read just the first one.
if ds.count != 1:
ds.close()
raise RuntimeError(
f"Assets must have exactly 1 band, but file {self.url!r} has {ds.count}. "
"We can't currently handle multi-band rasters (each band has to be "
"a separate STAC asset), so you'll need to exclude this asset from your analysis."
# TODO change this error message once we actually determine band counts from STAC metadata.
# Then, this should mention that the asset was missing `eo:bands` and `raster:bands` metadata,
# so the expected band count was unknown and defaults to 1.
# Alternatively, we could get rid of this bands==None codepath entirely, and always require
# STAC metadata to specify `eo:bands` or `raster:bands` (allowing you to explicitly provide
# values for them if they're missing?).
)
elif ds.count < len(self.bands):
ds.close()
raise RuntimeError(
f"Assets must have exactly 1 band, but file {self.url!r} has {ds.count}. "
"We can't currently handle multi-band rasters (each band has to be "
"a separate STAC asset), so you'll need to exclude this asset from your analysis."
f"Expected to read {len(self.bands)} {tuple(self.bands)}, but there are only "
f"{ds.count} band(s) in the asset at {self.url!r}."
)

# Only make a VRT if the dataset doesn't match the spatial spec we want
@@ -375,7 +436,7 @@ def _open(self) -> ThreadsafeRioDataset:
return SingleThreadedRioDataset(self.gdal_env, ds, vrt=vrt)

@property
def dataset(self):
def dataset(self) -> ThreadsafeRioDataset:
with self._dataset_lock:
if self._dataset is None:
self._dataset = self._open()
@@ -385,6 +446,7 @@ def read(self, window: Window, **kwargs) -> np.ndarray:
reader = self.dataset
try:
result = reader.read(
self.bands,
window=window,
masked=True,
# ^ NOTE: we always do a masked array, so we can safely apply scales and offsets
@@ -395,10 +457,14 @@ def read(self, window: Window, **kwargs) -> np.ndarray:
msg = f"Error reading {window} from {self.url!r}: {e!r}"
if exception_matches(e, self.errors_as_nodata):
warnings.warn(msg)
return nodata_for_window(window, self.fill_value, self.dtype)
return nodata_for_window(
len(self.bands), window, self.fill_value, self.dtype
)

raise RuntimeError(msg) from e

# TODO scale and offset might not apply to all bands.
# Should probably just remove this.
if self.rescale:
scale, offset = reader.scale_offset
if scale != 1 and offset != 0:
@@ -430,6 +496,7 @@ def __getstate__(
) -> PickleState:
return {
"url": self.url,
"bands": None if self.exactly_one_band else self.bands,
"spec": self.spec,
"resampling": self.resampling,
"dtype": self.dtype,
5 changes: 3 additions & 2 deletions stackstac/stack.py
@@ -276,7 +276,7 @@ def stack(
reverse=sortby_date == "desc",
)

asset_table, spec, asset_ids, plain_items = prepare_items(
asset_table, spec, asset_ids, plain_items, nbands_per_asset = prepare_items(
plain_items,
assets=assets,
epsg=epsg,
@@ -289,6 +289,7 @@ def stack(
asset_table,
spec,
chunksize=chunksize,
nbands_per_asset=nbands_per_asset,
dtype=dtype,
resampling=resampling,
fill_value=fill_value,
@@ -309,5 +310,5 @@ def stack(
band_coords=band_coords,
),
attrs=to_attrs(spec),
name="stackstac-" + dask.base.tokenize(arr)
name="stackstac-" + dask.base.tokenize(arr),
)
56 changes: 54 additions & 2 deletions stackstac/tests/test_to_dask.py
@@ -1,10 +1,12 @@
from __future__ import annotations
import itertools
from threading import Lock
from typing import ClassVar

from hypothesis import given, settings, strategies as st
from hypothesis import given, note, settings, strategies as st
import hypothesis.extra.numpy as st_np
import numpy as np
import pytest
from rasterio import windows
import dask.core
import dask.threaded
@@ -16,6 +18,7 @@
ChunksParam,
items_to_dask,
normalize_chunks,
process_multiband_chunks,
window_from_bounds,
)
from stackstac.testing import strategies as st_stc
@@ -194,9 +197,58 @@ def __setstate__(self, state):
def test_normalize_chunks(
chunksize: ChunksParam, shape: tuple[int, int, int, int], dtype: np.dtype
):
chunks = normalize_chunks(chunksize, shape, dtype)
nbands_per_asset = (1,) * shape[1] # not testing this here, keep it simple
chunks, asset_table_band_chunks = normalize_chunks(
chunksize, shape, nbands_per_asset, dtype
)
numblocks = tuple(map(len, chunks))
assert len(numblocks) == 4
assert all(x >= 1 for t in chunks for x in t)
if isinstance(chunksize, int) or (isinstance(chunksize, tuple) and len(chunksize) == 2):
assert numblocks[:2] == shape[:2]


@given(st.data(), st.lists(st.integers(1, 5), max_size=5).map(tuple))
def test_process_multiband_chunks(
data: st.DataObject, nbands_per_asset: tuple[int, ...]
):
total_bands = sum(nbands_per_asset)
chunks: list[int] = []
remaining = total_bands
while remaining:
c = data.draw(st.integers(1, remaining))
remaining -= c
assert remaining >= 0
chunks.append(c)

note(f"{nbands_per_asset=}")
note(f" {chunks=}")

# Expand the chunks into 1-element-per-band form. This is a simpler, but less efficient, representation to validate against.
# Ex: [2, 4, 1, 1] -> [0, 0, 1, 1, 1, 1, 2, 3]
physical_layout = [
x for i, n in enumerate(nbands_per_asset) for x in itertools.repeat(i, n)
]
requested_layout = [x for i, n in enumerate(chunks) for x in itertools.repeat(i, n)]
assert len(physical_layout) == len(requested_layout)

invalid = False
for i in range(1, len(requested_layout)):
if requested_layout[i - 1] != requested_layout[i]:
# Wherever the asset we're pulling from changes in the requested layout,
# it must also change in the physical layout.
if physical_layout[i - 1] == physical_layout[i]:
invalid = True
break

note(f" {physical_layout=}")
note(f"{requested_layout=}")

if invalid:
with pytest.raises(NotImplementedError):
process_multiband_chunks(tuple(chunks), nbands_per_asset)
else:
asset_table_band_chunks = process_multiband_chunks(
tuple(chunks), nbands_per_asset
)
assert len(asset_table_band_chunks) == len(chunks)
141 changes: 128 additions & 13 deletions stackstac/to_dask.py
@@ -18,13 +18,17 @@
from .reader_protocol import Reader

ChunkVal = Union[int, Literal["auto"], str, None]
ChunksParam = Union[ChunkVal, Tuple[ChunkVal, ...], Dict[int, ChunkVal]]
ChunksParam = Union[
ChunkVal, Tuple[Union[ChunkVal, Tuple[ChunkVal, ...]], ...], Dict[int, ChunkVal]
]
TBYXChunks = Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]


def items_to_dask(
asset_table: np.ndarray,
spec: RasterSpec,
chunksize: ChunksParam,
nbands_per_asset: tuple[int, ...],
resampling: Resampling = Resampling.nearest,
dtype: np.dtype = np.dtype("float64"),
fill_value: Union[int, float] = np.nan,
@@ -42,8 +46,11 @@ def items_to_dask(
f"Either use `dtype={np.array(fill_value).dtype.name!r}`, or pick a different `fill_value`."
)

chunks = normalize_chunks(chunksize, asset_table.shape + spec.shape, dtype)
chunks_tb, chunks_yx = chunks[:2], chunks[2:]
chunks, asset_table_band_chunks = normalize_chunks(
chunksize, asset_table.shape + spec.shape, nbands_per_asset, dtype
)
chunks_tb = chunks[:1] + (asset_table_band_chunks,)
# ^ NOTE: wrap in a tuple so this is (time chunks, band chunks) for the 2D asset table
chunks_yx = chunks[2:]

# The overall strategy in this function is to materialize the outer two dimensions (items, assets)
# as one dask array (the "asset table"), then map a function over it which opens each URL as a `Reader`
@@ -56,7 +63,7 @@ def items_to_dask(
# make URLs into dask array, chunked as requested for the time,band dimensions
asset_table_dask = da.from_array(
asset_table,
chunks=chunks_tb,
chunks=chunks_tb, # type: ignore
inline_array=True,
name="asset-table-" + dask.base.tokenize(asset_table),
)
@@ -97,6 +104,8 @@ def items_to_dask(
None,
fill_value,
None,
nbands_per_asset,
None,
numblocks={reader_table.name: reader_table.numblocks}, # ugh
)
dsk = HighLevelGraph.from_collections(name, lyr, [reader_table])
@@ -136,6 +145,7 @@ def asset_table_to_reader_and_window(
entry: ReaderTableEntry = (
reader(
url=url,
bands=asset_entry["bands"],
spec=spec,
resampling=resampling,
dtype=dtype,
@@ -155,24 +165,34 @@ def fetch_raster_window(
slices: Tuple[slice, slice],
dtype: np.dtype,
fill_value: Union[int, float],
nbands_per_asset: tuple[int, ...],
) -> np.ndarray:
"Do a spatially-windowed read of raster data from all the Readers in the table."
assert len(slices) == 2, slices
current_window = windows.Window.from_slices(*slices)

assert reader_table.size, f"Empty reader_table: {reader_table.shape=}"
assert (
len(nbands_per_asset) == reader_table.shape[1]
), f"{nbands_per_asset=}, {reader_table.shape[1]=}"
# Start with an empty output array, using the broadcast trick for even fewer memz.
# If none of the assets end up actually existing, or overlapping the current window,
# or containing data, we'll just return this 1-element array that's been broadcast
# to look like a full-size array.
output = np.broadcast_to(
np.array(fill_value, dtype),
reader_table.shape + (current_window.height, current_window.width),
(
reader_table.shape[0],
sum(nbands_per_asset),
current_window.height,
current_window.width,
),
)

asset_i_to_band = np.cumsum((0, *nbands_per_asset))
# ^ exclusive prefix sum: the first band index in `output` for each asset
all_empty: bool = True
entry: ReaderTableEntry
for index, entry in np.ndenumerate(reader_table):
for (time_i, asset_i), entry in np.ndenumerate(reader_table):
if entry:
reader, asset_window = entry
# Only read if the window we're fetching actually overlaps with the asset
@@ -183,6 +203,9 @@ def fetch_raster_window(

# TODO when the Reader won't be rescaling, support passing `output` to avoid the copy?
data = reader.read(current_window)
assert (
data.shape[0] == nbands_per_asset[asset_i]
), f"Band count mismatch: {nbands_per_asset[asset_i]=}, {data.shape[0]=}"

if all_empty:
# Turn `output` from a broadcast-trick array to a real array, so it's writeable
@@ -196,36 +219,128 @@ def fetch_raster_window(
output = np.array(output)
all_empty = False

output[index] = data
band_i = asset_i_to_band[asset_i]
output[time_i, band_i : band_i + data.shape[0]] = data
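# ^ e.g. with nbands_per_asset == (1, 3, 2), asset_i == 1 writes into output bands 1:4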

return output


def normalize_chunks(
chunks: ChunksParam, shape: Tuple[int, int, int, int], dtype: np.dtype
) -> Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
chunks: ChunksParam,
shape: Tuple[int, int, int, int],
nbands_per_asset: tuple[int, ...],
dtype: np.dtype,
) -> tuple[TBYXChunks, tuple[int, ...]]:
"""
Normalize chunks to a tuple of tuples, assuming 1D and 2D chunks only apply to the spatial coordinates.
If only 1 or 2 chunks are given, assume they're for the ``y, x`` coordinates;
``time`` gets chunksize 1, and ``band`` gets one chunk per asset (``nbands_per_asset``).
If "auto" is given for bands, ``nbands_per_asset`` is used.
Returns
-------
chunks:
Normalized chunks
asset_table_band_chunks:
Band chunks to apply to the asset table (see `process_multiband_chunks`)
"""
# TODO implement our own auto-chunking that makes the time,band coordinates
# >1 if the spatial chunking would create too many tasks?
if isinstance(chunks, int):
chunks = (1, 1, chunks, chunks)
chunks = (1, nbands_per_asset, chunks, chunks)
elif isinstance(chunks, tuple) and len(chunks) == 2:
chunks = (1, 1) + chunks
chunks = (1, nbands_per_asset) + chunks
elif isinstance(chunks, tuple) and len(chunks) == 4 and chunks[1] == "auto":
chunks = (chunks[0], nbands_per_asset, chunks[2], chunks[3])

return da.core.normalize_chunks(
norm: TBYXChunks = da.core.normalize_chunks(
chunks,
shape,
dtype=dtype,
previous_chunks=((1,) * shape[0], (1,) * shape[1], (shape[2],), (shape[3],)),
previous_chunks=((1,) * shape[0], nbands_per_asset, (shape[2],), (shape[3],)),
# ^ Give dask some hint of the physical layout of the data, so it prefers widening
# the spatial chunks over bundling together items/assets. This isn't totally accurate.
)

# Ensure we aren't trying to split apart multi-band assets. This would require rewriting
# the asset table (adding duplicate columns) and is generally not what you want, assuming
# that in multi-band assets, the bands are stored interleaved, so reading one requires reading
# them all anyway.
asset_table_band_chunks = process_multiband_chunks(norm[1], nbands_per_asset)
return norm, asset_table_band_chunks
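
For intuition, a worked sketch (shapes and chunksize chosen for illustration):

    import numpy as np

    # 10 items x 2 assets (one 1-band, one 3-band) on a 512x512 grid, chunksize=256:
    chunks, asset_table_band_chunks = normalize_chunks(
        256, (10, 2, 512, 512), nbands_per_asset=(1, 3), dtype=np.dtype("float64")
    )
    # chunks == ((1,)*10, (1, 3), (256, 256), (256, 256))
    # asset_table_band_chunks == (1, 1): each band chunk pulls from exactly one asset.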


def process_multiband_chunks(
Author comment:
There's probably an easier/more legible way to implement this function, but this turned out to be the crux of implementing multi-band. We really don't want two different chunks to source data from the same multi-band asset (different bands going to different chunks), because this would no longer be a blockwise operation (from asset table -> full array): it would require rewriting the asset table, opening the same dataset twice, and would probably perform badly anyway when bands are interleaved. It seemed easier to validate that this situation doesn't occur than to support it.

If you have an asset with lots of bands which aren't stored interleaved, I could definitely see wanting different chunks per band. I'm just not sure how common that is. It would be helpful to know if this seems like a common use-case @TomAugspurger.

Reply:
These docs describe the three methods for organizing multiple bands, two of which use interleaving: https://desktop.arcgis.com/en/arcmap/10.3/manage-data/raster-and-images/bil-bip-and-bsq-raster-files.htm

And a shorter synopsis of each: https://www.l3harrisgeospatial.com/docs/enviimagefiles.html
BIP is common for hyperspectral datasets like ASTER
BIL I think is the most common format. Landsat comes in BIL
BSQ seems to be less common, legacy sats like SPOT were distributed in BSQ: https://www.loc.gov/preservation/digital/formats/fdd/fdd000306.shtml but I'm not certain on how common BSQ is.

chunks: tuple[int, ...], nbands_per_asset: tuple[int, ...]
) -> tuple[int, ...]:
"""
Validate that the bands chunks don't try to split apart any multi-band assets.
Returns
-------
asset_table_band_chunks:
Band chunks to apply to the asset table (so that assets are combined into single chunks as necessary).
``len(asset_table_band_chunks) == len(chunks)``. In other words, along the bands, we'll have the same
``numblocks`` in the asset table as ``numblocks`` in the final array. But each block in the final array
may be longer (have more bands) than the number of assets (when they're multi-band assets).
"""
n_chunks = len(chunks)
n_assets = len(nbands_per_asset)

final_msg = (
f"Requested bands chunks: {chunks}\n"
f"Physical bands chunks: {nbands_per_asset}\n"
"This would entail splitting apart multi-band assets. This typically (but not always) has "
"much worse performance, since GeoTIFF bands are generally interleaved (so reading one "
"band from a file requires reading them all).\n"
"If you have a use-case for this, please discuss on https://github.com/gjoseph92/stackstac/issues."
)

if n_chunks > n_assets:
raise NotImplementedError(
f"Refusing to make {n_chunks} chunk(s) for the bands when there are only {n_assets} bands asset(s).\n"
+ final_msg
)
elif n_chunks == n_assets:
if chunks != nbands_per_asset:
raise NotImplementedError(final_msg)
return (1,) * n_assets
# ^ each band chunk corresponds to exactly one asset in the asset table
else:
# Trying to combine multiple assets into one chunk; must be whole multiples.
# n_chunks < n_assets
asset_table_band_chunks: list[int] = []
i = nbands_so_far = nbands_requested = n_assets_so_far = 0
for nb in nbands_per_asset:
if nbands_requested == 0:
if i == n_chunks:
raise ValueError(
f"Invalid chunks for {sum(nbands_per_asset)} band(s): only {sum(chunks)} band(s) used.\n"
f"Requested bands chunks: {chunks}\n"
f"Physical bands chunks: {nbands_per_asset}\n"
)
nbands_requested = chunks[i]

nbands_so_far += nb
n_assets_so_far += 1
if nbands_so_far < nbands_requested:
continue
elif nbands_so_far == nbands_requested:
# nailed it
i += 1
nbands_so_far = 0
nbands_requested = 0
asset_table_band_chunks.append(n_assets_so_far)
n_assets_so_far = 0
else:
# `nbands_so_far > nbands_requested`
raise NotImplementedError(
f"Specified chunks do not evenly combine multi-band assets: chunk {i} would split one apart.\n"
+ final_msg
)
return tuple(asset_table_band_chunks)
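
A few concrete cases (band counts chosen for illustration), given assets with (1, 3, 2) bands:

    process_multiband_chunks((1, 3, 2), (1, 3, 2))  # one asset per chunk -> (1, 1, 1)
    process_multiband_chunks((4, 2), (1, 3, 2))     # assets 0 and 1 share a chunk -> (2, 1)
    process_multiband_chunks((2, 4), (1, 3, 2))     # NotImplementedError: would split asset 1's bands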


# FIXME remove this once rasterio bugs are fixed
def window_from_bounds(bounds: Bbox, transform: Affine) -> windows.Window: