WIP: Chunking refactor #4595

Closed · wants to merge 21 commits

4 changes: 2 additions & 2 deletions xarray/backends/api.py
@@ -26,7 +26,7 @@
combine_by_coords,
)
from ..core.dataarray import DataArray
from ..core.dataset import Dataset, _maybe_chunk
from ..core.dataset import Dataset, _get_chunk, _maybe_chunk
from ..core.utils import close_on_error, is_grib_path, is_remote_uri
from .common import AbstractDataStore, ArrayWriter
from .locks import _get_scheduler
@@ -536,7 +536,7 @@ def maybe_decode_store(store, chunks):
k: _maybe_chunk(
k,
v,
store.get_chunk(k, v, chunks),
_get_chunk(v, chunks),
overwrite_encoded_chunks=overwrite_encoded_chunks,
)
for k, v in ds.variables.items()
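
With this change, maybe_decode_store computes each variable's chunk spec through the backend-agnostic _get_chunk helper from xarray.core.dataset instead of the Zarr-specific ZarrStore.get_chunk. A minimal usage sketch of the code path this serves; the store path and chunk sizes are assumptions for illustration, not part of the diff:

import xarray as xr

# Hypothetical store and chunk sizes, for illustration only.
# A per-dimension dict (or an int, or "auto") is normalized by _get_chunk;
# dimensions that are not listed fall back to the chunks recorded in the
# variable's encoding, i.e. the on-disk chunks.
ds = xr.open_dataset("example.zarr", engine="zarr", chunks={"time": 120})

# chunks="auto" asks dask to choose chunk sizes, using the encoded chunks
# as a hint through normalize_chunks(previous_chunks=...).
ds_auto = xr.open_dataset("example.zarr", engine="zarr", chunks="auto")
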
101 changes: 59 additions & 42 deletions xarray/backends/apiv2.py
@@ -1,7 +1,8 @@
import os

from ..core.dataset import _get_chunk, _maybe_chunk
from ..core.utils import is_remote_uri
from . import plugins, zarr
from . import plugins
from .api import (
_autodetect_engine,
_get_backend_cls,
@@ -10,8 +11,48 @@
)


def _get_mtime(filename_or_obj):
# if passed an actual file path, augment the token with
# the file modification time
if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj):
mtime = os.path.getmtime(filename_or_obj)
else:
mtime = None
return mtime


def _chunk_ds(
backend_ds,
filename_or_obj,
engine,
chunks,
overwrite_encoded_chunks,
**extra_tokens,
):
from dask.base import tokenize

mtime = _get_mtime(filename_or_obj)
token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
name_prefix = "open_dataset-%s" % token

variables = {}
for name, var in backend_ds.variables.items():
var_chunks = _get_chunk(var, chunks)
variables[name] = _maybe_chunk(
name,
var,
var_chunks,
overwrite_encoded_chunks=overwrite_encoded_chunks,
name_prefix=name_prefix,
token=token,
)
ds = backend_ds._replace(variables)
ds._file_obj = backend_ds._file_obj
return ds


def dataset_from_backend_dataset(
ds,
backend_ds,
filename_or_obj,
engine,
chunks,
@@ -26,50 +67,25 @@ def dataset_from_backend_dataset(
"Instead found %s. " % chunks
)

_protect_dataset_variables_inplace(ds, cache)
if chunks is not None and engine != "zarr":
from dask.base import tokenize

# if passed an actual file path, augment the token with
# the file modification time
if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj):
mtime = os.path.getmtime(filename_or_obj)
else:
mtime = None
token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
name_prefix = "open_dataset-%s" % token
ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token)

elif engine == "zarr":

if chunks == "auto":
try:
import dask.array # noqa
except ImportError:
chunks = None

if chunks is None:
return ds

if isinstance(chunks, int):
chunks = dict.fromkeys(ds.dims, chunks)

variables = {
k: zarr.ZarrStore.maybe_chunk(k, v, chunks, overwrite_encoded_chunks)
for k, v in ds.variables.items()
}
ds2 = ds._replace(variables)

_protect_dataset_variables_inplace(backend_ds, cache)
if chunks is None:
ds = backend_ds
else:
ds2 = ds
ds2._file_obj = ds._file_obj
ds = _chunk_ds(
backend_ds,
filename_or_obj,
engine,
chunks,
overwrite_encoded_chunks,
**extra_tokens,
)

# Ensure source filename always stored in dataset object (GH issue #2550)
if "source" not in ds.encoding:
if isinstance(filename_or_obj, str):
ds2.encoding["source"] = filename_or_obj
ds.encoding["source"] = filename_or_obj

return ds2
return ds


def resolve_decoders_kwargs(decode_cf, engine, **decoders):
@@ -236,12 +252,13 @@ def open_dataset(
open_backend_dataset = _get_backend_cls(engine, engines=plugins.ENGINES)[
"open_dataset"
]
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
backend_ds = open_backend_dataset(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
**backend_kwargs,
**{k: v for k, v in kwargs.items() if v is not None},
**filtered_kwargs,
)
ds = dataset_from_backend_dataset(
backend_ds,
@@ -253,7 +270,7 @@
drop_variables=drop_variables,
**decoders,
**backend_kwargs,
**kwargs,
**filtered_kwargs,
)

return ds
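
The chunking logic that was previously inlined in dataset_from_backend_dataset, including the zarr-only branch, is now factored into _get_mtime and _chunk_ds, and the dask task names are derived from a token over the file name, its modification time, the engine, and the chunk spec. A small sketch of that tokenization, mirroring the helpers above; the path, engine, and chunk spec are made up:

import os

from dask.base import tokenize

filename = "data.nc"                # hypothetical local path
mtime = os.path.getmtime(filename)  # _get_mtime returns None for remote URIs or file objects
token = tokenize(filename, mtime, "netcdf4", {"time": 100})
name_prefix = "open_dataset-%s" % token  # passed to _maybe_chunk for every variable
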
57 changes: 8 additions & 49 deletions xarray/backends/zarr.py
@@ -1,5 +1,3 @@
import warnings

import numpy as np

from .. import coding, conventions
@@ -368,53 +366,6 @@ def encode_variable(self, variable):
def encode_attribute(self, a):
return encode_zarr_attr_value(a)

@staticmethod
def get_chunk(name, var, chunks):
chunk_spec = dict(zip(var.dims, var.encoding.get("chunks")))

# Coordinate labels aren't chunked
if var.ndim == 1 and var.dims[0] == name:
return chunk_spec

if chunks == "auto":
return chunk_spec

for dim in var.dims:
if dim in chunks:
spec = chunks[dim]
if isinstance(spec, int):
spec = (spec,)
if isinstance(spec, (tuple, list)) and chunk_spec[dim]:
if any(s % chunk_spec[dim] for s in spec):
warnings.warn(
"Specified Dask chunks %r would "
"separate Zarr chunk shape %r for "
"dimension %r. This significantly "
"degrades performance. Consider "
"rechunking after loading instead."
% (chunks[dim], chunk_spec[dim], dim),
stacklevel=2,
)
chunk_spec[dim] = chunks[dim]
return chunk_spec

@classmethod
def maybe_chunk(cls, name, var, chunks, overwrite_encoded_chunks):
chunk_spec = cls.get_chunk(name, var, chunks)

if (var.ndim > 0) and (chunk_spec is not None):
from dask.base import tokenize

# does this cause any data to be read?
token2 = tokenize(name, var._data, chunks)
name2 = f"xarray-{name}-{token2}"
var = var.chunk(chunk_spec, name=name2, lock=None)
if overwrite_encoded_chunks and var.chunks is not None:
var.encoding["chunks"] = tuple(x[0] for x in var.chunks)
return var
else:
return var

def store(
self,
variables,
@@ -660,6 +611,14 @@ def open_zarr(
"""
from .api import open_dataset

if chunks == "auto":
try:
import dask.array # noqa

chunks = {}
except ImportError:
chunks = None

if kwargs:
raise TypeError(
"open_zarr() got unexpected keyword arguments " + ",".join(kwargs.keys())
53 changes: 53 additions & 0 deletions xarray/core/dataset.py
@@ -359,6 +359,59 @@ def _assert_empty(args: tuple, msg: str = "%s") -> None:
raise ValueError(msg % args)


def _check_chunks_compatibility(var, chunks, chunk_spec):
for dim in var.dims:
if dim not in chunks or (dim not in chunk_spec):
return

chunk_spec_dim = chunk_spec.get(dim)
chunks_dim = chunks.get(dim)

if isinstance(chunks_dim, int):
chunks_dim = (chunks_dim,)
if any(s % chunk_spec_dim for s in chunks_dim):
warnings.warn(
"Specified Dask chunks %r would "
"separate on disks chunk shape %r for "
"dimension %r. This could "
"degrades performance. Consider "
"rechunking after loading instead." % (chunks_dim, chunk_spec_dim, dim),
stacklevel=2,
)


def _get_chunk(var, chunks):
# chunks need to be computed explicitly to correctly take into account
# the backend's preferred chunking
import dask.array as da

if isinstance(chunks, int) or (chunks == "auto"):
chunks = dict.fromkeys(var.dims, chunks)

preferred_chunks_list = var.encoding.get("chunks", {})
preferred_chunks = dict(zip(var.dims, var.encoding.get("chunks", {})))
if isinstance(var, IndexVariable):
return {}

chunks_list = []
for dim in var.dims:
chunks_dim = chunks.get(dim, None)
preferred_chunks_dim = preferred_chunks.get(dim, None)
chunks_list.append(chunks_dim or preferred_chunks_dim)

output_chunks_list = da.core.normalize_chunks(
chunks_list,
shape=var.shape,
dtype=var.dtype,
previous_chunks=preferred_chunks_list,
)

output_chunks = dict(zip(var.dims, output_chunks_list))
_check_chunks_compatibility(var, output_chunks, preferred_chunks)

return output_chunks


def _maybe_chunk(
name,
var,
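
The new _get_chunk merges the requested chunks with the variable's preferred (encoded) chunks through dask.array.core.normalize_chunks, and _check_chunks_compatibility warns when a requested size would split the on-disk chunks. A worked illustration with made-up shapes and chunk sizes:

import dask.array as da

# Suppose a variable has dims ("time", "x"), shape (365, 1000), and
# on-disk (encoded) chunks of (30, 250).
requested = {"time": 100}            # the user only chunked "time"
preferred = {"time": 30, "x": 250}   # from var.encoding["chunks"]

# Per dimension: use the requested size if given, otherwise the preferred one.
chunks_list = [requested.get(d) or preferred.get(d) for d in ("time", "x")]

output = da.core.normalize_chunks(
    chunks_list, shape=(365, 1000), dtype="float64", previous_chunks=(30, 250)
)
# output == ((100, 100, 100, 65), (250, 250, 250, 250))
# 100 is not a multiple of 30, so _check_chunks_compatibility would warn that
# these chunks split the on-disk chunk shape along "time".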