diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 4bd9c24a9..03f3c2094 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -207,7 +207,7 @@ def nansum_of_squares(group_idx, array, *, axis=-1, size=None, fill_value=None, def nanlen(group_idx, array, *args, **kwargs): - return sum(group_idx, (notnull(array)).astype(int), *args, **kwargs) + return sum(group_idx, (notnull(array)).view(np.int8), *args, **kwargs) def mean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): @@ -215,14 +215,16 @@ def mean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): fill_value = 0 out = sum(group_idx, array, axis=axis, size=size, dtype=dtype, fill_value=fill_value) with np.errstate(invalid="ignore", divide="ignore"): - out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0) + out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0, dtype=np.intp) return out def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): if fill_value is None: fill_value = 0 - out = nansum(group_idx, array, size=size, axis=axis, dtype=dtype, fill_value=fill_value) + mask = notnull(array) + masked = np.where(mask, array, 0) + out = sum(group_idx, masked, size=size, axis=axis, dtype=dtype, fill_value=fill_value) with np.errstate(invalid="ignore", divide="ignore"): - out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0) + out /= sum(group_idx, mask.view(np.int8), size=size, axis=axis, fill_value=0, dtype=np.intp) return out