diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 0b9b5046cb9..5f9fc248e84 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -26,7 +26,7 @@
     combine_by_coords,
 )
 from ..core.dataarray import DataArray
-from ..core.dataset import Dataset, _maybe_chunk
+from ..core.dataset import Dataset, _get_chunk, _maybe_chunk
 from ..core.utils import close_on_error, is_grib_path, is_remote_uri
 from .common import AbstractDataStore, ArrayWriter
 from .locks import _get_scheduler
@@ -536,7 +536,7 @@ def maybe_decode_store(store, chunks):
                 k: _maybe_chunk(
                     k,
                     v,
-                    store.get_chunk(k, v, chunks),
+                    _get_chunk(v, chunks),
                     overwrite_encoded_chunks=overwrite_encoded_chunks,
                 )
                 for k, v in ds.variables.items()
diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py
index 7e4605c42ce..3f162d899b4 100644
--- a/xarray/backends/apiv2.py
+++ b/xarray/backends/apiv2.py
@@ -1,7 +1,8 @@
 import os
 
+from ..core.dataset import _get_chunk, _maybe_chunk
 from ..core.utils import is_remote_uri
-from . import plugins, zarr
+from . import plugins
 from .api import (
     _autodetect_engine,
     _get_backend_cls,
@@ -10,8 +11,48 @@
 )
 
 
+def _get_mtime(filename_or_obj):
+    # if passed an actual file path, augment the token with
+    # the file modification time
+    if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj):
+        mtime = os.path.getmtime(filename_or_obj)
+    else:
+        mtime = None
+    return mtime
+
+
+def _chunk_ds(
+    backend_ds,
+    filename_or_obj,
+    engine,
+    chunks,
+    overwrite_encoded_chunks,
+    **extra_tokens,
+):
+    from dask.base import tokenize
+
+    mtime = _get_mtime(filename_or_obj)
+    token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
+    name_prefix = "open_dataset-%s" % token
+
+    variables = {}
+    for name, var in backend_ds.variables.items():
+        var_chunks = _get_chunk(var, chunks)
+        variables[name] = _maybe_chunk(
+            name,
+            var,
+            var_chunks,
+            overwrite_encoded_chunks=overwrite_encoded_chunks,
+            name_prefix=name_prefix,
+            token=token,
+        )
+    ds = backend_ds._replace(variables)
+    ds._file_obj = backend_ds._file_obj
+    return ds
+
+
 def dataset_from_backend_dataset(
-    ds,
+    backend_ds,
     filename_or_obj,
     engine,
     chunks,
@@ -26,50 +67,25 @@ def dataset_from_backend_dataset(
                 "Instead found %s. " % chunks
             )
 
-    _protect_dataset_variables_inplace(ds, cache)
-    if chunks is not None and engine != "zarr":
-        from dask.base import tokenize
-
-        # if passed an actual file path, augment the token with
-        # the file modification time
-        if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj):
-            mtime = os.path.getmtime(filename_or_obj)
-        else:
-            mtime = None
-        token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
-        name_prefix = "open_dataset-%s" % token
-        ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token)
-
-    elif engine == "zarr":
-
-        if chunks == "auto":
-            try:
-                import dask.array  # noqa
-            except ImportError:
-                chunks = None
-
-        if chunks is None:
-            return ds
-
-        if isinstance(chunks, int):
-            chunks = dict.fromkeys(ds.dims, chunks)
-
-        variables = {
-            k: zarr.ZarrStore.maybe_chunk(k, v, chunks, overwrite_encoded_chunks)
-            for k, v in ds.variables.items()
-        }
-        ds2 = ds._replace(variables)
-
+    _protect_dataset_variables_inplace(backend_ds, cache)
+    if chunks is None:
+        ds = backend_ds
     else:
-        ds2 = ds
-    ds2._file_obj = ds._file_obj
+        ds = _chunk_ds(
+            backend_ds,
+            filename_or_obj,
+            engine,
+            chunks,
+            overwrite_encoded_chunks,
+            **extra_tokens,
+        )
 
     # Ensure source filename always stored in dataset object (GH issue #2550)
     if "source" not in ds.encoding:
         if isinstance(filename_or_obj, str):
-            ds2.encoding["source"] = filename_or_obj
+            ds.encoding["source"] = filename_or_obj
 
-    return ds2
+    return ds
 
 
 def resolve_decoders_kwargs(decode_cf, engine, **decoders):
@@ -236,12 +252,13 @@ def open_dataset(
     open_backend_dataset = _get_backend_cls(engine, engines=plugins.ENGINES)[
         "open_dataset"
     ]
+    filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
     backend_ds = open_backend_dataset(
         filename_or_obj,
         drop_variables=drop_variables,
         **decoders,
         **backend_kwargs,
-        **{k: v for k, v in kwargs.items() if v is not None},
+        **filtered_kwargs,
     )
     ds = dataset_from_backend_dataset(
         backend_ds,
@@ -253,7 +270,7 @@ def open_dataset(
         drop_variables=drop_variables,
         **decoders,
         **backend_kwargs,
-        **kwargs,
+        **filtered_kwargs,
     )
 
     return ds
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 9827c345239..abf0c7372eb 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -1,5 +1,3 @@
-import warnings
-
 import numpy as np
 
 from .. import coding, conventions
@@ -368,53 +366,6 @@ def encode_variable(self, variable):
     def encode_attribute(self, a):
         return encode_zarr_attr_value(a)
 
-    @staticmethod
-    def get_chunk(name, var, chunks):
-        chunk_spec = dict(zip(var.dims, var.encoding.get("chunks")))
-
-        # Coordinate labels aren't chunked
-        if var.ndim == 1 and var.dims[0] == name:
-            return chunk_spec
-
-        if chunks == "auto":
-            return chunk_spec
-
-        for dim in var.dims:
-            if dim in chunks:
-                spec = chunks[dim]
-                if isinstance(spec, int):
-                    spec = (spec,)
-                if isinstance(spec, (tuple, list)) and chunk_spec[dim]:
-                    if any(s % chunk_spec[dim] for s in spec):
-                        warnings.warn(
-                            "Specified Dask chunks %r would "
-                            "separate Zarr chunk shape %r for "
-                            "dimension %r. This significantly "
-                            "degrades performance. Consider "
-                            "rechunking after loading instead."
-                            % (chunks[dim], chunk_spec[dim], dim),
-                            stacklevel=2,
-                        )
-                chunk_spec[dim] = chunks[dim]
-        return chunk_spec
-
-    @classmethod
-    def maybe_chunk(cls, name, var, chunks, overwrite_encoded_chunks):
-        chunk_spec = cls.get_chunk(name, var, chunks)
-
-        if (var.ndim > 0) and (chunk_spec is not None):
-            from dask.base import tokenize
-
-            # does this cause any data to be read?
-            token2 = tokenize(name, var._data, chunks)
-            name2 = f"xarray-{name}-{token2}"
-            var = var.chunk(chunk_spec, name=name2, lock=None)
-            if overwrite_encoded_chunks and var.chunks is not None:
-                var.encoding["chunks"] = tuple(x[0] for x in var.chunks)
-            return var
-        else:
-            return var
-
     def store(
         self,
         variables,
@@ -660,6 +611,14 @@ def open_zarr(
     """
     from .api import open_dataset
 
+    if chunks == "auto":
+        try:
+            import dask.array  # noqa
+
+            chunks = {}
+        except ImportError:
+            chunks = None
+
     if kwargs:
         raise TypeError(
             "open_zarr() got unexpected keyword arguments " + ",".join(kwargs.keys())
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 04974c58113..269f58aca92 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -359,6 +359,59 @@ def _assert_empty(args: tuple, msg: str = "%s") -> None:
         raise ValueError(msg % args)
 
 
+def _check_chunks_compatibility(var, chunks, chunk_spec):
+    for dim in var.dims:
+        if dim not in chunks or (dim not in chunk_spec):
+            return
+
+        chunk_spec_dim = chunk_spec.get(dim)
+        chunks_dim = chunks.get(dim)
+
+        if isinstance(chunks_dim, int):
+            chunks_dim = (chunks_dim,)
+        if any(s % chunk_spec_dim for s in chunks_dim):
+            warnings.warn(
+                "Specified Dask chunks %r would "
+                "separate on-disk chunk shape %r for "
+                "dimension %r. This could "
+                "degrade performance. Consider "
+                "rechunking after loading instead." % (chunks_dim, chunk_spec_dim, dim),
+                stacklevel=2,
+            )
+
+
+def _get_chunk(var, chunks):
+    # chunks need to be computed explicitly to correctly take into account
+    # the backend's preferred chunking
+    import dask.array as da
+
+    if isinstance(chunks, int) or (chunks == "auto"):
+        chunks = dict.fromkeys(var.dims, chunks)
+
+    preferred_chunks_list = var.encoding.get("chunks", {})
+    preferred_chunks = dict(zip(var.dims, var.encoding.get("chunks", {})))
+    if isinstance(var, IndexVariable):
+        return {}
+
+    chunks_list = []
+    for dim in var.dims:
+        chunks_dim = chunks.get(dim, None)
+        preferred_chunks_dim = preferred_chunks.get(dim, None)
+        chunks_list.append(chunks_dim or preferred_chunks_dim)
+
+    output_chunks_list = da.core.normalize_chunks(
+        chunks_list,
+        shape=var.shape,
+        dtype=var.dtype,
+        previous_chunks=preferred_chunks_list,
+    )
+
+    output_chunks = dict(zip(var.dims, output_chunks_list))
+    _check_chunks_compatibility(var, output_chunks, preferred_chunks)
+
+    return output_chunks
+
+
 def _maybe_chunk(
     name,
     var,
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 43bf2de245b..8d8ead147d9 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -608,10 +608,6 @@ def test_orthogonal_indexing(self):
             actual = on_disk.isel(**indexers)
             assert_identical(expected, actual)
 
-    @pytest.mark.xfail(
-        not has_dask,
-        reason="the code for indexing without dask handles negative steps in slices incorrectly",
-    )
     def test_vectorized_indexing(self):
         in_memory = create_test_data()
         with self.roundtrip(in_memory) as on_disk:
@@ -676,6 +672,21 @@ def multiple_indexing(indexers):
             ]
         multiple_indexing(indexers)
 
+    def test_vectorized_indexing_negative_step_slice(self, open_kwargs=None):
+        in_memory = create_test_data()
+
+        def multiple_indexing(indexers):
+            # make sure a sequence of lazy indexing operations works.
+            with self.roundtrip(in_memory, open_kwargs=open_kwargs) as on_disk:
+                actual = on_disk["var3"]
+                expected = in_memory["var3"]
+                for ind in indexers:
+                    actual = actual.isel(**ind)
+                    expected = expected.isel(**ind)
+                    # make sure the array is not yet loaded into memory
+                    assert not actual.variable._in_memory
+                assert_identical(expected, actual.load())
+
         # with negative step slice.
         indexers = [
             {
@@ -1567,7 +1578,7 @@ def roundtrip(
         if save_kwargs is None:
             save_kwargs = {}
         if open_kwargs is None:
-            open_kwargs = {"chunks": "auto"}
+            open_kwargs = {}
         with self.create_zarr_target() as store_target:
             self.save(data, store_target, **save_kwargs)
             with self.open(store_target, **open_kwargs) as ds:
@@ -1604,7 +1615,7 @@ def test_auto_chunk(self):
                 # there should be no chunks
                 assert v.chunks is None
 
-        with self.roundtrip(original, open_kwargs={"chunks": "auto"}) as actual:
+        with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
             for k, v in actual.variables.items():
                 # only index variables should be in memory
                 assert v._in_memory == (k in actual.dims)
@@ -1701,7 +1712,7 @@ def test_deprecate_auto_chunk(self):
     def test_write_uneven_dask_chunks(self):
         # regression for GH#2225
        original = create_test_data().chunk({"dim1": 3, "dim2": 4, "dim3": 3})
-        with self.roundtrip(original, open_kwargs={"chunks": "auto"}) as actual:
+        with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
             for k, v in actual.data_vars.items():
                 print(k)
                 assert v.chunks == actual[k].chunks
@@ -1850,9 +1861,7 @@ def test_write_persistence_modes(self, group):
             ds.to_zarr(store_target, mode="w", group=group)
             ds_to_append.to_zarr(store_target, append_dim="time", group=group)
             original = xr.concat([ds, ds_to_append], dim="time")
-            actual = xr.open_dataset(
-                store_target, group=group, chunks="auto", engine="zarr"
-            )
+            actual = xr.open_dataset(store_target, group=group, engine="zarr")
             assert_identical(original, actual)
 
     def test_compressor_encoding(self):
@@ -1941,11 +1950,11 @@ def test_check_encoding_is_consistent_after_append(self):
             encoding = {"da": {"compressor": compressor}}
             ds.to_zarr(store_target, mode="w", encoding=encoding)
             ds_to_append.to_zarr(store_target, append_dim="time")
-            actual_ds = xr.open_dataset(store_target, chunks="auto", engine="zarr")
+            actual_ds = xr.open_dataset(store_target, engine="zarr")
             actual_encoding = actual_ds["da"].encoding["compressor"]
             assert actual_encoding.get_config() == compressor.get_config()
             assert_identical(
-                xr.open_dataset(store_target, chunks="auto", engine="zarr").compute(),
+                xr.open_dataset(store_target, engine="zarr").compute(),
                 xr.concat([ds, ds_to_append], dim="time"),
             )
 
@@ -1960,9 +1969,7 @@ def test_append_with_new_variable(self):
             ds_with_new_var.to_zarr(store_target, mode="a")
             combined = xr.concat([ds, ds_to_append], dim="time")
             combined["new_var"] = ds_with_new_var["new_var"]
-            assert_identical(
-                combined, xr.open_dataset(store_target, chunks="auto", engine="zarr")
-            )
+            assert_identical(combined, xr.open_dataset(store_target, engine="zarr"))
 
     @requires_dask
     def test_to_zarr_compute_false_roundtrip(self):
@@ -2172,6 +2179,10 @@ def test_open_zarr_use_cftime(self):
         ds_b = xr.open_zarr(store_target, consolidated=True, use_cftime=True)
         assert xr.coding.times.contains_cftime_datetimes(ds_b.time)
 
+    @requires_dask
+    def test_vectorized_indexing_negative_step_slice(self):
+        super().test_vectorized_indexing_negative_step_slice(open_kwargs={"chunks": {}})
+
 
 @requires_zarr
 class TestZarrDictStore(ZarrBase):
@@ -4803,3 +4814,61 @@ def test_load_single_value_h5netcdf(tmp_path):
     ds.to_netcdf(tmp_path / "test.nc")
     with xr.open_dataset(tmp_path / "test.nc", engine="h5netcdf") as ds2:
         ds2["test"][0].load()
+
+
+@requires_zarr
+@requires_dask
+@pytest.mark.parametrize(
+    "chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}]
+)
+def test_open_dataset_chunking_zarr(chunks, tmp_path):
+    encoded_chunks = 100
+    dask_arr = da.from_array(
+        np.ones((500, 500), dtype="float64"), chunks=encoded_chunks
+    )
+    ds = xr.Dataset(
+        {
+            "test": xr.DataArray(
+                dask_arr,
+                dims=("x", "y"),
+            )
+        }
+    )
+    ds["test"].encoding["chunks"] = encoded_chunks
+    ds.to_zarr(tmp_path / "test.zarr")
+
+    with dask.config.set({"array.chunk-size": "1MiB"}):
+        expected = ds.chunk(chunks)
+        actual = xr.open_dataset(tmp_path / "test.zarr", engine="zarr", chunks=chunks)
+        xr.testing.assert_chunks_equal(actual, expected)
+
+
+@requires_zarr
+@requires_dask
+@pytest.mark.parametrize(
+    "chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}]
+)
+def test_chunking_consistency(chunks, tmp_path):
+    encoded_chunks = {}
+    dask_arr = da.from_array(
+        np.ones((500, 500), dtype="float64"), chunks=encoded_chunks
+    )
+    ds = xr.Dataset(
+        {
+            "test": xr.DataArray(
+                dask_arr,
+                dims=("x", "y"),
+            )
+        }
+    )
+    ds["test"].encoding["chunks"] = encoded_chunks
+    ds.to_zarr(tmp_path / "test.zarr")
+    ds.to_netcdf(tmp_path / "test.nc")
+
+    with dask.config.set({"array.chunk-size": "1MiB"}):
+        expected = ds.chunk(chunks)
+        actual = xr.open_dataset(tmp_path / "test.zarr", engine="zarr", chunks=chunks)
+        xr.testing.assert_chunks_equal(actual, expected)
+
+        actual = xr.open_dataset(tmp_path / "test.nc", chunks=chunks)
+        xr.testing.assert_chunks_equal(actual, expected)
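
Usage note (illustrative sketch, not part of the patch): how the reworked chunking is expected to behave from the caller's side, assuming a Zarr store like the "test.zarr" written in the tests above with encoded chunks of 100 along each dimension; the path, dimension name, and sizes below are assumptions for the example only.

    import xarray as xr

    # chunks={} resolves each dimension to the backend's preferred (encoded)
    # chunking, so the resulting dask chunks should line up with the Zarr chunks;
    # chunks="auto" instead lets dask pick sizes while still taking the encoded
    # chunks into account via normalize_chunks(previous_chunks=...).
    ds = xr.open_dataset("test.zarr", engine="zarr", chunks={})

    # A per-dimension request is normalized against the preferred chunks; if it
    # would split the stored chunks, _check_chunks_compatibility emits a warning
    # suggesting to rechunk after loading instead.
    ds = xr.open_dataset("test.zarr", engine="zarr", chunks={"x": 70})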