Skip to content

Commit 881fa27

Browse files
committed
Merge branch 'main' into allow-dtype
* main: Add a dtype check for numpy arrays in assert_equal (#158)
2 parents 08dff04 + 031979d commit 881fa27

File tree

3 files changed

+63
-32
lines changed

3 files changed

+63
-32
lines changed

flox/core.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,7 @@
66
import operator
77
from collections import namedtuple
88
from functools import partial, reduce
9-
from typing import (
10-
TYPE_CHECKING,
11-
Any,
12-
Callable,
13-
Dict,
14-
Iterable,
15-
Literal,
16-
Mapping,
17-
Sequence,
18-
Union,
19-
)
9+
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Mapping, Sequence, Union
2010

2111
import numpy as np
2212
import numpy_groupies as npg
@@ -37,8 +27,11 @@
3727
if TYPE_CHECKING:
3828
import dask.array.Array as DaskArray
3929

30+
T_ExpectedGroups = Union[Sequence, np.ndarray, pd.Index]
31+
T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
4032
T_Func = Union[str, Callable]
4133
T_Funcs = Union[T_Func, Sequence[T_Func]]
34+
T_Agg = Union[str, Aggregation]
4235
T_Axis = int
4336
T_Axes = tuple[T_Axis, ...]
4437
T_AxesOpt = Union[T_Axis, T_Axes, None]
@@ -60,14 +53,20 @@
6053
DUMMY_AXIS = -2
6154

6255

63-
def _is_arg_reduction(func: str | Aggregation) -> bool:
56+
def _is_arg_reduction(func: T_Agg) -> bool:
6457
if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
6558
return True
6659
if isinstance(func, Aggregation) and func.reduction_type == "argreduce":
6760
return True
6861
return False
6962

7063

64+
def _is_minmax_reduction(func: T_Agg) -> bool:
65+
return not _is_arg_reduction(func) and (
66+
isinstance(func, str) and ("max" in func or "min" in func)
67+
)
68+
69+
7170
def _get_expected_groups(by, sort: bool) -> pd.Index:
7271
if is_duck_dask_array(by):
7372
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -1027,7 +1026,16 @@ def split_blocks(applied, split_out, expected_groups, split_name):
10271026

10281027

10291028
def _reduce_blockwise(
1030-
array, by, agg, *, axis: T_Axes, expected_groups, fill_value, engine: T_Engine, sort, reindex
1029+
array,
1030+
by,
1031+
agg: Aggregation,
1032+
*,
1033+
axis: T_Axes,
1034+
expected_groups,
1035+
fill_value,
1036+
engine: T_Engine,
1037+
sort,
1038+
reindex,
10311039
) -> FinalResultsDict:
10321040
"""
10331041
Blockwise groupby reduction that produces the final result. This code path is
@@ -1335,7 +1343,7 @@ def _assert_by_is_aligned(shape, by):
13351343

13361344

13371345
def _convert_expected_groups_to_index(
1338-
expected_groups: Iterable, isbin: Sequence[bool], sort: bool
1346+
expected_groups: T_ExpectedGroups, isbin: Sequence[bool], sort: bool
13391347
) -> tuple[pd.Index | None, ...]:
13401348
out: list[pd.Index | None] = []
13411349
for ex, isbin_ in zip(expected_groups, isbin):
@@ -1397,8 +1405,8 @@ def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
13971405
def groupby_reduce(
13981406
array: np.ndarray | DaskArray,
13991407
*by: np.ndarray | DaskArray,
1400-
func: str | Aggregation,
1401-
expected_groups: Sequence | np.ndarray | None = None,
1408+
func: T_Agg,
1409+
expected_groups: T_ExpectedGroupsOpt = None,
14021410
sort: bool = True,
14031411
isbin: T_IsBins = False,
14041412
axis: T_AxesOpt = None,
@@ -1523,7 +1531,8 @@ def groupby_reduce(
15231531

15241532
if not is_duck_array(array):
15251533
array = np.asarray(array)
1526-
array = array.astype(int) if np.issubdtype(array.dtype, bool) else array
1534+
is_bool_array = np.issubdtype(array.dtype, bool)
1535+
array = array.astype(int) if is_bool_array else array
15271536

15281537
if isinstance(isbin, Sequence):
15291538
isbins = isbin
@@ -1717,4 +1726,7 @@ def groupby_reduce(
17171726
result, from_=groups[0], to=expected_groups, fill_value=fill_value
17181727
).reshape(result.shape[:-1] + grp_shape)
17191728
groups = final_groups
1729+
1730+
if _is_minmax_reduction(func) and is_bool_array:
1731+
result = result.astype(bool)
17201732
return (result, *groups)

tests/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414

1515
dask_array_type = da.Array
1616
except ImportError:
17-
dask_array_type = ()
17+
dask_array_type = () # type: ignore
1818

1919

2020
try:
2121
import xarray as xr
2222

2323
xr_types = (xr.DataArray, xr.Dataset)
2424
except ImportError:
25-
xr_types = ()
25+
xr_types = () # type: ignore
2626

2727

2828
def _importorskip(modname, minversion=None):
@@ -98,6 +98,9 @@ def assert_equal(a, b):
9898
# does some validation of the dask graph
9999
da.utils.assert_eq(a, b, equal_nan=True)
100100
else:
101+
if a.dtype != b.dtype:
102+
raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
103+
101104
np.testing.assert_allclose(a, b, equal_nan=True)
102105

103106

tests/test_core.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from __future__ import annotations
2+
13
from functools import reduce
4+
from typing import TYPE_CHECKING
25

36
import numpy as np
47
import pandas as pd
@@ -63,6 +66,9 @@ def dask_array_ones(*args):
6366
pytest.param("nanmedian", marks=(pytest.mark.skip,)),
6467
)
6568

69+
if TYPE_CHECKING:
70+
from flox.core import T_Engine, T_ExpectedGroupsOpt, T_Func2
71+
6672

6773
def test_alignment_error():
6874
da = np.ones((12,))
@@ -101,21 +107,29 @@ def test_alignment_error():
101107
],
102108
)
103109
def test_groupby_reduce(
104-
array, by, expected, func, expected_groups, chunk, split_out, dtype, engine
105-
):
110+
engine: T_Engine,
111+
func: T_Func2,
112+
array: np.ndarray,
113+
by: np.ndarray,
114+
expected: list[float],
115+
expected_groups: T_ExpectedGroupsOpt,
116+
chunk: bool,
117+
split_out: int,
118+
dtype: np.typing.DTypeLike,
119+
) -> None:
106120
array = array.astype(dtype)
107121
if chunk:
108122
if not has_dask or expected_groups is None:
109123
pytest.skip()
110124
array = da.from_array(array, chunks=(3,) if array.ndim == 1 else (1, 3))
111125
by = da.from_array(by, chunks=(3,) if by.ndim == 1 else (1, 3))
112126

113-
if "mean" in func:
114-
expected = np.array(expected, dtype=float)
127+
if func == "mean" or func == "nanmean":
128+
expected_result = np.array(expected, dtype=float)
115129
elif func == "sum":
116-
expected = np.array(expected, dtype=dtype)
130+
expected_result = np.array(expected, dtype=dtype)
117131
elif func == "count":
118-
expected = np.array(expected, dtype=int)
132+
expected_result = np.array(expected, dtype=int)
119133

120134
result, groups, = groupby_reduce(
121135
array,
@@ -126,8 +140,10 @@ def test_groupby_reduce(
126140
split_out=split_out,
127141
engine=engine,
128142
)
129-
assert_equal(groups, [0, 1, 2])
130-
assert_equal(expected, result)
143+
g_dtype = by.dtype if expected_groups is None else np.asarray(expected_groups).dtype
144+
145+
assert_equal(groups, np.array([0, 1, 2], g_dtype))
146+
assert_equal(expected_result, result)
131147

132148

133149
def gen_array_by(size, func):
@@ -843,16 +859,16 @@ def test_bool_reductions(func, engine):
843859

844860

845861
@requires_dask
846-
def test_map_reduce_blockwise_mixed():
862+
def test_map_reduce_blockwise_mixed() -> None:
847863
t = pd.date_range("2000-01-01", "2000-12-31", freq="D").to_series()
848864
data = t.dt.dayofyear
849-
actual = groupby_reduce(
865+
actual, _ = groupby_reduce(
850866
dask.array.from_array(data.values, chunks=365),
851867
t.dt.month,
852868
func="mean",
853869
method="split-reduce",
854870
)
855-
expected = groupby_reduce(data, t.dt.month, func="mean")
871+
expected, _ = groupby_reduce(data, t.dt.month, func="mean")
856872
assert_equal(expected, actual)
857873

858874

@@ -908,7 +924,7 @@ def test_factorize_values_outside_bins():
908924
assert_equal(expected, actual)
909925

910926

911-
def test_multiple_groupers():
927+
def test_multiple_groupers() -> None:
912928
actual, *_ = groupby_reduce(
913929
np.ones((5, 2)),
914930
np.arange(10).reshape(5, 2),
@@ -921,7 +937,7 @@ def test_multiple_groupers():
921937
reindex=True,
922938
func="count",
923939
)
924-
expected = np.eye(5, 5)
940+
expected = np.eye(5, 5, dtype=int)
925941
assert_equal(expected, actual)
926942

927943

0 commit comments

Comments
 (0)