Skip to content

Commit cfb6b5e

Browse files
committed
Faster unique arrays
1 parent d8cbc06 commit cfb6b5e

File tree

4 files changed

+54
-25
lines changed

4 files changed

+54
-25
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
RELEASE_TYPE: patch
2+
3+
This patch makes unique :func:`~hypothesis.extra.numpy.arrays` much more
4+
efficient, especially when there are only a few valid elements - such as
5+
for eight-bit integers (:issue:`3066`).

hypothesis-python/src/hypothesis/extra/numpy.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,7 @@ def __init__(self, element_strategy, shape, dtype, fill, unique):
186186
self.unique = unique
187187
self._check_elements = dtype.kind not in ("O", "V")
188188

189-
def set_element(self, data, result, idx, strategy=None):
190-
strategy = strategy or self.element_strategy
191-
val = data.draw(strategy)
189+
def set_element(self, val, result, idx, *, fill=False):
192190
try:
193191
result[idx] = val
194192
except TypeError as err:
@@ -197,6 +195,7 @@ def set_element(self, data, result, idx, strategy=None):
197195
f"{result.dtype!r} - possible mismatch of time units in dtypes?"
198196
) from err
199197
if self._check_elements and val != result[idx] and val == val:
198+
strategy = self.fill if fill else self.element_strategy
200199
raise InvalidArgument(
201200
"Generated array element %r from %r cannot be represented as "
202201
"dtype %r - instead it becomes %r (type %r). Consider using a more "
@@ -229,28 +228,17 @@ def do_draw(self, data):
229228
# generate a fully dense array with a freshly drawn value for each
230229
# entry.
231230
if self.unique:
232-
seen = set()
233-
elements = cu.many(
234-
data,
231+
elems = st.lists(
232+
self.element_strategy,
235233
min_size=self.array_size,
236234
max_size=self.array_size,
237-
average_size=self.array_size,
235+
unique=True,
238236
)
239-
i = 0
240-
while elements.more():
241-
# We assign first because this means we check for
242-
# uniqueness after numpy has converted it to the relevant
243-
# type for us. Because we don't increment the counter on
244-
# a duplicate we will overwrite it on the next draw.
245-
self.set_element(data, result, i)
246-
if result[i] not in seen:
247-
seen.add(result[i])
248-
i += 1
249-
else:
250-
elements.reject()
237+
for i, v in enumerate(data.draw(elems)):
238+
self.set_element(v, result, i)
251239
else:
252240
for i in range(len(result)):
253-
self.set_element(data, result, i)
241+
self.set_element(data.draw(self.element_strategy), result, i)
254242
else:
255243
# We draw numpy arrays as "sparse with an offset". We draw a
256244
# collection of index assignments within the array and assign
@@ -277,7 +265,7 @@ def do_draw(self, data):
277265
if not needs_fill[i]:
278266
elements.reject()
279267
continue
280-
self.set_element(data, result, i)
268+
self.set_element(data.draw(self.element_strategy), result, i)
281269
if self.unique:
282270
if result[i] in seen:
283271
elements.reject()
@@ -300,7 +288,7 @@ def do_draw(self, data):
300288
one_element = np.zeros(
301289
shape=1, dtype=object if unsized_string_dtype else self.dtype
302290
)
303-
self.set_element(data, one_element, 0, self.fill)
291+
self.set_element(data.draw(self.fill), one_element, 0, fill=True)
304292
if unsized_string_dtype:
305293
one_element = one_element.astype(self.dtype)
306294
fill_value = one_element[0]

hypothesis-python/src/hypothesis/strategies/_internal/core.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,12 @@
9595
from hypothesis.strategies._internal.functions import FunctionStrategy
9696
from hypothesis.strategies._internal.lazy import LazyStrategy
9797
from hypothesis.strategies._internal.misc import just, none, nothing
98-
from hypothesis.strategies._internal.numbers import Real, floats, integers
98+
from hypothesis.strategies._internal.numbers import (
99+
IntegersStrategy,
100+
Real,
101+
floats,
102+
integers,
103+
)
99104
from hypothesis.strategies._internal.recursive import RecursiveStrategy
100105
from hypothesis.strategies._internal.shared import SharedStrategy
101106
from hypothesis.strategies._internal.strategies import (
@@ -283,6 +288,20 @@ def lists(
283288
tuple_suffixes = TupleStrategy(elements.element_strategies[1:])
284289
elements = elements.element_strategies[0]
285290

291+
# UniqueSampledListStrategy offers a substantial performance improvement for
292+
# unique arrays with few possible elements, e.g. of eight-bit integer types.
293+
if (
294+
isinstance(elements, IntegersStrategy)
295+
and None not in (elements.start, elements.end)
296+
and (elements.end - elements.start) <= 255
297+
):
298+
elements = SampledFromStrategy(
299+
sorted(range(elements.start, elements.end + 1), key=abs)
300+
if elements.end < 0 or elements.start > 0
301+
else list(range(0, elements.end + 1))
302+
+ list(range(-1, elements.start - 1, -1))
303+
)
304+
286305
if isinstance(elements, SampledFromStrategy):
287306
element_count = len(elements.elements)
288307
if min_size > element_count:

hypothesis-python/tests/numpy/test_gen_data.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pytest
2222

2323
from hypothesis import HealthCheck, assume, given, note, settings, strategies as st
24-
from hypothesis.errors import InvalidArgument, Unsatisfiable
24+
from hypothesis.errors import InvalidArgument
2525
from hypothesis.extra import numpy as nps
2626

2727
from tests.common.debug import find_any, minimal
@@ -251,7 +251,7 @@ def test_array_values_are_unique(arr):
251251

252252
def test_cannot_generate_unique_array_of_too_many_elements():
253253
strat = nps.arrays(dtype=int, elements=st.integers(0, 5), shape=10, unique=True)
254-
with pytest.raises(Unsatisfiable):
254+
with pytest.raises(InvalidArgument):
255255
strat.example()
256256

257257

@@ -274,6 +274,23 @@ def test_generates_all_values_for_unique_array(arr):
274274
assert len(set(arr)) == len(arr)
275275

276276

277+
@given(nps.arrays(dtype="int8", shape=255, unique=True))
278+
def test_efficiently_generates_all_unique_array(arr):
279+
# Avoids the birthday paradox with UniqueSampledListStrategy
280+
assert len(set(arr)) == len(arr)
281+
282+
283+
@given(st.data(), st.integers(-100, 100), st.integers(1, 100))
284+
def test_array_element_rewriting(data, start, size):
285+
arr = nps.arrays(
286+
dtype=np.dtype("int64"),
287+
shape=size,
288+
elements=st.integers(start, start + size - 1),
289+
unique=True,
290+
)
291+
assert set(data.draw(arr)) == set(range(start, start + size))
292+
293+
277294
def test_may_fill_with_nan_when_unique_is_set():
278295
find_any(
279296
nps.arrays(

0 commit comments

Comments
 (0)