From 9acb3b74550e805eaaf3b04752f7ed5ed84cea2c Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 8 Mar 2020 01:51:12 +0100 Subject: [PATCH 01/18] allow passing a callable as compat to diff_{dataset,array}_repr --- xarray/core/formatting.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 89246ff228d..4f8141ce6ec 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -525,7 +525,10 @@ def extra_items_repr(extra_keys, mapping, ab_side): for k in a_keys & b_keys: try: # compare xarray variable - compatible = getattr(a_mapping[k], compat)(b_mapping[k]) + if not callable(compat): + compatible = getattr(a_mapping[k], compat)(b_mapping[k]) + else: + compatible = compat(a_mapping[k], b_mapping[k]) is_variable = True except AttributeError: # compare attribute value @@ -582,8 +585,13 @@ def extra_items_repr(extra_keys, mapping, ab_side): def _compat_to_str(compat): + if callable(compat): + compat = compat.__name__ + if compat == "equals": return "equal" + elif compat == "allclose": + return "close" else: return compat @@ -597,8 +605,12 @@ def diff_array_repr(a, b, compat): ] summary.append(diff_dim_summary(a, b)) + if callable(compat): + equiv = compat + else: + equiv = array_equiv - if not array_equiv(a.data, b.data): + if not equiv(a.data, b.data): temp = [wrap_indent(short_numpy_repr(obj), start=" ") for obj in (a, b)] diff_data_repr = [ ab_side + "\n" + ab_data_repr From 660edd38d33f56b29cbf90b54f7e7f0d6e25d194 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 8 Mar 2020 01:54:29 +0100 Subject: [PATCH 02/18] rewrite assert_allclose to provide a failure summary --- xarray/testing.py | 43 ++++++++++++++++++++---------------- xarray/tests/test_testing.py | 24 ++++++++++++++++++++ 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/xarray/testing.py b/xarray/testing.py index ac189f7e023..2f4fd2abb53 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -1,10 +1,11 @@ """Testing functions exposed to the user API""" +import functools from typing import Hashable, Set, Union import numpy as np import pandas as pd -from xarray.core import duck_array_ops, formatting +from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.indexes import default_indexes @@ -121,29 +122,33 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): -------- assert_identical, assert_equal, numpy.testing.assert_allclose """ + # todo: + # - one assert statement per type using utils.dict_equiv with "_data_allclose_or_equiv" + # - add the possibility to pass callables as compat to diff_{array,dataset}_repr __tracebackhide__ = True assert type(a) == type(b) - kwargs = dict(rtol=rtol, atol=atol, decode_bytes=decode_bytes) + + equiv = functools.partial( + _data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes + ) + equiv.__name__ = "allclose" + + def compat_variable(a, b): + return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) + if isinstance(a, Variable): - assert a.dims == b.dims - allclose = _data_allclose_or_equiv(a.values, b.values, **kwargs) - assert allclose, f"{a.values}\n{b.values}" + allclose = compat_variable(a, b) + assert allclose, formatting.diff_array_repr(a, b, compat=equiv) elif isinstance(a, DataArray): - assert_allclose(a.variable, b.variable, **kwargs) - assert set(a.coords) == set(b.coords) - for v in a.coords.variables: - # can't recurse with this function as coord is sometimes a - # DataArray, so call into _data_allclose_or_equiv directly - allclose = _data_allclose_or_equiv( - a.coords[v].values, b.coords[v].values, **kwargs - ) - assert allclose, "{}\n{}".format(a.coords[v].values, b.coords[v].values) + allclose = utils.dict_equiv( + a.coords, b.coords, compat=compat_variable + ) and compat_variable(a.variable, b.variable) + assert allclose, formatting.diff_array_repr(a, b, compat=equiv) elif isinstance(a, Dataset): - assert set(a.data_vars) == set(b.data_vars) - assert set(a.coords) == set(b.coords) - for k in list(a.variables) + list(a.coords): - assert_allclose(a[k], b[k], **kwargs) - + allclose = a._coord_names == b._coord_names and utils.dict_equiv( + a.variables, b.variables, compat=compat_variable + ) + assert allclose, formatting.diff_dataset_repr(a, b, compat=equiv) else: raise TypeError("{} not supported by assertion comparison".format(type(a))) diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index 041b7341ade..27f917deb86 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py @@ -1,3 +1,5 @@ +import pytest + import xarray as xr @@ -5,3 +7,25 @@ def test_allclose_regression(): x = xr.DataArray(1.01) y = xr.DataArray(1.02) xr.testing.assert_allclose(x, y, atol=0.01) + + +@pytest.mark.parametrize( + "obj1,obj2", + ( + pytest.param( + xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable", + ), + pytest.param( + xr.DataArray([1e-17, 2], dims="x"), + xr.DataArray([0, 3], dims="x"), + id="DataArray", + ), + pytest.param( + xr.Dataset({"a": ("x", [1e-17, 2]), "b": ("y", [-2e-18, 2])}), + xr.Dataset({"a": ("x", [0, 2]), "b": ("y", [0, 1])}), + id="Dataset", + ), + ), +) +def test_assert_allclose(obj1, obj2): + xr.testing.assert_allclose(obj1, obj2) From fd1ca50ff0774e3511955cadf7ca53e635450266 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 8 Mar 2020 13:41:58 +0100 Subject: [PATCH 03/18] make sure we're comparing variables --- xarray/testing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/testing.py b/xarray/testing.py index 2f4fd2abb53..636c5ac1651 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -134,6 +134,9 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): equiv.__name__ = "allclose" def compat_variable(a, b): + a = getattr(a, "variable", a) + b = getattr(b, "variable", b) + return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) if isinstance(a, Variable): From 94f20e7743f69e91d625c474cc08c2964f4b8b61 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 8 Mar 2020 13:45:04 +0100 Subject: [PATCH 04/18] remove spurious comments --- xarray/testing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xarray/testing.py b/xarray/testing.py index 636c5ac1651..4897b958113 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -122,9 +122,6 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): -------- assert_identical, assert_equal, numpy.testing.assert_allclose """ - # todo: - # - one assert statement per type using utils.dict_equiv with "_data_allclose_or_equiv" - # - add the possibility to pass callables as compat to diff_{array,dataset}_repr __tracebackhide__ = True assert type(a) == type(b) From 0d5d00f5c29bde2249f31a0163ac61951252a0f8 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 8 Mar 2020 22:50:01 +0100 Subject: [PATCH 05/18] override test_aggregate_complex with a test compatible with pint --- xarray/tests/test_units.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 9f63ebb1d42..0744375b70f 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1404,6 +1404,14 @@ def test_aggregation(self, func, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + def test_aggregate_complex(self): + variable = xr.Variable("x", [1, 2j, np.nan] * unit_registry.m) + expected = xr.Variable((), (0.5 + 1j) * unit_registry.m) + actual = variable.mean() + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) + @pytest.mark.parametrize( "func", ( From 983a54545f11e82339ea8f4f527c4f1bded9f579 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 9 Mar 2020 01:20:34 +0100 Subject: [PATCH 06/18] expect the asserts to raise --- xarray/tests/test_testing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index 27f917deb86..f4961af58e9 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py @@ -28,4 +28,5 @@ def test_allclose_regression(): ), ) def test_assert_allclose(obj1, obj2): - xr.testing.assert_allclose(obj1, obj2) + with pytest.raises(AssertionError): + xr.testing.assert_allclose(obj1, obj2) From a724df70625ac66e42e10394c62c8b481b6249f0 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 9 Mar 2020 12:32:06 +0100 Subject: [PATCH 07/18] xfail the tests failing due to isclose not accepting non-quantity tolerances --- xarray/tests/test_units.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 0744375b70f..43c5fc67d9e 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1367,6 +1367,13 @@ def example_1d_objects(self): ]: yield (self.cls("x", data), data) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + ) + def test_real_and_imag(self): + super().test_real_and_imag() + @pytest.mark.parametrize( "func", ( @@ -1404,6 +1411,10 @@ def test_aggregation(self, func, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + ) def test_aggregate_complex(self): variable = xr.Variable("x", [1, 2j, np.nan] * unit_registry.m) expected = xr.Variable((), (0.5 + 1j) * unit_registry.m) @@ -1412,6 +1423,10 @@ def test_aggregate_complex(self): assert_units_equal(expected, actual) xr.testing.assert_allclose(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + ) @pytest.mark.parametrize( "func", ( From ccd20fe0a997822e547170ca8fb4c2008cc8c87b Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 9 Mar 2020 13:19:57 +0100 Subject: [PATCH 08/18] mark top-level function tests as xfailing if they use assert_allclose --- xarray/tests/test_units.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 17e6c7b5f02..111a81ae95d 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -425,6 +425,10 @@ def test_apply_ufunc_dataset(dtype): assert_identical(expected, actual) +# TODO: remove once pint==0.12 has been released +@pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" +) @pytest.mark.parametrize( "unit,error", ( @@ -512,6 +516,10 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype): assert_allclose(expected_b, actual_b) +# TODO: remove once pint==0.12 has been released +@pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" +) @pytest.mark.parametrize( "unit,error", ( @@ -929,6 +937,10 @@ def test_concat_dataset(variant, unit, error, dtype): assert_identical(expected, actual) +# TODO: remove once pint==0.12 has been released +@pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" +) @pytest.mark.parametrize( "unit,error", ( @@ -1036,6 +1048,10 @@ def test_merge_dataarray(variant, unit, error, dtype): assert_allclose(expected, actual) +# TODO: remove once pint==0.12 has been released +@pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" +) @pytest.mark.parametrize( "unit,error", ( From 29a6a1f0033c4236d1dcfdbce1aec04d3691773b Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 9 Mar 2020 13:24:56 +0100 Subject: [PATCH 09/18] mark test_1d_math as runnable but xfail it --- xarray/tests/test_units.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 111a81ae95d..2fee2df174c 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1401,7 +1401,6 @@ def wrapper(cls): "test_datetime64_conversion", "test_timedelta64_conversion", "test_pandas_period_index", - "test_1d_math", "test_1d_reduce", "test_array_interface", "test___array__", @@ -1787,6 +1786,10 @@ def test_isel(self, indices, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + ) @pytest.mark.parametrize( "unit,error", ( From 2c52f86ffe3bb1dda6d384101c60f4a5403831a8 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 19 Mar 2020 19:07:24 +0100 Subject: [PATCH 10/18] bump dask and distributed --- ci/requirements/py36-min-all-deps.yml | 4 ++-- ci/requirements/py36-min-nep18.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/requirements/py36-min-all-deps.yml b/ci/requirements/py36-min-all-deps.yml index 86540197dcc..a72cd000680 100644 --- a/ci/requirements/py36-min-all-deps.yml +++ b/ci/requirements/py36-min-all-deps.yml @@ -15,8 +15,8 @@ dependencies: - cfgrib=0.9 - cftime=1.0 - coveralls - - dask=2.2 - - distributed=2.2 + - dask=2.5 + - distributed=2.5 - flake8 - h5netcdf=0.7 - h5py=2.9 # Policy allows for 2.10, but it's a conflict-fest diff --git a/ci/requirements/py36-min-nep18.yml b/ci/requirements/py36-min-nep18.yml index a5eded49cd4..a2245e89b41 100644 --- a/ci/requirements/py36-min-nep18.yml +++ b/ci/requirements/py36-min-nep18.yml @@ -6,8 +6,8 @@ dependencies: # require drastically newer packages than everything else - python=3.6 - coveralls - - dask=2.4 - - distributed=2.4 + - dask=2.5 + - distributed=2.5 - msgpack-python=0.6 # remove once distributed is bumped. distributed GH3491 - numpy=1.17 - pandas=0.25 From 1932a52000e809ea90dffbeb4d69259ca6d07787 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 19 Mar 2020 19:11:06 +0100 Subject: [PATCH 11/18] entry to whats-new.rst --- doc/whats-new.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8140288f350..f14979ea22a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -53,7 +53,8 @@ New Features :py:meth:`core.groupby.DatasetGroupBy.quantile`, :py:meth:`core.groupby.DataArrayGroupBy.quantile` (:issue:`3843`, :pull:`3844`) By `Aaron Spring `_. - +- Add a diff summary for `testing.assert_allclose`. (:issue:`3617`, :pull:`3847`) + By `Justus Magin `_. Bug fixes ~~~~~~~~~ From f3b9b13c5ec22feec4df27b7987410148c055bb7 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 26 Mar 2020 23:59:42 +0100 Subject: [PATCH 12/18] attempt to fix the failing py36-min-all-deps and py36-min-nep18 CI --- xarray/testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/testing.py b/xarray/testing.py index 4897b958113..e1284690c47 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -134,7 +134,7 @@ def compat_variable(a, b): a = getattr(a, "variable", a) b = getattr(b, "variable", b) - return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) + return a.dims == b.dims and (a._data is b._data or equiv(b.data, a.data)) if isinstance(a, Variable): allclose = compat_variable(a, b) From 0286c0278c18c86ff566fb0a0382b1ab0ba2dc60 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 13 May 2020 17:08:44 +0200 Subject: [PATCH 13/18] conditionally xfail tests using assert_allclose with pint < 0.12 --- xarray/tests/test_units.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 33cd77a5b28..fca1db4fa7f 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -2266,6 +2266,10 @@ def test_repr(self, func, variant, dtype): # warnings or errors, but does not check the result func(data_array) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose", + ) @pytest.mark.parametrize( "func", ( @@ -2277,7 +2281,7 @@ def test_repr(self, func, variant, dtype): function("mean"), pytest.param( function("median"), - marks=pytest.mark.xfail( + marks=pytest.mark.skip( reason="median does not work with dataarrays yet" ), ), From c5ce18d3ee705c24af7ef6708b3c335c26281144 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 13 May 2020 19:31:46 +0200 Subject: [PATCH 14/18] xfail more tests depending on which pint version is used --- xarray/tests/test_units.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index fca1db4fa7f..6f4f9f768d9 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -3329,6 +3329,10 @@ def test_head_tail_thin(self, func, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + ) @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( "func", @@ -3402,6 +3406,10 @@ def test_interp_reindex_indexing(self, func, unit, error, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + ) @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( "func", @@ -3604,6 +3612,10 @@ def test_computation(self, func, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + ) @pytest.mark.parametrize( "func", ( From 1d6bfe84b21e995ce4fb2ddac4025a5f9fdf488d Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 13 May 2020 20:46:33 +0200 Subject: [PATCH 15/18] try using numpy.testing.assert_allclose instead --- xarray/testing.py | 2 +- xarray/tests/test_duck_array_ops.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/testing.py b/xarray/testing.py index e1284690c47..4897b958113 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -134,7 +134,7 @@ def compat_variable(a, b): a = getattr(a, "variable", a) b = getattr(b, "variable", b) - return a.dims == b.dims and (a._data is b._data or equiv(b.data, a.data)) + return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) if isinstance(a, Variable): allclose = compat_variable(a, b) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index e61881cfce3..feedcd27164 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -384,7 +384,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): actual = getattr(da, func)(skipna=skipna, dim=aggdim) assert_dask_array(actual, dask) - assert np.allclose( + np.testing.assert_allclose( actual.values, np.array(expected), rtol=1.0e-4, equal_nan=True ) except (TypeError, AttributeError, ZeroDivisionError): From 22f1fc3e5a1a27618e6217694188ec49d7f9b704 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 27 May 2020 19:38:40 +0200 Subject: [PATCH 16/18] try computing if the dask version is too old and dask.array[bool] --- xarray/core/duck_array_ops.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 1340b456cf2..ebc851db447 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -6,6 +6,7 @@ import contextlib import inspect import warnings +from distutils.version import LooseVersion from functools import partial import numpy as np @@ -199,6 +200,17 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): """ arr1 = asarray(arr1) arr2 = asarray(arr2) + # TODO: remove after we require dask > 2.9.1 + sufficient_dask_version = ( + dask_array is not None and LooseVersion(dask_array.__version__) >= "2.9.1" + ) + if sufficient_dask_version and any(arr.dtype.kind == "b" for arr in [arr1, arr2]): + if isinstance(arr1, dask_array_type): + arr1 = arr1.compute() + + if isinstance(arr2, dask_array_type): + arr2 = arr2.compute() + lazy_equiv = lazy_array_equiv(arr1, arr2) if lazy_equiv is None: return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) From 137a4eef16c94a47a72b8389b4fee6b614830da3 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 27 May 2020 20:01:13 +0200 Subject: [PATCH 17/18] fix the dask version checking --- xarray/core/duck_array_ops.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index ebc851db447..a1c53b5f5b3 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -21,6 +21,14 @@ except ImportError: dask_array = None # type: ignore +# TODO: remove after we stop supporting dask < 2.9.1 +try: + import dask + + dask_version = dask.__version__ +except ImportError: + dask_version = None + def _dask_or_eager_func( name, @@ -200,19 +208,22 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): """ arr1 = asarray(arr1) arr2 = asarray(arr2) - # TODO: remove after we require dask > 2.9.1 - sufficient_dask_version = ( - dask_array is not None and LooseVersion(dask_array.__version__) >= "2.9.1" - ) - if sufficient_dask_version and any(arr.dtype.kind == "b" for arr in [arr1, arr2]): - if isinstance(arr1, dask_array_type): - arr1 = arr1.compute() - - if isinstance(arr2, dask_array_type): - arr2 = arr2.compute() lazy_equiv = lazy_array_equiv(arr1, arr2) if lazy_equiv is None: + # TODO: remove after we require dask >= 2.9.1 + sufficient_dask_version = ( + dask_version is not None and LooseVersion(dask_version) >= "2.9.1" + ) + if sufficient_dask_version and any( + arr.dtype.kind == "b" for arr in [arr1, arr2] + ): + if isinstance(arr1, dask_array_type): + arr1 = arr1.compute() + + if isinstance(arr2, dask_array_type): + arr2 = arr2.compute() + return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) else: return lazy_equiv From 1b268516ae29fd3bc9f017a3b5340b95bcf1eb3e Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 27 May 2020 20:39:58 +0200 Subject: [PATCH 18/18] convert all dask arrays to numpy when using a insufficient dask version --- xarray/core/duck_array_ops.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index a1c53b5f5b3..76719699168 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -215,14 +215,11 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): sufficient_dask_version = ( dask_version is not None and LooseVersion(dask_version) >= "2.9.1" ) - if sufficient_dask_version and any( - arr.dtype.kind == "b" for arr in [arr1, arr2] + if not sufficient_dask_version and any( + isinstance(arr, dask_array_type) for arr in [arr1, arr2] ): - if isinstance(arr1, dask_array_type): - arr1 = arr1.compute() - - if isinstance(arr2, dask_array_type): - arr2 = arr2.compute() + arr1 = np.array(arr1) + arr2 = np.array(arr2) return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) else: