Skip to content

Commit 62712fa

Browse files
author
Johannes E. M. Mosig
committed
Fix test case
1 parent 953ab0a commit 62712fa

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

penzai/core/selectors.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
T = typing.TypeVar("T")
4040

4141

42-
def shift_negative_indices(indices: Iterable[int], shift: int) -> tuple[int, ...]:
42+
def _shift_negative_indices(indices: Iterable[int], shift: int) -> tuple[int, ...]:
4343
"""Adds `shift` to negative indices and leaves non-negative indices unchanged
4444
4545
Can be used to handle negative indices. For example, if we expect indices in
@@ -49,12 +49,24 @@ def shift_negative_indices(indices: Iterable[int], shift: int) -> tuple[int, ...
4949
shift_negative_indices([0, 3, -2], len(r))
5050
```
5151
52-
to get `(0, 3, 4)`
52+
to get `(0, 3, 4)`. The same can be achieved in more generality with
53+
54+
```py
55+
pz.select((0, 3, -2)) \
56+
.at_instances_of(int) \
57+
.where(lambda i: i < 0) \
58+
.apply(lambda i: i + shift)
59+
```
60+
61+
which is why this method is private to this module
5362
5463
Args:
5564
indices: The integers to shift
5665
shift: The offset to add to negative indices. Usually, this is the largest
5766
index + 1, i.e. the length of the range of indices
67+
68+
Returns:
69+
The indices as a tuple, with negative indices increased by `shift`
5870
"""
5971
maybe_shifted_indices = []
6072
for index in indices:
@@ -1382,7 +1394,7 @@ def pick_nth_selected(self, n: int | Sequence[int]) -> Selection:
13821394
else:
13831395
indices = n
13841396

1385-
indices = shift_negative_indices(indices, len(self.selected_by_path))
1397+
indices = _shift_negative_indices(indices, len(self.selected_by_path))
13861398

13871399
with _wrap_selection_errors(self):
13881400
keep = _InProgressSelectionBoundary

tests/core/selectors_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
@pytest.mark.parametrize(
3333
"input_indices, shift, expected_output",
3434
[
35-
((,), 1, (,)),
35+
((), 1, ()),
3636
([0, 3, -2], len(range(6)), (0, 3, 4)),
3737
]
3838
)
@@ -42,7 +42,7 @@ def test_shift_negative_indices(
4242
expected_output: tuple[int, ...],
4343
):
4444
assert (
45-
penzai.core.selectors.shift_negative_indices(input_indices, shift=shift)
45+
penzai.core.selectors._shift_negative_indices(input_indices, shift=shift)
4646
== expected_output
4747
)
4848

0 commit comments

Comments
 (0)