diff --git a/flox/aggregations.py b/flox/aggregations.py
index 2a1d68d6d..0cf6b259b 100644
--- a/flox/aggregations.py
+++ b/flox/aggregations.py
@@ -55,10 +55,7 @@ def generic_aggregate(
 
 def _normalize_dtype(dtype, array_dtype, fill_value=None):
     if dtype is None:
-        if fill_value is not None and np.isnan(fill_value):
-            dtype = np.floating
-        else:
-            dtype = array_dtype
+        dtype = array_dtype
     if dtype is np.floating:
         # mean, std, var always result in floating
         # but we preserve the array's dtype if it is floating
@@ -68,6 +65,8 @@ def _normalize_dtype(dtype, array_dtype, fill_value=None):
             dtype = np.dtype("float64")
     elif not isinstance(dtype, np.dtype):
         dtype = np.dtype(dtype)
+    if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
+        dtype = np.result_type(dtype, fill_value)
     return dtype
 
 
@@ -465,6 +464,7 @@ def _zip_index(array_, idx_):
 
 def _initialize_aggregation(
     func: str | Aggregation,
+    dtype,
     array_dtype,
     fill_value,
     min_count: int | None,
@@ -484,10 +484,18 @@ def _initialize_aggregation(
     else:
         raise ValueError("Bad type for func. Expected str or Aggregation")
 
-    agg.dtype[func] = _normalize_dtype(agg.dtype[func], array_dtype, fill_value)
+    # np.dtype(None) == np.dtype("float64")!!!
+    # so check for not None
+    if dtype is not None and not isinstance(dtype, np.dtype):
+        dtype = np.dtype(dtype)
+
+    agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value)
     agg.dtype["numpy"] = (agg.dtype[func],)
     agg.dtype["intermediate"] = [
-        _normalize_dtype(dtype, array_dtype) for dtype in agg.dtype["intermediate"]
+        _normalize_dtype(int_dtype, np.result_type(array_dtype, agg.dtype[func]), int_fv)
+        if int_dtype is None
+        else int_dtype
+        for int_dtype, int_fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
     ]
 
     # Replace sentinel fill values according to dtype
diff --git a/flox/core.py b/flox/core.py
index 15577d35c..c854ce953 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -789,6 +789,7 @@ def _finalize_results(
     else:
         finalized["groups"] = squeezed["groups"]
 
+    finalized[agg.name] = finalized[agg.name].astype(agg.dtype[agg.name], copy=False)
     return finalized
 
 
@@ -1411,6 +1412,7 @@ def groupby_reduce(
     isbin: T_IsBins = False,
     axis: T_AxesOpt = None,
     fill_value=None,
+    dtype: np.typing.DTypeLike = None,
     min_count: int | None = None,
     split_out: int = 1,
     method: T_Method = "map-reduce",
@@ -1444,6 +1446,8 @@ def groupby_reduce(
         Negative integers are normalized using array.ndim
     fill_value : Any
         Value to assign when a label in ``expected_groups`` is not present.
+    dtype : data-type, optional
+        DType for the output. Can be anything that is accepted by ``np.dtype``.
     min_count : int, default: None
         The required number of valid values to perform the operation.
         If fewer than min_count non-NA values are present the result will be
@@ -1621,7 +1625,7 @@ def groupby_reduce(
             fill_value = np.nan
 
     kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
-    agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
+    agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
 
     if not has_dask:
         results = _reduce_blockwise(
diff --git a/flox/xarray.py b/flox/xarray.py
index 5f87bafe6..3b8ec96e8 100644
--- a/flox/xarray.py
+++ b/flox/xarray.py
@@ -63,6 +63,7 @@ def xarray_reduce(
     dim: Dims | ellipsis = None,
     split_out: int = 1,
     fill_value=None,
+    dtype: np.typing.DTypeLike = None,
     method: str = "map-reduce",
     engine: str = "numpy",
     keep_attrs: bool | None = True,
@@ -98,6 +99,8 @@ def xarray_reduce(
     fill_value
         Value used for missing groups in the output i.e. when one of the labels
         in ``expected_groups`` is not actually present in ``by``.
+    dtype : data-type, optional
+        DType for the output. Can be anything accepted by ``np.dtype``.
     method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
         Strategy for reduction of dask arrays only:
           * ``"map-reduce"``:
@@ -387,7 +390,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
         exclude_dims=set(dim_tuple),
         output_core_dims=[group_names],
         dask="allowed",
-        dask_gufunc_kwargs=dict(output_sizes=group_sizes),
+        dask_gufunc_kwargs=dict(
+            output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
+        ),
         keep_attrs=keep_attrs,
         kwargs={
             "func": func,
@@ -403,6 +408,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
             "expected_groups": tuple(expected_groups),
             "isbin": isbins,
             "finalize_kwargs": finalize_kwargs,
+            "dtype": dtype,
         },
     )
 
diff --git a/tests/__init__.py b/tests/__init__.py
index 9917b41fc..0cd967d11 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -80,28 +80,39 @@ def raise_if_dask_computes(max_computes=0):
     return dask.config.set(scheduler=scheduler)
 
 
-def assert_equal(a, b):
+def assert_equal(a, b, tolerance=None):
     __tracebackhide__ = True
 
     if isinstance(a, list):
         a = np.array(a)
     if isinstance(b, list):
         b = np.array(b)
+
     if isinstance(a, pd_types) or isinstance(b, pd_types):
         pd.testing.assert_index_equal(a, b)
-    elif has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types):
+        return
+    if has_xarray and (isinstance(a, xr_types) or isinstance(b, xr_types)):
         xr.testing.assert_identical(a, b)
-    elif has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
+        return
+
+    if tolerance is None and (
+        np.issubdtype(a.dtype, np.float64) or np.issubdtype(b.dtype, np.float64)
+    ):
+        tolerance = {"atol": 1e-18, "rtol": 1e-15}
+    elif tolerance is None:
+        tolerance = {}
+
+    if has_dask and (isinstance(a, dask_array_type) or isinstance(b, dask_array_type)):
         # sometimes it's nice to see values and shapes
         # rather than being dropped into some file in dask
-        np.testing.assert_allclose(a, b)
+        np.testing.assert_allclose(a, b, **tolerance)
         # does some validation of the dask graph
         da.utils.assert_eq(a, b, equal_nan=True)
     else:
        if a.dtype != b.dtype:
             raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
-        np.testing.assert_allclose(a, b, equal_nan=True)
+        np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)
 
 
 @pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
diff --git a/tests/test_core.py b/tests/test_core.py
index 25660e734..05deae2fd 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -185,8 +185,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
     if "var" in func or "std" in func:
         finalize_kwargs = finalize_kwargs + [{"ddof": 1}, {"ddof": 0}]
         fill_value = np.nan
+        tolerance = {"rtol": 1e-14, "atol": 1e-16}
     else:
         fill_value = None
+        tolerance = None
 
     for kwargs in finalize_kwargs:
         flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
@@ -207,7 +209,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
             assert_equal(actual_group, expect)
         if "arg" in func:
             assert actual.dtype.kind == "i"
-        assert_equal(actual, expected)
+        assert_equal(actual, expected, tolerance)
 
         if not has_dask:
             continue
@@ -216,10 +218,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 continue
             actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs)
             for actual_group, expect in zip(groups, expected_groups):
-                assert_equal(actual_group, expect)
+                assert_equal(actual_group, expect, tolerance)
             if "arg" in func:
                 assert actual.dtype.kind == "i"
-            assert_equal(actual, expected)
+            assert_equal(actual, expected, tolerance)
 
 
 @requires_dask
@@ -466,6 +468,11 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
         fill_value = False
     else:
         fill_value = 123
+
+    if "var" in func or "std" in func:
+        tolerance = {"rtol": 1e-14, "atol": 1e-16}
+    else:
+        tolerance = None
     # tests against the numpy output to make sure dask compute matches
     by = np.broadcast_to(labels2d, (3, *labels2d.shape))
     rng = np.random.default_rng(12345)
@@ -484,7 +491,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
         kwargs.pop("engine")
         expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
         assert_equal(expected_npg, expected)
-    assert_equal(actual, expected)
+    assert_equal(actual, expected, tolerance)
 
 
 @pytest.mark.parametrize("chunks", [None, (2, 2, 3)])
@@ -1025,3 +1032,14 @@ def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
         method="blockwise",
     )
     assert_equal(expected, actual)
+
+
+@pytest.mark.parametrize("func", ALL_FUNCS)
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_dtype(func, dtype, engine):
+    if "arg" in func or func in ["any", "all"]:
+        pytest.skip()
+    arr = np.ones((4, 12), dtype=dtype)
+    labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
+    actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64)
+    assert actual.dtype == np.dtype("float64")
diff --git a/tests/test_xarray.py b/tests/test_xarray.py
index 6669830b5..995e1daaa 100644
--- a/tests/test_xarray.py
+++ b/tests/test_xarray.py
@@ -24,6 +24,10 @@
     pass
 
 
+tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
+np.random.seed(123)
+
+
 @pytest.mark.parametrize("reindex", [None, False, True])
 @pytest.mark.parametrize("min_count", [None, 1, 3])
 @pytest.mark.parametrize("add_nan", [True, False])
@@ -488,3 +492,76 @@ def test_mixed_grouping(chunk):
         fill_value=0,
     )
     assert (r.sel(v1=[3, 4, 5]) == 0).all().data
+
+
+@pytest.mark.parametrize("add_nan", [True, False])
+@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("chunk", (True, False))
+def test_dtype(add_nan, chunk, dtype, dtype_out, engine):
+    if chunk and not has_dask:
+        pytest.skip()
+
+    xp = dask.array if chunk else np
+    data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12))
+
+    if add_nan:
+        data[1, ...] = np.nan
+        data[0, [0, 2]] = np.nan
+
+    arr = xr.DataArray(
+        data,
+        dims=("x", "t"),
+        coords={
+            "labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]))
+        },
+        name="arr",
+    )
+    kwargs = dict(func="mean", dtype=dtype_out, engine=engine)
+    actual = xarray_reduce(arr, "labels", **kwargs)
+    expected = arr.groupby("labels").mean(dtype="float64")
+
+    assert actual.dtype == np.dtype("float64")
+    assert actual.compute().dtype == np.dtype("float64")
+    xr.testing.assert_allclose(expected, actual, **tolerance64)
+
+    actual = xarray_reduce(arr.to_dataset(), "labels", **kwargs)
+    expected = arr.to_dataset().groupby("labels").mean(dtype="float64")
+
+    assert actual.arr.dtype == np.dtype("float64")
+    assert actual.compute().arr.dtype == np.dtype("float64")
+    xr.testing.assert_allclose(expected, actual.transpose("labels", ...), **tolerance64)
+
+
+@pytest.mark.parametrize("chunk", [True, False])
+@pytest.mark.parametrize("use_flox", [True, False])
+def test_dtype_accumulation(use_flox, chunk):
+    if chunk and not has_dask:
+        pytest.skip()
+
+    datetimes = pd.date_range("2010-01", "2015-01", freq="6H", inclusive="left")
+    samples = 10 + np.cos(2 * np.pi * 0.001 * np.arange(len(datetimes)))
+    samples += np.random.randn(len(datetimes))
+    samples = samples.astype("float32")
+
+    nan_indices = np.random.default_rng().integers(0, len(samples), size=5_000)
+    samples[nan_indices] = np.nan
+
+    da = xr.DataArray(samples, dims=("time",), coords=[datetimes])
+    if chunk:
+        da = da.chunk(time=1024)
+
+    gb = da.groupby("time.month")
+
+    with xr.set_options(use_flox=use_flox):
+        expected = gb.reduce(np.nanmean)
+        actual = gb.mean()
+        xr.testing.assert_allclose(expected, actual)
+        assert np.issubdtype(actual.dtype, np.float32)
+        assert np.issubdtype(actual.compute().dtype, np.float32)
+
+        expected = gb.reduce(np.nanmean, dtype="float64")
+        actual = gb.mean(dtype="float64")
+        assert np.issubdtype(actual.dtype, np.float64)
+        assert np.issubdtype(actual.compute().dtype, np.float64)
+        xr.testing.assert_allclose(expected, actual, **tolerance64)
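
A quick usage sketch of the new keyword, mirroring the tests added above (the labels
are the ones used in ``test_dtype``; everything else is the ``groupby_reduce`` API this
patch extends). It also demonstrates the numpy quirk behind the "check for not None"
comment in ``_initialize_aggregation``: ``np.dtype(None)`` normalizes to ``dtype("float64")``.

    import numpy as np

    from flox.core import groupby_reduce

    # the quirk the inline comment guards against:
    # np.dtype(None) silently means float64
    assert np.dtype(None) == np.dtype("float64")

    arr = np.ones((4, 12), dtype=np.float32)
    labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])

    # by default, a floating input dtype is preserved for "mean" ...
    result, groups = groupby_reduce(arr, labels, func="mean")
    assert result.dtype == np.dtype("float32")

    # ... while dtype= forces the output dtype
    result, groups = groupby_reduce(arr, labels, func="mean", dtype=np.float64)
    assert result.dtype == np.dtype("float64")

    # the same keyword threads through the xarray layer, e.g. for some Dataset `ds`:
    #     xarray_reduce(ds, "labels", func="mean", dtype="float64")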