From f7cac0147f2820e4336051b743581ffc4d07b328 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 13:46:40 +0200 Subject: [PATCH 01/20] Add delegation for `quantile` This makes quantile available when the version of Scipy is not new enough to support array API inputs. --- docs/api-reference.md | 1 + src/array_api_extra/__init__.py | 3 +- src/array_api_extra/_delegation.py | 69 +++++++++++- src/array_api_extra/_lib/_funcs.py | 164 +++++++++++++++++++++++++++++ tests/test_funcs.py | 69 ++++++++++++ 5 files changed, 304 insertions(+), 2 deletions(-) diff --git a/docs/api-reference.md b/docs/api-reference.md index 38d0d26e..f1e0cb3d 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -19,6 +19,7 @@ nunique one_hot pad + quantile setdiff1d sinc ``` diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index ddfc715e..5610eb8d 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import isclose, one_hot, pad +from ._delegation import isclose, one_hot, pad, quantile from ._lib._at import at from ._lib._funcs import ( apply_where, @@ -36,6 +36,7 @@ "nunique", "one_hot", "pad", + "quantile", "setdiff1d", "sinc", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 756841c8..dc122faf 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -18,7 +18,7 @@ from ._lib._utils._helpers import asarrays from ._lib._utils._typing import Array, DType -__all__ = ["isclose", "one_hot", "pad"] +__all__ = ["isclose", "one_hot", "pad", "quantile"] def isclose( @@ -247,3 +247,70 @@ def pad( return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) + + +def quantile( + x: Array, + q: Array | float, + /, + *, + axis: int | None = None, + keepdims: bool = False, + method: str = "linear", + xp: ModuleType | None = None, +) -> Array: + """ + Compute the q-th quantile(s) of the data along the specified axis. + + Parameters + ---------- + x : array of real numbers + Data array. + q : array of float + Probability or sequence of probabilities of the quantiles to compute. + Values must be between 0 and 1 (inclusive). Must have length 1 along + `axis` unless ``keepdims=True``. + method : str, default: 'linear' + The method to use for estimating the quantile. The available options are: + 'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', + 'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default), + 'median_unbiased', 'normal_unbiased', 'harrell-davis'. + axis : int or None, default: None + Axis along which the quantiles are computed. ``None`` ravels both `x` + and `q` before performing the calculation. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result will + broadcast correctly against the original array `x`. + xp : array_namespace, optional + The standard-compatible namespace for `x` and `q`. Default: infer. + + Returns + ------- + array + An array with the quantiles of the data. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.asarray([[10, 8, 7, 5, 4], [0, 1, 2, 3, 5]]) + >>> xpx.quantile(x, 0.5, axis=-1) + Array([7., 2.], dtype=array_api_strict.float64) + >>> xpx.quantile(x, [0.25, 0.75], axis=-1) + Array([[5., 8.], + [1., 3.]], dtype=array_api_strict.float64) + """ + xp = array_namespace(x, q) if xp is None else xp + + try: + import scipy + from packaging import version + # The quantile function in scipy 1.17 supports array API directly, no need to delegate + if version.parse(scipy.__version__) >= version.parse("1.17"): + from scipy.stats import quantile as scipy_quantile + return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method) + except (ImportError, AttributeError): + pass + + return _funcs.quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index cf1a894a..7a7de597 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -28,6 +28,7 @@ "kron", "nunique", "pad", + "quantile", "setdiff1d", "sinc", ] @@ -988,3 +989,166 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)), ) return xp.sin(y) / y + + +def quantile( + x: Array, + q: Array | float, + /, + *, + axis: int | None = None, + keepdims: bool = False, + method: str = "linear", + xp: ModuleType | None = None, +) -> Array: + """See docstring in `array_api_extra._delegation.py`.""" + if xp is None: + xp = array_namespace(x, q) + + # Convert q to array if it's a scalar + q_is_scalar = isinstance(q, (int, float)) + if q_is_scalar: + q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x)) + + # Validate inputs + if not xp.isdtype(x.dtype, ("integral", "real floating")): + raise ValueError("`x` must have real dtype.") + if not xp.isdtype(q.dtype, "real floating"): + raise ValueError("`q` must have real floating dtype.") + + # Promote to common dtype + x = xp.astype(x, xp.float64) + q = xp.astype(q, xp.float64) + q = xp.asarray(q, device=_compat.device(x)) + + dtype = x.dtype + axis_none = axis is None + ndim = max(x.ndim, q.ndim) + + if axis_none: + x = xp.reshape(x, (-1,)) + q = xp.reshape(q, (-1,)) + axis = 0 + elif not isinstance(axis, int): + raise ValueError("`axis` must be an integer or None.") + elif axis >= ndim or axis < -ndim: + raise ValueError("`axis` is not compatible with the shapes of the inputs.") + else: + axis = int(axis) + + # Validate method + methods = { + 'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', + 'hazen', 'interpolated_inverted_cdf', 'linear', 'median_unbiased', + 'normal_unbiased', 'weibull', 'harrell-davis' + } + if method not in methods: + raise ValueError(f"`method` must be one of {methods}") + + # Handle keepdims parameter + if keepdims not in {None, True, False}: + raise ValueError("If specified, `keepdims` must be True or False.") + + # Handle empty arrays + if x.shape[axis] == 0: + shape = list(x.shape) + shape[axis] = 1 + x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x)) + + # Sort the data + y = xp.sort(x, axis=axis) + + # Move axis to the end for easier processing + y = xp.moveaxis(y, axis, -1) + if not (q_is_scalar or q.ndim == 0): + q = xp.moveaxis(q, axis, -1) + + # Get the number of elements along the axis + n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) + + # Apply quantile calculation based on method + if method in {'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', + 'hazen', 'interpolated_inverted_cdf', 'linear', 'median_unbiased', + 'normal_unbiased', 'weibull'}: + res = _quantile_hf(y, q, n, method, xp) + elif method == 'harrell-davis': + res = _quantile_hd(y, q, n, xp) + else: + raise ValueError(f"Unknown method: {method}") + + # Handle NaN output for invalid q values + p_mask = (q > 1) | (q < 0) | xp.isnan(q) + if xp.any(p_mask): + res = xp.asarray(res, copy=True) + res = at(res, p_mask).set(xp.nan) + + # Reshape per axis/keepdims + if axis_none and keepdims: + shape = (1,) * (ndim - 1) + res.shape + res = xp.reshape(res, shape) + axis = -1 + + # Move axis back to original position + res = xp.moveaxis(res, -1, axis) + + # Handle keepdims + if not keepdims and res.shape[axis] == 1: + res = xp.squeeze(res, axis=axis) + + # For scalar q, ensure we return a scalar result + if q_is_scalar: + if hasattr(res, 'shape') and res.shape != (): + res = res[()] + + return res + + +def _quantile_hf(y: Array, p: Array, n: Array, method: str, xp: ModuleType) -> Array: + """Helper function for Hyndman-Fan quantile methods.""" + ms = { + 'inverted_cdf': 0, + 'averaged_inverted_cdf': 0, + 'closest_observation': -0.5, + 'interpolated_inverted_cdf': 0, + 'hazen': 0.5, + 'weibull': p, + 'linear': 1 - p, + 'median_unbiased': p/3 + 1/3, + 'normal_unbiased': p/4 + 3/8 + } + m = ms[method] + + jg = p * n + m - 1 + j = xp.astype(jg // 1, xp.int64) # Convert to integer + g = jg % 1 + + if method == 'inverted_cdf': + g = xp.astype((g > 0), jg.dtype) + elif method == 'averaged_inverted_cdf': + g = (1 + xp.astype((g > 0), jg.dtype)) / 2 + elif method == 'closest_observation': + g = (1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype)) + if method in {'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation'}: + g = xp.asarray(g) + g = at(g, jg < 0).set(0) + g = at(g, j < 0).set(0) + j = xp.clip(j, 0, n - 1) + jp1 = xp.clip(j + 1, 0, n - 1) + + # Broadcast indices to match y shape except for the last axis + if y.ndim > 1: + # Create broadcast shape for indices + broadcast_shape = list(y.shape[:-1]) + [1] + j = xp.broadcast_to(j, broadcast_shape) + jp1 = xp.broadcast_to(jp1, broadcast_shape) + g = xp.broadcast_to(g, broadcast_shape) + + return ((1 - g) * xp.take_along_axis(y, j, axis=-1) + + g * xp.take_along_axis(y, jp1, axis=-1)) + + +def _quantile_hd(y: Array, p: Array, n: Array, xp: ModuleType) -> Array: + """Helper function for Harrell-Davis quantile method.""" + # For now, implement a simplified version that falls back to linear method + # since betainc is not available in the array API standard + return _quantile_hf(y, p, n, "linear", xp) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index dc0658ab..d3b389fa 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -24,6 +24,7 @@ nunique, one_hot, pad, + quantile, setdiff1d, sinc, ) @@ -43,6 +44,7 @@ lazy_xp_function(nunique) lazy_xp_function(one_hot) lazy_xp_function(pad) +lazy_xp_function(quantile) # FIXME calls in1d which calls xp.unique_values without size lazy_xp_function(setdiff1d, jax_jit=False) lazy_xp_function(sinc) @@ -1162,3 +1164,70 @@ def test_device(self, xp: ModuleType, device: Device): def test_xp(self, xp: ModuleType): xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) + + +class TestQuantile: + def test_basic(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = quantile(x, 0.5) + expect = xp.asarray(3.0) + xp_assert_close(actual, expect) + + def test_multiple_quantiles(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = quantile(x, xp.asarray([0.25, 0.5, 0.75])) + expect = xp.asarray([2.0, 3.0, 4.0]) + xp_assert_close(actual, expect) + + def test_2d_axis(self, xp: ModuleType): + x = xp.asarray([[1, 2, 3], [4, 5, 6]]) + actual = quantile(x, 0.5, axis=0) + expect = xp.asarray([2.5, 3.5, 4.5]) + xp_assert_close(actual, expect) + + def test_2d_axis_keepdims(self, xp: ModuleType): + x = xp.asarray([[1, 2, 3], [4, 5, 6]]) + actual = quantile(x, 0.5, axis=0, keepdims=True) + expect = xp.asarray([[2.5, 3.5, 4.5]]) + xp_assert_close(actual, expect) + + def test_methods(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + methods = ['linear', 'hazen', 'weibull'] + for method in methods: + actual = quantile(x, 0.5, method=method) + # All methods should give reasonable results + assert 2.5 <= float(actual) <= 3.5 + + def test_edge_cases(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + # q = 0 should give minimum + actual = quantile(x, 0.0) + expect = xp.asarray(1.0) + xp_assert_close(actual, expect) + + # q = 1 should give maximum + actual = quantile(x, 1.0) + expect = xp.asarray(5.0) + xp_assert_close(actual, expect) + + def test_invalid_q(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + # q > 1 should return NaN + actual = quantile(x, 1.5) + assert xp.isnan(actual) + + # q < 0 should return NaN + actual = quantile(x, -0.5) + assert xp.isnan(actual) + + def test_device(self, xp: ModuleType, device: Device): + x = xp.asarray([1, 2, 3, 4, 5], device=device) + actual = quantile(x, 0.5) + assert get_device(actual) == device + + def test_xp(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = quantile(x, 0.5, xp=xp) + expect = xp.asarray(3.0) + xp_assert_close(actual, expect) From 9577f119d70db3e489caab50eb0539244f38858b Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 14:19:58 +0200 Subject: [PATCH 02/20] Formatting --- src/array_api_extra/_delegation.py | 15 ++-- src/array_api_extra/_lib/_funcs.py | 129 ++++++++++++++++------------- tests/test_funcs.py | 6 +- 3 files changed, 84 insertions(+), 66 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index dc122faf..10f8ca55 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -270,11 +270,6 @@ def quantile( Probability or sequence of probabilities of the quantiles to compute. Values must be between 0 and 1 (inclusive). Must have length 1 along `axis` unless ``keepdims=True``. - method : str, default: 'linear' - The method to use for estimating the quantile. The available options are: - 'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', - 'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default), - 'median_unbiased', 'normal_unbiased', 'harrell-davis'. axis : int or None, default: None Axis along which the quantiles are computed. ``None`` ravels both `x` and `q` before performing the calculation. @@ -282,6 +277,11 @@ def quantile( If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array `x`. + method : str, default: 'linear' + The method to use for estimating the quantile. The available options are: + 'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', + 'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default), + 'median_unbiased', 'normal_unbiased', 'harrell-davis'. xp : array_namespace, optional The standard-compatible namespace for `x` and `q`. Default: infer. @@ -306,9 +306,12 @@ def quantile( try: import scipy from packaging import version - # The quantile function in scipy 1.17 supports array API directly, no need to delegate + + # The quantile function in scipy 1.17 supports array API directly, no need + # to delegate if version.parse(scipy.__version__) >= version.parse("1.17"): from scipy.stats import quantile as scipy_quantile + return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method) except (ImportError, AttributeError): pass diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 7a7de597..464dd84c 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -1004,147 +1004,162 @@ def quantile( """See docstring in `array_api_extra._delegation.py`.""" if xp is None: xp = array_namespace(x, q) - + # Convert q to array if it's a scalar - q_is_scalar = isinstance(q, (int, float)) + q_is_scalar = isinstance(q, int | float) if q_is_scalar: q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x)) - + # Validate inputs if not xp.isdtype(x.dtype, ("integral", "real floating")): - raise ValueError("`x` must have real dtype.") + raise ValueError("`x` must have real dtype.") # noqa: EM101 if not xp.isdtype(q.dtype, "real floating"): - raise ValueError("`q` must have real floating dtype.") - + raise ValueError("`q` must have real floating dtype.") # noqa: EM101 + # Promote to common dtype x = xp.astype(x, xp.float64) q = xp.astype(q, xp.float64) q = xp.asarray(q, device=_compat.device(x)) - + dtype = x.dtype axis_none = axis is None ndim = max(x.ndim, q.ndim) - + if axis_none: x = xp.reshape(x, (-1,)) q = xp.reshape(q, (-1,)) axis = 0 elif not isinstance(axis, int): - raise ValueError("`axis` must be an integer or None.") + raise ValueError("`axis` must be an integer or None.") # noqa: EM101 elif axis >= ndim or axis < -ndim: - raise ValueError("`axis` is not compatible with the shapes of the inputs.") + raise ValueError("`axis` is not compatible with the shapes of the inputs.") # noqa: EM101 else: axis = int(axis) - + # Validate method methods = { - 'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', - 'hazen', 'interpolated_inverted_cdf', 'linear', 'median_unbiased', - 'normal_unbiased', 'weibull', 'harrell-davis' + "inverted_cdf", + "averaged_inverted_cdf", + "closest_observation", + "hazen", + "interpolated_inverted_cdf", + "linear", + "median_unbiased", + "normal_unbiased", + "weibull", + "harrell-davis", } if method not in methods: - raise ValueError(f"`method` must be one of {methods}") - + raise ValueError(f"`method` must be one of {methods}") # noqa: EM102 + # Handle keepdims parameter if keepdims not in {None, True, False}: - raise ValueError("If specified, `keepdims` must be True or False.") - + raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101 + # Handle empty arrays if x.shape[axis] == 0: shape = list(x.shape) shape[axis] = 1 x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x)) - + # Sort the data y = xp.sort(x, axis=axis) - + # Move axis to the end for easier processing y = xp.moveaxis(y, axis, -1) if not (q_is_scalar or q.ndim == 0): q = xp.moveaxis(q, axis, -1) - + # Get the number of elements along the axis n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) - + # Apply quantile calculation based on method - if method in {'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', - 'hazen', 'interpolated_inverted_cdf', 'linear', 'median_unbiased', - 'normal_unbiased', 'weibull'}: + if method in { + "inverted_cdf", + "averaged_inverted_cdf", + "closest_observation", + "hazen", + "interpolated_inverted_cdf", + "linear", + "median_unbiased", + "normal_unbiased", + "weibull", + }: res = _quantile_hf(y, q, n, method, xp) - elif method == 'harrell-davis': + elif method == "harrell-davis": res = _quantile_hd(y, q, n, xp) else: - raise ValueError(f"Unknown method: {method}") - + raise ValueError(f"Unknown method: {method}") # noqa: EM102 + # Handle NaN output for invalid q values p_mask = (q > 1) | (q < 0) | xp.isnan(q) if xp.any(p_mask): res = xp.asarray(res, copy=True) res = at(res, p_mask).set(xp.nan) - + # Reshape per axis/keepdims if axis_none and keepdims: shape = (1,) * (ndim - 1) + res.shape res = xp.reshape(res, shape) axis = -1 - + # Move axis back to original position res = xp.moveaxis(res, -1, axis) - + # Handle keepdims if not keepdims and res.shape[axis] == 1: res = xp.squeeze(res, axis=axis) - + # For scalar q, ensure we return a scalar result - if q_is_scalar: - if hasattr(res, 'shape') and res.shape != (): - res = res[()] - + if q_is_scalar and hasattr(res, "shape") and res.shape != (): + res = res[()] + return res def _quantile_hf(y: Array, p: Array, n: Array, method: str, xp: ModuleType) -> Array: """Helper function for Hyndman-Fan quantile methods.""" ms = { - 'inverted_cdf': 0, - 'averaged_inverted_cdf': 0, - 'closest_observation': -0.5, - 'interpolated_inverted_cdf': 0, - 'hazen': 0.5, - 'weibull': p, - 'linear': 1 - p, - 'median_unbiased': p/3 + 1/3, - 'normal_unbiased': p/4 + 3/8 + "inverted_cdf": 0, + "averaged_inverted_cdf": 0, + "closest_observation": -0.5, + "interpolated_inverted_cdf": 0, + "hazen": 0.5, + "weibull": p, + "linear": 1 - p, + "median_unbiased": p / 3 + 1 / 3, + "normal_unbiased": p / 4 + 3 / 8, } m = ms[method] - + jg = p * n + m - 1 j = xp.astype(jg // 1, xp.int64) # Convert to integer g = jg % 1 - - if method == 'inverted_cdf': + + if method == "inverted_cdf": g = xp.astype((g > 0), jg.dtype) - elif method == 'averaged_inverted_cdf': + elif method == "averaged_inverted_cdf": g = (1 + xp.astype((g > 0), jg.dtype)) / 2 - elif method == 'closest_observation': - g = (1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype)) - if method in {'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation'}: + elif method == "closest_observation": + g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype) + if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}: g = xp.asarray(g) g = at(g, jg < 0).set(0) g = at(g, j < 0).set(0) j = xp.clip(j, 0, n - 1) jp1 = xp.clip(j + 1, 0, n - 1) - + # Broadcast indices to match y shape except for the last axis if y.ndim > 1: # Create broadcast shape for indices - broadcast_shape = list(y.shape[:-1]) + [1] + broadcast_shape = list(y.shape[:-1]) + [1] # noqa: RUF005 j = xp.broadcast_to(j, broadcast_shape) jp1 = xp.broadcast_to(jp1, broadcast_shape) g = xp.broadcast_to(g, broadcast_shape) - - return ((1 - g) * xp.take_along_axis(y, j, axis=-1) + - g * xp.take_along_axis(y, jp1, axis=-1)) + + return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis( + y, jp1, axis=-1 + ) def _quantile_hd(y: Array, p: Array, n: Array, xp: ModuleType) -> Array: diff --git a/tests/test_funcs.py b/tests/test_funcs.py index d3b389fa..84ebbaf7 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1193,7 +1193,7 @@ def test_2d_axis_keepdims(self, xp: ModuleType): def test_methods(self, xp: ModuleType): x = xp.asarray([1, 2, 3, 4, 5]) - methods = ['linear', 'hazen', 'weibull'] + methods = ["linear", "hazen", "weibull"] for method in methods: actual = quantile(x, 0.5, method=method) # All methods should give reasonable results @@ -1205,7 +1205,7 @@ def test_edge_cases(self, xp: ModuleType): actual = quantile(x, 0.0) expect = xp.asarray(1.0) xp_assert_close(actual, expect) - + # q = 1 should give maximum actual = quantile(x, 1.0) expect = xp.asarray(5.0) @@ -1216,7 +1216,7 @@ def test_invalid_q(self, xp: ModuleType): # q > 1 should return NaN actual = quantile(x, 1.5) assert xp.isnan(actual) - + # q < 0 should return NaN actual = quantile(x, -0.5) assert xp.isnan(actual) From e8a7d30bfcf999f70848f1d73adc7d93a22b939e Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 14:52:22 +0200 Subject: [PATCH 03/20] Fix scipy version --- src/array_api_extra/_delegation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 10f8ca55..3798108e 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -307,9 +307,9 @@ def quantile( import scipy from packaging import version - # The quantile function in scipy 1.17 supports array API directly, no need + # The quantile function in scipy 1.16 supports array API directly, no need # to delegate - if version.parse(scipy.__version__) >= version.parse("1.17"): + if version.parse(scipy.__version__) >= version.parse("1.16"): from scipy.stats import quantile as scipy_quantile return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method) From 10b9ec85e88dcaab63d5af0d9ffd9d5f8f29c525 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 14:56:45 +0200 Subject: [PATCH 04/20] Formatting --- src/array_api_extra/_lib/_funcs.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 464dd84c..5d2661b9 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -1000,7 +1000,7 @@ def quantile( keepdims: bool = False, method: str = "linear", xp: ModuleType | None = None, -) -> Array: +) -> Array: # numpydoc ignore=PR01,RT01 """See docstring in `array_api_extra._delegation.py`.""" if xp is None: xp = array_namespace(x, q) @@ -1117,7 +1117,9 @@ def quantile( return res -def _quantile_hf(y: Array, p: Array, n: Array, method: str, xp: ModuleType) -> Array: +def _quantile_hf( + y: Array, p: Array, n: Array, method: str, xp: ModuleType +) -> Array: # numpydoc ignore=PR01,RT01 """Helper function for Hyndman-Fan quantile methods.""" ms = { "inverted_cdf": 0, @@ -1162,7 +1164,9 @@ def _quantile_hf(y: Array, p: Array, n: Array, method: str, xp: ModuleType) -> A ) -def _quantile_hd(y: Array, p: Array, n: Array, xp: ModuleType) -> Array: +def _quantile_hd( + y: Array, p: Array, n: Array, xp: ModuleType +) -> Array: # numpydoc ignore=PR01,RT01 """Helper function for Harrell-Davis quantile method.""" # For now, implement a simplified version that falls back to linear method # since betainc is not available in the array API standard From 98570f68c1347375ced96a7915e28922b58acd94 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 14:59:14 +0200 Subject: [PATCH 05/20] Remove superfluous comments --- src/array_api_extra/_lib/_funcs.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 5d2661b9..101a1ad1 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -1005,12 +1005,10 @@ def quantile( if xp is None: xp = array_namespace(x, q) - # Convert q to array if it's a scalar q_is_scalar = isinstance(q, int | float) if q_is_scalar: q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x)) - # Validate inputs if not xp.isdtype(x.dtype, ("integral", "real floating")): raise ValueError("`x` must have real dtype.") # noqa: EM101 if not xp.isdtype(q.dtype, "real floating"): @@ -1036,7 +1034,6 @@ def quantile( else: axis = int(axis) - # Validate method methods = { "inverted_cdf", "averaged_inverted_cdf", @@ -1052,17 +1049,14 @@ def quantile( if method not in methods: raise ValueError(f"`method` must be one of {methods}") # noqa: EM102 - # Handle keepdims parameter if keepdims not in {None, True, False}: raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101 - # Handle empty arrays if x.shape[axis] == 0: shape = list(x.shape) shape[axis] = 1 x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x)) - # Sort the data y = xp.sort(x, axis=axis) # Move axis to the end for easier processing @@ -1070,10 +1064,8 @@ def quantile( if not (q_is_scalar or q.ndim == 0): q = xp.moveaxis(q, axis, -1) - # Get the number of elements along the axis n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) - # Apply quantile calculation based on method if method in { "inverted_cdf", "averaged_inverted_cdf", From 470f8b4713272dcbda49e13527ee0bcd7d04afc7 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 15:39:52 +0200 Subject: [PATCH 06/20] Remove unsupported method We can add it back later if we need it --- src/array_api_extra/_delegation.py | 18 +++++++++++++++++- src/array_api_extra/_lib/_funcs.py | 28 +--------------------------- 2 files changed, 18 insertions(+), 28 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 3798108e..c9742fad 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -281,7 +281,7 @@ def quantile( The method to use for estimating the quantile. The available options are: 'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', 'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default), - 'median_unbiased', 'normal_unbiased', 'harrell-davis'. + 'median_unbiased', 'normal_unbiased'. xp : array_namespace, optional The standard-compatible namespace for `x` and `q`. Default: infer. @@ -301,6 +301,22 @@ def quantile( Array([[5., 8.], [1., 3.]], dtype=array_api_strict.float64) """ + # We only support a subset of the methods supported by scipy.stats.quantile. + # So we need to perform the validation here. + methods = { + "inverted_cdf", + "averaged_inverted_cdf", + "closest_observation", + "hazen", + "interpolated_inverted_cdf", + "linear", + "median_unbiased", + "normal_unbiased", + "weibull", + } + if method not in methods: + raise ValueError(f"`method` must be one of {methods}") # noqa: EM102 + xp = array_namespace(x, q) if xp is None else xp try: diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 101a1ad1..9b720167 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -1034,21 +1034,6 @@ def quantile( else: axis = int(axis) - methods = { - "inverted_cdf", - "averaged_inverted_cdf", - "closest_observation", - "hazen", - "interpolated_inverted_cdf", - "linear", - "median_unbiased", - "normal_unbiased", - "weibull", - "harrell-davis", - } - if method not in methods: - raise ValueError(f"`method` must be one of {methods}") # noqa: EM102 - if keepdims not in {None, True, False}: raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101 @@ -1078,8 +1063,6 @@ def quantile( "weibull", }: res = _quantile_hf(y, q, n, method, xp) - elif method == "harrell-davis": - res = _quantile_hd(y, q, n, xp) else: raise ValueError(f"Unknown method: {method}") # noqa: EM102 @@ -1112,7 +1095,7 @@ def quantile( def _quantile_hf( y: Array, p: Array, n: Array, method: str, xp: ModuleType ) -> Array: # numpydoc ignore=PR01,RT01 - """Helper function for Hyndman-Fan quantile methods.""" + """Helper function for Hyndman-Fan quantile method.""" ms = { "inverted_cdf": 0, "averaged_inverted_cdf": 0, @@ -1154,12 +1137,3 @@ def _quantile_hf( return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis( y, jp1, axis=-1 ) - - -def _quantile_hd( - y: Array, p: Array, n: Array, xp: ModuleType -) -> Array: # numpydoc ignore=PR01,RT01 - """Helper function for Harrell-Davis quantile method.""" - # For now, implement a simplified version that falls back to linear method - # since betainc is not available in the array API standard - return _quantile_hf(y, p, n, "linear", xp) From a2eefa0d90c5e01be39d4db9df4195e722ce3491 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 15:46:47 +0200 Subject: [PATCH 07/20] More noqa --- src/array_api_extra/_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index c9742fad..7aa6ac7b 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -303,7 +303,7 @@ def quantile( """ # We only support a subset of the methods supported by scipy.stats.quantile. # So we need to perform the validation here. - methods = { + methods = { # pylint: disable=duplicate-code "inverted_cdf", "averaged_inverted_cdf", "closest_observation", From 37acd5b759b6c5040ea3a122ff86d2187c695215 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 15:48:36 +0200 Subject: [PATCH 08/20] yet more noqa --- src/array_api_extra/_lib/_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 9b720167..5a7a10ff 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -1051,7 +1051,7 @@ def quantile( n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) - if method in { + if method in { # pylint: disable=duplicate-code "inverted_cdf", "averaged_inverted_cdf", "closest_observation", From c3501e8c2da358b8c2d26ff8c62d27337625ba9b Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 15:58:42 +0200 Subject: [PATCH 09/20] Move quantile implementation to new file _funcs.py was getting too long. --- src/array_api_extra/_delegation.py | 3 +- src/array_api_extra/_lib/_funcs.py | 149 ------------------------ src/array_api_extra/_lib/_quantile.py | 156 ++++++++++++++++++++++++++ 3 files changed, 158 insertions(+), 150 deletions(-) create mode 100644 src/array_api_extra/_lib/_quantile.py diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 7aa6ac7b..b83cdd4b 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -5,6 +5,7 @@ from typing import Literal from ._lib import _funcs +from ._lib._quantile import quantile as _quantile from ._lib._utils._compat import ( array_namespace, is_cupy_namespace, @@ -332,4 +333,4 @@ def quantile( except (ImportError, AttributeError): pass - return _funcs.quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp) + return _quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 5a7a10ff..cf1a894a 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -28,7 +28,6 @@ "kron", "nunique", "pad", - "quantile", "setdiff1d", "sinc", ] @@ -989,151 +988,3 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)), ) return xp.sin(y) / y - - -def quantile( - x: Array, - q: Array | float, - /, - *, - axis: int | None = None, - keepdims: bool = False, - method: str = "linear", - xp: ModuleType | None = None, -) -> Array: # numpydoc ignore=PR01,RT01 - """See docstring in `array_api_extra._delegation.py`.""" - if xp is None: - xp = array_namespace(x, q) - - q_is_scalar = isinstance(q, int | float) - if q_is_scalar: - q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x)) - - if not xp.isdtype(x.dtype, ("integral", "real floating")): - raise ValueError("`x` must have real dtype.") # noqa: EM101 - if not xp.isdtype(q.dtype, "real floating"): - raise ValueError("`q` must have real floating dtype.") # noqa: EM101 - - # Promote to common dtype - x = xp.astype(x, xp.float64) - q = xp.astype(q, xp.float64) - q = xp.asarray(q, device=_compat.device(x)) - - dtype = x.dtype - axis_none = axis is None - ndim = max(x.ndim, q.ndim) - - if axis_none: - x = xp.reshape(x, (-1,)) - q = xp.reshape(q, (-1,)) - axis = 0 - elif not isinstance(axis, int): - raise ValueError("`axis` must be an integer or None.") # noqa: EM101 - elif axis >= ndim or axis < -ndim: - raise ValueError("`axis` is not compatible with the shapes of the inputs.") # noqa: EM101 - else: - axis = int(axis) - - if keepdims not in {None, True, False}: - raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101 - - if x.shape[axis] == 0: - shape = list(x.shape) - shape[axis] = 1 - x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x)) - - y = xp.sort(x, axis=axis) - - # Move axis to the end for easier processing - y = xp.moveaxis(y, axis, -1) - if not (q_is_scalar or q.ndim == 0): - q = xp.moveaxis(q, axis, -1) - - n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) - - if method in { # pylint: disable=duplicate-code - "inverted_cdf", - "averaged_inverted_cdf", - "closest_observation", - "hazen", - "interpolated_inverted_cdf", - "linear", - "median_unbiased", - "normal_unbiased", - "weibull", - }: - res = _quantile_hf(y, q, n, method, xp) - else: - raise ValueError(f"Unknown method: {method}") # noqa: EM102 - - # Handle NaN output for invalid q values - p_mask = (q > 1) | (q < 0) | xp.isnan(q) - if xp.any(p_mask): - res = xp.asarray(res, copy=True) - res = at(res, p_mask).set(xp.nan) - - # Reshape per axis/keepdims - if axis_none and keepdims: - shape = (1,) * (ndim - 1) + res.shape - res = xp.reshape(res, shape) - axis = -1 - - # Move axis back to original position - res = xp.moveaxis(res, -1, axis) - - # Handle keepdims - if not keepdims and res.shape[axis] == 1: - res = xp.squeeze(res, axis=axis) - - # For scalar q, ensure we return a scalar result - if q_is_scalar and hasattr(res, "shape") and res.shape != (): - res = res[()] - - return res - - -def _quantile_hf( - y: Array, p: Array, n: Array, method: str, xp: ModuleType -) -> Array: # numpydoc ignore=PR01,RT01 - """Helper function for Hyndman-Fan quantile method.""" - ms = { - "inverted_cdf": 0, - "averaged_inverted_cdf": 0, - "closest_observation": -0.5, - "interpolated_inverted_cdf": 0, - "hazen": 0.5, - "weibull": p, - "linear": 1 - p, - "median_unbiased": p / 3 + 1 / 3, - "normal_unbiased": p / 4 + 3 / 8, - } - m = ms[method] - - jg = p * n + m - 1 - j = xp.astype(jg // 1, xp.int64) # Convert to integer - g = jg % 1 - - if method == "inverted_cdf": - g = xp.astype((g > 0), jg.dtype) - elif method == "averaged_inverted_cdf": - g = (1 + xp.astype((g > 0), jg.dtype)) / 2 - elif method == "closest_observation": - g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype) - if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}: - g = xp.asarray(g) - g = at(g, jg < 0).set(0) - g = at(g, j < 0).set(0) - j = xp.clip(j, 0, n - 1) - jp1 = xp.clip(j + 1, 0, n - 1) - - # Broadcast indices to match y shape except for the last axis - if y.ndim > 1: - # Create broadcast shape for indices - broadcast_shape = list(y.shape[:-1]) + [1] # noqa: RUF005 - j = xp.broadcast_to(j, broadcast_shape) - jp1 = xp.broadcast_to(jp1, broadcast_shape) - g = xp.broadcast_to(g, broadcast_shape) - - return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis( - y, jp1, axis=-1 - ) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py new file mode 100644 index 00000000..03d316ef --- /dev/null +++ b/src/array_api_extra/_lib/_quantile.py @@ -0,0 +1,156 @@ +"""Quantile implementation.""" + +from types import ModuleType + +from ._at import at +from ._utils import _compat +from ._utils._compat import array_namespace +from ._utils._typing import Array + + +def quantile( + x: Array, + q: Array | float, + /, + *, + axis: int | None = None, + keepdims: bool = False, + method: str = "linear", + xp: ModuleType | None = None, +) -> Array: # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" + if xp is None: + xp = array_namespace(x, q) + + q_is_scalar = isinstance(q, int | float) + if q_is_scalar: + q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x)) + + if not xp.isdtype(x.dtype, ("integral", "real floating")): + raise ValueError("`x` must have real dtype.") # noqa: EM101 + if not xp.isdtype(q.dtype, "real floating"): + raise ValueError("`q` must have real floating dtype.") # noqa: EM101 + + # Promote to common dtype + x = xp.astype(x, xp.float64) + q = xp.astype(q, xp.float64) + q = xp.asarray(q, device=_compat.device(x)) + + dtype = x.dtype + axis_none = axis is None + ndim = max(x.ndim, q.ndim) + + if axis_none: + x = xp.reshape(x, (-1,)) + q = xp.reshape(q, (-1,)) + axis = 0 + elif not isinstance(axis, int): + raise ValueError("`axis` must be an integer or None.") # noqa: EM101 + elif axis >= ndim or axis < -ndim: + raise ValueError("`axis` is not compatible with the shapes of the inputs.") # noqa: EM101 + else: + axis = int(axis) + + if keepdims not in {None, True, False}: + raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101 + + if x.shape[axis] == 0: + shape = list(x.shape) + shape[axis] = 1 + x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x)) + + y = xp.sort(x, axis=axis) + + # Move axis to the end for easier processing + y = xp.moveaxis(y, axis, -1) + if not (q_is_scalar or q.ndim == 0): + q = xp.moveaxis(q, axis, -1) + + n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) + + if method in { # pylint: disable=duplicate-code + "inverted_cdf", + "averaged_inverted_cdf", + "closest_observation", + "hazen", + "interpolated_inverted_cdf", + "linear", + "median_unbiased", + "normal_unbiased", + "weibull", + }: + res = _quantile_hf(y, q, n, method, xp) + else: + raise ValueError(f"Unknown method: {method}") # noqa: EM102 + + # Handle NaN output for invalid q values + p_mask = (q > 1) | (q < 0) | xp.isnan(q) + if xp.any(p_mask): + res = xp.asarray(res, copy=True) + res = at(res, p_mask).set(xp.nan) + + # Reshape per axis/keepdims + if axis_none and keepdims: + shape = (1,) * (ndim - 1) + res.shape + res = xp.reshape(res, shape) + axis = -1 + + # Move axis back to original position + res = xp.moveaxis(res, -1, axis) + + # Handle keepdims + if not keepdims and res.shape[axis] == 1: + res = xp.squeeze(res, axis=axis) + + # For scalar q, ensure we return a scalar result + if q_is_scalar and hasattr(res, "shape") and res.shape != (): + res = res[()] + + return res + + +def _quantile_hf( + y: Array, p: Array, n: Array, method: str, xp: ModuleType +) -> Array: # numpydoc ignore=PR01,RT01 + """Helper function for Hyndman-Fan quantile method.""" + ms = { + "inverted_cdf": 0, + "averaged_inverted_cdf": 0, + "closest_observation": -0.5, + "interpolated_inverted_cdf": 0, + "hazen": 0.5, + "weibull": p, + "linear": 1 - p, + "median_unbiased": p / 3 + 1 / 3, + "normal_unbiased": p / 4 + 3 / 8, + } + m = ms[method] + + jg = p * n + m - 1 + j = xp.astype(jg // 1, xp.int64) # Convert to integer + g = jg % 1 + + if method == "inverted_cdf": + g = xp.astype((g > 0), jg.dtype) + elif method == "averaged_inverted_cdf": + g = (1 + xp.astype((g > 0), jg.dtype)) / 2 + elif method == "closest_observation": + g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype) + if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}: + g = xp.asarray(g) + g = at(g, jg < 0).set(0) + g = at(g, j < 0).set(0) + j = xp.clip(j, 0, n - 1) + jp1 = xp.clip(j + 1, 0, n - 1) + + # Broadcast indices to match y shape except for the last axis + if y.ndim > 1: + # Create broadcast shape for indices + broadcast_shape = list(y.shape[:-1]) + [1] # noqa: RUF005 + j = xp.broadcast_to(j, broadcast_shape) + jp1 = xp.broadcast_to(jp1, broadcast_shape) + g = xp.broadcast_to(g, broadcast_shape) + + return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis( + y, jp1, axis=-1 + ) From 7a7934cbf7ecd6edc4cc5f7ceca615fe79e367f2 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jun 2025 18:02:54 +0200 Subject: [PATCH 10/20] Remove duplicated code --- src/array_api_extra/_delegation.py | 4 ++-- src/array_api_extra/_lib/_quantile.py | 17 ++--------------- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index b83cdd4b..eab96f9d 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -256,7 +256,7 @@ def quantile( /, *, axis: int | None = None, - keepdims: bool = False, + keepdims: bool = None, # noqa: RUF013 method: str = "linear", xp: ModuleType | None = None, ) -> Array: @@ -304,7 +304,7 @@ def quantile( """ # We only support a subset of the methods supported by scipy.stats.quantile. # So we need to perform the validation here. - methods = { # pylint: disable=duplicate-code + methods = { "inverted_cdf", "averaged_inverted_cdf", "closest_observation", diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 03d316ef..15b59f40 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -14,7 +14,7 @@ def quantile( /, *, axis: int | None = None, - keepdims: bool = False, + keepdims: bool = None, # noqa: RUF013 method: str = "linear", xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 @@ -68,20 +68,7 @@ def quantile( n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) - if method in { # pylint: disable=duplicate-code - "inverted_cdf", - "averaged_inverted_cdf", - "closest_observation", - "hazen", - "interpolated_inverted_cdf", - "linear", - "median_unbiased", - "normal_unbiased", - "weibull", - }: - res = _quantile_hf(y, q, n, method, xp) - else: - raise ValueError(f"Unknown method: {method}") # noqa: EM102 + res = _quantile_hf(y, q, n, method, xp) # Handle NaN output for invalid q values p_mask = (q > 1) | (q < 0) | xp.isnan(q) From 3095889fcf19648be779b14ff6a1acdeae19b99b Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Mon, 30 Jun 2025 17:57:34 +0100 Subject: [PATCH 11/20] docstring keepdims --- src/array_api_extra/_delegation.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index eab96f9d..1b552398 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -256,7 +256,7 @@ def quantile( /, *, axis: int | None = None, - keepdims: bool = None, # noqa: RUF013 + keepdims: bool | None = None, method: str = "linear", xp: ModuleType | None = None, ) -> Array: @@ -274,10 +274,14 @@ def quantile( axis : int or None, default: None Axis along which the quantiles are computed. ``None`` ravels both `x` and `q` before performing the calculation. - keepdims : bool, optional - If this is set to True, the axes which are reduced are left in the + keepdims : bool or None, default: None + By default, the axis will be reduced away if possible + (i.e. if there is exactly one element of `q` per axis-slice of `x`). + If `keepdims` is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array `x`. + If `keepdims` is set to False, the axis will be reduced away if possible, + and an error will be raised otherwise. method : str, default: 'linear' The method to use for estimating the quantile. The available options are: 'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', From bd55318178228f08398ed57deecb1adb9b3fa7b6 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Mon, 30 Jun 2025 18:31:09 +0100 Subject: [PATCH 12/20] lint --- .pre-commit-config.yaml | 14 +- pixi.lock | 248 +++++++++++++++++++------- pyproject.toml | 49 ++--- src/array_api_extra/_delegation.py | 8 +- src/array_api_extra/_lib/_funcs.py | 2 +- src/array_api_extra/_lib/_quantile.py | 26 +-- src/array_api_extra/testing.py | 6 +- tests/test_helpers.py | 4 +- tests/test_testing.py | 2 +- 9 files changed, 237 insertions(+), 122 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7be7a1ca..46bea2df 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ exclude: ^.cruft.json|.copier-answers.yml$ repos: - repo: https://github.com/adamchainz/blacken-docs - rev: "1.18.0" + rev: "1.19.1" hooks: - id: blacken-docs additional_dependencies: [black==24.*] @@ -35,21 +35,21 @@ repos: - id: rst-inline-touching-normal - repo: https://github.com/rbubley/mirrors-prettier - rev: "v3.4.2" + rev: "v3.6.2" hooks: - id: prettier types_or: [yaml, markdown, html, css, scss, javascript, json] args: [--prose-wrap=always] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.8.2" + rev: "v0.12.1" hooks: - id: ruff-format - id: ruff args: ["--fix", "--show-fixes"] - repo: https://github.com/codespell-project/codespell - rev: "v2.3.0" + rev: "v2.4.1" hooks: - id: codespell exclude: pixi.lock @@ -68,17 +68,17 @@ repos: exclude: .pre-commit-config.yaml - repo: https://github.com/abravalheri/validate-pyproject - rev: "v0.23" + rev: "v0.24.1" hooks: - id: validate-pyproject additional_dependencies: ["validate-pyproject-schema-store[all]"] - repo: https://github.com/python-jsonschema/check-jsonschema - rev: "0.30.0" + rev: "0.33.1" hooks: - id: check-github-workflows - repo: https://github.com/numpy/numpydoc - rev: "v1.8.0" + rev: "v1.9.0" hooks: - id: numpydoc-validation diff --git a/pixi.lock b/pixi.lock index 115a6c0f..34566632 100644 --- a/pixi.lock +++ b/pixi.lock @@ -84,7 +84,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h2b53caa_26.conda - - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34438-hfd919c2_26.conda + - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_26.conda - pypi: ./ dev: channels: @@ -1613,8 +1613,8 @@ environments: - conda: https://prefix.dev/conda-forge/osx-64/libcxx-20.1.6-hf95d169_0.conda - conda: https://prefix.dev/conda-forge/osx-64/libexpat-2.7.0-h240833e_0.conda - conda: https://prefix.dev/conda-forge/osx-64/libffi-3.4.6-h281671d_1.conda - - conda: https://prefix.dev/conda-forge/osx-64/libgfortran-14.2.0-hef36b68_105.conda - - conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h58528f3_105.conda + - conda: https://prefix.dev/conda-forge/osx-64/libgfortran-5.0.0-14_2_0_h51e75f0_103.conda + - conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h51e75f0_103.conda - conda: https://prefix.dev/conda-forge/osx-64/liblapack-3.9.0-31_h236ab99_openblas.conda - conda: https://prefix.dev/conda-forge/osx-64/liblzma-5.8.1-hd471939_1.conda - conda: https://prefix.dev/conda-forge/osx-64/libmpdec-4.0.0-h6e16a3a_0.conda @@ -1700,8 +1700,8 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/libcxx-20.1.6-ha82da77_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.7.0-h286801f_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libffi-3.4.6-h1da3d7d_1.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-14.2.0-heb5dd2a_105.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-14.2.0-h2c44a93_105.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-5.0.0-14_2_0_h6c33f7e_103.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-14.2.0-h6c33f7e_103.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblapack-3.9.0-31_hc9a63f6_openblas.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblzma-5.8.1-h39f12f2_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libmpdec-4.0.0-h5505292_0.conda @@ -1840,8 +1840,8 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.4.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h2b53caa_26.conda - - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34438-hfd919c2_26.conda - - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34438-h7142326_26.conda + - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_26.conda + - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.44.35208-h38c0c73_26.conda - conda: https://prefix.dev/conda-forge/noarch/win_inet_pton-1.1.0-pyh7428d3b_8.conda - conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/zipp-3.22.0-pyhd8ed1ab_0.conda @@ -1862,7 +1862,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.10-py313h78bf25f_0.conda - conda: https://prefix.dev/conda-forge/noarch/attrs-25.3.0-pyh71513ae_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.17.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/noarch/basedpyright-1.29.4-pyhe01879c_0.conda + - conda: https://prefix.dev/conda-forge/noarch/basedpyright-1.29.5-pyhe01879c_0.conda - conda: https://prefix.dev/conda-forge/linux-64/brotli-python-1.1.0-py313h46c70d0_2.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/noarch/ca-certificates-2025.4.26-hbd8a1cb_0.conda @@ -1915,14 +1915,14 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/markupsafe-3.0.2-py313h8060acc_1.conda - conda: https://prefix.dev/conda-forge/noarch/mccabe-0.7.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/linux-64/mypy-1.16.0-py313h536fd9c_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/mypy-1.16.1-py313h536fd9c_0.conda - conda: https://prefix.dev/conda-forge/noarch/mypy_extensions-1.1.0-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda - conda: https://prefix.dev/conda-forge/noarch/nodeenv-1.9.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/nodejs-22.13.0-hf235a45_0.conda - - conda: https://prefix.dev/conda-forge/noarch/nodejs-wheel-22.16.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/nodejs-wheel-22.17.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/numpy-2.3.0-py313h17eae1a_0.conda - - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.8.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.9.0-pyhe01879c_1.conda - conda: https://prefix.dev/conda-forge/linux-64/openssl-3.5.0-h7b32b05_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -1953,10 +1953,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-qthelp-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/tabulate-0.9.0-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_hd72426e_102.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.2-pyha770c72_1.conda + - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.3-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.14.0-h32cad80_0.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.14.0-pyhe01879c_0.conda @@ -1976,7 +1975,7 @@ environments: - conda: https://prefix.dev/conda-forge/osx-64/astroid-3.3.10-py313habf4b1d_0.conda - conda: https://prefix.dev/conda-forge/noarch/attrs-25.3.0-pyh71513ae_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.17.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/noarch/basedpyright-1.29.4-pyhe01879c_0.conda + - conda: https://prefix.dev/conda-forge/noarch/basedpyright-1.29.5-pyhe01879c_0.conda - conda: https://prefix.dev/conda-forge/osx-64/brotli-python-1.1.0-py313h9ea2907_2.conda - conda: https://prefix.dev/conda-forge/osx-64/bzip2-1.0.8-hfdf4475_7.conda - conda: https://prefix.dev/conda-forge/noarch/ca-certificates-2025.4.26-hbd8a1cb_0.conda @@ -2011,8 +2010,8 @@ environments: - conda: https://prefix.dev/conda-forge/osx-64/libcxx-20.1.6-hf95d169_0.conda - conda: https://prefix.dev/conda-forge/osx-64/libexpat-2.7.0-h240833e_0.conda - conda: https://prefix.dev/conda-forge/osx-64/libffi-3.4.6-h281671d_1.conda - - conda: https://prefix.dev/conda-forge/osx-64/libgfortran-14.2.0-hef36b68_105.conda - - conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h58528f3_105.conda + - conda: https://prefix.dev/conda-forge/osx-64/libgfortran-5.0.0-14_2_0_h51e75f0_103.conda + - conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h51e75f0_103.conda - conda: https://prefix.dev/conda-forge/osx-64/liblapack-3.9.0-31_h236ab99_openblas.conda - conda: https://prefix.dev/conda-forge/osx-64/liblzma-5.8.1-hd471939_1.conda - conda: https://prefix.dev/conda-forge/osx-64/libmpdec-4.0.0-h6e16a3a_0.conda @@ -2024,14 +2023,14 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/osx-64/markupsafe-3.0.2-py313h717bdf5_1.conda - conda: https://prefix.dev/conda-forge/noarch/mccabe-0.7.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/osx-64/mypy-1.16.0-py313h63b0ddb_0.conda + - conda: https://prefix.dev/conda-forge/osx-64/mypy-1.16.1-py313h63b0ddb_0.conda - conda: https://prefix.dev/conda-forge/noarch/mypy_extensions-1.1.0-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/osx-64/ncurses-6.5-h0622a9a_3.conda - conda: https://prefix.dev/conda-forge/noarch/nodeenv-1.9.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-64/nodejs-22.13.0-hffbc63d_0.conda - - conda: https://prefix.dev/conda-forge/noarch/nodejs-wheel-22.16.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/nodejs-wheel-22.17.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-64/numpy-2.3.0-py313hc518a0f_0.conda - - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.8.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.9.0-pyhe01879c_1.conda - conda: https://prefix.dev/conda-forge/osx-64/openssl-3.5.0-hc426f3f_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -2062,10 +2061,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-qthelp-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/tabulate-0.9.0-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/osx-64/tk-8.6.13-hf689a15_2.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.2-pyha770c72_1.conda + - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.3-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.14.0-h32cad80_0.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.14.0-pyhe01879c_0.conda @@ -2085,7 +2083,7 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.10-py313h8f79df9_0.conda - conda: https://prefix.dev/conda-forge/noarch/attrs-25.3.0-pyh71513ae_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.17.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/noarch/basedpyright-1.29.4-pyhe01879c_0.conda + - conda: https://prefix.dev/conda-forge/noarch/basedpyright-1.29.5-pyhe01879c_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/brotli-python-1.1.0-py313h3579c5c_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/noarch/ca-certificates-2025.4.26-hbd8a1cb_0.conda @@ -2120,8 +2118,8 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/libcxx-20.1.6-ha82da77_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.7.0-h286801f_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libffi-3.4.6-h1da3d7d_1.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-14.2.0-heb5dd2a_105.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-14.2.0-h2c44a93_105.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-5.0.0-14_2_0_h6c33f7e_103.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-14.2.0-h6c33f7e_103.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblapack-3.9.0-31_hc9a63f6_openblas.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblzma-5.8.1-h39f12f2_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libmpdec-4.0.0-h5505292_0.conda @@ -2133,14 +2131,14 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/osx-arm64/markupsafe-3.0.2-py313ha9b7d5b_1.conda - conda: https://prefix.dev/conda-forge/noarch/mccabe-0.7.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/mypy-1.16.0-py313h90d716c_0.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/mypy-1.16.1-py313h90d716c_0.conda - conda: https://prefix.dev/conda-forge/noarch/mypy_extensions-1.1.0-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ncurses-6.5-h5e97a16_3.conda - conda: https://prefix.dev/conda-forge/noarch/nodeenv-1.9.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/nodejs-22.13.0-h02a13b7_0.conda - - conda: https://prefix.dev/conda-forge/noarch/nodejs-wheel-22.16.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/nodejs-wheel-22.17.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/numpy-2.3.0-py313h41a2e72_0.conda - - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.8.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.9.0-pyhe01879c_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/openssl-3.5.0-h81ee809_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -2171,10 +2169,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-qthelp-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/tabulate-0.9.0-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h892fb3f_2.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.2-pyha770c72_1.conda + - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.3-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.14.0-h32cad80_0.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.14.0-pyhe01879c_0.conda @@ -2194,7 +2191,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.10-py313hfa70ccb_0.conda - conda: https://prefix.dev/conda-forge/noarch/attrs-25.3.0-pyh71513ae_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.17.0-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/noarch/basedpyright-1.29.4-pyhe01879c_0.conda + - conda: https://prefix.dev/conda-forge/noarch/basedpyright-1.29.5-pyhe01879c_0.conda - conda: https://prefix.dev/conda-forge/win-64/brotli-python-1.1.0-py313h5813708_2.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/noarch/ca-certificates-2025.4.26-h4c7d964_0.conda @@ -2241,13 +2238,13 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/markupsafe-3.0.2-py313hb4c8b1a_1.conda - conda: https://prefix.dev/conda-forge/noarch/mccabe-0.7.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/mkl-2024.2.2-h66d3029_15.conda - - conda: https://prefix.dev/conda-forge/win-64/mypy-1.16.0-py313ha7868ed_0.conda + - conda: https://prefix.dev/conda-forge/win-64/mypy-1.16.1-py313ha7868ed_0.conda - conda: https://prefix.dev/conda-forge/noarch/mypy_extensions-1.1.0-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/nodeenv-1.9.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/nodejs-22.13.0-hfeaa22a_0.conda - - conda: https://prefix.dev/conda-forge/noarch/nodejs-wheel-22.16.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/nodejs-wheel-22.17.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/numpy-2.3.0-py313hefb8edb_0.conda - - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.8.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.9.0-pyhe01879c_1.conda - conda: https://prefix.dev/conda-forge/win-64/openssl-3.5.0-ha4e3fda_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -2277,11 +2274,10 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-qthelp-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/tabulate-0.9.0-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/win-64/tbb-2021.13.0-h62715c5_1.conda - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h2c6b04d_2.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.2-pyha770c72_1.conda + - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.3-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.14.0-h32cad80_0.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.14.0-pyhe01879c_0.conda @@ -2290,9 +2286,9 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/ukkonen-1.0.1-py313h1ec8472_5.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.4.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h2b53caa_26.conda - - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34438-hfd919c2_26.conda + - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_26.conda - conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.31.2-pyhd8ed1ab_0.conda - - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34438-h7142326_26.conda + - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.44.35208-h38c0c73_26.conda - conda: https://prefix.dev/conda-forge/noarch/win_inet_pton-1.1.0-pyh7428d3b_8.conda - conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/zipp-3.22.0-pyhd8ed1ab_0.conda @@ -2372,8 +2368,8 @@ environments: - conda: https://prefix.dev/conda-forge/osx-64/libcxx-20.1.6-hf95d169_0.conda - conda: https://prefix.dev/conda-forge/osx-64/libexpat-2.7.0-h240833e_0.conda - conda: https://prefix.dev/conda-forge/osx-64/libffi-3.4.6-h281671d_1.conda - - conda: https://prefix.dev/conda-forge/osx-64/libgfortran-14.2.0-hef36b68_105.conda - - conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h58528f3_105.conda + - conda: https://prefix.dev/conda-forge/osx-64/libgfortran-5.0.0-14_2_0_h51e75f0_103.conda + - conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h51e75f0_103.conda - conda: https://prefix.dev/conda-forge/osx-64/liblapack-3.9.0-31_h236ab99_openblas.conda - conda: https://prefix.dev/conda-forge/osx-64/liblzma-5.8.1-hd471939_1.conda - conda: https://prefix.dev/conda-forge/osx-64/libmpdec-4.0.0-h6e16a3a_0.conda @@ -2417,8 +2413,8 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/libcxx-20.1.6-ha82da77_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.7.0-h286801f_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libffi-3.4.6-h1da3d7d_1.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-14.2.0-heb5dd2a_105.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-14.2.0-h2c44a93_105.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-5.0.0-14_2_0_h6c33f7e_103.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-14.2.0-h6c33f7e_103.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblapack-3.9.0-31_hc9a63f6_openblas.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblzma-5.8.1-h39f12f2_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libmpdec-4.0.0-h5505292_0.conda @@ -2491,7 +2487,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h2b53caa_26.conda - - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34438-hfd919c2_26.conda + - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_26.conda - pypi: ./ tests-backends: channels: @@ -4001,8 +3997,8 @@ environments: - conda: https://prefix.dev/conda-forge/osx-64/libcxx-20.1.6-hf95d169_0.conda - conda: https://prefix.dev/conda-forge/osx-64/libexpat-2.7.0-h240833e_0.conda - conda: https://prefix.dev/conda-forge/osx-64/libffi-3.4.6-h281671d_1.conda - - conda: https://prefix.dev/conda-forge/osx-64/libgfortran-14.2.0-hef36b68_105.conda - - conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h58528f3_105.conda + - conda: https://prefix.dev/conda-forge/osx-64/libgfortran-5.0.0-14_2_0_h51e75f0_103.conda + - conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h51e75f0_103.conda - conda: https://prefix.dev/conda-forge/osx-64/liblapack-3.9.0-31_h236ab99_openblas.conda - conda: https://prefix.dev/conda-forge/osx-64/liblzma-5.8.1-hd471939_1.conda - conda: https://prefix.dev/conda-forge/osx-64/libmpdec-4.0.0-h6e16a3a_0.conda @@ -4046,8 +4042,8 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/libcxx-20.1.6-ha82da77_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.7.0-h286801f_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libffi-3.4.6-h1da3d7d_1.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-14.2.0-heb5dd2a_105.conda - - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-14.2.0-h2c44a93_105.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-5.0.0-14_2_0_h6c33f7e_103.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-14.2.0-h6c33f7e_103.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblapack-3.9.0-31_hc9a63f6_openblas.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblzma-5.8.1-h39f12f2_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libmpdec-4.0.0-h5505292_0.conda @@ -4120,7 +4116,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h2b53caa_26.conda - - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34438-hfd919c2_26.conda + - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_26.conda - pypi: ./ packages: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 @@ -4196,7 +4192,7 @@ packages: - pypi: ./ name: array-api-extra version: 0.8.1.dev0 - sha256: 8236f595888133a239f0b7b6a6ee07a83c143c6bb09bb8b26ffff0309336f27b + sha256: d7383cc1b3a9010b2857ae4ac3ddd995308b28189244bccc20e0c12ade5a0b18 requires_dist: - array-api-compat>=1.12.0,<2 requires_python: '>=3.10' @@ -4373,6 +4369,18 @@ packages: - pkg:pypi/basedpyright?source=hash-mapping size: 8225473 timestamp: 1749712419694 +- conda: https://prefix.dev/conda-forge/noarch/basedpyright-1.29.5-pyhe01879c_0.conda + sha256: 75ad1b095f3bed6962d6fea8c197a718df6751dfa0c92016647fe341d9426793 + md5: f493f7a303f65470586f4ad0f74bc4e7 + depends: + - python >=3.9 + - nodejs-wheel >=20.13.1 + - python + license: MIT AND Apache-2.0 + purls: + - pkg:pypi/basedpyright?source=hash-mapping + size: 8225085 + timestamp: 1751294452624 - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.13.4-pyha770c72_0.conda sha256: ddb0df12fd30b2d36272f5daf6b6251c7625d6a99414d7ea930005bbaecad06d md5: 9f07c4fc992adb2d6c30da7fab3959a7 @@ -6755,6 +6763,16 @@ packages: purls: [] size: 155635 timestamp: 1743911593527 +- conda: https://prefix.dev/conda-forge/osx-64/libgfortran-5.0.0-14_2_0_h51e75f0_103.conda + sha256: 124dcd89508bd16f562d9d3ce6a906336a7f18e963cd14f2877431adee14028e + md5: 090b3c9ae1282c8f9b394ac9e4773b10 + depends: + - libgfortran5 14.2.0 h51e75f0_103 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + purls: [] + size: 156202 + timestamp: 1743862427451 - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-14.2.0-heb5dd2a_105.conda sha256: 6ca48762c330d1cdbdaa450f197ccc16ffb7181af50d112b4ccf390223d916a1 md5: ad35937216e65cfeecd828979ee5e9e6 @@ -6765,6 +6783,16 @@ packages: purls: [] size: 155474 timestamp: 1743913530958 +- conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-5.0.0-14_2_0_h6c33f7e_103.conda + sha256: 8628746a8ecd311f1c0d14bb4f527c18686251538f7164982ccbe3b772de58b5 + md5: 044a210bc1d5b8367857755665157413 + depends: + - libgfortran5 14.2.0 h6c33f7e_103 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + purls: [] + size: 156291 + timestamp: 1743863532821 - conda: https://prefix.dev/conda-forge/linux-64/libgfortran5-15.1.0-hcea5267_2.conda sha256: be23750f3ca1a5cb3ada858c4f633effe777487d1ea35fddca04c0965c073350 md5: 01de444988ed960031dbe84cf4f9b1fc @@ -6778,6 +6806,18 @@ packages: purls: [] size: 1569986 timestamp: 1746642212331 +- conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h51e75f0_103.conda + sha256: d2ac5e09587e5b21b7bb5795d24f33257e44320749c125448611211088ef8795 + md5: 6183f7e9cd1e7ba20118ff0ca20a05e5 + depends: + - llvm-openmp >=8.0.0 + constrains: + - libgfortran 5.0.0 14_2_0_*_103 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + purls: [] + size: 1225013 + timestamp: 1743862382377 - conda: https://prefix.dev/conda-forge/osx-64/libgfortran5-14.2.0-h58528f3_105.conda sha256: 02fc48106e1ca65cf7de15f58ec567f866f6e8e9dcced157d0cff89f0768bb59 md5: 94560312ff3c78225bed62ab59854c31 @@ -6802,6 +6842,18 @@ packages: purls: [] size: 806283 timestamp: 1743913488925 +- conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-14.2.0-h6c33f7e_103.conda + sha256: 8599453990bd3a449013f5fa3d72302f1c68f0680622d419c3f751ff49f01f17 + md5: 69806c1e957069f1d515830dcc9f6cbb + depends: + - llvm-openmp >=8.0.0 + constrains: + - libgfortran 5.0.0 14_2_0_*_103 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + purls: [] + size: 806566 + timestamp: 1743863491726 - conda: https://prefix.dev/conda-forge/linux-64/libgomp-15.1.0-h767d61c_2.conda sha256: 05fff3dc7e80579bc28de13b511baec281c4343d703c406aefd54389959154fb md5: fbe7d535ff9d3a168c148e07358cd5b1 @@ -8413,9 +8465,9 @@ packages: - pkg:pypi/mypy?source=hash-mapping size: 18162508 timestamp: 1748547897549 -- conda: https://prefix.dev/conda-forge/linux-64/mypy-1.16.0-py313h536fd9c_0.conda - sha256: 542234e8707c1784acd744d6234d36ce3b5943acf998a2d453feff319266676e - md5: 3ffbe69221d685e9afa221ef0e2a71c6 +- conda: https://prefix.dev/conda-forge/linux-64/mypy-1.16.1-py313h536fd9c_0.conda + sha256: 01f9acea3bc0fcdfc17acbe9ac003e18c4cccdaad3cdef7c3595e5c996b74324 + md5: 5446d84e248f2ac04f88af2c393383c6 depends: - __glibc >=2.17,<3.0.a0 - libgcc >=13 @@ -8429,8 +8481,8 @@ packages: license_family: MIT purls: - pkg:pypi/mypy?source=hash-mapping - size: 17269093 - timestamp: 1748547494947 + size: 17242074 + timestamp: 1750118260507 - conda: https://prefix.dev/conda-forge/osx-64/mypy-1.16.0-py310hbb8c376_0.conda sha256: d272291c6e59725d88826cb140768e0bd6047f17ffe094ca054f51e9153f1a2c md5: f23ad57327906f61b85a0ce6b1a26194 @@ -8449,9 +8501,9 @@ packages: - pkg:pypi/mypy?source=hash-mapping size: 11951514 timestamp: 1748547572039 -- conda: https://prefix.dev/conda-forge/osx-64/mypy-1.16.0-py313h63b0ddb_0.conda - sha256: 83ba7f072c565ebdb7f0753078951d7e06c52f9f60511ca8481255406caadc5d - md5: ec5f77fa3f5136668cbcb6bcbfb1ef83 +- conda: https://prefix.dev/conda-forge/osx-64/mypy-1.16.1-py313h63b0ddb_0.conda + sha256: 49cbef241c24b6e4f15b5cce30104fbe41151988456381d1b3037574c5014c7e + md5: 9d3e25c02eeea1904392d24df67ec9dc depends: - __osx >=10.13 - mypy_extensions >=1.0.0 @@ -8464,8 +8516,8 @@ packages: license_family: MIT purls: - pkg:pypi/mypy?source=hash-mapping - size: 11232466 - timestamp: 1748547240219 + size: 11269073 + timestamp: 1750118493594 - conda: https://prefix.dev/conda-forge/osx-arm64/mypy-1.16.0-py310h078409c_0.conda sha256: e8c4d987cc13ced8bca9128ac325fe09f14ffbcedf14e295d25b1328b9d829dc md5: 1ad2eb08347327aa09a0331526089939 @@ -8485,9 +8537,9 @@ packages: - pkg:pypi/mypy?source=hash-mapping size: 9302149 timestamp: 1748547424712 -- conda: https://prefix.dev/conda-forge/osx-arm64/mypy-1.16.0-py313h90d716c_0.conda - sha256: a37d65ad2e837bc86eff91a0bf15ea86d6d64d7bb52dbf2720334314563cae50 - md5: 4946c89919f258c1aad6000225b729a6 +- conda: https://prefix.dev/conda-forge/osx-arm64/mypy-1.16.1-py313h90d716c_0.conda + sha256: 71805207ebe9def6100809c0a8ff5a5b2f88a1b32851b9a3ae339823db308762 + md5: 25298ce104edf05af28ed4f172c7e334 depends: - __osx >=11.0 - mypy_extensions >=1.0.0 @@ -8500,9 +8552,9 @@ packages: license: MIT license_family: MIT purls: - - pkg:pypi/mypy?source=compressed-mapping - size: 10453519 - timestamp: 1748547483049 + - pkg:pypi/mypy?source=hash-mapping + size: 10423256 + timestamp: 1750118390866 - conda: https://prefix.dev/conda-forge/win-64/mypy-1.16.0-py310ha8f682b_0.conda sha256: 240740e6e003851f79494e3d53abda4799854b874dfeaee01318def43fe16a12 md5: 4c1035de8a27a8684c821547428bd614 @@ -8523,9 +8575,9 @@ packages: - pkg:pypi/mypy?source=hash-mapping size: 9675989 timestamp: 1748547662848 -- conda: https://prefix.dev/conda-forge/win-64/mypy-1.16.0-py313ha7868ed_0.conda - sha256: 96f238306b14960b379570a209a864f448869b0c19a52b7ac5ac37d07c8ae797 - md5: ae82bb456e3d670e21deac8b067fdc45 +- conda: https://prefix.dev/conda-forge/win-64/mypy-1.16.1-py313ha7868ed_0.conda + sha256: d915755801ee459c174dcd7d40ddc6b1a4b0e96fa161c686582223a3b51077f2 + md5: 7c94601304b4e66c082e9c86ad219cea depends: - mypy_extensions >=1.0.0 - pathspec >=0.9.0 @@ -8540,8 +8592,8 @@ packages: license_family: MIT purls: - pkg:pypi/mypy?source=hash-mapping - size: 8496738 - timestamp: 1748547465206 + size: 8494415 + timestamp: 1750118712013 - conda: https://prefix.dev/conda-forge/noarch/mypy_extensions-1.1.0-pyha770c72_0.conda sha256: 6ed158e4e5dd8f6a10ad9e525631e35cee8557718f83de7a4e3966b1f772c4b1 md5: e9c622e0d00fa24a6292279af3ab6d06 @@ -8709,6 +8761,18 @@ packages: - pkg:pypi/nodejs-wheel-binaries?source=hash-mapping size: 12413 timestamp: 1747937965724 +- conda: https://prefix.dev/conda-forge/noarch/nodejs-wheel-22.17.0-pyhd8ed1ab_0.conda + sha256: 459fe173ba9a087d42602cbd399b6074f9641dcb0053be8edebe618e9020bfed + md5: 6f863bd4d3bbdf6fcd741aa004529bb9 + depends: + - nodejs + - python >=3.9 + license: MIT + license_family: MIT + purls: + - pkg:pypi/nodejs-wheel-binaries?source=hash-mapping + size: 12351 + timestamp: 1751232576536 - conda: https://prefix.dev/conda-forge/noarch/nomkl-1.0-h5ca1d4c_0.tar.bz2 sha256: d38542a151a90417065c1a234866f97fd1ea82a81de75ecb725955ab78f88b4b md5: 9a66894dfd07c4510beb6b3f9672ccc0 @@ -9149,6 +9213,19 @@ packages: - pkg:pypi/numpydoc?source=hash-mapping size: 58041 timestamp: 1733650959971 +- conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.9.0-pyhe01879c_1.conda + sha256: 9e1f3dda737ac9aeec3c245c5d856d0268c4f64a5293c094298d74bb55e2b165 + md5: 66f9ba52d846feffa1c5d62522324b4f + depends: + - python >=3.9 + - sphinx >=6 + - tomli >=1.1.0 + - python + license: BSD-3-Clause + purls: + - pkg:pypi/numpydoc?source=hash-mapping + size: 60220 + timestamp: 1750861325361 - conda: https://prefix.dev/conda-forge/linux-64/openssl-3.5.0-h7b32b05_1.conda sha256: b4491077c494dbf0b5eaa6d87738c22f2154e9277e5293175ec187634bd808a0 md5: de356753cfdbffcde5bb1e86e3aa6cd0 @@ -11066,6 +11143,17 @@ packages: - pkg:pypi/tomlkit?source=hash-mapping size: 37372 timestamp: 1733230836889 +- conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.3-pyha770c72_0.conda + sha256: f8d3b49c084831a20923f66826f30ecfc55a4cd951e544b7213c692887343222 + md5: 146402bf0f11cbeb8f781fa4309a95d3 + depends: + - python >=3.9 + license: MIT + license_family: MIT + purls: + - pkg:pypi/tomlkit?source=compressed-mapping + size: 38777 + timestamp: 1749127286558 - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda sha256: eda38f423c33c2eaeca49ed946a8d3bf466cc3364970e083a65eb2fd85258d87 md5: 40d0ed782a8aaa16ef248e68c06c168d @@ -11313,6 +11401,18 @@ packages: purls: [] size: 750733 timestamp: 1743195092905 +- conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_26.conda + sha256: 7bad6e25a7c836d99011aee59dcf600b7f849a6fa5caa05a406255527e80a703 + md5: 14d65350d3f5c8ff163dc4f76d6e2830 + depends: + - ucrt >=10.0.20348.0 + constrains: + - vs2015_runtime 14.44.35208.* *_26 + license: LicenseRef-MicrosoftVisualCpp2015-2022Runtime + license_family: Proprietary + purls: [] + size: 756109 + timestamp: 1750371459116 - conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.31.2-pyhd8ed1ab_0.conda sha256: 763dc774200b2eebdf5437b112834c5455a1dd1c9b605340696950277ff36729 md5: c0600c1b374efa7a1ff444befee108ca @@ -11337,6 +11437,16 @@ packages: purls: [] size: 17873 timestamp: 1743195097269 +- conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.44.35208-h38c0c73_26.conda + sha256: d18d77c8edfbad37fa0e0bb0f543ad80feb85e8fe5ced0f686b8be463742ec0b + md5: 312f3a0a6b3c5908e79ce24002411e32 + depends: + - vc14_runtime >=14.44.35208 + license: BSD-3-Clause + license_family: BSD + purls: [] + size: 17888 + timestamp: 1750371463202 - conda: https://prefix.dev/conda-forge/noarch/wcwidth-0.2.13-pyhd8ed1ab_1.conda sha256: f21e63e8f7346f9074fd00ca3b079bd3d2fa4d71f1f89d5b6934bf31446dc2a5 md5: b68980f2495d096e71c7fd9d7ccf63e6 diff --git a/pyproject.toml b/pyproject.toml index e7a2bf8b..fff7b9a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,17 +65,17 @@ array-api-strict = ">=2.3.1" numpy = ">=2.1.3" pytest = ">=8.4.0" hypothesis = ">=6.131.28" -dask-core = ">=2025.5.1" # No distributed, tornado, etc. +dask-core = ">=2025.5.1" # No distributed, tornado, etc. # NOTE: don't add cupy, jax, pytorch, or sparse here, # as they slow down mypy and are not portable across target OSs [tool.pixi.feature.lint.tasks] -pre-commit-install = { cmd = "pre-commit install", description = "Install pre-commit"} -pre-commit = { cmd = "pre-commit run --all-files", description = "Run pre-commit"} -mypy = { cmd = "mypy", description="Type check with mypy"} -pylint = { cmd = "pylint array_api_extra", cwd = "src" , description = "Lint using pylint"} -pyright = { cmd = "basedpyright", description = "Type check with basedpyright"} -lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] , description = "Run pre-commit, pylint, mypy, and pyright"} +pre-commit-install = { cmd = "pre-commit install", description = "Install pre-commit" } +pre-commit = { cmd = "pre-commit run --all-files", description = "Run pre-commit" } +mypy = { cmd = "mypy", description = "Type check with mypy" } +pylint = { cmd = "pylint array_api_extra", cwd = "src", description = "Lint using pylint" } +pyright = { cmd = "basedpyright", description = "Type check with basedpyright" } +lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"], description = "Run pre-commit, pylint, mypy, and pyright" } [tool.pixi.feature.tests.dependencies] pytest = ">=8.4.0" @@ -85,18 +85,18 @@ array-api-strict = ">=2.3.1" numpy = ">=1.22.0" [tool.pixi.feature.tests.tasks] -tests = { cmd = "pytest -v", description = "Run tests"} -tests-cov = { cmd="pytest -v -ra --cov --cov-report=xml --cov-report=term --durations=20", description = "Run tests with coverage"} +tests = { cmd = "pytest -v", description = "Run tests" } +tests-cov = { cmd = "pytest -v -ra --cov --cov-report=xml --cov-report=term --durations=20", description = "Run tests with coverage" } -clean-vendor-compat = { cmd = "rm -rf vendor_tests/array_api_compat", description = "Delete the existing vendored version of array-api-compat"} -clean-vendor-extra = { cmd = "rm -rf vendor_tests/array_api_extra" , description = "Delete the existing vendored version of array-api-extra"} -copy-vendor-compat = { cmd = "cp -r $(python -c 'import site; print(site.getsitepackages()[0])')/array_api_compat vendor_tests/", depends-on = ["clean-vendor-compat"] , description = "Vendor a clean copy of array-api-compat"} -copy-vendor-extra = { cmd = "cp -r src/array_api_extra vendor_tests/", depends-on = ["clean-vendor-extra"] , description = "Vendor a clean copy of array-api-extra"} -tests-vendor = { cmd = "pytest -v vendor_tests", depends-on = ["copy-vendor-compat", "copy-vendor-extra"] , description = "Check that array-api-extra and array-api-compat can be vendored together" } +clean-vendor-compat = { cmd = "rm -rf vendor_tests/array_api_compat", description = "Delete the existing vendored version of array-api-compat" } +clean-vendor-extra = { cmd = "rm -rf vendor_tests/array_api_extra", description = "Delete the existing vendored version of array-api-extra" } +copy-vendor-compat = { cmd = "cp -r $(python -c 'import site; print(site.getsitepackages()[0])')/array_api_compat vendor_tests/", depends-on = ["clean-vendor-compat"], description = "Vendor a clean copy of array-api-compat" } +copy-vendor-extra = { cmd = "cp -r src/array_api_extra vendor_tests/", depends-on = ["clean-vendor-extra"], description = "Vendor a clean copy of array-api-extra" } +tests-vendor = { cmd = "pytest -v vendor_tests", depends-on = ["copy-vendor-compat", "copy-vendor-extra"], description = "Check that array-api-extra and array-api-compat can be vendored together" } -tests-ci = { depends-on = ["tests-cov", "tests-vendor"] , description = "Run tests with coverage and vendor tests"} -coverage = { cmd = "coverage html", depends-on = ["tests-cov"], description = "Generate test coverage html report"} -open-coverage = { cmd = "open htmlcov/index.html", depends-on = ["coverage"] , description = "Open test coverage report"} +tests-ci = { depends-on = ["tests-cov", "tests-vendor"], description = "Run tests with coverage and vendor tests" } +coverage = { cmd = "coverage html", depends-on = ["tests-cov"], description = "Generate test coverage html report" } +open-coverage = { cmd = "open htmlcov/index.html", depends-on = ["coverage"], description = "Open test coverage report" } [tool.pixi.feature.docs.dependencies] sphinx = ">=7.4.7" @@ -105,20 +105,20 @@ myst-parser = ">=4.0.1" sphinx-copybutton = ">=0.5.2" sphinx-autodoc-typehints = ">=1.25.3" # Needed to import parsed modules with autodoc -dask-core = ">=2025.5.1" # No distributed, tornado, etc. +dask-core = ">=2025.5.1" # No distributed, tornado, etc. pytest = ">=8.4.0" typing-extensions = ">=4.14.0" numpy = ">=2.1.3" [tool.pixi.feature.docs.tasks] -docs = { cmd = "sphinx-build -E -W . build/", cwd = "docs" , description = "Build docs"} -open-docs = { cmd = "open build/index.html", cwd = "docs", depends-on = ["docs"] , description = "Open the generated docs"} +docs = { cmd = "sphinx-build -E -W . build/", cwd = "docs", description = "Build docs" } +open-docs = { cmd = "open build/index.html", cwd = "docs", depends-on = ["docs"], description = "Open the generated docs" } [tool.pixi.feature.dev.dependencies] ipython = ">=7.33.0" [tool.pixi.feature.dev.tasks] -ipython = { cmd = "ipython" , description = "Launch ipython"} +ipython = { cmd = "ipython", description = "Launch ipython" } [tool.pixi.feature.py310.dependencies] python = "~=3.10.0" @@ -135,7 +135,7 @@ numpy = "=1.22.0" # Note: JAX and PyTorch will install CPU variants. [tool.pixi.feature.backends.dependencies] pytorch = ">=2.7.0" -dask-core = ">=2025.5.1" # No distributed, tornado, etc. +dask-core = ">=2025.5.1" # No distributed, tornado, etc. sparse = ">=0.17.0" [tool.pixi.feature.backends.target.linux-64.dependencies] @@ -184,7 +184,7 @@ python-freethreading = "~=3.13.0" pytest-run-parallel = ">=0.4.4" numpy = ">=2.3.0" # pytorch = "*" # Not available on Python 3.13t yet -dask-core = ">=2025.5.1" # No distributed, tornado, etc. +dask-core = ">=2025.5.1" # No distributed, tornado, etc. # sparse = "*" # numba not available on Python 3.13t yet # jax = "*" # ml_dtypes not available on Python 3.13t yet @@ -245,7 +245,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = ["tests/*"] -disable_error_code = ["no-untyped-def"] # test(...) without -> None +disable_error_code = ["no-untyped-def"] # test(...) without -> None # pyright @@ -322,6 +322,7 @@ ignore = [ "N801", # Class name should use CapWords convention "N802", # Function name should be lowercase "N806", # Variable in function should be lowercase + "PLC0415", # `import` should be at the top-level of a file ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 1b552398..b2f57a6e 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -325,13 +325,15 @@ def quantile( xp = array_namespace(x, q) if xp is None else xp try: - import scipy + import scipy # type: ignore[import-untyped] from packaging import version # The quantile function in scipy 1.16 supports array API directly, no need # to delegate - if version.parse(scipy.__version__) >= version.parse("1.16"): - from scipy.stats import quantile as scipy_quantile + if version.parse(scipy.__version__) >= version.parse("1.16"): # pyright: ignore[reportUnknownArgumentType] + from scipy.stats import ( # type: ignore[import-untyped] + quantile as scipy_quantile, + ) return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method) except (ImportError, AttributeError): diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index cf1a894a..05db6251 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -268,7 +268,7 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ... for axis in range(-ndim, 0): sizes = {shape[axis] for shape in shapes if axis >= -len(shape)} # Dask uses NaN for unknown shape, which predates the Array API spec for None - none_size = None in sizes or math.nan in sizes + none_size = None in sizes or math.nan in sizes # noqa: PLW0177 sizes -= {1, None, math.nan} if len(sizes) > 1: msg = ( diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 15b59f40..3bcf1dda 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -1,6 +1,7 @@ """Quantile implementation.""" from types import ModuleType +from typing import cast from ._at import at from ._utils import _compat @@ -14,7 +15,7 @@ def quantile( /, *, axis: int | None = None, - keepdims: bool = None, # noqa: RUF013 + keepdims: bool | None = None, method: str = "linear", xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 @@ -25,26 +26,27 @@ def quantile( q_is_scalar = isinstance(q, int | float) if q_is_scalar: q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x)) + q_arr = cast(Array, q) if not xp.isdtype(x.dtype, ("integral", "real floating")): raise ValueError("`x` must have real dtype.") # noqa: EM101 - if not xp.isdtype(q.dtype, "real floating"): + if not xp.isdtype(q_arr.dtype, "real floating"): raise ValueError("`q` must have real floating dtype.") # noqa: EM101 # Promote to common dtype x = xp.astype(x, xp.float64) - q = xp.astype(q, xp.float64) - q = xp.asarray(q, device=_compat.device(x)) + q_arr = xp.astype(q_arr, xp.float64) + q_arr = xp.asarray(q_arr, device=_compat.device(x)) dtype = x.dtype axis_none = axis is None - ndim = max(x.ndim, q.ndim) + ndim = max(x.ndim, q_arr.ndim) if axis_none: x = xp.reshape(x, (-1,)) - q = xp.reshape(q, (-1,)) + q_arr = xp.reshape(q_arr, (-1,)) axis = 0 - elif not isinstance(axis, int): + elif not isinstance(axis, int): # pyright: ignore[reportUnnecessaryIsInstance] raise ValueError("`axis` must be an integer or None.") # noqa: EM101 elif axis >= ndim or axis < -ndim: raise ValueError("`axis` is not compatible with the shapes of the inputs.") # noqa: EM101 @@ -63,15 +65,15 @@ def quantile( # Move axis to the end for easier processing y = xp.moveaxis(y, axis, -1) - if not (q_is_scalar or q.ndim == 0): - q = xp.moveaxis(q, axis, -1) + if not (q_is_scalar or q_arr.ndim == 0): + q_arr = xp.moveaxis(q_arr, axis, -1) n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) - res = _quantile_hf(y, q, n, method, xp) + res = _quantile_hf(y, q_arr, n, method, xp) # Handle NaN output for invalid q values - p_mask = (q > 1) | (q < 0) | xp.isnan(q) + p_mask = (q_arr > 1) | (q_arr < 0) | xp.isnan(q_arr) if xp.any(p_mask): res = xp.asarray(res, copy=True) res = at(res, p_mask).set(xp.nan) @@ -100,7 +102,7 @@ def _quantile_hf( y: Array, p: Array, n: Array, method: str, xp: ModuleType ) -> Array: # numpydoc ignore=PR01,RT01 """Helper function for Hyndman-Fan quantile method.""" - ms = { + ms: dict[str, Array | int | float] = { "inverted_cdf": 0, "averaged_inverted_cdf": 0, "closest_observation": -0.5, diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index 47511c75..d40fea1a 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -297,9 +297,9 @@ def temp_setattr(mod: ModuleType, name: str, func: object) -> None: # Enable using patch_lazy_xp_function not as a context manager temp_setattr = monkeypatch.setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType] - def iter_tagged() -> ( - Iterator[tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]] - ): + def iter_tagged() -> Iterator[ + tuple[ModuleType, str, Callable[..., Any], dict[str, Any]] + ]: for mod in mods: for name, func in mod.__dict__.items(): tags: dict[str, Any] | None = None diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 43ad5dc7..3d5119cc 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -237,7 +237,7 @@ def test_device(self, xp: ModuleType, library: Backend, device: Device): actual = capabilities(xp, device=device.type) # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue] -class Wrapper(Generic[T]): +class Wrapper(Generic[T]): # noqa: PLW1641 """Trivial opaque wrapper. Must be pickleable.""" x: T @@ -263,7 +263,7 @@ def __reduce__(self) -> tuple[object, ...]: # Note: NotHashable() instances can be reduced to an # unserializable local class - class NotHashable: + class NotHashable: # noqa: PLW1641 @override def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and other.__dict__ == self.__dict__ diff --git a/tests/test_testing.py b/tests/test_testing.py index 7eda3fb6..054d951b 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -459,7 +459,7 @@ def test_patch_lazy_xp_functions_deprecated_monkeypatch( y = non_materializable5(x) xp_assert_equal(y, x) - with pytest.warns(DeprecationWarning): + with pytest.warns(DeprecationWarning, match="`monkeypatch` parameter"): _ = patch_lazy_xp_functions(request, monkeypatch, xp=xp) with pytest.raises(AssertionError, match=r"dask\.compute.* 1 times"): From 1ef7d5ede21a04e9c097dc95ccf3a54596957a4b Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Mon, 30 Jun 2025 18:47:33 +0100 Subject: [PATCH 13/20] improve style --- src/array_api_extra/_delegation.py | 3 ++- src/array_api_extra/_lib/_quantile.py | 17 +++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index b2f57a6e..0782de75 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -320,7 +320,8 @@ def quantile( "weibull", } if method not in methods: - raise ValueError(f"`method` must be one of {methods}") # noqa: EM102 + msg = f"`method` must be one of {methods}" + raise ValueError(msg) xp = array_namespace(x, q) if xp is None else xp diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 3bcf1dda..4eb84bd4 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -29,9 +29,11 @@ def quantile( q_arr = cast(Array, q) if not xp.isdtype(x.dtype, ("integral", "real floating")): - raise ValueError("`x` must have real dtype.") # noqa: EM101 + msg = "`x` must have real dtype." + raise ValueError(msg) if not xp.isdtype(q_arr.dtype, "real floating"): - raise ValueError("`q` must have real floating dtype.") # noqa: EM101 + msg = "`q` must have real floating dtype." + raise ValueError(msg) # Promote to common dtype x = xp.astype(x, xp.float64) @@ -47,14 +49,17 @@ def quantile( q_arr = xp.reshape(q_arr, (-1,)) axis = 0 elif not isinstance(axis, int): # pyright: ignore[reportUnnecessaryIsInstance] - raise ValueError("`axis` must be an integer or None.") # noqa: EM101 + msg = "`axis` must be an integer or None." + raise ValueError(msg) elif axis >= ndim or axis < -ndim: - raise ValueError("`axis` is not compatible with the shapes of the inputs.") # noqa: EM101 + msg = "`axis` is not compatible with the shapes of the inputs." + raise ValueError(msg) else: axis = int(axis) if keepdims not in {None, True, False}: - raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101 + msg = "If specified, `keepdims` must be True or False." + raise ValueError(msg) if x.shape[axis] == 0: shape = list(x.shape) @@ -135,7 +140,7 @@ def _quantile_hf( # Broadcast indices to match y shape except for the last axis if y.ndim > 1: # Create broadcast shape for indices - broadcast_shape = list(y.shape[:-1]) + [1] # noqa: RUF005 + broadcast_shape = list(y.shape[:-1]).append(1) j = xp.broadcast_to(j, broadcast_shape) jp1 = xp.broadcast_to(jp1, broadcast_shape) g = xp.broadcast_to(g, broadcast_shape) From 440106f3c9eb65c72007dae185b4c4249d47a950 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Mon, 30 Jun 2025 18:54:25 +0100 Subject: [PATCH 14/20] fix list --- src/array_api_extra/_lib/_quantile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 4eb84bd4..2d46ec57 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -140,7 +140,7 @@ def _quantile_hf( # Broadcast indices to match y shape except for the last axis if y.ndim > 1: # Create broadcast shape for indices - broadcast_shape = list(y.shape[:-1]).append(1) + broadcast_shape = [*y.shape[:-1], 1] j = xp.broadcast_to(j, broadcast_shape) jp1 = xp.broadcast_to(jp1, broadcast_shape) g = xp.broadcast_to(g, broadcast_shape) From 3ef67271d4d2384430efefd7daf765454b058aeb Mon Sep 17 00:00:00 2001 From: Tim Head Date: Tue, 1 Jul 2025 14:39:34 +0200 Subject: [PATCH 15/20] Raise exception for invalid q values --- src/array_api_extra/_delegation.py | 2 +- src/array_api_extra/_lib/_quantile.py | 23 +++++++++++++---------- tests/test_funcs.py | 15 +++++++++------ 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 0782de75..78dafa8b 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -331,7 +331,7 @@ def quantile( # The quantile function in scipy 1.16 supports array API directly, no need # to delegate - if version.parse(scipy.__version__) >= version.parse("1.16"): # pyright: ignore[reportUnknownArgumentType] + if version.parse(scipy.__version__) >= version.parse("1.17"): # pyright: ignore[reportUnknownArgumentType] from scipy.stats import ( # type: ignore[import-untyped] quantile as scipy_quantile, ) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 2d46ec57..3d03c8d6 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -75,13 +75,12 @@ def quantile( n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y)) - res = _quantile_hf(y, q_arr, n, method, xp) + # Validate that q values are in the range [0, 1] + if xp.any((q_arr < 0) | (q_arr > 1)): + msg = "`q` must contain values between 0 and 1 inclusive." + raise ValueError(msg) - # Handle NaN output for invalid q values - p_mask = (q_arr > 1) | (q_arr < 0) | xp.isnan(q_arr) - if xp.any(p_mask): - res = xp.asarray(res, copy=True) - res = at(res, p_mask).set(xp.nan) + res = _quantile_hf(y, q_arr, n, method, xp) # Reshape per axis/keepdims if axis_none and keepdims: @@ -97,9 +96,10 @@ def quantile( res = xp.squeeze(res, axis=axis) # For scalar q, ensure we return a scalar result - if q_is_scalar and hasattr(res, "shape") and res.shape != (): - res = res[()] - + # if q_is_scalar and hasattr(res, "shape") and res.shape != (): + # res = res[()] + if res.ndim == 0: + return res[()] return res @@ -121,7 +121,10 @@ def _quantile_hf( m = ms[method] jg = p * n + m - 1 - j = xp.astype(jg // 1, xp.int64) # Convert to integer + # Convert both to integers, the type of j and n must be the same + # for us to be able to `xp.clip` them. + j = xp.astype(jg // 1, xp.int64) + n = xp.astype(n, xp.int64) g = jg % 1 if method == "inverted_cdf": diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 84ebbaf7..de4e8a89 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1213,13 +1213,16 @@ def test_edge_cases(self, xp: ModuleType): def test_invalid_q(self, xp: ModuleType): x = xp.asarray([1, 2, 3, 4, 5]) - # q > 1 should return NaN - actual = quantile(x, 1.5) - assert xp.isnan(actual) + # q > 1 should raise + with pytest.raises( + ValueError, match="`q` must contain values between 0 and 1 inclusive" + ): + quantile(x, 1.5) - # q < 0 should return NaN - actual = quantile(x, -0.5) - assert xp.isnan(actual) + with pytest.raises( + ValueError, match="`q` must contain values between 0 and 1 inclusive" + ): + quantile(x, -0.5) def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3, 4, 5], device=device) From 007a61f25e0c638d0f3efecccfc52d9cb35eaec3 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Tue, 1 Jul 2025 14:51:47 +0200 Subject: [PATCH 16/20] Tweak --- src/array_api_extra/_lib/_quantile.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 3d03c8d6..adf5200d 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -91,13 +91,9 @@ def quantile( # Move axis back to original position res = xp.moveaxis(res, -1, axis) - # Handle keepdims if not keepdims and res.shape[axis] == 1: res = xp.squeeze(res, axis=axis) - # For scalar q, ensure we return a scalar result - # if q_is_scalar and hasattr(res, "shape") and res.shape != (): - # res = res[()] if res.ndim == 0: return res[()] return res @@ -148,6 +144,7 @@ def _quantile_hf( jp1 = xp.broadcast_to(jp1, broadcast_shape) g = xp.broadcast_to(g, broadcast_shape) - return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis( + res = (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis( y, jp1, axis=-1 ) + return res From 1ccdac4d61b2784c009801aa8598f263bf6c0786 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Tue, 1 Jul 2025 15:03:53 +0200 Subject: [PATCH 17/20] noqa --- src/array_api_extra/_lib/_quantile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index adf5200d..1b2953a5 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -147,4 +147,4 @@ def _quantile_hf( res = (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis( y, jp1, axis=-1 ) - return res + return res # noqa: RET504 From ebcec0ef5ba9e6c82cf281046a26be0afd2c43ca Mon Sep 17 00:00:00 2001 From: Tim Head Date: Tue, 1 Jul 2025 15:05:08 +0200 Subject: [PATCH 18/20] More lint pleasure --- tests/test_funcs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 3847df29..d7d6bc56 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1217,12 +1217,12 @@ def test_invalid_q(self, xp: ModuleType): with pytest.raises( ValueError, match="`q` must contain values between 0 and 1 inclusive" ): - quantile(x, 1.5) + _ = quantile(x, 1.5) with pytest.raises( ValueError, match="`q` must contain values between 0 and 1 inclusive" ): - quantile(x, -0.5) + _ = quantile(x, -0.5) def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3, 4, 5], device=device) From 5c974a4fc7f64c5a2e30307b9136be6243e387e6 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Tue, 1 Jul 2025 15:30:41 +0200 Subject: [PATCH 19/20] Delegate to dask directly --- src/array_api_extra/_delegation.py | 3 +++ src/array_api_extra/_lib/_quantile.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 78dafa8b..9bca87ae 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -325,6 +325,9 @@ def quantile( xp = array_namespace(x, q) if xp is None else xp + if is_dask_namespace(xp): + return xp.quantile(x, q, axis=axis, keepdims=keepdims, method=method) + try: import scipy # type: ignore[import-untyped] from packaging import version diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 1b2953a5..4cf4f5ff 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -37,8 +37,7 @@ def quantile( # Promote to common dtype x = xp.astype(x, xp.float64) - q_arr = xp.astype(q_arr, xp.float64) - q_arr = xp.asarray(q_arr, device=_compat.device(x)) + q_arr = xp.asarray(q_arr, xp.float64, device=_compat.device(x)) dtype = x.dtype axis_none = axis is None From 477c916f47120ebda3208f64dc5a30628eb15ad4 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Tue, 1 Jul 2025 15:45:07 +0200 Subject: [PATCH 20/20] Fix --- src/array_api_extra/_lib/_quantile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 4cf4f5ff..9670d4dd 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -37,7 +37,7 @@ def quantile( # Promote to common dtype x = xp.astype(x, xp.float64) - q_arr = xp.asarray(q_arr, xp.float64, device=_compat.device(x)) + q_arr = xp.asarray(q_arr, dtype=xp.float64, device=_compat.device(x)) dtype = x.dtype axis_none = axis is None