diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 8dd431c5f62..90d83284632 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -25,7 +25,7 @@
     combine_by_coords,
 )
 from ..core.dataarray import DataArray
-from ..core.dataset import Dataset, _maybe_chunk
+from ..core.dataset import Dataset, _get_chunk, _maybe_chunk
 from ..core.utils import close_on_error, is_grib_path, is_remote_uri
 from .common import AbstractDataStore, ArrayWriter
 from .locks import _get_scheduler
@@ -535,7 +535,7 @@ def maybe_decode_store(store, chunks):
                 k: _maybe_chunk(
                     k,
                     v,
-                    store.get_chunk(k, v, chunks),
+                    _get_chunk(k, v, chunks),
                     overwrite_encoded_chunks=overwrite_encoded_chunks,
                 )
                 for k, v in ds.variables.items()
diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py
index 2f34cc285ff..496b7b3e484 100644
--- a/xarray/backends/apiv2.py
+++ b/xarray/backends/apiv2.py
@@ -1,5 +1,6 @@
 import os
 
+from ..core.dataset import _get_chunk, _maybe_chunk
 from ..core.utils import is_remote_uri
 from . import cfgrib_, h5netcdf_, zarr
 from .api import (
@@ -16,8 +17,18 @@
 }
 
 
+def get_mtime(filename_or_obj):
+    # if passed an actual file path, return its modification time so the
+    # caller can mix it into the dask token
+    if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj):
+        mtime = os.path.getmtime(filename_or_obj)
+    else:
+        mtime = None
+    return mtime
+
+
 def dataset_from_backend_dataset(
-    ds,
+    backend_ds,
     filename_or_obj,
     engine,
     chunks,
@@ -25,57 +36,39 @@ def dataset_from_backend_dataset(
     overwrite_encoded_chunks,
     extra_tokens,
 ):
-    if not (isinstance(chunks, (int, dict)) or chunks is None):
-        if chunks != "auto":
-            raise ValueError(
-                "chunks must be an int, dict, 'auto', or None. "
-                "Instead found %s. " % chunks
-            )
-    _protect_dataset_variables_inplace(ds, cache)
+    _protect_dataset_variables_inplace(backend_ds, cache)
-    if chunks is not None and engine != "zarr":
+    if chunks is None:
+        ds = backend_ds
+    else:
         from dask.base import tokenize
 
-        # if passed an actual file path, augment the token with
-        # the file modification time
-        if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj):
-            mtime = os.path.getmtime(filename_or_obj)
-        else:
-            mtime = None
+        mtime = get_mtime(filename_or_obj)
         token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
         name_prefix = "open_dataset-%s" % token
-        ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token)
-
-    elif engine == "zarr":
-
-        if chunks == "auto":
-            try:
-                import dask.array  # noqa
-            except ImportError:
-                chunks = None
-
-        if chunks is None:
-            return ds
 
         if isinstance(chunks, int):
-            chunks = dict.fromkeys(ds.dims, chunks)
+            chunks = dict.fromkeys(backend_ds.dims, chunks)
 
         variables = {
-            k: zarr.ZarrStore.maybe_chunk(k, v, chunks, overwrite_encoded_chunks)
-            for k, v in ds.variables.items()
+            name: _maybe_chunk(
+                name,
+                var,
+                _get_chunk(name, var, chunks),
+                overwrite_encoded_chunks=overwrite_encoded_chunks,
+                name_prefix=name_prefix,
+                token=token,
+            )
+            for name, var in backend_ds.variables.items()
         }
-        ds2 = ds._replace(variables)
+        ds = backend_ds._replace(variables)
 
-    else:
-        ds2 = ds
-
-    ds2._file_obj = ds._file_obj
+    ds._file_obj = backend_ds._file_obj
 
-    # Ensure source filename always stored in dataset object (GH issue #2550)
     if "source" not in ds.encoding:
         if isinstance(filename_or_obj, str):
             ds.encoding["source"] = filename_or_obj
 
-    return ds2
+    return ds
 
 
 def open_dataset(
@@ -191,6 +184,14 @@
     open_mfdataset
     """
 
+    if chunks == "auto":
chunks == "auto": + chunks = {} + if not (isinstance(chunks, (int, dict)) or chunks is None): + raise ValueError( + "chunks must be an int, dict, 'auto', or None. " + "Instead found %s. " % chunks + ) + if cache is None: cache = chunks is None diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index b54be09e749..2145685b2b2 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -362,53 +362,6 @@ def encode_variable(self, variable): def encode_attribute(self, a): return encode_zarr_attr_value(a) - @staticmethod - def get_chunk(name, var, chunks): - chunk_spec = dict(zip(var.dims, var.encoding.get("chunks"))) - - # Coordinate labels aren't chunked - if var.ndim == 1 and var.dims[0] == name: - return chunk_spec - - if chunks == "auto": - return chunk_spec - - for dim in var.dims: - if dim in chunks: - spec = chunks[dim] - if isinstance(spec, int): - spec = (spec,) - if isinstance(spec, (tuple, list)) and chunk_spec[dim]: - if any(s % chunk_spec[dim] for s in spec): - warnings.warn( - "Specified Dask chunks %r would " - "separate Zarr chunk shape %r for " - "dimension %r. This significantly " - "degrades performance. Consider " - "rechunking after loading instead." - % (chunks[dim], chunk_spec[dim], dim), - stacklevel=2, - ) - chunk_spec[dim] = chunks[dim] - return chunk_spec - - @classmethod - def maybe_chunk(cls, name, var, chunks, overwrite_encoded_chunks): - chunk_spec = cls.get_chunk(name, var, chunks) - - if (var.ndim > 0) and (chunk_spec is not None): - from dask.base import tokenize - - # does this cause any data to be read? - token2 = tokenize(name, var._data, chunks) - name2 = f"xarray-{name}-{token2}" - var = var.chunk(chunk_spec, name=name2, lock=None) - if overwrite_encoded_chunks and var.chunks is not None: - var.encoding["chunks"] = tuple(x[0] for x in var.chunks) - return var - else: - return var - def store( self, variables, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1ceb5623abd..562a99b9028 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -359,6 +359,40 @@ def _assert_empty(args: tuple, msg: str = "%s") -> None: raise ValueError(msg % args) +def _check_chunks_compatibility(dim, chunks, chunk_spec): + spec = chunks[dim] + if isinstance(spec, int): + spec = (spec,) + if any(s % chunk_spec[dim] for s in spec): + warnings.warn( + "Specified Dask chunks %r would " + "separate on disks chunk shape %r for " + "dimension %r. This could " + "degrades performance. Consider " + "rechunking after loading instead." % (chunks[dim], chunk_spec[dim], dim), + stacklevel=2, + ) + + +def _get_chunk(name, var, chunks): + if chunks == "auto": + chunks = {} + + preferred_chunks = dict(zip(var.dims, var.encoding.get("chunks", {}))) + if var.ndim == 1 and var.dims[0] == name: + return preferred_chunks + + output_chunks = {} + if chunks is not None: + for dim in preferred_chunks: + if dim in chunks: + _check_chunks_compatibility(dim, chunks, preferred_chunks) + output_chunks[dim] = chunks[dim] + else: + output_chunks[dim] = preferred_chunks[dim] + return output_chunks + + def _maybe_chunk( name, var,