From cfd8aed2886ad369ee857899b9821a337036d346 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 9 Aug 2022 12:47:06 -0600 Subject: [PATCH 01/21] Allow specifying output dtype Closes https://github.com/pydata/xarray/issues/6902 --- flox/core.py | 8 +++++++- flox/xarray.py | 8 +++++++- tests/test_core.py | 11 +++++++++++ tests/test_xarray.py | 17 +++++++++++++++++ 4 files changed, 42 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index f39a3fe4e..a6118cc3f 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1364,6 +1364,7 @@ def groupby_reduce( isbin: bool = False, axis=None, fill_value=None, + dtype=None, min_count: int | None = None, split_out: int = 1, method: str = "map-reduce", @@ -1566,8 +1567,13 @@ def groupby_reduce( # overwrite than when min_count is set fill_value = np.nan + if dtype is not None and not isinstance(dtype, np.dtype): + dtype = np.dtype(dtype) + kwargs = dict(axis=axis, fill_value=fill_value, engine=engine) - agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs) + agg = _initialize_aggregation( + func, array.dtype if dtype is None else dtype, fill_value, min_count, finalize_kwargs + ) if not has_dask: results = _reduce_blockwise( diff --git a/flox/xarray.py b/flox/xarray.py index 358b57abd..7a8d56d0a 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -60,6 +60,7 @@ def xarray_reduce( dim: Hashable = None, split_out: int = 1, fill_value=None, + dtype=None, method: str = "map-reduce", engine: str = "flox", keep_attrs: bool | None = True, @@ -95,6 +96,8 @@ def xarray_reduce( fill_value Value used for missing groups in the output i.e. when one of the labels in ``expected_groups`` is not actually present in ``by``. + dtype: str + DType for the output (DataArray only). method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional Strategy for reduction of dask arrays only: * ``"map-reduce"``: @@ -341,7 +344,9 @@ def wrapper(array, *by, func, skipna, **kwargs): exclude_dims=set(dim), output_core_dims=[group_names], dask="allowed", - dask_gufunc_kwargs=dict(output_sizes=group_sizes), + dask_gufunc_kwargs=dict( + output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None + ), keep_attrs=keep_attrs, kwargs={ "func": func, @@ -357,6 +362,7 @@ def wrapper(array, *by, func, skipna, **kwargs): "expected_groups": tuple(expected_groups), "isbin": isbin, "finalize_kwargs": finalize_kwargs, + "dtype": dtype, }, ) diff --git a/tests/test_core.py b/tests/test_core.py index 9c0bd6adb..e30db6064 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1009,3 +1009,14 @@ def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dty method="blockwise", ) assert_equal(expected, actual) + + +@pytest.mark.parametrize("func", ALL_FUNCS) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_dtype(func, dtype, engine): + if "arg" in func or func in ["any", "all"]: + pytest.skip() + arr = np.ones((4, 12), dtype=dtype) + labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]) + actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64) + assert actual.dtype == np.dtype("float64") diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 0a696b24a..34ae551a7 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -485,3 +485,20 @@ def test_mixed_grouping(chunk): fill_value=0, ) assert (r.sel(v1=[3, 4, 5]) == 0).all().data + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_dtype(dtype, engine): + arr = xr.DataArray( + data=np.ones((4, 12), dtype=dtype), + dims=("x", "t"), + coords={ + "labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])) + }, + ) + actual = xarray_reduce(arr, "labels", func="mean", dtype=np.float64) + assert actual.dtype == np.dtype("float64") + + actual = xarray_reduce(arr.chunk({"x": 1}), arr.labels, func="mean", dtype=np.float64) + assert actual.dtype == np.dtype("float64") + assert actual.compute().dtype == np.dtype("float64") From e9d1bd9c4e1856de922f55bceab31a054bee81c9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 9 Aug 2022 12:48:36 -0600 Subject: [PATCH 02/21] Guard dask test --- tests/test_xarray.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 34ae551a7..c44af57b1 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -499,6 +499,7 @@ def test_dtype(dtype, engine): actual = xarray_reduce(arr, "labels", func="mean", dtype=np.float64) assert actual.dtype == np.dtype("float64") - actual = xarray_reduce(arr.chunk({"x": 1}), arr.labels, func="mean", dtype=np.float64) - assert actual.dtype == np.dtype("float64") - assert actual.compute().dtype == np.dtype("float64") + if has_dask: + actual = xarray_reduce(arr.chunk({"x": 1}), arr.labels, func="mean", dtype=np.float64) + assert actual.dtype == np.dtype("float64") + assert actual.compute().dtype == np.dtype("float64") From a294a879d8dd16d728d12c3124cffe6a8a613366 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 9 Aug 2022 12:51:30 -0600 Subject: [PATCH 03/21] Add dataset test too --- tests/test_xarray.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index c44af57b1..0ff4cb08d 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -488,18 +488,27 @@ def test_mixed_grouping(chunk): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_dtype(dtype, engine): +@pytest.mark.parametrize("chunk", (True, False)) +def test_dtype(chunk, dtype, engine): + if chunk and not has_dask: + pytest.skip() + + if chunk: + data = dask.array.ones((4, 12), dtype=dtype, chunks=(1, -1)) + else: + data = np.ones((4, 12), dtype=dtype) arr = xr.DataArray( - data=np.ones((4, 12), dtype=dtype), + data, dims=("x", "t"), coords={ "labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])) }, + name="arr", ) actual = xarray_reduce(arr, "labels", func="mean", dtype=np.float64) assert actual.dtype == np.dtype("float64") + assert actual.compute().dtype == np.dtype("float64") - if has_dask: - actual = xarray_reduce(arr.chunk({"x": 1}), arr.labels, func="mean", dtype=np.float64) - assert actual.dtype == np.dtype("float64") - assert actual.compute().dtype == np.dtype("float64") + actual = xarray_reduce(arr.to_dataset(), "labels", func="mean", dtype=np.float64) + assert actual.arr.dtype == np.dtype("float64") + assert actual.compute().arr.dtype == np.dtype("float64") From 8e320e302a4ad2f316352ee9a8f97e1a0e6d9e2f Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 9 Aug 2022 12:53:33 -0600 Subject: [PATCH 04/21] More dtype options. --- tests/test_xarray.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 0ff4cb08d..caa29bcf4 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -487,9 +487,10 @@ def test_mixed_grouping(chunk): assert (r.sel(v1=[3, 4, 5]) == 0).all().data +@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("chunk", (True, False)) -def test_dtype(chunk, dtype, engine): +def test_dtype(chunk, dtype, dtype_out, engine): if chunk and not has_dask: pytest.skip() @@ -497,6 +498,7 @@ def test_dtype(chunk, dtype, engine): data = dask.array.ones((4, 12), dtype=dtype, chunks=(1, -1)) else: data = np.ones((4, 12), dtype=dtype) + arr = xr.DataArray( data, dims=("x", "t"), @@ -505,10 +507,10 @@ def test_dtype(chunk, dtype, engine): }, name="arr", ) - actual = xarray_reduce(arr, "labels", func="mean", dtype=np.float64) + actual = xarray_reduce(arr, "labels", func="mean", dtype=dtype_out) assert actual.dtype == np.dtype("float64") assert actual.compute().dtype == np.dtype("float64") - actual = xarray_reduce(arr.to_dataset(), "labels", func="mean", dtype=np.float64) + actual = xarray_reduce(arr.to_dataset(), "labels", func="mean", dtype=dtype_out) assert actual.arr.dtype == np.dtype("float64") assert actual.compute().arr.dtype == np.dtype("float64") From b6c1693906103060ba01e3e726537c09c39a289e Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 9 Aug 2022 12:59:13 -0600 Subject: [PATCH 05/21] Update docstring --- flox/core.py | 2 ++ flox/xarray.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index a6118cc3f..04d8b0fb4 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1398,6 +1398,8 @@ def groupby_reduce( Negative integers are normalized using array.ndim fill_value : Any Value to assign when a label in ``expected_groups`` is not present. + dtype: data-type , optional + DType for the output. Can be anything that is accepted by ``np.dtype``. min_count : int, default: None The required number of valid values to perform the operation. If fewer than min_count non-NA values are present the result will be diff --git a/flox/xarray.py b/flox/xarray.py index 7a8d56d0a..64ea92c40 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -96,8 +96,8 @@ def xarray_reduce( fill_value Value used for missing groups in the output i.e. when one of the labels in ``expected_groups`` is not actually present in ``by``. - dtype: str - DType for the output (DataArray only). + dtype: data-type, optional + DType for the output. Can be anything accepted by ``np.dtype``. method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional Strategy for reduction of dask arrays only: * ``"map-reduce"``: From 592a2fa2a59d68969c69c0e1a82ac86cc3111553 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 9 Aug 2022 14:14:51 -0600 Subject: [PATCH 06/21] Better testing --- tests/test_xarray.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index caa29bcf4..32ca4f346 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -494,10 +494,8 @@ def test_dtype(chunk, dtype, dtype_out, engine): if chunk and not has_dask: pytest.skip() - if chunk: - data = dask.array.ones((4, 12), dtype=dtype, chunks=(1, -1)) - else: - data = np.ones((4, 12), dtype=dtype) + xp = dask.array if chunk else np + data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12)) arr = xr.DataArray( data, @@ -507,10 +505,17 @@ def test_dtype(chunk, dtype, dtype_out, engine): }, name="arr", ) - actual = xarray_reduce(arr, "labels", func="mean", dtype=dtype_out) + kwargs = dict(func="mean", dtype=dtype_out, engine=engine) + actual = xarray_reduce(arr, "labels", **kwargs) + expected = arr.groupby("labels").mean(dtype="float64") + assert actual.dtype == np.dtype("float64") assert actual.compute().dtype == np.dtype("float64") + assert_equal(expected, actual) + + actual = xarray_reduce(arr.to_dataset(), "labels", **kwargs) + expected = arr.to_dataset().groupby("labels").mean(dtype="float64") - actual = xarray_reduce(arr.to_dataset(), "labels", func="mean", dtype=dtype_out) assert actual.arr.dtype == np.dtype("float64") assert actual.compute().arr.dtype == np.dtype("float64") + assert_equal(expected, actual.transpose("labels", ...)) From c0cb3001d16573bcee3ceb72c7b048f3db5ac1b7 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 9 Aug 2022 15:42:11 -0600 Subject: [PATCH 07/21] Nicer test --- tests/test_xarray.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 32ca4f346..986e91a59 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -487,16 +487,21 @@ def test_mixed_grouping(chunk): assert (r.sel(v1=[3, 4, 5]) == 0).all().data +@pytest.mark.parametrize("add_nan", [True, False]) @pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("chunk", (True, False)) -def test_dtype(chunk, dtype, dtype_out, engine): +def test_dtype(add_nan, chunk, dtype, dtype_out, engine): if chunk and not has_dask: pytest.skip() xp = dask.array if chunk else np data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12)) + if add_nan: + data[1, ...] = np.nan + data[0, [0, 2]] = np.nan + arr = xr.DataArray( data, dims=("x", "t"), @@ -511,11 +516,11 @@ def test_dtype(chunk, dtype, dtype_out, engine): assert actual.dtype == np.dtype("float64") assert actual.compute().dtype == np.dtype("float64") - assert_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) actual = xarray_reduce(arr.to_dataset(), "labels", **kwargs) expected = arr.to_dataset().groupby("labels").mean(dtype="float64") assert actual.arr.dtype == np.dtype("float64") assert actual.compute().arr.dtype == np.dtype("float64") - assert_equal(expected, actual.transpose("labels", ...)) + xr.testing.assert_allclose(expected, actual.transpose("labels", ...)) From dae969c9a4a94c1a30c9f5621531503839a03b35 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 16 Aug 2022 12:00:27 -0600 Subject: [PATCH 08/21] Add dtype type Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- flox/core.py | 2 +- flox/xarray.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index 04d8b0fb4..78c47a812 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1364,7 +1364,7 @@ def groupby_reduce( isbin: bool = False, axis=None, fill_value=None, - dtype=None, + dtype: np.typing.DTypeLike = None, min_count: int | None = None, split_out: int = 1, method: str = "map-reduce", diff --git a/flox/xarray.py b/flox/xarray.py index 64ea92c40..ba834968e 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -60,7 +60,7 @@ def xarray_reduce( dim: Hashable = None, split_out: int = 1, fill_value=None, - dtype=None, + dtype: np.typing.DTypeLike = None, method: str = "map-reduce", engine: str = "flox", keep_attrs: bool | None = True, From 4dab89af64eb7e8dc2bf04242ce8be156dcd3cb3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 24 Sep 2022 09:21:00 +0200 Subject: [PATCH 09/21] Add dtype check for numpy arrays --- tests/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/__init__.py b/tests/__init__.py index fef0a778e..01366aad3 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -98,6 +98,8 @@ def assert_equal(a, b): # does some validation of the dask graph da.utils.assert_eq(a, b, equal_nan=True) else: + if a.dtype != b.dtype: + raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})") np.testing.assert_allclose(a, b, equal_nan=True) From a32a40cb28189b6c9eee85258ba093671e404e71 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Sep 2022 14:16:27 +0200 Subject: [PATCH 10/21] Revert "Add dtype check for numpy arrays" This reverts commit 4dab89af64eb7e8dc2bf04242ce8be156dcd3cb3. --- tests/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 01366aad3..fef0a778e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -98,8 +98,6 @@ def assert_equal(a, b): # does some validation of the dask graph da.utils.assert_eq(a, b, equal_nan=True) else: - if a.dtype != b.dtype: - raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})") np.testing.assert_allclose(a, b, equal_nan=True) From c37a7d9836021778057cef1896f56a704e6fef15 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 12:13:27 -0600 Subject: [PATCH 11/21] stricter test --- tests/test_xarray.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 7280ae970..35d2834e5 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -517,13 +517,15 @@ def test_dtype(add_nan, chunk, dtype, dtype_out, engine): actual = xarray_reduce(arr, "labels", **kwargs) expected = arr.groupby("labels").mean(dtype="float64") + tolerance = {"rtol": 1e-15, "atol": 1e-18} + assert actual.dtype == np.dtype("float64") assert actual.compute().dtype == np.dtype("float64") - xr.testing.assert_allclose(expected, actual) + xr.testing.assert_allclose(expected, actual, **tolerance) actual = xarray_reduce(arr.to_dataset(), "labels", **kwargs) expected = arr.to_dataset().groupby("labels").mean(dtype="float64") assert actual.arr.dtype == np.dtype("float64") assert actual.compute().arr.dtype == np.dtype("float64") - xr.testing.assert_allclose(expected, actual.transpose("labels", ...)) + xr.testing.assert_allclose(expected, actual.transpose("labels", ...), **tolerance) From ff89e59ccd882ca86ecdc8e2f4d415565c3ff831 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 12:27:27 -0600 Subject: [PATCH 12/21] Fix dtype of output with float32 & dask --- flox/core.py | 1 + tests/__init__.py | 8 ++++++-- tests/test_xarray.py | 44 ++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 47 insertions(+), 6 deletions(-) diff --git a/flox/core.py b/flox/core.py index 00df7ca4e..93a8e67ed 100644 --- a/flox/core.py +++ b/flox/core.py @@ -789,6 +789,7 @@ def _finalize_results( else: finalized["groups"] = squeezed["groups"] + finalized[agg.name] = finalized[agg.name].astype(agg.dtype[agg.name], copy=False) return finalized diff --git a/tests/__init__.py b/tests/__init__.py index 9917b41fc..be0b59dec 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -94,7 +94,11 @@ def assert_equal(a, b): elif has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type): # sometimes it's nice to see values and shapes # rather than being dropped into some file in dask - np.testing.assert_allclose(a, b) + if np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64): + kwargs = {"atol": 1e-18, "rtol": 1e-15} + else: + kwargs = {} + np.testing.assert_allclose(a, b, **kwargs) # does some validation of the dask graph da.utils.assert_eq(a, b, equal_nan=True) else: @@ -104,7 +108,7 @@ def assert_equal(a, b): np.testing.assert_allclose(a, b, equal_nan=True) -@pytest.fixture(scope="module", params=["flox", "numpy", "numba"]) +@pytest.fixture(scope="module", params=["flox", "numpy"]) def engine(request): if request.param == "numba": try: diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 35d2834e5..995e1daaa 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -24,6 +24,10 @@ pass +tolerance64 = {"rtol": 1e-15, "atol": 1e-18} +np.random.seed(123) + + @pytest.mark.parametrize("reindex", [None, False, True]) @pytest.mark.parametrize("min_count", [None, 1, 3]) @pytest.mark.parametrize("add_nan", [True, False]) @@ -517,15 +521,47 @@ def test_dtype(add_nan, chunk, dtype, dtype_out, engine): actual = xarray_reduce(arr, "labels", **kwargs) expected = arr.groupby("labels").mean(dtype="float64") - tolerance = {"rtol": 1e-15, "atol": 1e-18} - assert actual.dtype == np.dtype("float64") assert actual.compute().dtype == np.dtype("float64") - xr.testing.assert_allclose(expected, actual, **tolerance) + xr.testing.assert_allclose(expected, actual, **tolerance64) actual = xarray_reduce(arr.to_dataset(), "labels", **kwargs) expected = arr.to_dataset().groupby("labels").mean(dtype="float64") assert actual.arr.dtype == np.dtype("float64") assert actual.compute().arr.dtype == np.dtype("float64") - xr.testing.assert_allclose(expected, actual.transpose("labels", ...), **tolerance) + xr.testing.assert_allclose(expected, actual.transpose("labels", ...), **tolerance64) + + +@pytest.mark.parametrize("chunk", [True, False]) +@pytest.mark.parametrize("use_flox", [True, False]) +def test_dtype_accumulation(use_flox, chunk): + if chunk and not has_dask: + pytest.skip() + + datetimes = pd.date_range("2010-01", "2015-01", freq="6H", inclusive="left") + samples = 10 + np.cos(2 * np.pi * 0.001 * np.arange(len(datetimes))) * 1 + samples += np.random.randn(len(datetimes)) + samples = samples.astype("float32") + + nan_indices = np.random.default_rng().integers(0, len(samples), size=5_000) + samples[nan_indices] = np.nan + + da = xr.DataArray(samples, dims=("time",), coords=[datetimes]) + if chunk: + da = da.chunk(time=1024) + + gb = da.groupby("time.month") + + with xr.set_options(use_flox=use_flox): + expected = gb.reduce(np.nanmean) + actual = gb.mean() + xr.testing.assert_allclose(expected, actual) + assert np.issubdtype(actual.dtype, np.float32) + assert np.issubdtype(actual.compute().dtype, np.float32) + + expected = gb.reduce(np.nanmean, dtype="float64") + actual = gb.mean(dtype="float64") + assert np.issubdtype(actual.dtype, np.float64) + assert np.issubdtype(actual.compute().dtype, np.float64) + xr.testing.assert_allclose(expected, actual, **tolerance64) From 5e77eda9ba53b8cd1575167da734eb27fb05fed6 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 12:30:32 -0600 Subject: [PATCH 13/21] restore numba test --- tests/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/__init__.py b/tests/__init__.py index be0b59dec..5c168fbe1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -108,7 +108,7 @@ def assert_equal(a, b): np.testing.assert_allclose(a, b, equal_nan=True) -@pytest.fixture(scope="module", params=["flox", "numpy"]) +@pytest.fixture(scope="module", params=["flox", "numpy", "numba"]) def engine(request): if request.param == "numba": try: From 6d5dd46612fffe214f8b7683ed35d509ce39b784 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 13:28:39 -0600 Subject: [PATCH 14/21] Fix dtype promotion by fill_value --- flox/aggregations.py | 17 +++++++++++------ flox/core.py | 7 +------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 2a1d68d6d..c5ffe7ced 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -55,10 +55,7 @@ def generic_aggregate( def _normalize_dtype(dtype, array_dtype, fill_value=None): if dtype is None: - if fill_value is not None and np.isnan(fill_value): - dtype = np.floating - else: - dtype = array_dtype + dtype = array_dtype if dtype is np.floating: # mean, std, var always result in floating # but we preserve the array's dtype if it is floating @@ -68,6 +65,8 @@ def _normalize_dtype(dtype, array_dtype, fill_value=None): dtype = np.dtype("float64") elif not isinstance(dtype, np.dtype): dtype = np.dtype(dtype) + if fill_value is not None: + dtype = np.result_type(dtype, fill_value) return dtype @@ -465,6 +464,7 @@ def _zip_index(array_, idx_): def _initialize_aggregation( func: str | Aggregation, + dtype, array_dtype, fill_value, min_count: int | None, @@ -484,10 +484,15 @@ def _initialize_aggregation( else: raise ValueError("Bad type for func. Expected str or Aggregation") - agg.dtype[func] = _normalize_dtype(agg.dtype[func], array_dtype, fill_value) + # np.dtype(None) == np.dtype("float64")!!! + # so check for not None + if dtype is not None and not isinstance(dtype, np.dtype): + dtype = np.dtype(dtype) + + agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value) agg.dtype["numpy"] = (agg.dtype[func],) agg.dtype["intermediate"] = [ - _normalize_dtype(dtype, array_dtype) for dtype in agg.dtype["intermediate"] + _normalize_dtype(dtype_, dtype) for dtype_ in agg.dtype["intermediate"] ] # Replace sentinel fill values according to dtype diff --git a/flox/core.py b/flox/core.py index 93a8e67ed..c854ce953 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1624,13 +1624,8 @@ def groupby_reduce( # overwrite than when min_count is set fill_value = np.nan - if dtype is not None and not isinstance(dtype, np.dtype): - dtype = np.dtype(dtype) - kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine) - agg = _initialize_aggregation( - func, array.dtype if dtype is None else dtype, fill_value, min_count, finalize_kwargs - ) + agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs) if not has_dask: results = _reduce_blockwise( From 3d99c5a6e7f2f3095a3e0ffc5061300fb1e6e395 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 15:36:41 -0600 Subject: [PATCH 15/21] Cleaner dtype setting for intermediates --- flox/aggregations.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index c5ffe7ced..6928c36f6 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -491,9 +491,7 @@ def _initialize_aggregation( agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value) agg.dtype["numpy"] = (agg.dtype[func],) - agg.dtype["intermediate"] = [ - _normalize_dtype(dtype_, dtype) for dtype_ in agg.dtype["intermediate"] - ] + agg.dtype["intermediate"] = [int_dtype or dtype for int_dtype in agg.dtype["intermediate"]] # Replace sentinel fill values according to dtype agg.fill_value["intermediate"] = tuple( From 5bc675c7d8c841925c759b38149aa15c4b1ba010 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 15:37:00 -0600 Subject: [PATCH 16/21] Loosen tolerance for var, std --- tests/__init__.py | 18 +++++++++++------- tests/test_core.py | 8 +++++--- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 5c168fbe1..8b25f4526 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -80,13 +80,21 @@ def raise_if_dask_computes(max_computes=0): return dask.config.set(scheduler=scheduler) -def assert_equal(a, b): +def assert_equal(a, b, tolerance=None): __tracebackhide__ = True if isinstance(a, list): a = np.array(a) if isinstance(b, list): b = np.array(b) + + if tolerance is None and ( + np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64) + ): + tolerance = {"atol": 1e-18, "rtol": 1e-15} + else: + tolerance = {} + if isinstance(a, pd_types) or isinstance(b, pd_types): pd.testing.assert_index_equal(a, b) elif has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types): @@ -94,18 +102,14 @@ def assert_equal(a, b): elif has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type): # sometimes it's nice to see values and shapes # rather than being dropped into some file in dask - if np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64): - kwargs = {"atol": 1e-18, "rtol": 1e-15} - else: - kwargs = {} - np.testing.assert_allclose(a, b, **kwargs) + np.testing.assert_allclose(a, b, **tolerance) # does some validation of the dask graph da.utils.assert_eq(a, b, equal_nan=True) else: if a.dtype != b.dtype: raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})") - np.testing.assert_allclose(a, b, equal_nan=True) + np.testing.assert_allclose(a, b, equal_nan=True, **tolerance) @pytest.fixture(scope="module", params=["flox", "numpy", "numba"]) diff --git a/tests/test_core.py b/tests/test_core.py index 413a72a75..642e26459 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -185,8 +185,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): if "var" in func or "std" in func: finalize_kwargs = finalize_kwargs + [{"ddof": 1}, {"ddof": 0}] fill_value = np.nan + tolerance = {"rtol": 1e-14, "atol": 1e-16} else: fill_value = None + tolerance = None for kwargs in finalize_kwargs: flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value) @@ -207,7 +209,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): assert_equal(actual_group, expect) if "arg" in func: assert actual.dtype.kind == "i" - assert_equal(actual, expected) + assert_equal(actual, expected, tolerance) if not has_dask: continue @@ -216,10 +218,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): continue actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs) for actual_group, expect in zip(groups, expected_groups): - assert_equal(actual_group, expect) + assert_equal(actual_group, expect, tolerance) if "arg" in func: assert actual.dtype.kind == "i" - assert_equal(actual, expected) + assert_equal(actual, expected, tolerance) @requires_dask From c0c1dc1618b9bab7af3bebf922a6d8541032a5f7 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 15:44:02 -0600 Subject: [PATCH 17/21] bugfix --- flox/aggregations.py | 4 +++- tests/__init__.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 6928c36f6..fff661e3d 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -491,7 +491,9 @@ def _initialize_aggregation( agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value) agg.dtype["numpy"] = (agg.dtype[func],) - agg.dtype["intermediate"] = [int_dtype or dtype for int_dtype in agg.dtype["intermediate"]] + agg.dtype["intermediate"] = [ + int_dtype or agg.dtype[func] for int_dtype in agg.dtype["intermediate"] + ] # Replace sentinel fill values according to dtype agg.fill_value["intermediate"] = tuple( diff --git a/tests/__init__.py b/tests/__init__.py index 8b25f4526..0cd967d11 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -88,6 +88,13 @@ def assert_equal(a, b, tolerance=None): if isinstance(b, list): b = np.array(b) + if isinstance(a, pd_types) or isinstance(b, pd_types): + pd.testing.assert_index_equal(a, b) + return + if has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types): + xr.testing.assert_identical(a, b) + return + if tolerance is None and ( np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64) ): @@ -95,11 +102,7 @@ def assert_equal(a, b, tolerance=None): else: tolerance = {} - if isinstance(a, pd_types) or isinstance(b, pd_types): - pd.testing.assert_index_equal(a, b) - elif has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types): - xr.testing.assert_identical(a, b) - elif has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type): + if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type): # sometimes it's nice to see values and shapes # rather than being dropped into some file in dask np.testing.assert_allclose(a, b, **tolerance) From 3d5ccf9bd91a563fd61d88b216d6d500e950e658 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 15:52:42 -0600 Subject: [PATCH 18/21] Fix more var tests --- tests/test_core.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_core.py b/tests/test_core.py index 642e26459..05deae2fd 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -468,6 +468,11 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine): fill_value = False else: fill_value = 123 + + if "var" in func or "std" in func: + tolerance = {"rtol": 1e-14, "atol": 1e-16} + else: + tolerance = None # tests against the numpy output to make sure dask compute matches by = np.broadcast_to(labels2d, (3, *labels2d.shape)) rng = np.random.default_rng(12345) @@ -486,7 +491,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine): kwargs.pop("engine") expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy") assert_equal(expected_npg, expected) - assert_equal(actual, expected) + assert_equal(actual, expected, tolerance) @pytest.mark.parametrize("chunks", [None, (2, 2, 3)]) From e134fcaac44ab4814e3aaa53f236cd2803022ff8 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 18:49:45 -0600 Subject: [PATCH 19/21] Fix argreductions --- flox/aggregations.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index fff661e3d..1777e317f 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -65,7 +65,7 @@ def _normalize_dtype(dtype, array_dtype, fill_value=None): dtype = np.dtype("float64") elif not isinstance(dtype, np.dtype): dtype = np.dtype(dtype) - if fill_value is not None: + if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]: dtype = np.result_type(dtype, fill_value) return dtype @@ -492,7 +492,8 @@ def _initialize_aggregation( agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value) agg.dtype["numpy"] = (agg.dtype[func],) agg.dtype["intermediate"] = [ - int_dtype or agg.dtype[func] for int_dtype in agg.dtype["intermediate"] + _normalize_dtype(int_dtype, array_dtype, int_fv) + for int_dtype, int_fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"]) ] # Replace sentinel fill values according to dtype From 5db0bca11304c10e5badd97b300d470721e88fd4 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 20:28:50 -0600 Subject: [PATCH 20/21] Another fix? --- flox/aggregations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 1777e317f..c7e6afa69 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -492,7 +492,7 @@ def _initialize_aggregation( agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value) agg.dtype["numpy"] = (agg.dtype[func],) agg.dtype["intermediate"] = [ - _normalize_dtype(int_dtype, array_dtype, int_fv) + _normalize_dtype(int_dtype, np.result_type(agg.dtype[func], array_dtype), int_fv) for int_dtype, int_fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"]) ] From 3bb94fb241c7a1d2cb04682f2e78e6c18ca763f4 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 9 Oct 2022 20:36:20 -0600 Subject: [PATCH 21/21] Another fix. --- flox/aggregations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index c7e6afa69..0cf6b259b 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -492,7 +492,9 @@ def _initialize_aggregation( agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value) agg.dtype["numpy"] = (agg.dtype[func],) agg.dtype["intermediate"] = [ - _normalize_dtype(int_dtype, np.result_type(agg.dtype[func], array_dtype), int_fv) + _normalize_dtype(int_dtype, np.result_type(array_dtype, agg.dtype[func]), int_fv) + if int_dtype is None + else int_dtype for int_dtype, int_fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"]) ]