From 05ae2904602c3196efb6ac96c00fc8688946e10d Mon Sep 17 00:00:00 2001 From: "Ulrich J. Herter" Date: Thu, 22 Aug 2019 22:40:46 +0200 Subject: [PATCH 1/7] Make argmin/max work lazy with dask (#3237). --- xarray/core/nanops.py | 25 ++++--------------------- xarray/core/nputils.py | 2 ++ 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 9ba4eae29ae..784a1d0109c 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -91,17 +91,9 @@ def nanargmin(a, axis=None): fill_value = dtypes.get_pos_infinity(a.dtype) if a.dtype.kind == "O": return _nan_argminmax_object("argmin", fill_value, a, axis=axis) - a, mask = _replace_nan(a, fill_value) - if isinstance(a, dask_array_type): - res = dask_array.argmin(a, axis=axis) - else: - res = np.argmin(a, axis=axis) - if mask is not None: - mask = mask.all(axis=axis) - if mask.any(): - raise ValueError("All-NaN slice encountered") - return res + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanargmin(a, axis=axis) def nanargmax(a, axis=None): @@ -109,17 +101,8 @@ def nanargmax(a, axis=None): if a.dtype.kind == "O": return _nan_argminmax_object("argmax", fill_value, a, axis=axis) - a, mask = _replace_nan(a, fill_value) - if isinstance(a, dask_array_type): - res = dask_array.argmax(a, axis=axis) - else: - res = np.argmax(a, axis=axis) - - if mask is not None: - mask = mask.all(axis=axis) - if mask.any(): - raise ValueError("All-NaN slice encountered") - return res + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanargmax(a, axis=axis) def nansum(a, axis=None, dtype=None, out=None, min_count=None): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index a9971e7125a..ee21accd349 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -236,3 +236,5 @@ def f(values, axis=None, **kwargs): nanprod = _create_bottleneck_method("nanprod") nancumsum = _create_bottleneck_method("nancumsum") nancumprod = _create_bottleneck_method("nancumprod") +nanargmin = _create_bottleneck_method("nanargmin") +nanargmax = _create_bottleneck_method("nanargmax") From a4c3622c9ee1edcfef3cb7e9170522150556d948 Mon Sep 17 00:00:00 2001 From: "Ulrich J. Herter" Date: Fri, 23 Aug 2019 13:05:13 +0200 Subject: [PATCH 2/7] dask: Testing number of computes on reduce methods. --- xarray/tests/test_dask.py | 57 ++++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index e3fc6f65e0f..83bca6c06ba 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -27,14 +27,43 @@ dd = pytest.importorskip("dask.dataframe") +class CountingScheduler: + """ Simple dask scheduler counting the number of computes. + + Reference: https://stackoverflow.com/questions/53289286/ """ + + def __init__(self, max_computes=0): + self.total_computes = 0 + self.max_computes = max_computes + + def __call__(self, dsk, keys, **kwargs): + self.total_computes += 1 + if self.total_computes > self.max_computes: + raise RuntimeError( + "To many computes. Total: %d > max: %d." + % (self.total_computes, self.max_computes) + ) + return dask.get(dsk, keys, **kwargs) + + +def _set_dask_scheduler(scheduler): + if LooseVersion(dask.__version__) >= LooseVersion("0.18.0"): + return dask.config.set(scheduler=scheduler) + return dask.set_options(get=scheduler) + + +def test_counting_scheduler(): + data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2)) + sched = CountingScheduler(0) + with raises_regex(RuntimeError, "To many computes"): + with _set_dask_scheduler(sched): + data.compute() + assert sched.total_computes == 1 + + class DaskTestCase: def assertLazyAnd(self, expected, actual, test): - - with ( - dask.config.set(scheduler="single-threaded") - if LooseVersion(dask.__version__) >= LooseVersion("0.18.0") - else dask.set_options(get=dask.get) - ): + with _set_dask_scheduler(CountingScheduler(1)): test(actual, expected) if isinstance(actual, Dataset): @@ -172,13 +201,15 @@ def test_pickle(self): def test_reduce(self): u = self.eager_var v = self.lazy_var - self.assertLazyAndAllClose(u.mean(), v.mean()) - self.assertLazyAndAllClose(u.std(), v.std()) - self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x")) - self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) - self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) - with raises_regex(NotImplementedError, "dask"): - v.median() + with _set_dask_scheduler(CountingScheduler(0)): + # None of the methods should trigger compute at this stage. + self.assertLazyAndAllClose(u.mean(), v.mean()) + self.assertLazyAndAllClose(u.std(), v.std()) + self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x")) + self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) + self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) + with raises_regex(NotImplementedError, "dask"): + v.median() def test_missing_values(self): values = np.array([0, 1, np.nan, 3]) From c77145c88dc4ab1e410f4fc4392b5948ed7e19d8 Mon Sep 17 00:00:00 2001 From: "Ulrich J. Herter" Date: Fri, 23 Aug 2019 13:12:11 +0200 Subject: [PATCH 3/7] what's new updated --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 705c54b2d30..37bfe62f7cb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -106,6 +106,8 @@ Bug fixes - Fix error that arises when using open_mfdataset on a series of netcdf files having differing values for a variable attribute of type list. (:issue:`3034`) By `Hasan Ahmad `_. +- :py:meth:`~xarray.DataArray.argmax` and :py:meth:`~xarray.DataArray.argmin` did cause + dask to compute (:issue:`3237`). By `Ulrich Herter `_. .. _whats-new.0.12.3: From 1b4bc3a177a50a5013a3265b8754e3a35f4ce17f Mon Sep 17 00:00:00 2001 From: ulijh Date: Tue, 27 Aug 2019 13:48:00 +0200 Subject: [PATCH 4/7] Fix typo Co-Authored-By: Stephan Hoyer --- xarray/tests/test_dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 83bca6c06ba..d9dda4bb518 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -40,7 +40,7 @@ def __call__(self, dsk, keys, **kwargs): self.total_computes += 1 if self.total_computes > self.max_computes: raise RuntimeError( - "To many computes. Total: %d > max: %d." + "Too many computes. Total: %d > max: %d." % (self.total_computes, self.max_computes) ) return dask.get(dsk, keys, **kwargs) From abe66556c0c30ccc635a43b22b4890fb4743dde2 Mon Sep 17 00:00:00 2001 From: ulijh Date: Tue, 27 Aug 2019 13:50:44 +0200 Subject: [PATCH 5/7] Be more explicit. Co-Authored-By: Stephan Hoyer --- xarray/tests/test_dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index d9dda4bb518..c301bb0f242 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -54,7 +54,7 @@ def _set_dask_scheduler(scheduler): def test_counting_scheduler(): data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2)) - sched = CountingScheduler(0) + sched = CountingScheduler(max_computes=0) with raises_regex(RuntimeError, "To many computes"): with _set_dask_scheduler(sched): data.compute() From c0d74622ee295b9f844741aada1b3ff571a00deb Mon Sep 17 00:00:00 2001 From: "Ulrich J. Herter" Date: Tue, 27 Aug 2019 15:58:32 +0200 Subject: [PATCH 6/7] More explicit raise_if_dask_computes --- xarray/tests/test_dask.py | 41 ++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index c301bb0f242..d105765481e 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -46,24 +46,30 @@ def __call__(self, dsk, keys, **kwargs): return dask.get(dsk, keys, **kwargs) -def _set_dask_scheduler(scheduler): +def _set_dask_scheduler(scheduler=dask.get): + """ Backwards compatible way of setting scheduler. """ if LooseVersion(dask.__version__) >= LooseVersion("0.18.0"): return dask.config.set(scheduler=scheduler) return dask.set_options(get=scheduler) -def test_counting_scheduler(): +def raise_if_dask_computes(max_computes=0): + scheduler = CountingScheduler(max_computes) + return _set_dask_scheduler(scheduler) + + +def test_raise_if_dask_computes(): data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2)) - sched = CountingScheduler(max_computes=0) - with raises_regex(RuntimeError, "To many computes"): - with _set_dask_scheduler(sched): + with raises_regex(RuntimeError, "Too many computes"): + with raise_if_dask_computes(): data.compute() - assert sched.total_computes == 1 class DaskTestCase: def assertLazyAnd(self, expected, actual, test): - with _set_dask_scheduler(CountingScheduler(1)): + with _set_dask_scheduler(dask.get): + # dask.get is the syncronous scheduler, which get's set also by + # dask.config.set(scheduler="syncronous") in current versions. test(actual, expected) if isinstance(actual, Dataset): @@ -201,15 +207,18 @@ def test_pickle(self): def test_reduce(self): u = self.eager_var v = self.lazy_var - with _set_dask_scheduler(CountingScheduler(0)): - # None of the methods should trigger compute at this stage. - self.assertLazyAndAllClose(u.mean(), v.mean()) - self.assertLazyAndAllClose(u.std(), v.std()) - self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x")) - self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) - self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) - with raises_regex(NotImplementedError, "dask"): - v.median() + self.assertLazyAndAllClose(u.mean(), v.mean()) + self.assertLazyAndAllClose(u.std(), v.std()) + with raise_if_dask_computes(): + actual = v.argmax(dim="x") + self.assertLazyAndAllClose(u.argmax(dim="x"), actual) + with raise_if_dask_computes(): + actual = v.argmin(dim="x") + self.assertLazyAndAllClose(u.argmin(dim="x"), actual) + self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) + self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) + with raises_regex(NotImplementedError, "dask"): + v.median() def test_missing_values(self): values = np.array([0, 1, np.nan, 3]) From faa7eab73f2b1189550132fa06c422e3c6b74661 Mon Sep 17 00:00:00 2001 From: "Ulrich J. Herter" Date: Tue, 27 Aug 2019 16:13:35 +0200 Subject: [PATCH 7/7] nanargmin/max: only set fill_value when needed --- xarray/core/nanops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 784a1d0109c..17240faf007 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -88,8 +88,8 @@ def nanmax(a, axis=None, out=None): def nanargmin(a, axis=None): - fill_value = dtypes.get_pos_infinity(a.dtype) if a.dtype.kind == "O": + fill_value = dtypes.get_pos_infinity(a.dtype) return _nan_argminmax_object("argmin", fill_value, a, axis=axis) module = dask_array if isinstance(a, dask_array_type) else nputils @@ -97,8 +97,8 @@ def nanargmin(a, axis=None): def nanargmax(a, axis=None): - fill_value = dtypes.get_neg_infinity(a.dtype) if a.dtype.kind == "O": + fill_value = dtypes.get_neg_infinity(a.dtype) return _nan_argminmax_object("argmax", fill_value, a, axis=axis) module = dask_array if isinstance(a, dask_array_type) else nputils