Skip to content

Switch to fast-array-utils #3598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ classifiers = [
dependencies = [
"anndata>=0.9",
"numpy>=1.25",
"fast-array-utils[accel,sparse]>=1.1",
"matplotlib>=3.7",
"pandas >=2.0",
"scipy>=1.11",
Expand Down
78 changes: 3 additions & 75 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from packaging.version import Version

from .. import logging as logg
from .._compat import CSBase, DaskArray, _CSArray, _CSMatrix, pkg_version
from .._compat import CSBase, DaskArray, _CSArray, pkg_version
from .._settings import settings
from .compute.is_constant import is_constant # noqa: F401

Expand All @@ -53,7 +53,7 @@

from anndata import AnnData
from igraph import Graph
from numpy.typing import ArrayLike, DTypeLike, NDArray
from numpy.typing import ArrayLike, NDArray

from .._compat import CSRBase
from ..neighbors import NeighborsParams, RPForestDict
Expand Down Expand Up @@ -606,7 +606,7 @@ def axis_mul_or_truediv(
@axis_mul_or_truediv.register(CSBase)
def _(
X: CSBase,
scaling_array,
scaling_array: np.ndarray,
axis: Literal[0, 1],
op: Callable[[Any, Any], Any],
*,
Expand Down Expand Up @@ -746,78 +746,6 @@ def _(X: DaskArray, axis: Literal[0, 1]) -> DaskArray:
)


# `axis_sum`: sum the elements of `X` with `np.sum`, optionally along `axis`
# and with an explicit result/accumulator `dtype`.
# The two `@overload` stubs below only describe the return type for static
# type checkers; the runtime behaviour lives in the `@singledispatch` function.
@overload
def axis_sum(
X: _CSMatrix,
*,
axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> np.matrix: ...


@overload
def axis_sum(
X: np.ndarray, # TODO: or sparray
*,
axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> np.ndarray: ...


# Fallback implementation; type-specific variants are attached with
# `axis_sum.register(...)`.
@singledispatch
def axis_sum(
X: np.ndarray | CSBase,
*,
axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> np.ndarray | np.matrix:
# Delegate directly to NumPy, forwarding `axis` and `dtype` unchanged.
return np.sum(X, axis=axis, dtype=dtype)


# Dask implementation of `axis_sum`, registered for `DaskArray` inputs.
# Returns either `X.sum(...)` or a `da.reduction(...)` built from the two
# helper functions defined below.
@axis_sum.register(DaskArray)
def _(
X: DaskArray,
*,
axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> DaskArray:
import dask.array as da

# No dtype requested: use the dtype NumPy yields when summing a 1-element
# array of `X.dtype`, falling back to `object` if the result has no `.dtype`.
if dtype is None:
dtype = getattr(np.zeros(1, dtype=X.dtype).sum(), "dtype", object)

# When the chunk meta is a plain ndarray (not np.matrix), use Dask's own `sum`.
if isinstance(X._meta, np.ndarray) and not isinstance(X._meta, np.matrix):
return X.sum(axis=axis, dtype=dtype)

# Chunk function handed to `da.reduction` below.
def sum_drop_keepdims(*args, **kwargs):
# Discard `computing_meta` so it is not forwarded to `np.sum`.
kwargs.pop("computing_meta", None)
# masked operations on sparse produce numpy matrices, which give the same API issues handled here
if isinstance(X._meta, _CSMatrix | np.matrix) or isinstance(
args[0], _CSMatrix | np.matrix
):
kwargs.pop("keepdims", None)
axis = kwargs["axis"]
# A 1-tuple `axis` is unwrapped to its single entry; any other tuple
# length is rejected with a ValueError.
if isinstance(axis, tuple):
if len(axis) != 1:
msg = f"`axis_sum` can only sum over one axis when `axis` arg is provided but got {axis} instead"
raise ValueError(msg)
kwargs["axis"] = axis[0]
# returns a np.matrix normally, which is undesirable
return np.array(np.sum(*args, dtype=dtype, **kwargs))

# Aggregate function handed to `da.reduction`: sums its first positional argument.
def aggregate_sum(*args, **kwargs):
return np.sum(args[0], dtype=dtype, **kwargs)

# `meta` is an empty ndarray of the result dtype.
return da.reduction(
X,
sum_drop_keepdims,
aggregate_sum,
axis=axis,
dtype=dtype,
meta=np.array([], dtype=dtype),
)


@singledispatch
def check_nonnegative_integers(X: _SupportedArray) -> bool | DaskArray:
"""Check values of X to ensure it is count data."""
Expand Down
36 changes: 25 additions & 11 deletions src/scanpy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from warnings import warn

import numpy as np
from fast_array_utils import stats

from .. import logging as logg
from .._compat import CSBase, DaskArray, old_positionals
from .._utils import axis_mul_or_truediv, axis_sum, view_to_actual
from .._utils import axis_mul_or_truediv, view_to_actual
from ..get import _get_obs_rep, _set_obs_rep

try:
Expand All @@ -34,7 +35,16 @@ def _compute_nnz_median(counts: np.ndarray | DaskArray) -> np.floating:
return median


def _normalize_data(X, counts, after=None, *, copy: bool = False):
def _normalize_data(
X: np.ndarray | CSBase,
counts: np.ndarray,
after: float | None = None,
*,
copy: bool = False,
) -> np.ndarray:
if counts.ndim != 1:
msg = "counts must be a 1D array"
raise ValueError(msg)
X = X.copy() if copy else X
if issubclass(X.dtype.type, int | np.integer):
X = X.astype(np.float32) # TODO: Check if float64 should be used
Expand Down Expand Up @@ -90,8 +100,8 @@ def normalize_total( # noqa: PLR0912, PLR0915
call functions that trigger `.compute()` on the :class:`~dask.array.Array` if `exclude_highly_expressed`
is `True`, `layer_norm` is not `None`, or if `key_added` is not `None`.

Params
------
Parameters
----------
adata
The annotated data matrix of shape `n_obs` × `n_vars`.
Rows correspond to cells and columns to genes.
Expand Down Expand Up @@ -211,23 +221,27 @@ def normalize_total( # noqa: PLR0912, PLR0915
gene_subset = None
msg = "normalizing counts per cell"

counts_per_cell = axis_sum(x, axis=1)
counts_per_cell = stats.sum(x, axis=1)
assert counts_per_cell.ndim == 1
if exclude_highly_expressed:
counts_per_cell = np.ravel(counts_per_cell)

# at least one cell has more than max_fraction of counts per cell

gene_subset = axis_sum((x > counts_per_cell[:, None] * max_fraction), axis=0)
gene_subset = np.asarray(np.ravel(gene_subset) == 0)
x = x > counts_per_cell[:, None] * max_fraction
if isinstance(x, np.matrix):
x = x.A
elif isinstance(x, DaskArray) and isinstance(x._meta, np.matrix):
x = x.map_blocks(np.asarray, meta=np.array([], dtype=x.dtype))
gene_subset = stats.sum(x, axis=0) == 0
assert gene_subset.ndim == 1

msg += (
". The following highly-expressed genes are not considered during "
f"normalization factor computation:\n{adata.var_names[~gene_subset].tolist()}"
)
counts_per_cell = axis_sum(x[:, gene_subset], axis=1)
counts_per_cell = stats.sum(x[:, gene_subset], axis=1)
assert counts_per_cell.ndim == 1

start = logg.info(msg)
counts_per_cell = np.ravel(counts_per_cell)

cell_subset = counts_per_cell > 0
if not isinstance(cell_subset, DaskArray) and not np.all(cell_subset):
Expand Down
11 changes: 6 additions & 5 deletions src/scanpy/preprocessing/_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import numba
import numpy as np
import pandas as pd
from fast_array_utils import stats
from scipy import sparse

from scanpy.get import _get_obs_rep
from scanpy.preprocessing._distributed import materialize_as_ndarray
from scanpy.preprocessing._utils import _get_mean_var

from .._compat import CSBase, CSRBase, DaskArray, njit
from .._utils import _doc_params, axis_nnz, axis_sum
from .._utils import _doc_params, axis_nnz
from ._docs import (
doc_adata_basic,
doc_expr_reps,
Expand Down Expand Up @@ -101,7 +102,7 @@ def describe_obs( # noqa: PLR0913
obs_metrics[f"log1p_n_{var_type}_by_{expr_type}"] = np.log1p(
obs_metrics[f"n_{var_type}_by_{expr_type}"]
)
obs_metrics[f"total_{expr_type}"] = np.ravel(axis_sum(X, axis=1))
obs_metrics[f"total_{expr_type}"] = stats.sum(X, axis=1)
if log1p:
obs_metrics[f"log1p_total_{expr_type}"] = np.log1p(
obs_metrics[f"total_{expr_type}"]
Expand All @@ -114,8 +115,8 @@ def describe_obs( # noqa: PLR0913
proportions[:, i] * 100
)
for qc_var in qc_vars:
obs_metrics[f"total_{expr_type}_{qc_var}"] = np.ravel(
axis_sum(X[:, adata.var[qc_var].values], axis=1)
obs_metrics[f"total_{expr_type}_{qc_var}"] = stats.sum(
X[:, adata.var[qc_var].values], axis=1
)
if log1p:
obs_metrics[f"log1p_total_{expr_type}_{qc_var}"] = np.log1p(
Expand Down Expand Up @@ -189,7 +190,7 @@ def describe_var(
var_metrics[f"pct_dropout_by_{expr_type}"] = (
1 - var_metrics[f"n_cells_by_{expr_type}"] / X.shape[0]
) * 100
var_metrics[f"total_{expr_type}"] = np.ravel(axis_sum(X, axis=0))
var_metrics[f"total_{expr_type}"] = stats.sum(X, axis=0)
if log1p:
var_metrics[f"log1p_total_{expr_type}"] = np.log1p(
var_metrics[f"total_{expr_type}"]
Expand Down
18 changes: 7 additions & 11 deletions src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numba
import numpy as np
from anndata import AnnData
from fast_array_utils import stats
from pandas.api.types import CategoricalDtype
from sklearn.utils import check_array, sparsefuncs

Expand All @@ -22,7 +23,6 @@
from .._utils import (
_check_array_function_arguments,
_resolve_axis,
axis_sum,
is_backed_type,
raise_not_implemented_error_if_backed_type,
renamed_arg,
Expand Down Expand Up @@ -170,17 +170,15 @@ def filter_cells(
X = data # proceed with processing the data matrix
min_number = min_counts if min_genes is None else min_genes
max_number = max_counts if max_genes is None else max_genes
number_per_cell = axis_sum(
number_per_cell = stats.sum(
X if min_genes is None and max_genes is None else X > 0, axis=1
)
if isinstance(number_per_cell, np.matrix):
number_per_cell = number_per_cell.A1
if min_number is not None:
cell_subset = number_per_cell >= min_number
if max_number is not None:
cell_subset = number_per_cell <= max_number

s = axis_sum(~cell_subset)
s = stats.sum(~cell_subset)
if s > 0:
msg = f"filtered out {s} cells that have "
if min_genes is not None or min_counts is not None:
Expand Down Expand Up @@ -288,17 +286,15 @@ def filter_genes(
X = data # proceed with processing the data matrix
min_number = min_counts if min_cells is None else min_cells
max_number = max_counts if max_cells is None else max_cells
number_per_gene = axis_sum(
number_per_gene = stats.sum(
X if min_cells is None and max_cells is None else X > 0, axis=0
)
if isinstance(number_per_gene, np.matrix):
number_per_gene = number_per_gene.A1
if min_number is not None:
gene_subset = number_per_gene >= min_number
if max_number is not None:
gene_subset = number_per_gene <= max_number

s = axis_sum(~gene_subset)
s = stats.sum(~gene_subset)
if s > 0:
msg = f"filtered out {s} genes that are detected "
if min_cells is not None or min_counts is not None:
Expand Down Expand Up @@ -1051,7 +1047,7 @@ def _downsample_per_cell(
original_type = type(X)
if not isinstance(X, CSRBase):
X = X.tocsr()
totals = np.ravel(axis_sum(X, axis=1)) # Faster for csr matrix
totals = stats.sum(X, axis=1) # Faster for csr matrix
under_target = np.nonzero(totals > counts_per_cell)[0]
rows = np.split(X.data, X.indptr[1:-1])
for rowidx in under_target:
Expand All @@ -1067,7 +1063,7 @@ def _downsample_per_cell(
if not issubclass(original_type, CSRBase): # Put it back
X = original_type(X)
else:
totals = np.ravel(axis_sum(X, axis=1))
totals = stats.sum(X, axis=1)
under_target = np.nonzero(totals > counts_per_cell)[0]
for rowidx in under_target:
row = X[rowidx, :]
Expand Down
5 changes: 3 additions & 2 deletions src/scanpy/preprocessing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import numba
import numpy as np
from fast_array_utils import stats
from sklearn.random_projection import sample_without_replacement

from .._compat import CSBase, CSCBase, CSRBase, SpBase, njit
from .._utils import axis_sum, elem_mul
from .._utils import elem_mul

if TYPE_CHECKING:
from typing import Literal
Expand All @@ -22,7 +23,7 @@

# Mean along `axis`: the sum along that axis divided by `X.shape[axis]`.
@singledispatch
def axis_mean(X: DaskArray, *, axis: Literal[0, 1], dtype: DTypeLike) -> DaskArray:
# NOTE(review): the two `total = ...` assignments below look like the removed
# and added sides of a diff hunk (`axis_sum` -> `stats.sum`) — confirm that
# only one of them belongs in the real file.
total = axis_sum(X, axis=axis, dtype=dtype)
total = stats.sum(X, axis=axis, dtype=dtype)
return total / X.shape[axis]


Expand Down
19 changes: 12 additions & 7 deletions tests/test_normalization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import numpy as np
import pytest
from anndata import AnnData
from anndata.tests.helpers import assert_equal
from fast_array_utils import conv, stats
from scipy import sparse

import scanpy as sc
from scanpy._compat import CSBase
from scanpy._utils import axis_sum
from scanpy.preprocessing._normalization import _compute_nnz_median
from testing.scanpy._helpers import (
_check_check_values_warnings,
Expand All @@ -25,14 +26,18 @@
from collections.abc import Callable
from typing import Any

to_ndarray = partial(conv.to_dense, to_cpu_memory=True)

X_total = np.array([[1, 0], [3, 0], [5, 6]])
X_frac = np.array([[1, 0, 1], [3, 0, 1], [5, 6, 1]])


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("dtype", ["float32", "int64"])
@pytest.mark.parametrize("target_sum", [None, 1.0])
@pytest.mark.parametrize("exclude_highly_expressed", [True, False])
@pytest.mark.parametrize("target_sum", [None, 1.0], ids=["no_target_sum", "target_sum"])
@pytest.mark.parametrize(
"exclude_highly_expressed", [True, False], ids=["excl_hi", "no_excl_hi"]
)
def test_normalize_matrix_types(
array_type, dtype, target_sum, exclude_highly_expressed
):
Expand Down Expand Up @@ -63,13 +68,13 @@ def test_normalize_matrix_types(
def test_normalize_total(array_type, dtype):
adata = AnnData(array_type(X_total).astype(dtype))
sc.pp.normalize_total(adata, key_added="n_counts")
assert np.allclose(np.ravel(axis_sum(adata.X, axis=1)), [3.0, 3.0, 3.0])
assert np.allclose(to_ndarray(stats.sum(adata.X, axis=1)), [3.0, 3.0, 3.0])
sc.pp.normalize_total(adata, target_sum=1, key_added="n_counts2")
assert np.allclose(np.ravel(axis_sum(adata.X, axis=1)), [1.0, 1.0, 1.0])
assert np.allclose(to_ndarray(stats.sum(adata.X, axis=1)), [1.0, 1.0, 1.0])

adata = AnnData(array_type(X_frac).astype(dtype))
sc.pp.normalize_total(adata, exclude_highly_expressed=True, max_fraction=0.7)
assert np.allclose(np.ravel(axis_sum(adata.X[:, 1:3], axis=1)), [1.0, 1.0, 1.0])
assert np.allclose(to_ndarray(stats.sum(adata.X[:, 1:3], axis=1)), [1.0, 1.0, 1.0])


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
Expand All @@ -88,7 +93,7 @@ def test_normalize_total_layers(array_type, dtype):
adata.layers["layer"] = adata.X.copy()
with pytest.warns(FutureWarning, match=r".*layers.*deprecated"):
sc.pp.normalize_total(adata, layers=["layer"])
assert np.allclose(axis_sum(adata.layers["layer"], axis=1), [3.0, 3.0, 3.0])
assert np.allclose(stats.sum(adata.layers["layer"], axis=1), [3.0, 3.0, 3.0])


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
Expand Down
Loading
Loading