-
Notifications
You must be signed in to change notification settings - Fork 52
Handle reading multi-band datasets #146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Tuple, Union | ||
import re | ||
|
||
import numpy as np | ||
from rasterio.windows import Window | ||
|
||
State = Tuple[np.dtype, Union[int, float]] | ||
|
||
|
||
def nodata_for_window(
    ndim: int, window: Window, fill_value: Union[int, float], dtype: np.dtype
):
    """
    Build a constant array of ``fill_value`` covering ``window``.

    The result has shape ``(ndim, window.height, window.width)`` and the
    given ``dtype`` — a stand-in for asset data that couldn't be read.
    """
    out_shape = (ndim, window.height, window.width)
    return np.full(out_shape, fill_value, dtype)
|
||
|
||
def exception_matches(e: Exception, patterns: Tuple[Exception, ...]) -> bool:
    """
    Whether an exception matches one of the pattern exceptions

    Parameters
    ----------
    e:
        The exception to check
    patterns:
        Instances of an Exception type to catch, where ``str(exception_pattern)``
        is a regex pattern to match against ``str(e)``.
    """
    msg = str(e)
    # A pattern matches when `e` is an instance of the pattern's type
    # (subclasses included) AND the pattern's message, used as a regex,
    # matches the start of `e`'s message.
    return any(
        isinstance(e, type(pattern)) and re.match(str(pattern), msg)
        for pattern in patterns
    )
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,14 @@ | ||
from __future__ import annotations | ||
from typing import Optional, Protocol, Tuple, Type, TYPE_CHECKING, TypeVar, Union | ||
from typing import ( | ||
Optional, | ||
Protocol, | ||
Sequence, | ||
Tuple, | ||
Type, | ||
TYPE_CHECKING, | ||
TypeVar, | ||
Union, | ||
) | ||
|
||
import numpy as np | ||
|
||
|
@@ -30,6 +39,7 @@ def __init__( | |
self, | ||
*, | ||
url: str, | ||
bands: Optional[Sequence[int]], | ||
spec: RasterSpec, | ||
resampling: Resampling, | ||
dtype: np.dtype, | ||
|
@@ -45,6 +55,9 @@ def __init__( | |
---------- | ||
url: | ||
Fetch data from the asset at this URL. | ||
bands: | ||
List of (one-indexed!) band indices to read, or None for all bands. | ||
If None, the asset must have exactly one band. | ||
spec: | ||
Reproject data to match this georeferencing information. | ||
resampling: | ||
|
@@ -69,7 +82,6 @@ def __init__( | |
where ``str(exception_pattern)`` is a regex pattern to match against | ||
``str(raised_exception)``. | ||
""" | ||
# TODO colormaps? | ||
|
||
def read(self, window: Window) -> np.ndarray: | ||
""" | ||
|
@@ -87,7 +99,7 @@ def read(self, window: Window) -> np.ndarray: | |
|
||
Returns | ||
------- | ||
array: The window of data read | ||
array: The window of data read from all bands, as a 3D array | ||
""" | ||
... | ||
|
||
|
@@ -113,11 +125,16 @@ class FakeReader: | |
or inherent to the dask graph. | ||
""" | ||
|
||
def __init__(
    self, *, bands: Optional[Sequence[int]], dtype: np.dtype, **kwargs
) -> None:
    """
    Fake reader for testing: records the dtype and band count to emit.

    ``bands=None`` is treated as a single-band asset.
    """
    self.dtype = dtype
    self.ndim = 1 if bands is None else len(bands)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oof, |
||
|
||
def read(self, window: Window, **kwargs) -> np.ndarray:
    """Return random data of shape ``(ndim, height, width)`` in ``self.dtype``."""
    out_shape = (self.ndim, window.height, window.width)
    return np.random.random(out_shape).astype(self.dtype)
|
||
def close(self) -> None:
    """No-op: the fake reader holds no resources to release."""
    pass
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,13 +18,17 @@ | |
from .reader_protocol import Reader | ||
|
||
ChunkVal = Union[int, Literal["auto"], str, None] | ||
ChunksParam = Union[ChunkVal, Tuple[ChunkVal, ...], Dict[int, ChunkVal]] | ||
ChunksParam = Union[ | ||
ChunkVal, Tuple[Union[ChunkVal, Tuple[ChunkVal, ...]], ...], Dict[int, ChunkVal] | ||
] | ||
TBYXChunks = Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]] | ||
|
||
|
||
def items_to_dask( | ||
asset_table: np.ndarray, | ||
spec: RasterSpec, | ||
chunksize: ChunksParam, | ||
nbands_per_asset: tuple[int, ...], | ||
resampling: Resampling = Resampling.nearest, | ||
dtype: np.dtype = np.dtype("float64"), | ||
fill_value: Union[int, float] = np.nan, | ||
|
@@ -42,8 +46,11 @@ def items_to_dask( | |
f"Either use `dtype={np.array(fill_value).dtype.name!r}`, or pick a different `fill_value`." | ||
) | ||
|
||
chunks = normalize_chunks(chunksize, asset_table.shape + spec.shape, dtype) | ||
chunks_tb, chunks_yx = chunks[:2], chunks[2:] | ||
chunks, asset_table_band_chunks = normalize_chunks( | ||
chunksize, asset_table.shape + spec.shape, nbands_per_asset, dtype | ||
) | ||
chunks_tb = chunks[:1] + asset_table_band_chunks | ||
chunks_yx = chunks[2:] | ||
|
||
# The overall strategy in this function is to materialize the outer two dimensions (items, assets) | ||
# as one dask array (the "asset table"), then map a function over it which opens each URL as a `Reader` | ||
|
@@ -56,7 +63,7 @@ def items_to_dask( | |
# make URLs into dask array, chunked as requested for the time,band dimensions | ||
asset_table_dask = da.from_array( | ||
asset_table, | ||
chunks=chunks_tb, | ||
chunks=chunks_tb, # type: ignore | ||
inline_array=True, | ||
name="asset-table-" + dask.base.tokenize(asset_table), | ||
) | ||
|
@@ -97,6 +104,8 @@ def items_to_dask( | |
None, | ||
fill_value, | ||
None, | ||
nbands_per_asset, | ||
None, | ||
numblocks={reader_table.name: reader_table.numblocks}, # ugh | ||
) | ||
dsk = HighLevelGraph.from_collections(name, lyr, [reader_table]) | ||
|
@@ -136,6 +145,7 @@ def asset_table_to_reader_and_window( | |
entry: ReaderTableEntry = ( | ||
reader( | ||
url=url, | ||
bands=asset_entry["bands"], | ||
spec=spec, | ||
resampling=resampling, | ||
dtype=dtype, | ||
|
@@ -155,24 +165,34 @@ def fetch_raster_window( | |
slices: Tuple[slice, slice], | ||
dtype: np.dtype, | ||
fill_value: Union[int, float], | ||
nbands_per_asset: tuple[int, ...], | ||
) -> np.ndarray: | ||
"Do a spatially-windowed read of raster data from all the Readers in the table." | ||
assert len(slices) == 2, slices | ||
current_window = windows.Window.from_slices(*slices) | ||
|
||
assert reader_table.size, f"Empty reader_table: {reader_table.shape=}" | ||
assert ( | ||
len(nbands_per_asset) == reader_table.shape[1] | ||
), f"{nbands_per_asset=}, {reader_table.shape[1]=}" | ||
# Start with an empty output array, using the broadcast trick for even fewer memz. | ||
# If none of the assets end up actually existing, or overlapping the current window, | ||
# or containing data, we'll just return this 1-element array that's been broadcast | ||
# to look like a full-size array. | ||
output = np.broadcast_to( | ||
np.array(fill_value, dtype), | ||
reader_table.shape + (current_window.height, current_window.width), | ||
( | ||
reader_table.shape[0], | ||
sum(nbands_per_asset), | ||
current_window.height, | ||
current_window.width, | ||
), | ||
) | ||
|
||
asset_i_to_band = np.cumsum(nbands_per_asset) | ||
all_empty: bool = True | ||
entry: ReaderTableEntry | ||
for index, entry in np.ndenumerate(reader_table): | ||
for (time_i, asset_i), entry in np.ndenumerate(reader_table): | ||
if entry: | ||
reader, asset_window = entry | ||
# Only read if the window we're fetching actually overlaps with the asset | ||
|
@@ -183,6 +203,9 @@ def fetch_raster_window( | |
|
||
# TODO when the Reader won't be rescaling, support passing `output` to avoid the copy? | ||
data = reader.read(current_window) | ||
assert ( | ||
data.shape[0] == nbands_per_asset[asset_i] | ||
), f"Band count mismatch: {nbands_per_asset[asset_i]=}, {data.shape[0]=}" | ||
|
||
if all_empty: | ||
# Turn `output` from a broadcast-trick array to a real array, so it's writeable | ||
|
@@ -196,36 +219,128 @@ def fetch_raster_window( | |
output = np.array(output) | ||
all_empty = False | ||
|
||
output[index] = data | ||
band_i = asset_i_to_band[asset_i] | ||
output[time_i, band_i : band_i + data.shape[0]] = data | ||
|
||
return output | ||
|
||
|
||
def normalize_chunks(
    chunks: ChunksParam,
    shape: Tuple[int, int, int, int],
    nbands_per_asset: tuple[int, ...],
    dtype: np.dtype,
) -> tuple[TBYXChunks, tuple[int, ...]]:
    """
    Normalize chunks to tuple-of-tuples form for a ``(time, band, y, x)`` array.

    If only 1 or 2 chunk values are given, assume they're for the ``y, x``
    coordinates, and that the ``time, band`` coordinates should be chunksize 1.
    If "auto" is given for bands, uses ``nbands_per_asset``.
    Returns
    -------
    chunks:
        Normalized chunks
    asset_table_band_chunks:
        Band chunks to apply to the asset table (see `process_multiband_chunks`)
    """
    # TODO implement our own auto-chunking that makes the time,band coordinates
    # >1 if the spatial chunking would create too many tasks?
    if isinstance(chunks, int):
        chunks = (1, nbands_per_asset, chunks, chunks)
    elif isinstance(chunks, tuple):
        if len(chunks) == 2:
            chunks = (1, nbands_per_asset) + chunks
        elif len(chunks) == 4 and chunks[1] == "auto":
            chunks = (chunks[0], nbands_per_asset, chunks[2], chunks[3])

    norm: TBYXChunks = da.core.normalize_chunks(
        chunks,
        shape,
        dtype=dtype,
        previous_chunks=((1,) * shape[0], nbands_per_asset, (shape[2],), (shape[3],)),
        # ^ Give dask some hint of the physical layout of the data, so it prefers widening
        # the spatial chunks over bundling together items/assets. This isn't totally accurate.
    )

    # Ensure we aren't trying to split apart multi-band assets. This would require rewriting
    # the asset table (adding duplicate columns) and is generally not what you want, assuming
    # that in multi-band assets, the bands are stored interleaved, so reading one requires reading
    # them all anyway.
    asset_table_band_chunks = process_multiband_chunks(norm[1], nbands_per_asset)
    return norm, asset_table_band_chunks
|
||
|
||
def process_multiband_chunks(
    chunks: tuple[int, ...], nbands_per_asset: tuple[int, ...]
) -> tuple[int, ...]:
    """
    Validate that the bands chunks don't try to split apart any multi-band assets.

    Parameters
    ----------
    chunks:
        Requested chunk sizes along the bands dimension.
    nbands_per_asset:
        Number of bands in each asset, one entry per asset.

    Returns
    -------
    asset_table_band_chunks:
        Band chunks to apply to the asset table (so that assets are combined into single chunks as necessary).
        ``len(asset_table_band_chunks) == len(chunks)``. In other words, along the bands, we'll have the same
        ``numblocks`` in the asset table as ``numblocks`` in the final array. But each block in the final array
        may be longer (have more bands) than the number of assets (when they're multi-band assets).

    Raises
    ------
    NotImplementedError:
        If ``chunks`` would split a multi-band asset across chunk boundaries.
    ValueError:
        If ``sum(chunks)`` doesn't exactly cover ``sum(nbands_per_asset)`` bands.
    """
    n_chunks = len(chunks)
    n_assets = len(nbands_per_asset)

    final_msg = (
        f"Requested bands chunks: {chunks}\n"
        f"Physical bands chunks: {nbands_per_asset}\n"
        "This would entail splitting apart multi-band assets. This typically (but not always) has "
        "much worse performance, since GeoTIFF bands are generally interleaved (so reading one "
        "band from a file requires reading them all).\n"
        "If you have a use-case for this, please discuss on https://github.com/gjoseph92/stackstac/issues."
    )

    if n_chunks > n_assets:
        raise NotImplementedError(
            f"Refusing to make {n_chunks} chunk(s) for the bands when there are only {n_assets} bands asset(s).\n"
            + final_msg
        )
    elif n_chunks == n_assets:
        # One chunk per asset: only valid when they line up exactly.
        if chunks != nbands_per_asset:
            raise NotImplementedError(final_msg)
        return chunks
    else:
        # Trying to combine multiple assets into one chunk; must be whole multiples.
        # n_chunks < n_assets
        asset_table_band_chunks: list[int] = []
        i = nbands_so_far = nbands_requested = n_assets_so_far = 0
        for nb in nbands_per_asset:
            if nbands_requested == 0:
                # Starting a new requested chunk; there must be one left to consume.
                if i == n_chunks:
                    raise ValueError(
                        f"Invalid chunks for {sum(nbands_per_asset)} band(s): only {sum(chunks)} band(s) used.\n"
                        f"Requested bands chunks: {chunks}\n"
                        f"Physical bands chunks: {nbands_per_asset}\n"
                    )
                nbands_requested = chunks[i]

            nbands_so_far += nb
            n_assets_so_far += 1
            if nbands_so_far < nbands_requested:
                continue
            elif nbands_so_far == nbands_requested:
                # Chunk boundary falls exactly on an asset boundary — record how many
                # assets were combined into this chunk.
                i += 1
                nbands_so_far = 0
                nbands_requested = 0
                asset_table_band_chunks.append(n_assets_so_far)
                n_assets_so_far = 0
            else:
                # `nbands_so_far > nbands_requested`: the boundary lands inside an asset.
                raise NotImplementedError(
                    f"Specified chunks do not evenly combine multi-band assets: chunk {i} would split one apart.\n"
                    + final_msg
                )
        if nbands_requested != 0:
            # Fix: the last requested chunk asked for more bands than exist; previously
            # this fell through and silently returned a short (possibly empty) result.
            raise ValueError(
                f"Invalid chunks for {sum(nbands_per_asset)} band(s): chunks request {sum(chunks)} band(s).\n"
                f"Requested bands chunks: {chunks}\n"
                f"Physical bands chunks: {nbands_per_asset}\n"
            )
        return tuple(asset_table_band_chunks)
|
||
|
||
# FIXME remove this once rasterio bugs are fixed | ||
def window_from_bounds(bounds: Bbox, transform: Affine) -> windows.Window: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, this is where we'd actually figure out band counts from STAC metadata