Skip to content

Commit 39870ee

Browse files
authored
Fix func count for dtype O with numpy and numba (#138)
1 parent b21b040 commit 39870ee

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

flox/xarray.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,10 @@ def wrapper(array, *by, func, skipna, **kwargs):
296296
if "nan" not in func and func not in ["all", "any", "count"]:
297297
func = f"nan{func}"
298298

299-
requires_numeric = func not in ["count", "any", "all"]
299+
# Flox's count works with non-numeric and its faster than converting.
300+
requires_numeric = func not in ["count", "any", "all"] or (
301+
func == "count" and engine != "flox"
302+
)
300303
if requires_numeric:
301304
is_npdatetime = array.dtype.kind in "Mm"
302305
is_cftime = _contains_cftime_datetimes(array)
@@ -311,7 +314,8 @@ def wrapper(array, *by, func, skipna, **kwargs):
311314

312315
result, *groups = groupby_reduce(array, *by, func=func, **kwargs)
313316

314-
if requires_numeric:
317+
# Output of count has an int dtype.
318+
if requires_numeric and func != "count":
315319
if is_npdatetime:
316320
return result.astype(dtype) + offset
317321
elif is_cftime:

tests/test_xarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,15 +420,15 @@ def test_cache():
420420

421421
@pytest.mark.parametrize("use_cftime", [True, False])
422422
@pytest.mark.parametrize("func", ["count", "mean"])
423-
def test_datetime_array_reduce(use_cftime, func):
423+
def test_datetime_array_reduce(use_cftime, func, engine):
424424

425425
time = xr.DataArray(
426426
xr.date_range("2009-01-01", "2012-12-31", use_cftime=use_cftime),
427427
dims=("time",),
428428
name="time",
429429
)
430430
expected = getattr(time.resample(time="YS"), func)()
431-
actual = resample_reduce(time.resample(time="YS"), func=func, engine="flox")
431+
actual = resample_reduce(time.resample(time="YS"), func=func, engine=engine)
432432
assert_equal(expected, actual)
433433

434434

0 commit comments

Comments
 (0)