support chunks in open_groups and open_datatree #9660

Merged · 23 commits · Oct 24, 2024

Changes from 4 commits

Commits:
fe95b16
support chunking and default values in `open_groups`
keewis Oct 22, 2024
3bfbc3a
same for `open_datatree`
keewis Oct 22, 2024
f4abb01
use `group_subtrees` instead of `map_over_datasets`
keewis Oct 22, 2024
b0458aa
check that `chunks` on `open_datatree` works
keewis Oct 22, 2024
4dbd91e
specify the chunksizes when opening from disk
keewis Oct 23, 2024
11850fd
check that `open_groups` with chunks works, too
keewis Oct 23, 2024
a71f5e2
require dask for `test_open_groups_chunks`
TomNicholas Oct 23, 2024
6d3deed
protect variables from write operations
keewis Oct 23, 2024
7f770cf
copy over `_close` from the backend tree
keewis Oct 23, 2024
05efaf6
copy a lot of the docstring from `open_dataset`
keewis Oct 23, 2024
f9fee40
same for `open_groups`
keewis Oct 23, 2024
2e10bdc
Merge branch 'main' into open_datatree-dask
keewis Oct 23, 2024
a4e99c6
reuse `_protect_dataset_variables_inplace`
keewis Oct 23, 2024
3e8b80c
final missing `requires_dask`
keewis Oct 23, 2024
cf1a6b0
typing for the test utils
keewis Oct 24, 2024
114c4dc
type hints for `_protect_datatree_variables_inplace`
keewis Oct 24, 2024
9eac19d
type hints for `_protect_dataset_variables_inplace`
keewis Oct 24, 2024
446a53d
copy over the name of the backend tree
keewis Oct 24, 2024
5b36701
typo
keewis Oct 24, 2024
66616f7
swap the order of arguments to `assert_identical`
keewis Oct 24, 2024
843b2fc
try explicitly typing `data`
keewis Oct 24, 2024
8950841
typo
keewis Oct 24, 2024
4d93ada
use `Hashable` for variable names
keewis Oct 24, 2024
176 changes: 173 additions & 3 deletions xarray/backends/api.py
@@ -41,7 +41,9 @@
)
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk
from xarray.core.datatree import DataTree
from xarray.core.indexes import Index
from xarray.core.treenode import group_subtrees
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes
from xarray.core.utils import is_remote_uri
from xarray.namedarray.daskmanager import DaskManager
@@ -74,7 +76,6 @@
T_NetcdfTypes = Literal[
"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC"
]
from xarray.core.datatree import DataTree

DATAARRAY_NAME = "__xarray_dataarray_name__"
DATAARRAY_VARIABLE = "__xarray_dataarray_variable__"
@@ -414,6 +415,56 @@ def _dataset_from_backend_dataset(
return ds


def _datatree_from_backend_datatree(
backend_tree,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
**extra_tokens,
):
if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
raise ValueError(
f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
)

# _protect_datatree_variables_inplace(backend_tree, cache)
if chunks is None:
tree = backend_tree
else:
tree = DataTree.from_dict(
{
path: _chunk_ds(
node.dataset,
filename_or_obj,
engine,
chunks,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
**extra_tokens,
)
for path, [node] in group_subtrees(backend_tree)
}
)

# ds.set_close(backend_ds._close)
Member (review comment): backend_tree should have been created using datatree_from_dict_with_io_cleanup, so one way to handle this could be just to copy over the _close attribute from every node of backend_tree?

Collaborator (author): The question is, do we even need that here? I copied this from open_dataset, where it is set explicitly, but since datatree_from_dict_with_io_cleanup already does this, we might be able to just remove it. The only reason I kept the commented-out line is to discuss whether the shift in paradigm (have the backend set _close vs. doing it the same way for all backends) is intentional, and whether we should do the same for open_dataset.

Member: I agree it would be nice to remove this; I'm just worried that mapping over each .dataset might not properly propagate ._close (does it? should it?)

Collaborator (author): It does not (I think), so I'm explicitly copying it over. So far that doesn't appear to break anything.

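A minimal sketch of the "copy _close over" idea from the thread above (illustrative only: the two-tree group_subtrees call matches its use elsewhere in this PR, but assigning _close directly is an assumption about private API, not the merged implementation):

    # Hypothetical: propagate each backend node's closer onto the chunked tree.
    # Assumes both trees share an identical group structure.
    for path, (backend_node, node) in group_subtrees(backend_tree, tree):
        node._close = backend_node._close  # _close is private API (assumption)
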

# Ensure source filename always stored in dataset object
if "source" not in tree.encoding:
path = getattr(filename_or_obj, "path", filename_or_obj)

if isinstance(path, str | os.PathLike):
tree.encoding["source"] = _normalize_path(path)

return tree
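
The _protect_datatree_variables_inplace call commented out near the top of this function is filled in by later commits in this PR ("protect variables from write operations", "reuse _protect_dataset_variables_inplace"). A hedged sketch of what such a helper could look like, by analogy with the existing per-dataset helper (the body is an assumption, not the merged code):

    def _protect_datatree_variables_inplace(tree, cache):
        # Illustrative: apply the existing Dataset-level protection to every
        # node, wrapping backend arrays (optionally with caching) so they
        # cannot be written to through the opened tree.
        for node in tree.subtree:
            _protect_dataset_variables_inplace(node.dataset, cache)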


def open_dataset(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
@@ -838,7 +889,22 @@ def open_dataarray(

def open_datatree(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
engine: T_Engine = None,
chunks: T_Chunks = None,
cache: bool | None = None,
decode_cf: bool | None = None,
mask_and_scale: bool | Mapping[str, bool] | None = None,
decode_times: bool | Mapping[str, bool] | None = None,
decode_timedelta: bool | Mapping[str, bool] | None = None,
use_cftime: bool | Mapping[str, bool] | None = None,
concat_characters: bool | Mapping[str, bool] | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
backend_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> DataTree:
"""
@@ -856,17 +922,75 @@
-------
xarray.DataTree
"""
if cache is None:
cache = chunks is None

if backend_kwargs is not None:
kwargs.update(backend_kwargs)

if engine is None:
engine = plugins.guess_engine(filename_or_obj)

if from_array_kwargs is None:
from_array_kwargs = {}

backend = plugins.get_backend(engine)

return backend.open_datatree(filename_or_obj, **kwargs)
decoders = _resolve_decoders_kwargs(
decode_cf,
open_backend_dataset_parameters=(),
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
concat_characters=concat_characters,
use_cftime=use_cftime,
decode_coords=decode_coords,
)
overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)

backend_tree = backend.open_datatree(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
**kwargs,
)

tree = _datatree_from_backend_datatree(
backend_tree,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
**decoders,
**kwargs,
)

return tree


def open_groups(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
engine: T_Engine = None,
chunks: T_Chunks = None,
cache: bool | None = None,
decode_cf: bool | None = None,
mask_and_scale: bool | Mapping[str, bool] | None = None,
decode_times: bool | Mapping[str, bool] | None = None,
decode_timedelta: bool | Mapping[str, bool] | None = None,
use_cftime: bool | Mapping[str, bool] | None = None,
concat_characters: bool | Mapping[str, bool] | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
backend_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> dict[str, Dataset]:
"""
@@ -893,12 +1017,58 @@
open_datatree()
DataTree.from_dict()
"""
if cache is None:
cache = chunks is None

if backend_kwargs is not None:
kwargs.update(backend_kwargs)

if engine is None:
engine = plugins.guess_engine(filename_or_obj)

if from_array_kwargs is None:
from_array_kwargs = {}

backend = plugins.get_backend(engine)

return backend.open_groups_as_dict(filename_or_obj, **kwargs)
decoders = _resolve_decoders_kwargs(
decode_cf,
open_backend_dataset_parameters=(),
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
concat_characters=concat_characters,
use_cftime=use_cftime,
decode_coords=decode_coords,
)
overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)

backend_groups = backend.open_groups_as_dict(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
**kwargs,
)

groups = {
name: _dataset_from_backend_dataset(
backend_ds,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
**decoders,
**kwargs,
)
for name, backend_ds in backend_groups.items()
}

return groups
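
With both functions updated, chunks, cache, and the decoder options now flow through open_datatree and open_groups for every group. A minimal usage sketch (the file name is illustrative; chunks={} requests dask arrays using the engine's preferred chunk sizes):

    import xarray as xr

    # Open the whole hierarchy lazily; every node's variables become dask arrays.
    tree = xr.open_datatree("example.nc", engine="netcdf4", chunks={})

    # open_groups returns a flat dict of Datasets keyed by group path instead.
    groups = xr.open_groups("example.nc", engine="netcdf4", chunks={})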


def open_mfdataset(
78 changes: 78 additions & 0 deletions xarray/tests/test_backends_datatree.py
@@ -11,6 +11,7 @@
from xarray.core.datatree import DataTree
from xarray.testing import assert_equal, assert_identical
from xarray.tests import (
requires_dask,
requires_h5netcdf,
requires_netCDF4,
requires_zarr,
@@ -25,6 +26,43 @@
pass


def diff_chunks(comparison, tree1, tree2):
mismatching_variables = [loc for loc, equals in comparison.items() if not equals]

variable_messages = [
"\n".join(
[
f"L {path}:{name}: {tree1[path].variables[name].chunksizes}",
f"R {path}:{name}: {tree2[path].variables[name].chunksizes}",
]
)
for path, name in mismatching_variables
]
return "\n".join(["Differing chunk sizes:"] + variable_messages)


def assert_chunks_equal(actual, expected, enforce_dask=False):
__tracebackhide__ = True

from xarray.namedarray.pycompat import array_type

dask_array_type = array_type("dask")

comparison = {
(path, name): (
(
not enforce_dask
or isinstance(node1.variables[name].data, dask_array_type)
)
and node1.variables[name].chunksizes == node2.variables[name].chunksizes
)
for path, (node1, node2) in xr.group_subtrees(actual, expected)
for name in node1.variables.keys()
}

assert all(comparison.values()), diff_chunks(comparison, actual, expected)


@pytest.fixture(scope="module")
def unaligned_datatree_nc(tmp_path_factory):
"""Creates a test netCDF4 file with the following unaligned structure, writes it to a /tmp directory
@@ -170,6 +208,26 @@
):
open_datatree(unaligned_datatree_nc)

@requires_dask
def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None:
filepath = tmpdir / "test.nc"

root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
original_tree = DataTree.from_dict(
{
"/": root_data.chunk({"x": 2, "y": 1}),
"/group1": set1_data.chunk({"x": 1, "y": 2}),
"/group2": set2_data.chunk({"x": 2, "y": 3}),
}
)
original_tree.to_netcdf(filepath, engine="netcdf4")

with open_datatree(filepath, engine="netcdf4", chunks={}) as tree:
xr.testing.assert_identical(tree, original_tree)
assert_chunks_equal(tree, original_tree, enforce_dask=True)

Check failure on line 229 in xarray/tests/test_backends_datatree.py — the same failure across six CI jobs (macos-latest and ubuntu-latest, py3.10–py3.12, including min-all-deps and all-but-numba):

TestNetCDF4DatatreeIO.test_open_datatree_chunks
AssertionError: Differing chunk sizes:
L .:a: {'y': (3,)}       R .:a: {'y': (1, 1, 1)}
L group1:a: {'y': (3,)}  R group1:a: {'y': (2, 1)}
L group1:b: {'x': (2,)}  R group1:b: {'x': (1, 1)}

(The netCDF file evidently did not preserve the in-memory dask chunking, so chunks={} fell back to whole-variable chunks; the later commit 4dbd91e, "specify the chunksizes when opening from disk", appears to address this.)

def test_open_groups(self, unaligned_datatree_nc) -> None:
"""Test `open_groups` with a netCDF4 file with an unaligned group hierarchy."""
unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc)
@@ -348,6 +406,26 @@
):
open_datatree(unaligned_datatree_zarr, engine="zarr")

@requires_dask
def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None:
filepath = tmpdir / "test.zarr"

root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
original_tree = DataTree.from_dict(
{
"/": root_data.chunk({"x": 2, "y": 1}),
"/group1": set1_data.chunk({"x": 1, "y": 2}),
"/group2": set2_data.chunk({"x": 2, "y": 3}),
}
)
original_tree.to_zarr(filepath)

with open_datatree(filepath, engine="zarr", chunks={}) as tree:
xr.testing.assert_identical(original_tree, tree)
assert_chunks_equal(tree, original_tree, enforce_dask=True)

def test_open_groups(self, unaligned_datatree_zarr) -> None:
"""Test `open_groups` with a zarr store of an unaligned group hierarchy."""
