
Commit 21ef762

Backport PR #2034: fix: no cache indptr zarr dask (#2045)
Co-authored-by: Ilan Gold <[email protected]>
Parent: 642357b

3 files changed (+24 / -9 lines)


benchmarks/benchmarks/sparse_dataset.py

Lines changed: 18 additions & 7 deletions
@@ -4,11 +4,13 @@
 
 import numpy as np
 import zarr
+from dask.array.core import Array as DaskArray
 from scipy import sparse
 
 from anndata import AnnData
 from anndata._core.sparse_dataset import sparse_dataset
 from anndata._io.specs import write_elem
+from anndata.experimental import read_elem_lazy
 
 
 def make_alternating_mask(n):
@@ -37,27 +39,36 @@ class SparseCSRContiguousSlice:
             # (10_000, 500)
         ],
         _slices.keys(),
+        [True, False],
     )
-    param_names = ("shape", "slice")
+    param_names = ("shape", "slice", "use_dask")
 
-    def setup(self, shape: tuple[int, int], slice: str):
+    def setup(self, shape: tuple[int, int], slice: str, use_dask: bool):  # noqa: FBT001
         X = sparse.random(
             *shape, density=0.01, format="csr", random_state=np.random.default_rng(42)
         )
         self.slice = self._slices[slice]
         g = zarr.group()
         write_elem(g, "X", X)
-        self.x = sparse_dataset(g["X"])
+        self.x = read_elem_lazy(g["X"]) if use_dask else sparse_dataset(g["X"])
         self.adata = AnnData(self.x)
 
     def time_getitem(self, *_):
-        self.x[self.slice]
+        res = self.x[self.slice]
+        if isinstance(res, DaskArray):
+            res.compute()
 
     def peakmem_getitem(self, *_):
-        self.x[self.slice]
+        res = self.x[self.slice]
+        if isinstance(res, DaskArray):
+            res.compute()
 
     def time_getitem_adata(self, *_):
-        self.adata[self.slice]
+        res = self.adata[self.slice]
+        if isinstance(res, DaskArray):
+            res.compute()
 
     def peakmem_getitem_adata(self, *_):
-        self.adata[self.slice]
+        res = self.adata[self.slice]
+        if isinstance(res, DaskArray):
+            res.compute()
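
For orientation, a minimal sketch of the two access paths that the new use_dask parameter toggles, written against an in-memory zarr group; the shape, density, and slice below are illustrative rather than taken from the benchmark:

    import numpy as np
    import zarr
    from scipy import sparse

    from anndata._core.sparse_dataset import sparse_dataset
    from anndata._io.specs import write_elem
    from anndata.experimental import read_elem_lazy

    # Write a CSR matrix into an in-memory zarr group, as setup() does above.
    g = zarr.group()
    X = sparse.random(
        1_000, 1_000, density=0.01, format="csr",
        random_state=np.random.default_rng(0),
    )
    write_elem(g, "X", X)

    # On-disk path: slicing the backed dataset returns a scipy sparse matrix directly.
    backed = sparse_dataset(g["X"])
    chunk_eager = backed[:100]

    # Lazy path: slicing returns a dask array; zarr is only read on compute().
    lazy = read_elem_lazy(g["X"])
    chunk_lazy = lazy[:100].compute()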

src/anndata/_core/sparse_dataset.py

Lines changed: 5 additions & 1 deletion
@@ -165,7 +165,11 @@ def _offsets(
     def _get_contiguous_compressed_slice(
         self, s: slice
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-        new_indptr = self.indptr[s.start : s.stop + 1].copy()
+        new_indptr = self.indptr[s.start : s.stop + 1]
+        # If indptr is cached, we need to make a copy of the subset
+        # so as not to alter the underlying cached data.
+        if isinstance(self.indptr, np.ndarray):
+            new_indptr = new_indptr.copy()
 
         start = new_indptr[0]
         stop = new_indptr[-1]
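
The isinstance guard matters because of numpy's view semantics: slicing a cached in-memory indptr (an np.ndarray) yields a view into the cache, so any later in-place adjustment of the sliced pointer array, the situation the comment above guards against, would silently mutate the cached data. Slicing a zarr- or h5py-backed indptr already returns a fresh array and needs no defensive copy. A small standalone illustration (the values and the in-place shift are illustrative, not anndata code):

    import numpy as np

    cached_indptr = np.array([0, 2, 5, 7, 9], dtype=np.int64)  # stand-in for a cached indptr

    view = cached_indptr[1:4]  # numpy slicing returns a view, not a copy
    view -= view[0]            # an in-place shift applied to the "slice"

    print(cached_indptr)       # [0 0 3 5 9] -- the cache has been mutated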

src/anndata/_io/specs/lazy_methods.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def read_sparse_as_dask(
     path_or_sparse_dataset = (
         Path(filename(elem))
         if isinstance(elem, H5Group)
-        else ad.io.sparse_dataset(elem)
+        else ad.io.sparse_dataset(elem, should_cache_indptr=False)
     )
     elem_name = get_elem_name(elem)
     shape: tuple[int, int] = tuple(elem.attrs["shape"])
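
Here the sparse dataset only backs dask's per-chunk reads of a zarr store, so the indptr cache is skipped, presumably to avoid materializing and holding the full pointer array up front and to sidestep the cached-view issue addressed in sparse_dataset.py above. A hedged sketch of the same call outside the reader; should_cache_indptr is the keyword introduced in this diff, so it may not exist in older anndata releases, and the written data is illustrative:

    import numpy as np
    import zarr
    from scipy import sparse

    import anndata as ad
    from anndata._io.specs import write_elem

    g = zarr.group()
    write_elem(
        g, "X",
        sparse.random(1_000, 1_000, density=0.01, format="csr",
                      random_state=np.random.default_rng(0)),
    )

    # Open the zarr-backed CSR without caching its indptr, as read_sparse_as_dask now does.
    ds = ad.io.sparse_dataset(g["X"], should_cache_indptr=False)
    sub = ds[:100]  # a scipy sparse matrix; indptr is read from zarr on demand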
