Skip to content

Commit c955449

Browse files
dcherianmax-sixty
authored andcommitted
Drop groups associated with nans in group variable (#3406)
* Drop nans in grouped variable. * Add NaTs * whats-new * fix merge. * fix whats-new * fix test
1 parent 02288b4 commit c955449

File tree

3 files changed

+83
-11
lines changed

3 files changed

+83
-11
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,14 @@ Bug fixes
5555
~~~~~~~~~
5656
- Fix regression introduced in v0.14.0 that would cause a crash if dask is installed
5757
but cloudpickle isn't (:issue:`3401`) by `Rhys Doyle <https://github.com/rdoyle45>`_
58-
59-
- Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
58+
- Fix grouping over variables with NaNs. (:issue:`2383`, :pull:`3406`).
59+
By `Deepak Cherian <https://github.com/dcherian>`_.
60+
- Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
6061
By `Anderson Banihirwe <https://github.com/andersy005>`_.
61-
6262
- Fix :py:meth:`xarray.core.groupby.DataArrayGroupBy.reduce` and
6363
:py:meth:`xarray.core.groupby.DatasetGroupBy.reduce` when reducing over multiple dimensions.
6464
(:issue:`3402`). By `Deepak Cherian <https://github.com/dcherian/>`_
6565

66-
6766
Documentation
6867
~~~~~~~~~~~~~
6968

xarray/core/groupby.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,13 @@ def __init__(
361361
group_indices = [slice(i, i + 1) for i in group_indices]
362362
unique_coord = group
363363
else:
364+
if group.isnull().any():
365+
# drop any NaN valued groups.
366+
# also drop obj values where group was NaN
367+
# Use where instead of reindex to account for duplicate coordinate labels.
368+
obj = obj.where(group.notnull(), drop=True)
369+
group = group.dropna(group_dim)
370+
364371
# look through group to find the unique values
365372
unique_values, group_indices = unique_value_groups(
366373
safe_cast_to_index(group), sort=(bins is None)

xarray/tests/test_groupby.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import xarray as xr
66
from xarray.core.groupby import _consolidate_slices
77

8-
from . import assert_allclose, assert_identical, raises_regex
8+
from . import assert_allclose, assert_equal, assert_identical, raises_regex
99

1010

1111
@pytest.fixture
@@ -48,14 +48,14 @@ def test_groupby_dims_property(dataset):
4848
def test_multi_index_groupby_apply(dataset):
4949
# regression test for GH873
5050
ds = dataset.isel(z=1, drop=True)[["foo"]]
51-
doubled = 2 * ds
52-
group_doubled = (
51+
expected = 2 * ds
52+
actual = (
5353
ds.stack(space=["x", "y"])
5454
.groupby("space")
5555
.apply(lambda x: 2 * x)
5656
.unstack("space")
5757
)
58-
assert doubled.equals(group_doubled)
58+
assert_equal(expected, actual)
5959

6060

6161
def test_multi_index_groupby_sum():
@@ -66,7 +66,7 @@ def test_multi_index_groupby_sum():
6666
)
6767
expected = ds.sum("z")
6868
actual = ds.stack(space=["x", "y"]).groupby("space").sum("z").unstack("space")
69-
assert expected.equals(actual)
69+
assert_equal(expected, actual)
7070

7171

7272
def test_groupby_da_datetime():
@@ -86,15 +86,15 @@ def test_groupby_da_datetime():
8686
expected = xr.DataArray(
8787
[3, 7], coords=dict(reference_date=reference_dates), dims="reference_date"
8888
)
89-
assert actual.equals(expected)
89+
assert_equal(expected, actual)
9090

9191

9292
def test_groupby_duplicate_coordinate_labels():
9393
# fix for http://stackoverflow.com/questions/38065129
9494
array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])])
9595
expected = xr.DataArray([3, 3], [("x", [1, 2])])
9696
actual = array.groupby("x").sum()
97-
assert expected.equals(actual)
97+
assert_equal(expected, actual)
9898

9999

100100
def test_groupby_input_mutation():
@@ -263,6 +263,72 @@ def test_groupby_repr_datetime(obj):
263263
assert actual == expected
264264

265265

266+
def test_groupby_drops_nans():
267+
# GH2383
268+
# nan in 2D data variable (requires stacking)
269+
ds = xr.Dataset(
270+
{
271+
"variable": (("lat", "lon", "time"), np.arange(60.0).reshape((4, 3, 5))),
272+
"id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))),
273+
},
274+
coords={"lat": np.arange(4), "lon": np.arange(3), "time": np.arange(5)},
275+
)
276+
277+
ds["id"].values[0, 0] = np.nan
278+
ds["id"].values[3, 0] = np.nan
279+
ds["id"].values[-1, -1] = np.nan
280+
281+
grouped = ds.groupby(ds.id)
282+
283+
# non reduction operation
284+
expected = ds.copy()
285+
expected.variable.values[0, 0, :] = np.nan
286+
expected.variable.values[-1, -1, :] = np.nan
287+
expected.variable.values[3, 0, :] = np.nan
288+
actual = grouped.apply(lambda x: x).transpose(*ds.variable.dims)
289+
assert_identical(actual, expected)
290+
291+
# reduction along grouped dimension
292+
actual = grouped.mean()
293+
stacked = ds.stack({"xy": ["lat", "lon"]})
294+
expected = (
295+
stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset()
296+
)
297+
expected["id"] = stacked.id.values
298+
assert_identical(actual, expected.dropna("id").transpose(*actual.dims))
299+
300+
# reduction operation along a different dimension
301+
actual = grouped.mean("time")
302+
expected = ds.mean("time").where(ds.id.notnull())
303+
assert_identical(actual, expected)
304+
305+
# NaN in non-dimensional coordinate
306+
array = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])])
307+
array["x1"] = ("x", [1, 1, np.nan])
308+
expected = xr.DataArray(3, [("x1", [1])])
309+
actual = array.groupby("x1").sum()
310+
assert_equal(expected, actual)
311+
312+
# NaT in non-dimensional coordinate
313+
array["t"] = (
314+
"x",
315+
[
316+
np.datetime64("2001-01-01"),
317+
np.datetime64("2001-01-01"),
318+
np.datetime64("NaT"),
319+
],
320+
)
321+
expected = xr.DataArray(3, [("t", [np.datetime64("2001-01-01")])])
322+
actual = array.groupby("t").sum()
323+
assert_equal(expected, actual)
324+
325+
# test for repeated coordinate labels
326+
array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])])
327+
expected = xr.DataArray([3, 3], [("x", [1, 2])])
328+
actual = array.groupby("x").sum()
329+
assert_equal(expected, actual)
330+
331+
266332
def test_groupby_grouping_errors():
267333
dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})
268334
with raises_regex(ValueError, "None of the data falls within bins with edges"):

0 commit comments

Comments
 (0)