Skip to content

Commit 5bf94a8

Browse files
committed
Drop nans in grouped variable.
1 parent 3f9069b commit 5bf94a8

File tree

2 files changed

+67
-7
lines changed

2 files changed

+67
-7
lines changed

xarray/core/groupby.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,13 @@ def __init__(
348348
group_indices = [slice(i, i + 1) for i in group_indices]
349349
unique_coord = group
350350
else:
351+
if group.isnull().any():
352+
# drop any NaN valued groups.
353+
# also drop obj values where group was NaN
354+
# Use where instead of reindex to account for duplicate coordinate labels.
355+
obj = obj.where(group.notnull(), drop=True)
356+
group = group.dropna(group_dim)
357+
351358
# look through group to find the unique values
352359
unique_values, group_indices = unique_value_groups(
353360
safe_cast_to_index(group), sort=(bins is None)

xarray/tests/test_groupby.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import xarray as xr
66
from xarray.core.groupby import _consolidate_slices
77

8-
from . import assert_identical, raises_regex
8+
from . import assert_equal, assert_identical, raises_regex
99

1010

1111
def test_consolidate_slices():
@@ -40,14 +40,14 @@ def test_multi_index_groupby_apply():
4040
{"foo": (("x", "y"), np.random.randn(3, 4))},
4141
{"x": ["a", "b", "c"], "y": [1, 2, 3, 4]},
4242
)
43-
doubled = 2 * ds
44-
group_doubled = (
43+
expected = 2 * ds
44+
actual = (
4545
ds.stack(space=["x", "y"])
4646
.groupby("space")
4747
.apply(lambda x: 2 * x)
4848
.unstack("space")
4949
)
50-
assert doubled.equals(group_doubled)
50+
assert_equal(expected, actual)
5151

5252

5353
def test_multi_index_groupby_sum():
@@ -58,7 +58,7 @@ def test_multi_index_groupby_sum():
5858
)
5959
expected = ds.sum("z")
6060
actual = ds.stack(space=["x", "y"]).groupby("space").sum("z").unstack("space")
61-
assert expected.equals(actual)
61+
assert_equal(expected, actual)
6262

6363

6464
def test_groupby_da_datetime():
@@ -78,15 +78,15 @@ def test_groupby_da_datetime():
7878
expected = xr.DataArray(
7979
[3, 7], coords=dict(reference_date=reference_dates), dims="reference_date"
8080
)
81-
assert actual.equals(expected)
81+
assert_equal(expected, actual)
8282

8383

8484
def test_groupby_duplicate_coordinate_labels():
    """Grouping over a coordinate with repeated labels sums within each label.

    Regression test for http://stackoverflow.com/questions/38065129
    """
    data = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])])
    actual = data.groupby("x").sum()
    # The two entries labelled x=1 collapse into a single group (1 + 2 = 3).
    expected = xr.DataArray([3, 3], [("x", [1, 2])])
    assert_equal(expected, actual)
9090

9191

9292
def test_groupby_input_mutation():
@@ -255,6 +255,59 @@ def test_groupby_repr_datetime(obj):
255255
assert actual == expected
256256

257257

258+
def test_groupby_drops_nans():
    """Groups whose key is NaN are silently dropped by groupby (GH2383)."""
    # NaN in a 2D data variable used as the group key (requires stacking).
    ds = xr.Dataset(
        {
            "variable": (("lat", "lon", "time"), np.arange(60.0).reshape((4, 3, 5))),
            "id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))),
        },
        coords={"lat": np.arange(4), "lon": np.arange(3), "time": np.arange(5)},
    )

    # Knock out three group labels.
    nan_positions = [(0, 0), (3, 0), (-1, -1)]
    for pos in nan_positions:
        ds["id"].values[pos] = np.nan

    grouped = ds.groupby(ds.id)

    # Non-reduction operation: values belonging to NaN groups come back as NaN.
    expected = ds.copy()
    for pos in nan_positions:
        expected.variable.values[pos + (slice(None),)] = np.nan
    actual = grouped.apply(lambda x: x).transpose(*ds.variable.dims)
    assert_identical(actual, expected)

    # Reduction along the grouped dimension: NaN groups are absent entirely.
    actual = grouped.mean()
    stacked = ds.stack({"xy": ["lat", "lon"]})
    expected = (
        stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset()
    )
    expected["id"] = stacked.id.values
    assert_identical(actual, expected.dropna("id").transpose(*actual.dims))

    # Reduction along a different dimension keeps the original grid,
    # masked where the group key was NaN.
    actual = grouped.mean("time")
    expected = ds.mean("time").where(ds.id.notnull())
    assert_identical(actual, expected)

    # NaN in a non-dimensional coordinate used as the group key.
    array = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])])
    array["x1"] = ("x", [1, 1, np.nan])
    actual = array.groupby("x1").sum()
    expected = xr.DataArray(3, [("x1", [1])])
    assert_equal(expected, actual)

    # Repeated coordinate labels mixed with NaNs: NaN entries are dropped,
    # the duplicated label 1 still aggregates (1 + 2 = 3).
    array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])])
    actual = array.groupby("x").sum()
    expected = xr.DataArray([3, 3], [("x", [1, 2])])
    assert_equal(expected, actual)
309+
310+
258311
def test_groupby_grouping_errors():
259312
dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})
260313
with raises_regex(ValueError, "None of the data falls within bins with edges"):

0 commit comments

Comments
 (0)