5
5
import xarray as xr
6
6
from xarray .core .groupby import _consolidate_slices
7
7
8
- from . import assert_identical , raises_regex
8
+ from . import assert_equal , assert_identical , raises_regex
9
9
10
10
11
11
def test_consolidate_slices ():
@@ -40,14 +40,14 @@ def test_multi_index_groupby_apply():
40
40
{"foo" : (("x" , "y" ), np .random .randn (3 , 4 ))},
41
41
{"x" : ["a" , "b" , "c" ], "y" : [1 , 2 , 3 , 4 ]},
42
42
)
43
- doubled = 2 * ds
44
- group_doubled = (
43
+ expected = 2 * ds
44
+ actual = (
45
45
ds .stack (space = ["x" , "y" ])
46
46
.groupby ("space" )
47
47
.apply (lambda x : 2 * x )
48
48
.unstack ("space" )
49
49
)
50
- assert doubled . equals ( group_doubled )
50
+ assert_equal ( expected , actual )
51
51
52
52
53
53
def test_multi_index_groupby_sum ():
@@ -58,7 +58,7 @@ def test_multi_index_groupby_sum():
58
58
)
59
59
expected = ds .sum ("z" )
60
60
actual = ds .stack (space = ["x" , "y" ]).groupby ("space" ).sum ("z" ).unstack ("space" )
61
- assert expected . equals ( actual )
61
+ assert_equal ( expected , actual )
62
62
63
63
64
64
def test_groupby_da_datetime ():
@@ -78,15 +78,15 @@ def test_groupby_da_datetime():
78
78
expected = xr .DataArray (
79
79
[3 , 7 ], coords = dict (reference_date = reference_dates ), dims = "reference_date"
80
80
)
81
- assert actual . equals (expected )
81
+ assert_equal (expected , actual )
82
82
83
83
84
84
def test_groupby_duplicate_coordinate_labels():
    # Regression test for http://stackoverflow.com/questions/38065129:
    # grouping over a coordinate with repeated labels must aggregate the
    # duplicates into a single group rather than erroring or duplicating.
    da = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])])
    summed = da.groupby("x").sum()
    want = xr.DataArray([3, 3], [("x", [1, 2])])
    assert_equal(want, summed)
def test_groupby_input_mutation ():
@@ -255,6 +255,59 @@ def test_groupby_repr_datetime(obj):
255
255
assert actual == expected
256
256
257
257
258
def test_groupby_drops_nans():
    # GH2383: groups whose key is NaN must be dropped from groupby results.
    # Build a dataset with a NaN-containing 2D grouping variable ("id"),
    # which forces the groupby machinery to stack dimensions internally.
    ds = xr.Dataset(
        {
            "variable": (("lat", "lon", "time"), np.arange(60.0).reshape((4, 3, 5))),
            "id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))),
        },
        coords={"lat": np.arange(4), "lon": np.arange(3), "time": np.arange(5)},
    )

    # Poke NaNs into three positions of the grouping variable.
    for idx in [(0, 0), (3, 0), (-1, -1)]:
        ds["id"].values[idx] = np.nan

    grouped = ds.groupby(ds.id)

    # Non-reduction (identity apply): cells whose id is NaN come back as NaN.
    expected = ds.copy()
    for idx in [(0, 0), (-1, -1), (3, 0)]:
        expected.variable.values[idx[0], idx[1], :] = np.nan
    actual = grouped.apply(lambda x: x).transpose(*ds.variable.dims)
    assert_identical(actual, expected)

    # Reduction along the grouped dimension: NaN-keyed rows are dropped.
    actual = grouped.mean()
    stacked = ds.stack({"xy": ["lat", "lon"]})
    expected = (
        stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset()
    )
    expected["id"] = stacked.id.values
    assert_identical(actual, expected.dropna("id").transpose(*actual.dims))

    # Reduction along a dimension other than the grouped one.
    actual = grouped.mean("time")
    expected = ds.mean("time").where(ds.id.notnull())
    assert_identical(actual, expected)

    # NaN in a non-dimensional coordinate: that element is excluded.
    da = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])])
    da["x1"] = ("x", [1, 1, np.nan])
    expected = xr.DataArray(3, [("x1", [1])])
    actual = da.groupby("x1").sum()
    assert_equal(expected, actual)

    # Repeated coordinate labels mixed with NaNs: NaN-keyed values dropped,
    # duplicates aggregated.
    da = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])])
    expected = xr.DataArray([3, 3], [("x", [1, 2])])
    actual = da.groupby("x").sum()
    assert_equal(expected, actual)
258
311
def test_groupby_grouping_errors ():
259
312
dataset = xr .Dataset ({"foo" : ("x" , [1 , 1 , 1 ])}, {"x" : [1 , 2 , 3 ]})
260
313
with raises_regex (ValueError , "None of the data falls within bins with edges" ):
0 commit comments