diff --git a/docs/api-reference.md b/docs/api-reference.md index 38d0d26..f1e0cb3 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -19,6 +19,7 @@ nunique one_hot pad + quantile setdiff1d sinc ``` diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index ddfc715..5610eb8 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import isclose, one_hot, pad +from ._delegation import isclose, one_hot, pad, quantile from ._lib._at import at from ._lib._funcs import ( apply_where, @@ -36,6 +36,7 @@ "nunique", "one_hot", "pad", + "quantile", "setdiff1d", "sinc", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 756841c..9bca87a 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -5,6 +5,7 @@ from typing import Literal from ._lib import _funcs +from ._lib._quantile import quantile as _quantile from ._lib._utils._compat import ( array_namespace, is_cupy_namespace, @@ -18,7 +19,7 @@ from ._lib._utils._helpers import asarrays from ._lib._utils._typing import Array, DType -__all__ = ["isclose", "one_hot", "pad"] +__all__ = ["isclose", "one_hot", "pad", "quantile"] def isclose( @@ -247,3 +248,99 @@ def pad( return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) + + +def quantile( + x: Array, + q: Array | float, + /, + *, + axis: int | None = None, + keepdims: bool | None = None, + method: str = "linear", + xp: ModuleType | None = None, +) -> Array: + """ + Compute the q-th quantile(s) of the data along the specified axis. + + Parameters + ---------- + x : array of real numbers + Data array. + q : array of float + Probability or sequence of probabilities of the quantiles to compute. + Values must be between 0 and 1 (inclusive). Must have length 1 along + `axis` unless ``keepdims=True``. + axis : int or None, default: None + Axis along which the quantiles are computed. ``None`` ravels both `x` + and `q` before performing the calculation. + keepdims : bool or None, default: None + By default, the axis will be reduced away if possible + (i.e. if there is exactly one element of `q` per axis-slice of `x`). + If `keepdims` is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result will + broadcast correctly against the original array `x`. + If `keepdims` is set to False, the axis will be reduced away if possible, + and an error will be raised otherwise. + method : str, default: 'linear' + The method to use for estimating the quantile. The available options are: + 'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', + 'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default), + 'median_unbiased', 'normal_unbiased'. + xp : array_namespace, optional + The standard-compatible namespace for `x` and `q`. Default: infer. + + Returns + ------- + array + An array with the quantiles of the data. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.asarray([[10, 8, 7, 5, 4], [0, 1, 2, 3, 5]]) + >>> xpx.quantile(x, 0.5, axis=-1) + Array([7., 2.], dtype=array_api_strict.float64) + >>> xpx.quantile(x, [0.25, 0.75], axis=-1) + Array([[5., 8.], + [1., 3.]], dtype=array_api_strict.float64) + """ + # We only support a subset of the methods supported by scipy.stats.quantile. + # So we need to perform the validation here. + methods = { + "inverted_cdf", + "averaged_inverted_cdf", + "closest_observation", + "hazen", + "interpolated_inverted_cdf", + "linear", + "median_unbiased", + "normal_unbiased", + "weibull", + } + if method not in methods: + msg = f"`method` must be one of {methods}" + raise ValueError(msg) + + xp = array_namespace(x, q) if xp is None else xp + + if is_dask_namespace(xp): + return xp.quantile(x, q, axis=axis, keepdims=keepdims, method=method) + + try: + import scipy # type: ignore[import-untyped] + from packaging import version + + # The quantile function in scipy 1.16 supports array API directly, no need + # to delegate + if version.parse(scipy.__version__) >= version.parse("1.17"): # pyright: ignore[reportUnknownArgumentType] + from scipy.stats import ( # type: ignore[import-untyped] + quantile as scipy_quantile, + ) + + return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method) + except (ImportError, AttributeError): + pass + + return _quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py new file mode 100644 index 0000000..9670d4d --- /dev/null +++ b/src/array_api_extra/_lib/_quantile.py @@ -0,0 +1,149 @@ +"""Quantile implementation.""" + +from types import ModuleType +from typing import cast + +from ._at import at +from ._utils import _compat +from ._utils._compat import array_namespace +from ._utils._typing import Array + + +def quantile( + x: Array, + q: Array | float, + /, + *, + axis: int | None = None, + keepdims: bool | None = None, + method: str = "linear", + xp: ModuleType | None = None, +) -> Array: # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" + if xp is None: + xp = array_namespace(x, q) + + q_is_scalar = isinstance(q, int | float) + if q_is_scalar: + q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x)) + q_arr = cast(Array, q) + + if not xp.isdtype(x.dtype, ("integral", "real floating")): + msg = "`x` must have real dtype." + raise ValueError(msg) + if not xp.isdtype(q_arr.dtype, "real floating"): + msg = "`q` must have real floating dtype." + raise ValueError(msg) + + # Promote to common dtype + x = xp.astype(x, xp.float64) + q_arr = xp.asarray(q_arr, dtype=xp.float64, device=_compat.device(x)) + + dtype = x.dtype + axis_none = axis is None + ndim = max(x.ndim, q_arr.ndim) + + if axis_none: + x = xp.reshape(x, (-1,)) + q_arr = xp.reshape(q_arr, (-1,)) + axis = 0 + elif not isinstance(axis, int): # pyright: ignore[reportUnnecessaryIsInstance] + msg = "`axis` must be an integer or None." + raise ValueError(msg) + elif axis >= ndim or axis < -ndim: + msg = "`axis` is not compatible with the shapes of the inputs." + raise ValueError(msg) + else: + axis = int(axis) + + if keepdims not in {None, True, False}: + msg = "If specified, `keepdims` must be True or False." + raise ValueError(msg) + + if x.shape[axis] == 0: + shape = list(x.shape) + shape[axis] = 1 + x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x)) + + y = xp.sort(x, axis=axis) + + # Move axis to the end for easier processing + y = xp.moveaxis(y, axis, -1) + if not (q_is_scalar or q_arr.ndim == 0): + q_arr = xp.moveaxis(q_arr, axis, -1) + + n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) + + # Validate that q values are in the range [0, 1] + if xp.any((q_arr < 0) | (q_arr > 1)): + msg = "`q` must contain values between 0 and 1 inclusive." + raise ValueError(msg) + + res = _quantile_hf(y, q_arr, n, method, xp) + + # Reshape per axis/keepdims + if axis_none and keepdims: + shape = (1,) * (ndim - 1) + res.shape + res = xp.reshape(res, shape) + axis = -1 + + # Move axis back to original position + res = xp.moveaxis(res, -1, axis) + + if not keepdims and res.shape[axis] == 1: + res = xp.squeeze(res, axis=axis) + + if res.ndim == 0: + return res[()] + return res + + +def _quantile_hf( + y: Array, p: Array, n: Array, method: str, xp: ModuleType +) -> Array: # numpydoc ignore=PR01,RT01 + """Helper function for Hyndman-Fan quantile method.""" + ms: dict[str, Array | int | float] = { + "inverted_cdf": 0, + "averaged_inverted_cdf": 0, + "closest_observation": -0.5, + "interpolated_inverted_cdf": 0, + "hazen": 0.5, + "weibull": p, + "linear": 1 - p, + "median_unbiased": p / 3 + 1 / 3, + "normal_unbiased": p / 4 + 3 / 8, + } + m = ms[method] + + jg = p * n + m - 1 + # Convert both to integers, the type of j and n must be the same + # for us to be able to `xp.clip` them. + j = xp.astype(jg // 1, xp.int64) + n = xp.astype(n, xp.int64) + g = jg % 1 + + if method == "inverted_cdf": + g = xp.astype((g > 0), jg.dtype) + elif method == "averaged_inverted_cdf": + g = (1 + xp.astype((g > 0), jg.dtype)) / 2 + elif method == "closest_observation": + g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype) + if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}: + g = xp.asarray(g) + g = at(g, jg < 0).set(0) + g = at(g, j < 0).set(0) + j = xp.clip(j, 0, n - 1) + jp1 = xp.clip(j + 1, 0, n - 1) + + # Broadcast indices to match y shape except for the last axis + if y.ndim > 1: + # Create broadcast shape for indices + broadcast_shape = [*y.shape[:-1], 1] + j = xp.broadcast_to(j, broadcast_shape) + jp1 = xp.broadcast_to(jp1, broadcast_shape) + g = xp.broadcast_to(g, broadcast_shape) + + res = (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis( + y, jp1, axis=-1 + ) + return res # noqa: RET504 diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 769b411..d7d6bc5 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -24,6 +24,7 @@ nunique, one_hot, pad, + quantile, setdiff1d, sinc, ) @@ -43,6 +44,7 @@ lazy_xp_function(nunique) lazy_xp_function(one_hot) lazy_xp_function(pad) +lazy_xp_function(quantile) # FIXME calls in1d which calls xp.unique_values without size lazy_xp_function(setdiff1d, jax_jit=False) lazy_xp_function(sinc) @@ -1162,3 +1164,73 @@ def test_device(self, xp: ModuleType, device: Device): def test_xp(self, xp: ModuleType): xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) + + +class TestQuantile: + def test_basic(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = quantile(x, 0.5) + expect = xp.asarray(3.0) + xp_assert_close(actual, expect) + + def test_multiple_quantiles(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = quantile(x, xp.asarray([0.25, 0.5, 0.75])) + expect = xp.asarray([2.0, 3.0, 4.0]) + xp_assert_close(actual, expect) + + def test_2d_axis(self, xp: ModuleType): + x = xp.asarray([[1, 2, 3], [4, 5, 6]]) + actual = quantile(x, 0.5, axis=0) + expect = xp.asarray([2.5, 3.5, 4.5]) + xp_assert_close(actual, expect) + + def test_2d_axis_keepdims(self, xp: ModuleType): + x = xp.asarray([[1, 2, 3], [4, 5, 6]]) + actual = quantile(x, 0.5, axis=0, keepdims=True) + expect = xp.asarray([[2.5, 3.5, 4.5]]) + xp_assert_close(actual, expect) + + def test_methods(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + methods = ["linear", "hazen", "weibull"] + for method in methods: + actual = quantile(x, 0.5, method=method) + # All methods should give reasonable results + assert 2.5 <= float(actual) <= 3.5 + + def test_edge_cases(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + # q = 0 should give minimum + actual = quantile(x, 0.0) + expect = xp.asarray(1.0) + xp_assert_close(actual, expect) + + # q = 1 should give maximum + actual = quantile(x, 1.0) + expect = xp.asarray(5.0) + xp_assert_close(actual, expect) + + def test_invalid_q(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + # q > 1 should raise + with pytest.raises( + ValueError, match="`q` must contain values between 0 and 1 inclusive" + ): + _ = quantile(x, 1.5) + + with pytest.raises( + ValueError, match="`q` must contain values between 0 and 1 inclusive" + ): + _ = quantile(x, -0.5) + + def test_device(self, xp: ModuleType, device: Device): + x = xp.asarray([1, 2, 3, 4, 5], device=device) + actual = quantile(x, 0.5) + assert get_device(actual) == device + + def test_xp(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = quantile(x, 0.5, xp=xp) + expect = xp.asarray(3.0) + xp_assert_close(actual, expect)