Skip to content

Allow specifying output dtype #131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,7 @@ def groupby_reduce(
isbin: bool = False,
axis=None,
fill_value=None,
dtype=None,
min_count: int | None = None,
split_out: int = 1,
method: str = "map-reduce",
Expand Down Expand Up @@ -1397,6 +1398,8 @@ def groupby_reduce(
Negative integers are normalized using array.ndim
fill_value : Any
Value to assign when a label in ``expected_groups`` is not present.
dtype: data-type , optional
DType for the output. Can be anything that is accepted by ``np.dtype``.
min_count : int, default: None
The required number of valid values to perform the operation. If
fewer than min_count non-NA values are present the result will be
Expand Down Expand Up @@ -1566,8 +1569,13 @@ def groupby_reduce(
# overwrite than when min_count is set
fill_value = np.nan

if dtype is not None and not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)

kwargs = dict(axis=axis, fill_value=fill_value, engine=engine)
agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
agg = _initialize_aggregation(
func, array.dtype if dtype is None else dtype, fill_value, min_count, finalize_kwargs
)

if not has_dask:
results = _reduce_blockwise(
Expand Down
8 changes: 7 additions & 1 deletion flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def xarray_reduce(
dim: Hashable = None,
split_out: int = 1,
fill_value=None,
dtype=None,
method: str = "map-reduce",
engine: str = "flox",
keep_attrs: bool | None = True,
Expand Down Expand Up @@ -95,6 +96,8 @@ def xarray_reduce(
fill_value
Value used for missing groups in the output i.e. when one of the labels
in ``expected_groups`` is not actually present in ``by``.
dtype: data-type, optional
DType for the output. Can be anything accepted by ``np.dtype``.
method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
Strategy for reduction of dask arrays only:
* ``"map-reduce"``:
Expand Down Expand Up @@ -341,7 +344,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
exclude_dims=set(dim),
output_core_dims=[group_names],
dask="allowed",
dask_gufunc_kwargs=dict(output_sizes=group_sizes),
dask_gufunc_kwargs=dict(
output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
),
keep_attrs=keep_attrs,
kwargs={
"func": func,
Expand All @@ -357,6 +362,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
"expected_groups": tuple(expected_groups),
"isbin": isbin,
"finalize_kwargs": finalize_kwargs,
"dtype": dtype,
},
)

Expand Down
11 changes: 11 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,3 +1009,14 @@ def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dty
method="blockwise",
)
assert_equal(expected, actual)


@pytest.mark.parametrize("func", ALL_FUNCS)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_dtype(func, dtype, engine):
if "arg" in func or func in ["any", "all"]:
pytest.skip()
arr = np.ones((4, 12), dtype=dtype)
labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64)
assert actual.dtype == np.dtype("float64")
29 changes: 29 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,32 @@ def test_mixed_grouping(chunk):
fill_value=0,
)
assert (r.sel(v1=[3, 4, 5]) == 0).all().data


@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("chunk", (True, False))
def test_dtype(chunk, dtype, dtype_out, engine):
if chunk and not has_dask:
pytest.skip()

if chunk:
data = dask.array.ones((4, 12), dtype=dtype, chunks=(1, -1))
else:
data = np.ones((4, 12), dtype=dtype)

arr = xr.DataArray(
data,
dims=("x", "t"),
coords={
"labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]))
},
name="arr",
)
actual = xarray_reduce(arr, "labels", func="mean", dtype=dtype_out)
assert actual.dtype == np.dtype("float64")
assert actual.compute().dtype == np.dtype("float64")

actual = xarray_reduce(arr.to_dataset(), "labels", func="mean", dtype=dtype_out)
assert actual.arr.dtype == np.dtype("float64")
assert actual.compute().arr.dtype == np.dtype("float64")