Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ Bug fixes
- Fix error that arises when using open_mfdataset on a series of netcdf files
having differing values for a variable attribute of type list. (:issue:`3034`)
By `Hasan Ahmad <https://github.com/HasanAhmadQ7>`_.
- :py:meth:`~xarray.DataArray.argmax` and :py:meth:`~xarray.DataArray.argmin` did cause
dask to compute (:issue:`3237`). By `Ulrich Herter <https://github.com/ulijh>`_.

.. _whats-new.0.12.3:

Expand Down
25 changes: 4 additions & 21 deletions xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,35 +91,18 @@ def nanargmin(a, axis=None):
fill_value = dtypes.get_pos_infinity(a.dtype)
if a.dtype.kind == "O":
return _nan_argminmax_object("argmin", fill_value, a, axis=axis)
a, mask = _replace_nan(a, fill_value)
if isinstance(a, dask_array_type):
res = dask_array.argmin(a, axis=axis)
else:
res = np.argmin(a, axis=axis)

if mask is not None:
mask = mask.all(axis=axis)
if mask.any():
raise ValueError("All-NaN slice encountered")
return res
module = dask_array if isinstance(a, dask_array_type) else nputils
return module.nanargmin(a, axis=axis)


def nanargmax(a, axis=None):
fill_value = dtypes.get_neg_infinity(a.dtype)
if a.dtype.kind == "O":
return _nan_argminmax_object("argmax", fill_value, a, axis=axis)

a, mask = _replace_nan(a, fill_value)
if isinstance(a, dask_array_type):
res = dask_array.argmax(a, axis=axis)
else:
res = np.argmax(a, axis=axis)

if mask is not None:
mask = mask.all(axis=axis)
if mask.any():
raise ValueError("All-NaN slice encountered")
return res
module = dask_array if isinstance(a, dask_array_type) else nputils
return module.nanargmax(a, axis=axis)


def nansum(a, axis=None, dtype=None, out=None, min_count=None):
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,5 @@ def f(values, axis=None, **kwargs):
nanprod = _create_bottleneck_method("nanprod")
nancumsum = _create_bottleneck_method("nancumsum")
nancumprod = _create_bottleneck_method("nancumprod")
nanargmin = _create_bottleneck_method("nanargmin")
nanargmax = _create_bottleneck_method("nanargmax")
57 changes: 44 additions & 13 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,43 @@
dd = pytest.importorskip("dask.dataframe")


class CountingScheduler:
""" Simple dask scheduler counting the number of computes.

Reference: https://stackoverflow.com/questions/53289286/ """

def __init__(self, max_computes=0):
self.total_computes = 0
self.max_computes = max_computes

def __call__(self, dsk, keys, **kwargs):
self.total_computes += 1
if self.total_computes > self.max_computes:
raise RuntimeError(
"To many computes. Total: %d > max: %d."
% (self.total_computes, self.max_computes)
)
return dask.get(dsk, keys, **kwargs)


def _set_dask_scheduler(scheduler):
if LooseVersion(dask.__version__) >= LooseVersion("0.18.0"):
return dask.config.set(scheduler=scheduler)
return dask.set_options(get=scheduler)


def test_counting_scheduler():
data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2))
sched = CountingScheduler(0)
with raises_regex(RuntimeError, "To many computes"):
with _set_dask_scheduler(sched):
data.compute()
assert sched.total_computes == 1


class DaskTestCase:
def assertLazyAnd(self, expected, actual, test):

with (
dask.config.set(scheduler="single-threaded")
if LooseVersion(dask.__version__) >= LooseVersion("0.18.0")
else dask.set_options(get=dask.get)
):
with _set_dask_scheduler(CountingScheduler(1)):
test(actual, expected)

if isinstance(actual, Dataset):
Expand Down Expand Up @@ -172,13 +201,15 @@ def test_pickle(self):
def test_reduce(self):
u = self.eager_var
v = self.lazy_var
self.assertLazyAndAllClose(u.mean(), v.mean())
self.assertLazyAndAllClose(u.std(), v.std())
self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x"))
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
with raises_regex(NotImplementedError, "dask"):
v.median()
with _set_dask_scheduler(CountingScheduler(0)):
# None of the methods should trigger compute at this stage.
self.assertLazyAndAllClose(u.mean(), v.mean())
self.assertLazyAndAllClose(u.std(), v.std())
self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x"))
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
with raises_regex(NotImplementedError, "dask"):
v.median()

def test_missing_values(self):
values = np.array([0, 1, np.nan, 3])
Expand Down