Allow specifying output dtype #131

Merged (24 commits) on Oct 11, 2022
20 changes: 14 additions & 6 deletions flox/aggregations.py
@@ -55,10 +55,7 @@ def generic_aggregate(

def _normalize_dtype(dtype, array_dtype, fill_value=None):
if dtype is None:
if fill_value is not None and np.isnan(fill_value):
dtype = np.floating
else:
dtype = array_dtype
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
@@ -68,6 +65,8 @@ def _normalize_dtype(dtype, array_dtype, fill_value=None):
dtype = np.dtype("float64")
elif not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
dtype = np.result_type(dtype, fill_value)
return dtype


@@ -465,6 +464,7 @@ def _zip_index(array_, idx_):

def _initialize_aggregation(
func: str | Aggregation,
dtype,
array_dtype,
fill_value,
min_count: int | None,
@@ -484,10 +484,18 @@ def _initialize_aggregation(
else:
raise ValueError("Bad type for func. Expected str or Aggregation")

agg.dtype[func] = _normalize_dtype(agg.dtype[func], array_dtype, fill_value)
# np.dtype(None) == np.dtype("float64")!!!
# so check for not None
if dtype is not None and not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)

agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value)
agg.dtype["numpy"] = (agg.dtype[func],)
agg.dtype["intermediate"] = [
_normalize_dtype(dtype, array_dtype) for dtype in agg.dtype["intermediate"]
_normalize_dtype(int_dtype, np.result_type(array_dtype, agg.dtype[func]), int_fv)
if int_dtype is None
else int_dtype
for int_dtype, int_fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
]

# Replace sentinel fill values according to dtype
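
A minimal standalone sketch of the dtype normalization behaviour above (the function name is illustrative, not flox internals): a requested dtype of ``None`` falls back to the array's dtype, the ``None`` check must be explicit because ``np.dtype(None)`` silently resolves to ``float64``, and a concrete fill value is promoted into the result via ``np.result_type``.

```python
import numpy as np

def normalize_dtype_sketch(dtype, array_dtype, fill_value=None):
    # np.dtype(None) == np.dtype("float64"), so None needs an explicit check.
    if dtype is None:
        dtype = np.dtype(array_dtype)
    else:
        dtype = np.dtype(dtype)
    # A concrete fill value (e.g. np.nan for an integer array) widens the result dtype.
    if fill_value is not None:
        dtype = np.result_type(dtype, fill_value)
    return dtype

print(normalize_dtype_sketch(None, np.int64))          # int64
print(normalize_dtype_sketch(None, np.int64, np.nan))  # float64: NaN forces a float result
```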
6 changes: 5 additions & 1 deletion flox/core.py
@@ -789,6 +789,7 @@ def _finalize_results(
else:
finalized["groups"] = squeezed["groups"]

finalized[agg.name] = finalized[agg.name].astype(agg.dtype[agg.name], copy=False)
return finalized


@@ -1411,6 +1412,7 @@ def groupby_reduce(
isbin: T_IsBins = False,
axis: T_AxesOpt = None,
fill_value=None,
dtype: np.typing.DTypeLike = None,
min_count: int | None = None,
split_out: int = 1,
method: T_Method = "map-reduce",
@@ -1444,6 +1446,8 @@
Negative integers are normalized using array.ndim
fill_value : Any
Value to assign when a label in ``expected_groups`` is not present.
dtype : data-type, optional
DType for the output. Can be anything that is accepted by ``np.dtype``.
min_count : int, default: None
The required number of valid values to perform the operation. If
fewer than min_count non-NA values are present the result will be
@@ -1621,7 +1625,7 @@
fill_value = np.nan

kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)

if not has_dask:
results = _reduce_blockwise(
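
The new ``dtype`` argument is threaded through to ``_initialize_aggregation``, so the output dtype can be requested directly. A usage sketch mirroring the ``test_dtype`` test added below (the array and labels are illustrative):

```python
import numpy as np
from flox.core import groupby_reduce

arr = np.ones((4, 12), dtype=np.float32)
labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])

# Request a float64 result even though the input array is float32.
result, groups = groupby_reduce(arr, labels, func="mean", dtype=np.float64)
print(result.dtype)  # float64
```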
8 changes: 7 additions & 1 deletion flox/xarray.py
@@ -63,6 +63,7 @@ def xarray_reduce(
dim: Dims | ellipsis = None,
split_out: int = 1,
fill_value=None,
dtype: np.typing.DTypeLike = None,
method: str = "map-reduce",
engine: str = "numpy",
keep_attrs: bool | None = True,
@@ -98,6 +99,8 @@
fill_value
Value used for missing groups in the output i.e. when one of the labels
in ``expected_groups`` is not actually present in ``by``.
dtype : data-type, optional
DType for the output. Can be anything accepted by ``np.dtype``.
method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
Strategy for reduction of dask arrays only:
* ``"map-reduce"``:
@@ -387,7 +390,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
exclude_dims=set(dim_tuple),
output_core_dims=[group_names],
dask="allowed",
dask_gufunc_kwargs=dict(output_sizes=group_sizes),
dask_gufunc_kwargs=dict(
output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
),
keep_attrs=keep_attrs,
kwargs={
"func": func,
Expand All @@ -403,6 +408,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
"expected_groups": tuple(expected_groups),
"isbin": isbins,
"finalize_kwargs": finalize_kwargs,
"dtype": dtype,
},
)

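
On the xarray side, ``dtype`` is forwarded both to ``groupby_reduce`` and to ``apply_ufunc`` via ``output_dtypes``. A usage sketch based on the ``test_dtype`` test added in ``tests/test_xarray.py`` below (data and labels are illustrative):

```python
import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

arr = xr.DataArray(
    np.linspace(0, 1, 48, dtype=np.float32).reshape((4, 12)),
    dims=("x", "t"),
    coords={"labels": ("t", list("aacccbbccbbf"))},
    name="arr",
)

# Request a float64 mean from float32 data.
out = xarray_reduce(arr, "labels", func="mean", dtype="float64")
print(out.dtype)  # float64
```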
21 changes: 16 additions & 5 deletions tests/__init__.py
@@ -80,28 +80,39 @@ def raise_if_dask_computes(max_computes=0):
return dask.config.set(scheduler=scheduler)


def assert_equal(a, b):
def assert_equal(a, b, tolerance=None):
__tracebackhide__ = True

if isinstance(a, list):
a = np.array(a)
if isinstance(b, list):
b = np.array(b)

if isinstance(a, pd_types) or isinstance(b, pd_types):
pd.testing.assert_index_equal(a, b)
elif has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types):
return
if has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types):
xr.testing.assert_identical(a, b)
elif has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
return

if tolerance is None and (
np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64)
):
tolerance = {"atol": 1e-18, "rtol": 1e-15}
else:
tolerance = {}

if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
# sometimes it's nice to see values and shapes
# rather than being dropped into some file in dask
np.testing.assert_allclose(a, b)
np.testing.assert_allclose(a, b, **tolerance)
# does some validation of the dask graph
da.utils.assert_eq(a, b, equal_nan=True)
else:
if a.dtype != b.dtype:
raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")

np.testing.assert_allclose(a, b, equal_nan=True)
np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)


@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
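
For reference, the tolerance defaulting added to ``assert_equal`` amounts to this pattern: a tight ``rtol``/``atol`` pair is passed to ``np.testing.assert_allclose`` only when either operand is float64. A standalone sketch with illustrative values:

```python
import numpy as np

a = np.array([1.0, 2.0], dtype=np.float64)
b = a + 1e-16  # rounding-level difference, well within the float64 tolerances

# Same defaulting pattern as assert_equal above: tight tolerances only for float64.
tolerance = (
    {"atol": 1e-18, "rtol": 1e-15}
    if np.issubdtype(a.dtype, np.float64) or np.issubdtype(b.dtype, np.float64)
    else {}
)
np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)
```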
26 changes: 22 additions & 4 deletions tests/test_core.py
@@ -185,8 +185,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
if "var" in func or "std" in func:
finalize_kwargs = finalize_kwargs + [{"ddof": 1}, {"ddof": 0}]
fill_value = np.nan
tolerance = {"rtol": 1e-14, "atol": 1e-16}
else:
fill_value = None
tolerance = None

for kwargs in finalize_kwargs:
flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
@@ -207,7 +209,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
assert_equal(actual_group, expect)
if "arg" in func:
assert actual.dtype.kind == "i"
assert_equal(actual, expected)
assert_equal(actual, expected, tolerance)

if not has_dask:
continue
@@ -216,10 +218,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
continue
actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs)
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect)
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
assert actual.dtype.kind == "i"
assert_equal(actual, expected)
assert_equal(actual, expected, tolerance)


@requires_dask
@@ -466,6 +468,11 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
fill_value = False
else:
fill_value = 123

if "var" in func or "std" in func:
tolerance = {"rtol": 1e-14, "atol": 1e-16}
else:
tolerance = None
# tests against the numpy output to make sure dask compute matches
by = np.broadcast_to(labels2d, (3, *labels2d.shape))
rng = np.random.default_rng(12345)
@@ -484,7 +491,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
kwargs.pop("engine")
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
assert_equal(expected_npg, expected)
assert_equal(actual, expected)
assert_equal(actual, expected, tolerance)


@pytest.mark.parametrize("chunks", [None, (2, 2, 3)])
Expand Down Expand Up @@ -1025,3 +1032,14 @@ def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dty
method="blockwise",
)
assert_equal(expected, actual)


@pytest.mark.parametrize("func", ALL_FUNCS)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_dtype(func, dtype, engine):
if "arg" in func or func in ["any", "all"]:
pytest.skip()
arr = np.ones((4, 12), dtype=dtype)
labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64)
assert actual.dtype == np.dtype("float64")
77 changes: 77 additions & 0 deletions tests/test_xarray.py
@@ -24,6 +24,10 @@
pass


tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
np.random.seed(123)


@pytest.mark.parametrize("reindex", [None, False, True])
@pytest.mark.parametrize("min_count", [None, 1, 3])
@pytest.mark.parametrize("add_nan", [True, False])
@@ -488,3 +492,76 @@ def test_mixed_grouping(chunk):
fill_value=0,
)
assert (r.sel(v1=[3, 4, 5]) == 0).all().data


@pytest.mark.parametrize("add_nan", [True, False])
@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("chunk", (True, False))
def test_dtype(add_nan, chunk, dtype, dtype_out, engine):
if chunk and not has_dask:
pytest.skip()

xp = dask.array if chunk else np
data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12))

if add_nan:
data[1, ...] = np.nan
data[0, [0, 2]] = np.nan

arr = xr.DataArray(
data,
dims=("x", "t"),
coords={
"labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]))
},
name="arr",
)
kwargs = dict(func="mean", dtype=dtype_out, engine=engine)
actual = xarray_reduce(arr, "labels", **kwargs)
expected = arr.groupby("labels").mean(dtype="float64")

assert actual.dtype == np.dtype("float64")
assert actual.compute().dtype == np.dtype("float64")
xr.testing.assert_allclose(expected, actual, **tolerance64)

actual = xarray_reduce(arr.to_dataset(), "labels", **kwargs)
expected = arr.to_dataset().groupby("labels").mean(dtype="float64")

assert actual.arr.dtype == np.dtype("float64")
assert actual.compute().arr.dtype == np.dtype("float64")
xr.testing.assert_allclose(expected, actual.transpose("labels", ...), **tolerance64)


@pytest.mark.parametrize("chunk", [True, False])
@pytest.mark.parametrize("use_flox", [True, False])
def test_dtype_accumulation(use_flox, chunk):
if chunk and not has_dask:
pytest.skip()

datetimes = pd.date_range("2010-01", "2015-01", freq="6H", inclusive="left")
samples = 10 + np.cos(2 * np.pi * 0.001 * np.arange(len(datetimes))) * 1
samples += np.random.randn(len(datetimes))
samples = samples.astype("float32")

nan_indices = np.random.default_rng().integers(0, len(samples), size=5_000)
samples[nan_indices] = np.nan

da = xr.DataArray(samples, dims=("time",), coords=[datetimes])
if chunk:
da = da.chunk(time=1024)

gb = da.groupby("time.month")

with xr.set_options(use_flox=use_flox):
expected = gb.reduce(np.nanmean)
actual = gb.mean()
xr.testing.assert_allclose(expected, actual)
assert np.issubdtype(actual.dtype, np.float32)
assert np.issubdtype(actual.compute().dtype, np.float32)

expected = gb.reduce(np.nanmean, dtype="float64")
actual = gb.mean(dtype="float64")
assert np.issubdtype(actual.dtype, np.float64)
assert np.issubdtype(actual.compute().dtype, np.float64)
xr.testing.assert_allclose(expected, actual, **tolerance64)