5
5
import xarray as xr
6
6
from xarray .core .groupby import _consolidate_slices
7
7
8
- from . import assert_allclose , assert_identical , raises_regex
8
+ from . import assert_allclose , assert_equal , assert_identical , raises_regex
9
9
10
10
11
11
@pytest .fixture
@@ -48,14 +48,14 @@ def test_groupby_dims_property(dataset):
48
48
def test_multi_index_groupby_apply(dataset):
    """Grouping over a stacked multi-index and unstacking round-trips.

    Regression test for GH873: doubling each group of the stacked ``space``
    index and unstacking must equal doubling the dataset directly.
    """
    # regression test for GH873
    ds = dataset.isel(z=1, drop=True)[["foo"]]

    stacked = ds.stack(space=["x", "y"])
    doubled_per_group = stacked.groupby("space").apply(lambda x: 2 * x)
    actual = doubled_per_group.unstack("space")

    expected = 2 * ds
    assert_equal(expected, actual)
59
59
60
60
61
61
def test_multi_index_groupby_sum ():
@@ -66,7 +66,7 @@ def test_multi_index_groupby_sum():
66
66
)
67
67
expected = ds .sum ("z" )
68
68
actual = ds .stack (space = ["x" , "y" ]).groupby ("space" ).sum ("z" ).unstack ("space" )
69
- assert expected . equals ( actual )
69
+ assert_equal ( expected , actual )
70
70
71
71
72
72
def test_groupby_da_datetime ():
@@ -86,15 +86,15 @@ def test_groupby_da_datetime():
86
86
expected = xr .DataArray (
87
87
[3 , 7 ], coords = dict (reference_date = reference_dates ), dims = "reference_date"
88
88
)
89
- assert actual . equals (expected )
89
+ assert_equal (expected , actual )
90
90
91
91
92
92
def test_groupby_duplicate_coordinate_labels():
    """Summing by a coordinate with repeated labels merges the duplicates."""
    # fix for http://stackoverflow.com/questions/38065129
    labelled = xr.DataArray([1, 2, 3], coords=[("x", [1, 1, 2])])
    summed = labelled.groupby("x").sum()

    # The two entries labelled 1 collapse into one group (1 + 2 == 3).
    expected = xr.DataArray([3, 3], coords=[("x", [1, 2])])
    assert_equal(expected, summed)
98
98
99
99
100
100
def test_groupby_input_mutation ():
@@ -263,6 +263,72 @@ def test_groupby_repr_datetime(obj):
263
263
assert actual == expected
264
264
265
265
266
def test_groupby_drops_nans():
    """Check that groupby ignores NaN/NaT group labels (GH2383).

    Covers a 2D grouping variable containing NaN (which requires stacking),
    NaN and NaT in a non-dimensional coordinate, and NaN among repeated
    labels of a dimension coordinate.
    """
    # GH2383
    # nan in 2D data variable (requires stacking)
    ds = xr.Dataset(
        {
            "variable": (("lat", "lon", "time"), np.arange(60.0).reshape((4, 3, 5))),
            "id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))),
        },
        coords={"lat": np.arange(4), "lon": np.arange(3), "time": np.arange(5)},
    )

    ds["id"].values[0, 0] = np.nan
    ds["id"].values[3, 0] = np.nan
    ds["id"].values[-1, -1] = np.nan

    grouped = ds.groupby(ds.id)

    # non reduction operation
    # Deep copy on purpose: a shallow copy shares its arrays with ``ds``, so
    # the in-place writes below would also blank the *input* of the groupby
    # and the assertion could no longer tell whether groupby itself masked
    # the cells whose group label is NaN.
    expected = ds.copy(deep=True)
    expected.variable.values[0, 0, :] = np.nan
    expected.variable.values[-1, -1, :] = np.nan
    expected.variable.values[3, 0, :] = np.nan
    actual = grouped.apply(lambda x: x).transpose(*ds.variable.dims)
    assert_identical(actual, expected)

    # reduction along grouped dimension
    actual = grouped.mean()
    stacked = ds.stack({"xy": ["lat", "lon"]})
    expected = (
        stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset()
    )
    expected["id"] = stacked.id.values
    assert_identical(actual, expected.dropna("id").transpose(*actual.dims))

    # reduction operation along a different dimension
    actual = grouped.mean("time")
    expected = ds.mean("time").where(ds.id.notnull())
    assert_identical(actual, expected)

    # NaN in non-dimensional coordinate
    array = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])])
    array["x1"] = ("x", [1, 1, np.nan])
    expected = xr.DataArray(3, [("x1", [1])])
    actual = array.groupby("x1").sum()
    assert_equal(expected, actual)

    # NaT in non-dimensional coordinate
    array["t"] = (
        "x",
        [
            np.datetime64("2001-01-01"),
            np.datetime64("2001-01-01"),
            np.datetime64("NaT"),
        ],
    )
    expected = xr.DataArray(3, [("t", [np.datetime64("2001-01-01")])])
    actual = array.groupby("t").sum()
    assert_equal(expected, actual)

    # test for repeated coordinate labels
    array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])])
    expected = xr.DataArray([3, 3], [("x", [1, 2])])
    actual = array.groupby("x").sum()
    assert_equal(expected, actual)
330
+
331
+
266
332
def test_groupby_grouping_errors ():
267
333
dataset = xr .Dataset ({"foo" : ("x" , [1 , 1 , 1 ])}, {"x" : [1 , 2 , 3 ]})
268
334
with raises_regex (ValueError , "None of the data falls within bins with edges" ):
0 commit comments