diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 22cb03fd11c..e6f0fe83858 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -2907,32 +2907,40 @@ def from_iris(cls, cube: "iris_Cube") -> "DataArray":
         return from_iris(cube)
 
-    def _all_compat(self, other: "DataArray", compat_str: str) -> bool:
+    def _all_compat(
+        self, other: "DataArray", compat_str: str, check_dtype: bool = False
+    ) -> bool:
         """Helper function for equals, broadcast_equals, and identical"""
 
         def compat(x, y):
-            return getattr(x.variable, compat_str)(y.variable)
+            return getattr(x.variable, compat_str)(y.variable, check_dtype=check_dtype)
 
         return utils.dict_equiv(self.coords, other.coords, compat=compat) and compat(
             self, other
         )
 
-    def broadcast_equals(self, other: "DataArray") -> bool:
+    def broadcast_equals(self, other: "DataArray", check_dtype: bool = False) -> bool:
         """Two DataArrays are broadcast equal if they are equal after
         broadcasting them against each other such that they have the same
         dimensions.
 
+        Parameters
+        ----------
+        check_dtype : bool, default: False
+            Whether to check if the objects' dtypes are identical. Compares the
+            dtypes of the data and the coords.
+
         See Also
         --------
         DataArray.equals
         DataArray.identical
         """
         try:
-            return self._all_compat(other, "broadcast_equals")
+            return self._all_compat(other, "broadcast_equals", check_dtype=check_dtype)
         except (TypeError, AttributeError):
             return False
 
-    def equals(self, other: "DataArray") -> bool:
+    def equals(self, other: "DataArray", check_dtype: bool = False) -> bool:
         """True if two DataArrays have the same dimensions, coordinates and
         values; otherwise False.
@@ -2942,27 +2950,41 @@ def equals(self, other: "DataArray") -> bool:
         This method is necessary because `v1 == v2` for ``DataArray``
         does element-wise comparisons (like numpy.ndarrays).
 
+        Parameters
+        ----------
+        check_dtype : bool, default: False
+            Whether to check if the objects' dtypes are identical. Compares the
+            dtypes of the data and the coords.
+
         See Also
         --------
         DataArray.broadcast_equals
         DataArray.identical
         """
         try:
-            return self._all_compat(other, "equals")
+            return self._all_compat(other, "equals", check_dtype=check_dtype)
         except (TypeError, AttributeError):
             return False
 
-    def identical(self, other: "DataArray") -> bool:
+    def identical(self, other: "DataArray", check_dtype: bool = False) -> bool:
         """Like equals, but also checks the array name and attributes, and
         attributes on all coordinates.
 
+        Parameters
+        ----------
+        check_dtype : bool, default: False
+            Whether to check if the objects' dtypes are identical. Compares the
+            dtypes of the data and the coords.
+
         See Also
         --------
         DataArray.broadcast_equals
         DataArray.equals
         """
         try:
-            return self.name == other.name and self._all_compat(other, "identical")
+            return self.name == other.name and self._all_compat(
+                other, "identical", check_dtype=check_dtype
+            )
         except (TypeError, AttributeError):
             return False
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index f59b9b6bea5..2807caa4b6b 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -1537,19 +1537,19 @@ def __delitem__(self, key: Hashable) -> None:
     # https://github.com/python/mypy/issues/4266
     __hash__ = None  # type: ignore[assignment]
 
-    def _all_compat(self, other: "Dataset", compat_str: str) -> bool:
+    def _all_compat(self, other: "Dataset", compat_str: str, check_dtype: bool) -> bool:
         """Helper function for equals and identical"""
 
         # some stores (e.g., scipy) do not seem to preserve order, so don't
         # require matching order for equality
         def compat(x: Variable, y: Variable) -> bool:
-            return getattr(x, compat_str)(y)
+            return getattr(x, compat_str)(y, check_dtype=check_dtype)
 
         return self._coord_names == other._coord_names and utils.dict_equiv(
             self._variables, other._variables, compat=compat
         )
 
-    def broadcast_equals(self, other: "Dataset") -> bool:
+    def broadcast_equals(self, other: "Dataset", check_dtype: bool = False) -> bool:
         """Two Datasets are broadcast equal if they are equal after
         broadcasting all variables against each other.
 
@@ -1557,17 +1557,23 @@ def broadcast_equals(self, other: "Dataset") -> bool:
         the other dataset can still be broadcast equal if the the non-scalar
         variable is a constant.
 
+        Parameters
+        ----------
+        check_dtype : bool, default: False
+            Whether to check if the objects' dtypes are identical. Compares the
+            dtypes of all data variables and coords.
+
         See Also
         --------
         Dataset.equals
         Dataset.identical
         """
         try:
-            return self._all_compat(other, "broadcast_equals")
+            return self._all_compat(other, "broadcast_equals", check_dtype=check_dtype)
         except (TypeError, AttributeError):
             return False
 
-    def equals(self, other: "Dataset") -> bool:
+    def equals(self, other: "Dataset", check_dtype: bool = False) -> bool:
         """Two Datasets are equal if they have matching variables and
         coordinates, all of which are equal.
 
@@ -1577,20 +1583,32 @@ def equals(self, other: "Dataset") -> bool:
         This method is necessary because `v1 == v2` for ``Dataset``
         does element-wise comparisons (like numpy.ndarrays).
 
+        Parameters
+        ----------
+        check_dtype : bool, default: False
+            Whether to check if the objects' dtypes are identical. Compares the
+            dtypes of all data variables and coords.
+
         See Also
         --------
         Dataset.broadcast_equals
         Dataset.identical
         """
         try:
-            return self._all_compat(other, "equals")
+            return self._all_compat(other, "equals", check_dtype=check_dtype)
         except (TypeError, AttributeError):
             return False
 
-    def identical(self, other: "Dataset") -> bool:
+    def identical(self, other: "Dataset", check_dtype: bool = False) -> bool:
         """Like equals, but also checks all dataset attributes and the
         attributes on all variables and coordinates.
 
+        Parameters
+        ----------
+        check_dtype : bool, default: False
+            Whether to check if the objects' dtypes are identical. Compares the
+            dtypes of all data variables and coords.
+
         See Also
         --------
         Dataset.broadcast_equals
@@ -1598,7 +1616,7 @@ def identical(self, other: "Dataset") -> bool:
         """
         try:
             return utils.dict_equiv(self.attrs, other.attrs) and self._all_compat(
-                other, "identical"
+                other, "identical", check_dtype=check_dtype
             )
         except (TypeError, AttributeError):
             return False
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index e32fd4be376..bdeac6c883b 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -218,10 +218,11 @@ def as_shared_dtype(scalars_or_arrays):
     return [x.astype(out_type, copy=False) for x in arrays]
 
 
-def lazy_array_equiv(arr1, arr2):
+def lazy_array_equiv(arr1, arr2, check_dtype=False):
     """Like array_equal, but doesn't actually compare values.
     Returns True when arr1, arr2 identical or their dask tokens are equal.
     Returns False when shapes are not equal.
+    Returns False if dtype does not match and check_dtype is True.
     Returns None when equality cannot determined: one or both of arr1, arr2
     are numpy arrays; or their dask tokens are not equal
     """
@@ -231,6 +232,9 @@ def lazy_array_equiv(arr1, arr2):
     arr2 = asarray(arr2)
     if arr1.shape != arr2.shape:
         return False
+    # `is False` needed: must not return False when same_dtype returns None
+    if check_dtype and same_dtype(arr1, arr2, lazy=True) is False:
+        return False
     if dask_array and is_duck_dask_array(arr1) and is_duck_dask_array(arr2):
         # GH3068, GH4221
         if tokenize(arr1) == tokenize(arr2):
@@ -240,13 +244,32 @@
     return None
 
 
-def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
+def same_dtype(arr1, arr2, lazy):
+
+    # object dask arrays can change dtype -> need to compute them
+    if arr1.dtype == object and is_duck_dask_array(arr1):
+        if lazy:
+            return None
+        # arr.compute() can return a scalar -> wrap in an array
+        arr1 = asarray(arr1.compute())
+
+    if arr2.dtype == object and is_duck_dask_array(arr2):
+        if lazy:
+            return None
+        arr2 = asarray(arr2.compute())
+
+    return arr1.dtype == arr2.dtype
+
+
+def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8, check_dtype=False):
     """Like np.allclose, but also allows values to be NaN in both arrays"""
     arr1 = asarray(arr1)
     arr2 = asarray(arr2)
 
-    lazy_equiv = lazy_array_equiv(arr1, arr2)
+    lazy_equiv = lazy_array_equiv(arr1, arr2, check_dtype=check_dtype)
     if lazy_equiv is None:
+        if check_dtype and not same_dtype(arr1, arr2, lazy=False):
+            return False
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
             return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
@@ -254,12 +277,14 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
         return lazy_equiv
 
 
-def array_equiv(arr1, arr2):
+def array_equiv(arr1, arr2, check_dtype=False):
     """Like np.array_equal, but also allows values to be NaN in both arrays"""
     arr1 = asarray(arr1)
     arr2 = asarray(arr2)
-    lazy_equiv = lazy_array_equiv(arr1, arr2)
+    lazy_equiv = lazy_array_equiv(arr1, arr2, check_dtype=check_dtype)
     if lazy_equiv is None:
+        if check_dtype and not same_dtype(arr1, arr2, lazy=False):
+            return False
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
             flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
@@ -268,14 +293,16 @@ def array_notnull_equiv(arr1, arr2):
-def array_notnull_equiv(arr1, arr2):
+def array_notnull_equiv(arr1, arr2, check_dtype=False):
     """Like np.array_equal, but also allows values to be NaN in either or both
     arrays
     """
     arr1 = asarray(arr1)
     arr2 = asarray(arr2)
-    lazy_equiv = lazy_array_equiv(arr1, arr2)
+    lazy_equiv = lazy_array_equiv(arr1, arr2, check_dtype=check_dtype)
     if lazy_equiv is None:
+        if check_dtype and not same_dtype(arr1, arr2, lazy=False):
+            return False
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
             flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py
index c3e7187abed..8960c03dc73 100644
--- a/xarray/core/formatting.py
+++ b/xarray/core/formatting.py
@@ -566,7 +566,9 @@ def diff_dim_summary(a, b):
         return ""
 
 
-def _diff_mapping_repr(a_mapping, b_mapping, compat, title, summarizer, col_width=None):
+def _diff_mapping_repr(
+    a_mapping, b_mapping, compat, title, summarizer, col_width=None, check_dtype=False
+):
     def extra_items_repr(extra_keys, mapping, ab_side):
         extra_repr = [summarizer(k, mapping[k], col_width) for k in extra_keys]
         if extra_repr:
@@ -586,9 +588,15 @@ def extra_items_repr(extra_keys, mapping, ab_side):
         try:
             # compare xarray variable
             if not callable(compat):
-                compatible = getattr(a_mapping[k], compat)(b_mapping[k])
+                compatible = getattr(a_mapping[k], compat)(
+                    b_mapping[k], check_dtype=check_dtype
+                )
             else:
-                compatible = compat(a_mapping[k], b_mapping[k])
+                compatible = compat(
+                    a_mapping[k],
+                    b_mapping[k],
+                    check_dtype=check_dtype,
+                )
             is_variable = True
         except AttributeError:
             # compare attribute value
@@ -620,8 +628,9 @@ def extra_items_repr(extra_keys, mapping, ab_side):
             diff_items += [ab_side + s[1:] for ab_side, s in zip(("L", "R"), temp)]
 
+    maybe_dtype = " (values and/or dtype)" if check_dtype else ""
     if diff_items:
-        summary += [f"Differing {title.lower()}:"] + diff_items
+        summary += [f"Differing {title.lower()}{maybe_dtype}:"] + diff_items
 
     summary += extra_items_repr(a_keys - b_keys, a_mapping, "left")
     summary += extra_items_repr(b_keys - a_keys, b_mapping, "right")
@@ -656,7 +665,7 @@ def _compat_to_str(compat):
         return compat
 
 
-def diff_array_repr(a, b, compat):
+def diff_array_repr(a, b, compat, check_dtype=False):
     # used for DataArray, Variable and IndexVariable
     summary = [
         "Left and right {} objects are not {}".format(
@@ -670,18 +679,22 @@ def diff_array_repr(a, b, compat):
     else:
         equiv = array_equiv
 
-    if not equiv(a.data, b.data):
+    maybe_dtype = " or dtype" if check_dtype else ""
+
+    if not equiv(a.data, b.data, check_dtype=check_dtype):
         temp = [wrap_indent(short_numpy_repr(obj), start=" ") for obj in (a, b)]
         diff_data_repr = [
             ab_side + "\n" + ab_data_repr
             for ab_side, ab_data_repr in zip(("L", "R"), temp)
         ]
-        summary += ["Differing values:"] + diff_data_repr
+        summary += [f"Differing values{maybe_dtype}:"] + diff_data_repr
 
     if hasattr(a, "coords"):
         col_width = _calculate_col_width(set(a.coords) | set(b.coords))
         summary.append(
-            diff_coords_repr(a.coords, b.coords, compat, col_width=col_width)
+            diff_coords_repr(
+                a.coords, b.coords, compat, col_width=col_width, check_dtype=check_dtype
+            )
         )
 
     if compat == "identical":
@@ -690,7 +703,7 @@ def diff_array_repr(a, b, compat):
     return "\n".join(summary)
 
 
-def diff_dataset_repr(a, b, compat):
+def diff_dataset_repr(a, b, compat, check_dtype=False):
     summary = [
         "Left and right {} objects are not {}".format(
             type(a).__name__, _compat_to_str(compat)
@@ -702,9 +715,19 @@ def diff_dataset_repr(a, b, compat):
     )
 
     summary.append(diff_dim_summary(a, b))
-    summary.append(diff_coords_repr(a.coords, b.coords, compat, col_width=col_width))
     summary.append(
-        diff_data_vars_repr(a.data_vars, b.data_vars, compat, col_width=col_width)
+        diff_coords_repr(
+            a.coords, b.coords, compat, col_width=col_width, check_dtype=check_dtype
+        )
+    )
+    summary.append(
+        diff_data_vars_repr(
+            a.data_vars,
+            b.data_vars,
+            compat,
+            col_width=col_width,
+            check_dtype=check_dtype,
+        )
     )
 
     if compat == "identical":
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 1abaad49cdf..6640aa3275a 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -1842,7 +1842,7 @@ def concat(
         return cls(dims, data, attrs, encoding)
 
-    def equals(self, other, equiv=duck_array_ops.array_equiv):
+    def equals(self, other, equiv=duck_array_ops.array_equiv, check_dtype=False):
         """True if two Variables have the same dimensions and values;
         otherwise False.
 
@@ -1855,12 +1855,15 @@ def equals(self, other, equiv=duck_array_ops.array_equiv):
         other = getattr(other, "variable", other)
         try:
             return self.dims == other.dims and (
-                self._data is other._data or equiv(self.data, other.data)
+                self._data is other._data
+                or equiv(self.data, other.data, check_dtype=check_dtype)
             )
         except (TypeError, AttributeError):
             return False
 
-    def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv):
+    def broadcast_equals(
+        self, other, equiv=duck_array_ops.array_equiv, check_dtype=False
+    ):
         """True if two Variables have the values after being broadcast against
         each other; otherwise False.
 
@@ -1871,25 +1874,27 @@ def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv):
             self, other = broadcast_variables(self, other)
         except (ValueError, AttributeError):
             return False
-        return self.equals(other, equiv=equiv)
+        return self.equals(other, equiv=equiv, check_dtype=check_dtype)
 
-    def identical(self, other, equiv=duck_array_ops.array_equiv):
+    def identical(self, other, equiv=duck_array_ops.array_equiv, check_dtype=False):
         """Like equals, but also checks attributes."""
         try:
             return utils.dict_equiv(self.attrs, other.attrs) and self.equals(
-                other, equiv=equiv
+                other, equiv=equiv, check_dtype=check_dtype
             )
         except (TypeError, AttributeError):
             return False
 
-    def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv):
+    def no_conflicts(
+        self, other, equiv=duck_array_ops.array_notnull_equiv, check_dtype=False
+    ):
         """True if the intersection of two Variable's non-null data is
         equal; otherwise false.
 
         Variables can thus still be equal if there are locations where either,
         or both, contain NaN values.
""" - return self.broadcast_equals(other, equiv=equiv) + return self.broadcast_equals(other, equiv=equiv, check_dtype=check_dtype) def quantile( self, q, dim=None, interpolation="linear", keep_attrs=None, skipna=True @@ -2674,20 +2679,26 @@ def copy(self, deep=True, data=None): ) return self._replace(data=data) - def equals(self, other, equiv=None): + def equals(self, other, equiv=None, check_dtype=False): # if equiv is specified, super up if equiv is not None: - return super().equals(other, equiv) + return super().equals(other, equiv, check_dtype=check_dtype) # otherwise use the native index equals, rather than looking at _data other = getattr(other, "variable", other) try: - return self.dims == other.dims and self._data_equals(other) + return self.dims == other.dims and self._data_equals( + other, check_dtype=check_dtype + ) except (TypeError, AttributeError): return False - def _data_equals(self, other): - return self.to_index().equals(other.to_index()) + def _data_equals(self, other, check_dtype=False): + + if check_dtype and self.dtype != other.dtype: + return False + + return bool(self.to_index().equals(other.to_index())) def to_index_variable(self): """Return this variable as an xarray.IndexVariable""" diff --git a/xarray/testing.py b/xarray/testing.py index 40ca12852b9..769af0592d9 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -42,19 +42,23 @@ def _decode_string_data(data): return data -def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=True): +def _data_allclose_or_equiv( + arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=True, check_dtype=False +): if any(arr.dtype.kind == "S" for arr in [arr1, arr2]) and decode_bytes: arr1 = _decode_string_data(arr1) arr2 = _decode_string_data(arr2) exact_dtypes = ["M", "m", "O", "S", "U"] if any(arr.dtype.kind in exact_dtypes for arr in [arr1, arr2]): - return duck_array_ops.array_equiv(arr1, arr2) + return duck_array_ops.array_equiv(arr1, arr2, check_dtype=check_dtype) else: - return duck_array_ops.allclose_or_equiv(arr1, arr2, rtol=rtol, atol=atol) + return duck_array_ops.allclose_or_equiv( + arr1, arr2, rtol=rtol, atol=atol, check_dtype=check_dtype + ) @ensure_warnings -def assert_equal(a, b): +def assert_equal(a, b, check_dtype=False): """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray objects. @@ -69,6 +73,9 @@ def assert_equal(a, b): The first object to compare. b : xarray.Dataset, xarray.DataArray or xarray.Variable The second object to compare. + check_dtype : bool, default: False + Whether to check if the objects' dtypes are identical. Compares the + dtypes of all data variables and coords. See Also -------- @@ -78,15 +85,19 @@ def assert_equal(a, b): __tracebackhide__ = True assert type(a) == type(b) if isinstance(a, (Variable, DataArray)): - assert a.equals(b), formatting.diff_array_repr(a, b, "equals") + assert a.equals(b, check_dtype=check_dtype), formatting.diff_array_repr( + a, b, "equals", check_dtype=check_dtype + ) elif isinstance(a, Dataset): - assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals") + assert a.equals(b, check_dtype=check_dtype), formatting.diff_dataset_repr( + a, b, "equals", check_dtype=check_dtype + ) else: raise TypeError("{} not supported by assertion comparison".format(type(a))) @ensure_warnings -def assert_identical(a, b): +def assert_identical(a, b, check_dtype=False): """Like :py:func:`xarray.testing.assert_equal`, but also matches the objects' names and attributes. 
@@ -98,6 +109,9 @@
         The first object to compare.
     b : xarray.Dataset, xarray.DataArray or xarray.Variable
         The second object to compare.
+    check_dtype : bool, default: False
+        Whether to check if the objects' dtypes are identical. Compares the
+        dtypes of all data variables and coords.
 
     See Also
     --------
@@ -106,18 +120,24 @@
     __tracebackhide__ = True
     assert type(a) == type(b)
     if isinstance(a, Variable):
-        assert a.identical(b), formatting.diff_array_repr(a, b, "identical")
+        assert a.identical(b, check_dtype=check_dtype), formatting.diff_array_repr(
+            a, b, "identical", check_dtype=check_dtype
+        )
     elif isinstance(a, DataArray):
         assert a.name == b.name
-        assert a.identical(b), formatting.diff_array_repr(a, b, "identical")
-    elif isinstance(a, (Dataset, Variable)):
-        assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical")
+        assert a.identical(b, check_dtype=check_dtype), formatting.diff_array_repr(
+            a, b, "identical", check_dtype=check_dtype
+        )
+    elif isinstance(a, Dataset):
+        assert a.identical(b, check_dtype=check_dtype), formatting.diff_dataset_repr(
+            a, b, "identical", check_dtype=check_dtype
+        )
     else:
         raise TypeError("{} not supported by assertion comparison".format(type(a)))
 
 
 @ensure_warnings
-def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
+def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True, check_dtype=False):
     """Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects.
 
     Raises an AssertionError if two objects are not equal up to desired
@@ -137,6 +157,9 @@
         Whether byte dtypes should be decoded to strings as UTF-8 or not.
         This is useful for testing serialization methods on Python 3 that
        return saved strings as bytes.
+    check_dtype : bool, default: False
+        Whether to check if the objects' dtypes are identical. Compares the
+        dtypes of all data variables and coords.
 
     See Also
     --------
@@ -146,7 +169,11 @@
     assert type(a) == type(b)
 
     equiv = functools.partial(
-        _data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes
+        _data_allclose_or_equiv,
+        rtol=rtol,
+        atol=atol,
+        decode_bytes=decode_bytes,
+        check_dtype=check_dtype,
     )
     equiv.__name__ = "allclose"
 
@@ -158,17 +185,23 @@ def compat_variable(a, b):
     if isinstance(a, Variable):
         allclose = compat_variable(a, b)
-        assert allclose, formatting.diff_array_repr(a, b, compat=equiv)
+        assert allclose, formatting.diff_array_repr(
+            a, b, compat=equiv, check_dtype=check_dtype
+        )
     elif isinstance(a, DataArray):
         allclose = utils.dict_equiv(
             a.coords, b.coords, compat=compat_variable
         ) and compat_variable(a.variable, b.variable)
-        assert allclose, formatting.diff_array_repr(a, b, compat=equiv)
+        assert allclose, formatting.diff_array_repr(
+            a, b, compat=equiv, check_dtype=check_dtype
+        )
     elif isinstance(a, Dataset):
         allclose = a._coord_names == b._coord_names and utils.dict_equiv(
             a.variables, b.variables, compat=compat_variable
         )
-        assert allclose, formatting.diff_dataset_repr(a, b, compat=equiv)
+        assert allclose, formatting.diff_dataset_repr(
+            a, b, compat=equiv, check_dtype=check_dtype
+        )
     else:
         raise TypeError("{} not supported by assertion comparison".format(type(a)))
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index 2797db3cf8b..56548d24498 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -181,22 +181,22 @@ def source_ndarray(array):
 # invariants
 
 
-def assert_equal(a, b):
+def assert_equal(a, b, check_dtype=False):
     __tracebackhide__ = True
-    xarray.testing.assert_equal(a, b)
+    xarray.testing.assert_equal(a, b, check_dtype=check_dtype)
     xarray.testing._assert_internal_invariants(a)
     xarray.testing._assert_internal_invariants(b)
 
 
-def assert_identical(a, b):
+def assert_identical(a, b, check_dtype=False):
     __tracebackhide__ = True
-    xarray.testing.assert_identical(a, b)
+    xarray.testing.assert_identical(a, b, check_dtype=check_dtype)
     xarray.testing._assert_internal_invariants(a)
     xarray.testing._assert_internal_invariants(b)
 
 
-def assert_allclose(a, b, **kwargs):
+def assert_allclose(a, b, check_dtype=False, **kwargs):
     __tracebackhide__ = True
-    xarray.testing.assert_allclose(a, b, **kwargs)
+    xarray.testing.assert_allclose(a, b, check_dtype=check_dtype, **kwargs)
     xarray.testing._assert_internal_invariants(a)
     xarray.testing._assert_internal_invariants(b)
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 8012fad18d0..b54777db318 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -494,6 +494,43 @@ def test_broadcast_equals(self):
         assert not a.broadcast_equals(c)
         assert not c.broadcast_equals(a)
 
+    @pytest.mark.parametrize(
+        "dtype1, dtype2, expected",
+        [
+            [float, float, True],
+            [float, int, False],
+            [object, int, False],
+        ],
+    )
+    def test_equals_check_dtype(self, dtype1, dtype2, expected):
+        data1 = np.array([1], dtype=dtype1)
+        data2 = np.array([1], dtype=dtype2)
+
+        da1 = DataArray(data1)
+        da2 = DataArray(data2)
+
+        assert da1.equals(da2, check_dtype=True) is expected
+        assert da1.identical(da2, check_dtype=True) is expected
+        assert da1.broadcast_equals(da2, check_dtype=True) is expected
+
+        # check the default (check_dtype=False)
+        assert da1.equals(da2)
+        assert da1.identical(da2)
+        assert da1.broadcast_equals(da2)
+
+        # it also checks the dtype of the coords
+        da1 = DataArray([1], dims="x", coords={"x": data1})
+        da2 = DataArray([1], dims="x", coords={"x": data2})
+
+        assert da1.equals(da2, check_dtype=True) is expected
+        assert da1.identical(da2, check_dtype=True) is expected
+        assert da1.broadcast_equals(da2, check_dtype=True) is expected
+
+        # check the default (check_dtype=False)
+        assert da1.equals(da2)
+        assert da1.identical(da2)
+        assert da1.broadcast_equals(da2)
+
     def test_getitem(self):
         # strings pull out dataarrays
         assert_identical(self.dv, self.ds["foo"])
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index b8e1cd4b03b..9e65b595214 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -925,6 +925,43 @@ def test_broadcast_equals(self):
         assert not data1.equals(data2)
         assert not data1.identical(data2)
 
+    @pytest.mark.parametrize(
+        "dtype1, dtype2, expected",
+        [
+            [float, float, True],
+            [float, int, False],
+            [object, int, False],
+        ],
+    )
+    def test_equals_check_dtype(self, dtype1, dtype2, expected):
+        data1 = np.array([1], dtype=dtype1)
+        data2 = np.array([1], dtype=dtype2)
+
+        ds1 = Dataset({"data": ("x", data1)})
+        ds2 = Dataset({"data": ("x", data2)})
+
+        assert ds1.equals(ds2, check_dtype=True) is expected
+        assert ds1.identical(ds2, check_dtype=True) is expected
+        assert ds1.broadcast_equals(ds2, check_dtype=True) is expected
+
+        # check the default (check_dtype=False)
+        assert ds1.equals(ds2)
+        assert ds1.identical(ds2)
+        assert ds1.broadcast_equals(ds2)
+
+        # ensure the check also fails if the coords have different dtype
+        ds1 = Dataset(coords={"x": data1})
+        ds2 = Dataset(coords={"x": data2})
+
+        assert ds1.equals(ds2, check_dtype=True) is expected
+        assert ds1.identical(ds2, check_dtype=True) is expected
+        assert ds1.broadcast_equals(ds2, check_dtype=True) is expected
+
+        # check the default (check_dtype=False)
+        assert ds1.equals(ds2)
+        assert ds1.identical(ds2)
+        assert ds1.broadcast_equals(ds2)
+
     def test_attrs(self):
         data = create_test_data(seed=42)
         data.attrs = {"foobar": "baz"}
diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py
index 0eb007259bb..a60be8c3e53 100644
--- a/xarray/tests/test_duck_array_ops.py
+++ b/xarray/tests/test_duck_array_ops.py
@@ -10,18 +10,22 @@
 from xarray import DataArray, Dataset, cftime_range, concat
 from xarray.core import dtypes, duck_array_ops
 from xarray.core.duck_array_ops import (
+    allclose_or_equiv,
+    array_equiv,
     array_notnull_equiv,
     concatenate,
     count,
     first,
     gradient,
     last,
+    lazy_array_equiv,
     least_squares,
     mean,
     np_timedelta64_to_float,
     pd_timedelta_to_float,
     push,
     py_timedelta_to_float,
+    same_dtype,
     stack,
     timedelta_to_numeric,
     where,
@@ -216,6 +220,114 @@ def test_types(self, val1, val2, val3, null):
         assert array_notnull_equiv(arr1, arr2)
 
 
+@pytest.mark.parametrize(
+    "dtype1, dtype2, expected_lazy, expected_eager",
+    [
+        [bool, bool, True, True],
+        [int, int, True, True],
+        [float, float, True, True],
+        [str, str, True, True],
+        [bytes, bytes, True, True],
+        [object, object, None, True],
+        [bool, int, False, False],
+        ["U1", "U2", False, False],
+        ["S1", "S2", False, False],
+        [float, int, False, False],
+        [object, int, None, False],
+        [int, object, None, False],
+        [object, float, None, False],
+        [float, object, None, False],
+    ],
+)
+def test_same_dtype(dtype1, dtype2, expected_lazy, expected_eager):
+
+    ar1 = np.array(1, dtype=dtype1)
+    ar2 = np.array(1, dtype=dtype2)
+
+    # lazy makes no difference when not passing dask arrays
+    assert same_dtype(ar1, ar2, lazy=True) is expected_eager
+    assert same_dtype(ar1, ar2, lazy=False) is expected_eager
+
+    if has_dask:
+
+        import dask.array as da
+
+        ar1 = da.array(1, dtype=dtype1)
+        ar2 = da.array(1, dtype=dtype2)
+
+        # lazy differs for dask arrays with object dtype
+        # because these can change dtype on compute
+        with raise_if_dask_computes():
+            assert same_dtype(ar1, ar2, lazy=True) is expected_lazy
+
+        assert same_dtype(ar1, ar2, lazy=False) is expected_eager
+
+
+@requires_dask
+@pytest.mark.parametrize(
+    "dtype_numpy, dtype_dask, expected",
+    [
+        [float, float, None],
+        [object, object, None],
+        [float, int, False],
+        [object, int, False],
+        [int, object, None],
+    ],
+)
+def test_lazy_array_equiv_dask(dtype_numpy, dtype_dask, expected):
+
+    import dask.array as da
+
+    # one needs to be a numpy array, else "tokenize" says it's the same array
+    arr_numpy = np.array(1, dtype=dtype_numpy)
+
+    arr_dask = da.array(1, dtype=dtype_dask)
+
+    with raise_if_dask_computes():
+        assert lazy_array_equiv(arr_numpy, arr_dask, check_dtype=True) is expected
+        assert lazy_array_equiv(arr_dask, arr_numpy, check_dtype=True) is expected
+
+
+@pytest.mark.parametrize(
+    "dtype1, dtype2, expected",
+    [
+        [float, float, True],
+        [float, int, False],
+        [object, int, False],
+    ],
+)
+def test_equiv_check_dtype(dtype1, dtype2, expected):
+
+    ar1 = np.array(1, dtype=dtype1)
+    ar2 = np.array(1, dtype=dtype2)
+
+    assert allclose_or_equiv(ar1, ar2, check_dtype=True) is expected
+    assert array_equiv(ar1, ar2, check_dtype=True) is expected
+    assert array_notnull_equiv(ar1, ar2, check_dtype=True) is expected
+
+    # np.allclose does not work for object array
+    if dtype1 is not object:
+        assert allclose_or_equiv(ar1, ar2)
+        assert array_equiv(ar1, ar2)
+        assert array_notnull_equiv(ar1, ar2)
+
+    if has_dask:
+        import dask.array as da
+
+        ar1 = da.array(1, dtype=dtype1)
+        ar2 = da.array(1, dtype=dtype2)
+
+        assert allclose_or_equiv(ar1, ar2, check_dtype=True) is expected
+        assert array_equiv(ar1, ar2, check_dtype=True) is expected
+        assert array_notnull_equiv(ar1, ar2, check_dtype=True) is expected
+
+        # np.allclose does not work for object array
+        if dtype1 is not object:
+            assert allclose_or_equiv(ar1, ar2)
+            assert array_equiv(ar1, ar2)
+            assert array_notnull_equiv(ar1, ar2)
+
+
 def construct_dataarray(dim_num, dtype, contains_nan, dask):
     # dimnum <= 3
     rng = np.random.RandomState(0)
diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py
index e823a67f2db..4384a6eb8b7 100644
--- a/xarray/tests/test_formatting.py
+++ b/xarray/tests/test_formatting.py
@@ -290,6 +290,63 @@ def test_diff_array_repr(self):
         except AssertionError:
             assert actual == expected.replace(", dtype=int64", "")
 
+        # test with check_dtype=True
+
+        da_a_dtype = xr.DataArray(
+            np.array([1], int), dims="x", coords=dict(x=np.array([1], int))
+        )
+        da_b_dtype = xr.DataArray(
+            np.array([1], float), dims="x", coords=dict(x=np.array([1], float))
+        )
+
+        expected = dedent(
+            """\
+            Left and right DataArray objects are not identical
+
+            Differing values or dtype:
+            L
+                array([1], dtype=int64)
+            R
+                array([1.], dtype=float64)
+            Differing coordinates (values and/or dtype):
+            L * x (x) int64 1
+            R * x (x) float64 1.0
+            """
+        )
+
+        actual = formatting.diff_array_repr(
+            da_a_dtype, da_b_dtype, "identical", check_dtype=True
+        )
+
+        try:
+            assert actual == expected
+        except AssertionError:
+            expected = expected.replace(", dtype=int64", "")
+            expected = expected.replace(", dtype=float64", "")
+            assert actual == expected
+
+        expected = dedent(
+            """\
+            Left and right Variable objects are not identical
+
+            Differing values or dtype:
+            L
+                array([1])
+            R
+                array([1.])
+            """
+        )
+
+        actual = formatting.diff_array_repr(
+            da_a_dtype.variable, da_b_dtype.variable, "identical", check_dtype=True
+        )
+        try:
+            assert actual == expected
+        except AssertionError:
+            expected = expected.replace(", dtype=int64", "")
+            expected = expected.replace(", dtype=float64", "")
+            assert actual == expected
+
     @pytest.mark.filterwarnings("error")
     def test_diff_attrs_repr_with_array(self):
         attrs_a = {"attr": np.array([0, 1])}
@@ -380,6 +437,33 @@ def test_diff_dataset_repr(self):
         actual = formatting.diff_dataset_repr(ds_a, ds_b, "identical")
         assert actual == expected
 
+        ds_a_dtype = xr.Dataset(
+            data_vars={"a": ("x", np.array([1], int))},
+            coords=dict(x=np.array([1], int)),
+        )
+        ds_b_dtype = xr.Dataset(
+            data_vars={"a": ("x", np.array([1], float))},
+            coords=dict(x=np.array([1], float)),
+        )
+
+        expected = dedent(
+            """\
+            Left and right Dataset objects are not identical
+
+            Differing coordinates (values and/or dtype):
+            L * x (x) int64 1
+            R * x (x) float64 1.0
+            Differing data variables (values and/or dtype):
+            L a (x) int64 1
+            R a (x) float64 1.0
+            """
+        )
+
+        actual = formatting.diff_dataset_repr(
+            ds_a_dtype, ds_b_dtype, "identical", check_dtype=True
+        )
+        assert actual == expected
+
     def test_array_repr(self):
         ds = xr.Dataset(coords={"foo": [1, 2, 3], "bar": [1, 2, 3]})
         ds[(1, 2)] = xr.DataArray([0], dims="test")
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 1e0dff45dd2..06bd624f320 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -420,6 +420,32 @@ def test_equals_all_dtypes(self):
         assert v[:2].identical(v2[:2])
         assert v[:2].no_conflicts(v2[:2])
 
+    @pytest.mark.parametrize(
+        "dtype1, dtype2, expected",
+        [
+            [float, float, True],
+            [float, int, False],
+        ],
+    )
+    def test_equals_check_dtype(self, dtype1, dtype2, expected):
+
+        data1 = np.array([1], dtype=dtype1)
+        data2 = np.array([1], dtype=dtype2)
+
+        v1 = self.cls(["data"], data1)
+        v2 = self.cls(["data"], data2)
+
+        assert v1.equals(v2, check_dtype=True) is expected
+        assert v1.identical(v2, check_dtype=True) is expected
+        assert v1.broadcast_equals(v2, check_dtype=True) is expected
+        assert v1.no_conflicts(v2, check_dtype=True) is expected
+
+        # test the default (check_dtype=False)
+        assert v1.equals(v2)
+        assert v1.identical(v2)
+        assert v1.broadcast_equals(v2)
+        assert v1.no_conflicts(v2)
+
     def test_eq_all_dtypes(self):
         # ensure that we don't choke on comparisons for which numpy returns
         # scalars
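
A brief usage sketch of the `check_dtype` keyword introduced by this patch (not part of the diff itself; the arrays and values below are illustrative, and the expected outcomes follow the docstrings and tests above):

    import numpy as np
    import xarray as xr

    a = xr.DataArray(np.array([1, 2], dtype="int64"), dims="x")
    b = xr.DataArray(np.array([1.0, 2.0], dtype="float64"), dims="x")

    # default behaviour is unchanged: dtypes are ignored
    assert a.equals(b)
    xr.testing.assert_equal(a, b)

    # with check_dtype=True, equal values but differing dtypes compare unequal
    assert not a.equals(b, check_dtype=True)
    assert not a.identical(b, check_dtype=True)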