diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 277c32b1016..d2da3f40938 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -48,6 +48,10 @@ Breaking changes :ref:`weather-climate` (:pull:`2844`, :issue:`3689`) - remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull:`4725`). By `Aureliana Barghini `_. +- As a result of :pull:`4911` the output from calling :py:meth:`DataArray.sum` + or :py:meth:`DataArray.prod` on an integer array with ``skipna=True`` and a + non-None value for ``min_count`` will now be a float array rather than an + integer array. Deprecations ~~~~~~~~~~~~ @@ -125,6 +129,12 @@ Bug fixes By `Leif Denby `_. - Fix time encoding bug associated with using cftime versions greater than 1.4.0 with xarray (:issue:`4870`, :pull:`4871`). By `Spencer Clark `_. +- Stop :py:meth:`DataArray.sum` and :py:meth:`DataArray.prod` computing lazy + arrays when called with a ``min_count`` parameter (:issue:`4898`, :pull:`4911`). + By `Blair Bonnett `_. +- Fix bug preventing the ``min_count`` parameter to :py:meth:`DataArray.sum` and + :py:meth:`DataArray.prod` working correctly when calculating over all axes of + a float64 array (:issue:`4898`, :pull:`4911`). By `Blair Bonnett `_. - Fix decoding of vlen strings using h5py versions greater than 3.0.0 with h5netcdf backend (:issue:`4570`, :pull:`4893`). By `Kai Mühlbauer `_. diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 167f00fa932..898e7e650b3 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -78,7 +78,7 @@ def maybe_promote(dtype): return np.dtype(dtype), fill_value -NAT_TYPES = (np.datetime64("NaT"), np.timedelta64("NaT")) +NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} def get_fill_value(dtype): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 5eb88bcd096..1cfd66103a2 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -3,7 +3,14 @@ import numpy as np from . import dtypes, nputils, utils -from .duck_array_ops import _dask_or_eager_func, count, fillna, isnull, where_method +from .duck_array_ops import ( + _dask_or_eager_func, + count, + fillna, + isnull, + where, + where_method, +) from .pycompat import dask_array_type try: @@ -28,18 +35,14 @@ def _maybe_null_out(result, axis, mask, min_count=1): """ xarray version of pandas.core.nanops._maybe_null_out """ - if axis is not None and getattr(result, "ndim", False): null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0 - if null_mask.any(): - dtype, fill_value = dtypes.maybe_promote(result.dtype) - result = result.astype(dtype) - result[null_mask] = fill_value + dtype, fill_value = dtypes.maybe_promote(result.dtype) + result = where(null_mask, fill_value, result.astype(dtype)) elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES: null_mask = mask.size - mask.sum() - if null_mask < min_count: - result = np.nan + result = where(null_mask < min_count, np.nan, result) return result diff --git a/xarray/core/ops.py b/xarray/core/ops.py index d56b0d59df0..1c899115a5b 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -114,9 +114,12 @@ _MINCOUNT_DOCSTRING = """ min_count : int, default: None - The required number of valid values to perform the operation. - If fewer than min_count non-NA values are present the result will - be NA. New in version 0.10.8: Added with the default being None.""" + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. New in version 0.10.8: Added with the default being + None. Changed in version 0.17.0: if specified on an integer array + and skipna=True, the result will be a float array.""" _COARSEN_REDUCE_DOCSTRING_TEMPLATE = """\ Coarsen this object by applying `{name}` along its dimensions. diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 5ad1a6355e6..53ed2c87133 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -90,3 +90,9 @@ def test_maybe_promote(kind, expected): actual = dtypes.maybe_promote(np.dtype(kind)) assert actual[0] == expected[0] assert str(actual[1]) == expected[1] + + +def test_nat_types_membership(): + assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES + assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES + assert np.float64 not in dtypes.NAT_TYPES diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 1342950f3e5..90e742dee62 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -34,6 +34,7 @@ assert_array_equal, has_dask, has_scipy, + raise_if_dask_computes, raises_regex, requires_cftime, requires_dask, @@ -587,7 +588,10 @@ def test_min_count(dim_num, dtype, dask, func, aggdim, contains_nan, skipna): da = construct_dataarray(dim_num, dtype, contains_nan=contains_nan, dask=dask) min_count = 3 - actual = getattr(da, func)(dim=aggdim, skipna=skipna, min_count=min_count) + # If using Dask, the function call should be lazy. + with raise_if_dask_computes(): + actual = getattr(da, func)(dim=aggdim, skipna=skipna, min_count=min_count) + expected = series_reduce(da, func, skipna=skipna, dim=aggdim, min_count=min_count) assert_allclose(actual, expected) assert_dask_array(actual, dask) @@ -603,7 +607,13 @@ def test_min_count_nd(dtype, dask, func): min_count = 3 dim_num = 3 da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) - actual = getattr(da, func)(dim=["x", "y", "z"], skipna=True, min_count=min_count) + + # If using Dask, the function call should be lazy. + with raise_if_dask_computes(): + actual = getattr(da, func)( + dim=["x", "y", "z"], skipna=True, min_count=min_count + ) + # Supplying all dims is equivalent to supplying `...` or `None` expected = getattr(da, func)(dim=..., skipna=True, min_count=min_count) @@ -611,6 +621,48 @@ def test_min_count_nd(dtype, dask, func): assert_dask_array(actual, dask) +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("func", ["sum", "prod"]) +@pytest.mark.parametrize("dim", [None, "a", "b"]) +def test_min_count_specific(dask, func, dim): + if dask and not has_dask: + pytest.skip("requires dask") + + # Simple array with four non-NaN values. + da = DataArray(np.ones((6, 6), dtype=np.float64) * np.nan, dims=("a", "b")) + da[0][0] = 2 + da[0][3] = 2 + da[3][0] = 2 + da[3][3] = 2 + if dask: + da = da.chunk({"a": 3, "b": 3}) + + # Expected result if we set min_count to the number of non-NaNs in a + # row/column/the entire array. + if dim: + min_count = 2 + expected = DataArray( + [4.0, np.nan, np.nan] * 2, dims=("a" if dim == "b" else "b",) + ) + else: + min_count = 4 + expected = DataArray(8.0 if func == "sum" else 16.0) + + # Check for that min_count. + with raise_if_dask_computes(): + actual = getattr(da, func)(dim, skipna=True, min_count=min_count) + assert_dask_array(actual, dask) + assert_allclose(actual, expected) + + # With min_count being one higher, should get all NaN. + min_count += 1 + expected *= np.nan + with raise_if_dask_computes(): + actual = getattr(da, func)(dim, skipna=True, min_count=min_count) + assert_dask_array(actual, dask) + assert_allclose(actual, expected) + + @pytest.mark.parametrize("func", ["sum", "prod"]) def test_min_count_dataset(func): da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False) @@ -655,9 +707,12 @@ def test_docs(): have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None - The required number of valid values to perform the operation. - If fewer than min_count non-NA values are present the result will - be NA. New in version 0.10.8: Added with the default being None. + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. New in version 0.10.8: Added with the default being + None. Changed in version 0.17.0: if specified on an integer array + and skipna=True, the result will be a float array. keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be