diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 19bfd7130d5..1d25b1bf88c 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1252,20 +1252,16 @@ def isomorphic( except (TypeError, TreeIsomorphismError): return False - def equals(self, other: DataTree, from_root: bool = True) -> bool: + def equals(self, other: DataTree) -> bool: """ - Two DataTrees are equal if they have isomorphic node structures, with matching node names, - and if they have matching variables and coordinates, all of which are equal. - - By default this method will check the whole tree above the given node. + Two DataTrees are equal if they have isomorphic node structures, with + matching node names, and if they have matching variables and + coordinates, all of which are equal. Parameters ---------- other : DataTree The other tree object to compare to. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the two trees before checking for isomorphism. - If neither tree has a parent then this has no effect. See Also -------- @@ -1273,30 +1269,27 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool: DataTree.isomorphic DataTree.identical """ - if not self.isomorphic(other, from_root=from_root, strict_names=True): + if not self.isomorphic(other, strict_names=True): return False return all( - [ - node.dataset.equals(other_node.dataset) - for node, other_node in zip(self.subtree, other.subtree, strict=True) - ] + node.dataset.equals(other_node.dataset) + for node, other_node in zip(self.subtree, other.subtree, strict=True) ) - def identical(self, other: DataTree, from_root=True) -> bool: - """ - Like equals, but will also check all dataset attributes and the attributes on - all variables and coordinates. + def _inherited_coords_set(self) -> set[str]: + return set(self.parent.coords if self.parent else []) - By default this method will check the whole tree above the given node. + def identical(self, other: DataTree) -> bool: + """ + Like equals, but also checks attributes on all datasets, variables and + coordinates, and requires that any inherited coordinates at the tree + root are also inherited on the other tree. Parameters ---------- other : DataTree The other tree object to compare to. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the two trees before checking for isomorphism. - If neither tree has a parent then this has no effect. See Also -------- @@ -1304,9 +1297,16 @@ def identical(self, other: DataTree, from_root=True) -> bool: DataTree.isomorphic DataTree.equals """ - if not self.isomorphic(other, from_root=from_root, strict_names=True): + if not self.isomorphic(other, strict_names=True): + return False + + if self.name != other.name: + return False + + if self._inherited_coords_set() != other._inherited_coords_set(): return False + # TODO: switch to zip_subtrees, when available return all( node.dataset.identical(other_node.dataset) for node, other_node in zip(self.subtree, other.subtree, strict=True) diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 3a0cc96b20d..f3165bb9a11 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -3,7 +3,6 @@ import functools import warnings from collections.abc import Hashable -from typing import overload import numpy as np import pandas as pd @@ -107,16 +106,8 @@ def maybe_transpose_dims(a, b, check_dim_order: bool): return b -@overload -def assert_equal(a, b): ... - - -@overload -def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ... - - @ensure_warnings -def assert_equal(a, b, from_root=True, check_dim_order: bool = True): +def assert_equal(a, b, check_dim_order: bool = True): """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray objects. @@ -135,10 +126,6 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True): or xarray.core.datatree.DataTree. The first object to compare. b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates or xarray.core.datatree.DataTree. The second object to compare. - from_root : bool, optional, default is True - Only used when comparing DataTree objects. Indicates whether or not to - first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. check_dim_order : bool, optional, default is True Whether dimensions must be in the same order. @@ -159,25 +146,13 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True): elif isinstance(a, Coordinates): assert a.equals(b), formatting.diff_coords_repr(a, b, "equals") elif isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals") + assert a.equals(b), diff_datatree_repr(a, b, "equals") else: raise TypeError(f"{type(a)} not supported by assertion comparison") -@overload -def assert_identical(a, b): ... - - -@overload -def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ... - - @ensure_warnings -def assert_identical(a, b, from_root=True): +def assert_identical(a, b): """Like :py:func:`xarray.testing.assert_equal`, but also matches the objects' names and attributes. @@ -193,12 +168,6 @@ def assert_identical(a, b, from_root=True): The first object to compare. b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The second object to compare. - from_root : bool, optional, default is True - Only used when comparing DataTree objects. Indicates whether or not to - first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. - check_dim_order : bool, optional, default is True - Whether dimensions must be in the same order. See Also -------- @@ -220,13 +189,7 @@ def assert_identical(a, b, from_root=True): elif isinstance(a, Coordinates): assert a.identical(b), formatting.diff_coords_repr(a, b, "identical") elif isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.identical(b, from_root=from_root), diff_datatree_repr( - a, b, "identical" - ) + assert a.identical(b), diff_datatree_repr(a, b, "identical") else: raise TypeError(f"{type(a)} not supported by assertion comparison") diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 9c11cde3bbb..a710fbfafa0 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1538,6 +1538,110 @@ def f(x, tree, y): assert actual is dt and actual.attrs == attrs +class TestEqualsAndIdentical: + + def test_minimal_variations(self): + tree = DataTree.from_dict( + { + "/": Dataset({"x": 1}), + "/child": Dataset({"x": 2}), + } + ) + assert tree.equals(tree) + assert tree.identical(tree) + + child = tree.children["child"] + assert child.equals(child) + assert child.identical(child) + + new_child = DataTree(dataset=Dataset({"x": 2}), name="child") + assert child.equals(new_child) + assert child.identical(new_child) + + anonymous_child = DataTree(dataset=Dataset({"x": 2})) + # TODO: re-enable this after fixing .equals() not to require matching + # names on the root node (i.e., after switching to use zip_subtrees) + # assert child.equals(anonymous_child) + assert not child.identical(anonymous_child) + + different_variables = DataTree.from_dict( + { + "/": Dataset(), + "/other": Dataset({"x": 2}), + } + ) + assert not tree.equals(different_variables) + assert not tree.identical(different_variables) + + different_root_data = DataTree.from_dict( + { + "/": Dataset({"x": 4}), + "/child": Dataset({"x": 2}), + } + ) + assert not tree.equals(different_root_data) + assert not tree.identical(different_root_data) + + different_child_data = DataTree.from_dict( + { + "/": Dataset({"x": 1}), + "/child": Dataset({"x": 3}), + } + ) + assert not tree.equals(different_child_data) + assert not tree.identical(different_child_data) + + different_child_node_attrs = DataTree.from_dict( + { + "/": Dataset({"x": 1}), + "/child": Dataset({"x": 2}, attrs={"foo": "bar"}), + } + ) + assert tree.equals(different_child_node_attrs) + assert not tree.identical(different_child_node_attrs) + + different_child_variable_attrs = DataTree.from_dict( + { + "/": Dataset({"x": 1}), + "/child": Dataset({"x": ((), 2, {"foo": "bar"})}), + } + ) + assert tree.equals(different_child_variable_attrs) + assert not tree.identical(different_child_variable_attrs) + + different_name = DataTree.from_dict( + { + "/": Dataset({"x": 1}), + "/child": Dataset({"x": 2}), + }, + name="different", + ) + # TODO: re-enable this after fixing .equals() not to require matching + # names on the root node (i.e., after switching to use zip_subtrees) + # assert tree.equals(different_name) + assert not tree.identical(different_name) + + def test_differently_inherited_coordinates(self): + root = DataTree.from_dict( + { + "/": Dataset(coords={"x": [1, 2]}), + "/child": Dataset(), + } + ) + child = root.children["child"] + assert child.equals(child) + assert child.identical(child) + + new_child = DataTree(dataset=Dataset(coords={"x": [1, 2]}), name="child") + assert child.equals(new_child) + assert not child.identical(new_child) + + deeper_root = DataTree(children={"root": root}) + grandchild = deeper_root["/root/child"] + assert child.equals(grandchild) + assert child.identical(grandchild) + + class TestSubset: def test_match(self) -> None: # TODO is this example going to cause problems with case sensitivity? @@ -1599,7 +1703,7 @@ def test_isel_siblings(self) -> None: } ) actual = tree.isel(x=-1) - assert_equal(actual, expected) + assert_identical(actual, expected) expected = DataTree.from_dict( { @@ -1608,13 +1712,13 @@ def test_isel_siblings(self) -> None: } ) actual = tree.isel(x=slice(1)) - assert_equal(actual, expected) + assert_identical(actual, expected) actual = tree.isel(x=[0]) - assert_equal(actual, expected) + assert_identical(actual, expected) actual = tree.isel(x=slice(None)) - assert_equal(actual, tree) + assert_identical(actual, tree) def test_isel_inherited(self) -> None: tree = DataTree.from_dict( @@ -1631,7 +1735,7 @@ def test_isel_inherited(self) -> None: } ) actual = tree.isel(x=-1) - assert_equal(actual, expected) + assert_identical(actual, expected) expected = DataTree.from_dict( { @@ -1639,7 +1743,7 @@ def test_isel_inherited(self) -> None: } ) actual = tree.isel(x=-1, drop=True) - assert_equal(actual, expected) + assert_identical(actual, expected) expected = DataTree.from_dict( { @@ -1648,7 +1752,7 @@ def test_isel_inherited(self) -> None: } ) actual = tree.isel(x=[0]) - assert_equal(actual, expected) + assert_identical(actual, expected) actual = tree.isel(x=slice(None)) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 766df76a259..1334468b54d 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -264,7 +264,7 @@ def times_ten(ds): expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] result_tree = times_ten(subtree) - assert_equal(result_tree, expected, from_root=False) + assert_equal(result_tree, expected) def test_skip_empty_nodes_with_attrs(self, create_test_datatree): # inspired by xarray-datatree GH262