Skip to content

Support **kwargs form in .chunk() #6471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ New Features
elements which trigger summarization rather than full repr in (numpy) array
detailed views of the html repr (:pull:`6400`).
By `Benoît Bovy <https://github.com/benbovy>`_.
- Allow passing chunks in **kwargs form to :py:meth:`Dataset.chunk`, :py:meth:`DataArray.chunk`, and
:py:meth:`Variable.chunk`. (:pull:`6471`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
18 changes: 17 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,7 @@ def chunk(
name_prefix: str = "xarray-",
token: str = None,
lock: bool = False,
**chunks_kwargs: Any,
) -> DataArray:
"""Coerce this array's data into a dask arrays with the given chunks.

Expand All @@ -1136,13 +1137,28 @@ def chunk(
lock : optional
Passed on to :py:func:`dask.array.from_array`, if the array is not
already as dask array.
**chunks_kwargs : {dim: chunks, ...}, optional
The keyword arguments form of ``chunks``.
One of chunks or chunks_kwargs must be provided.

Returns
-------
chunked : xarray.DataArray
"""
if isinstance(chunks, (tuple, list)):
if chunks is None:
warnings.warn(
"None value for 'chunks' is deprecated. "
"It will raise an error in the future. Use instead '{}'",
category=FutureWarning,
)
chunks = {}

if isinstance(chunks, (Number, str, int)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should have maybe commented here. Number is never imported.

chunks = dict.fromkeys(self.dims, chunks)
elif isinstance(chunks, (tuple, list)):
chunks = dict(zip(self.dims, chunks))
else:
chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

ds = self._to_temp_dataset().chunk(
chunks, name_prefix=name_prefix, token=token, lock=lock
Expand Down
10 changes: 8 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1994,6 +1994,7 @@ def chunk(
name_prefix: str = "xarray-",
token: str = None,
lock: bool = False,
**chunks_kwargs: Any,
) -> Dataset:
"""Coerce all arrays in this dataset into dask arrays with the given
chunks.
Expand All @@ -2007,7 +2008,7 @@ def chunk(

Parameters
----------
chunks : int, "auto" or mapping of hashable to int, optional
chunks : int, tuple of int, "auto" or mapping of hashable to int, optional
Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
``{"x": 5, "y": 5}``.
name_prefix : str, optional
Expand All @@ -2017,6 +2018,9 @@ def chunk(
lock : optional
Passed on to :py:func:`dask.array.from_array`, if the array is not
already as dask array.
**chunks_kwargs : {dim: chunks, ...}, optional
The keyword arguments form of ``chunks``.
One of chunks or chunks_kwargs must be provided

Returns
-------
Expand All @@ -2028,7 +2032,7 @@ def chunk(
Dataset.chunksizes
xarray.unify_chunks
"""
if chunks is None:
if chunks is None and chunks_kwargs is None:
warnings.warn(
"None value for 'chunks' is deprecated. "
"It will raise an error in the future. Use instead '{}'",
Expand All @@ -2038,6 +2042,8 @@ def chunk(

if isinstance(chunks, (Number, str, int)):
chunks = dict.fromkeys(self.dims, chunks)
else:
chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

bad_dims = chunks.keys() - self.dims.keys()
if bad_dims:
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def either_dict_or_kwargs(
kw_kwargs: Mapping[str, T],
func_name: str,
) -> Mapping[Hashable, T]:
if pos_kwargs is None:
if pos_kwargs is None or pos_kwargs == {}:
# Need an explicit cast to appease mypy due to invariance; see
# https://github.com/python/mypy/issues/6228
return cast(Mapping[Hashable, T], kw_kwargs)
Expand Down
24 changes: 22 additions & 2 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numbers
import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Hashable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Sequence

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1012,7 +1012,19 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]:

_array_counter = itertools.count()

def chunk(self, chunks={}, name=None, lock=False):
def chunk(
self,
chunks: (
int
| Literal["auto"]
| tuple[int, ...]
| tuple[tuple[int, ...], ...]
| Mapping[Any, None | int | tuple[int, ...]]
) = {},
name: str = None,
lock: bool = False,
**chunks_kwargs: Any,
) -> Variable:
"""Coerce this array's data into a dask array with the given chunks.

If this variable is a non-dask array, it will be converted to dask
Expand All @@ -1034,6 +1046,9 @@ def chunk(self, chunks={}, name=None, lock=False):
lock : optional
Passed on to :py:func:`dask.array.from_array`, if the array is not
already as dask array.
**chunks_kwargs : {dim: chunks, ...}, optional
The keyword arguments form of ``chunks``.
One of chunks or chunks_kwargs must be provided.

Returns
-------
Expand All @@ -1049,6 +1064,11 @@ def chunk(self, chunks={}, name=None, lock=False):
)
chunks = {}

if isinstance(chunks, (float, str, int, tuple, list)):
pass # dask.array.from_array can handle these directly
else:
chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

if utils.is_dict_like(chunks):
chunks = {self.get_axis_num(dim): chunk for dim, chunk in chunks.items()}

Expand Down
5 changes: 5 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,11 @@ def test_chunk(self):
assert isinstance(blocked.data, da.Array)
assert "testname_" in blocked.data.name

# test kwargs form of chunks
blocked = unblocked.chunk(dim_0=3, dim_1=3)
assert blocked.chunks == ((3,), (3, 1))
assert blocked.data.name != first_dask_name

def test_isel(self):
assert_identical(self.dv[0], self.dv.isel(x=0))
assert_identical(self.dv, self.dv.isel(x=slice(None)))
Expand Down
5 changes: 4 additions & 1 deletion xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,9 @@ def test_chunk(self):
expected_chunks = {"dim1": (8,), "dim2": (9,), "dim3": (10,)}
assert reblocked.chunks == expected_chunks

# test kwargs form of chunks
assert data.chunk(**expected_chunks).chunks == expected_chunks

def get_dask_names(ds):
return {k: v.data.name for k, v in ds.items()}

Expand All @@ -947,7 +950,7 @@ def get_dask_names(ds):
new_dask_names = get_dask_names(reblocked)
assert reblocked.chunks == expected_chunks
assert_identical(reblocked, data)
# recuhnking with same chunk sizes should not change names
# rechunking with same chunk sizes should not change names
for k, v in new_dask_names.items():
assert v == orig_dask_names[k]

Expand Down
34 changes: 34 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2154,6 +2154,40 @@ def test_coarsen_keep_attrs(self, operation="mean"):
class TestVariableWithDask(VariableSubclassobjects):
cls = staticmethod(lambda *args: Variable(*args).chunk())

def test_chunk(self):
unblocked = Variable(["dim_0", "dim_1"], np.ones((3, 4)))
assert unblocked.chunks is None

blocked = unblocked.chunk()
assert blocked.chunks == ((3,), (4,))
first_dask_name = blocked.data.name

blocked = unblocked.chunk(chunks=((2, 1), (2, 2)))
assert blocked.chunks == ((2, 1), (2, 2))
assert blocked.data.name != first_dask_name

blocked = unblocked.chunk(chunks=(3, 3))
assert blocked.chunks == ((3,), (3, 1))
assert blocked.data.name != first_dask_name

# name doesn't change when rechunking by same amount
# this fails if ReprObject doesn't have __dask_tokenize__ defined
assert unblocked.chunk(2).data.name == unblocked.chunk(2).data.name

assert blocked.load().chunks is None

# Check that kwargs are passed
import dask.array as da

blocked = unblocked.chunk(name="testname_")
assert isinstance(blocked.data, da.Array)
assert "testname_" in blocked.data.name

# test kwargs form of chunks
blocked = unblocked.chunk(dim_0=3, dim_1=3)
assert blocked.chunks == ((3,), (3, 1))
assert blocked.data.name != first_dask_name

@pytest.mark.xfail
def test_0d_object_array_with_list(self):
super().test_0d_object_array_with_list()
Expand Down