Skip to content

Commit 95a9d97

Browse files
Add another missing walk_fsm condition
The full-match option handling was not correct for scanned/walked strings with valid transitions but not ending in a final state.
1 parent 3cf3f96 commit 95a9d97

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

outlines/text/fsm.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ def _walk_fsm(
268268

269269
accepted_states.append(_nonoptional(state))
270270

271+
if full_match and last_final_idx - 1 != i:
272+
return numba.typed.List.empty_list(numba.int64)
273+
271274
return accepted_states
272275

273276

@@ -305,6 +308,9 @@ def walk_fsm(
305308

306309
accepted_states.append(state)
307310

311+
if full_match and last_final_idx - 1 != i:
312+
return []
313+
308314
return accepted_states
309315

310316

@@ -376,7 +382,7 @@ def process_token_string(
376382
res = set()
377383
vocab_string_len = len(token)
378384

379-
for end_idx, state_seq in find_partial_matches(fsm_info, token):
385+
for end_idx, state_seq in find_partial_matches(fsm_info, token, full_match=False):
380386
if end_idx is not None and end_idx < vocab_string_len - 1:
381387
continue
382388

@@ -603,6 +609,7 @@ def state_scan_tokens(
603609
fsm_finals,
604610
token,
605611
start_state,
612+
False,
606613
)
607614

608615
if state_seq is not None and len(state_seq) < len(token):

tests/text/test_fsm.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ def test_walk_fsm(function):
6161
res = tuple(function(regex_fsm, "0", 1, full_match=True))
6262
assert res == tuple()
6363

64+
regex_pattern = interegular.parse_pattern("0|[1-9][2-9]+")
65+
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
66+
67+
res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=True))
68+
assert res == tuple()
69+
70+
res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=False))
71+
assert res == (2,)
72+
73+
res = tuple(function(regex_fsm, "12", regex_fsm.initial, full_match=True))
74+
assert res == (2, 3)
75+
6476
pattern = interegular.parse_pattern(r"(?:[^\W\d]\w*|[\t \x0c]+)")
6577
fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce())
6678

@@ -90,19 +102,19 @@ def to_python(res):
90102

91103
res = to_python(find_partial_matches(def_fsm, "def"))
92104
assert res == {(2, (0, 1, 2, 3))}
93-
res = to_python(find_partial_matches(def_fsm, "de"))
105+
res = to_python(find_partial_matches(def_fsm, "de", full_match=False))
94106
assert res == {(1, (0, 1, 2))}
95-
res = to_python(find_partial_matches(def_fsm, "d"))
107+
res = to_python(find_partial_matches(def_fsm, "d", full_match=False))
96108
assert res == {(0, (0, 1))}
97109
res = to_python(find_partial_matches(def_fsm, ""))
98110
assert res == set()
99111
res = to_python(find_partial_matches(def_fsm, "df"))
100112
assert res == set()
101-
res = to_python(find_partial_matches(def_fsm, "ef"))
113+
res = to_python(find_partial_matches(def_fsm, "ef", full_match=False))
102114
assert res == {(1, (1, 2, 3))}
103-
res = to_python(find_partial_matches(def_fsm, "e"))
115+
res = to_python(find_partial_matches(def_fsm, "e", full_match=False))
104116
assert res == {(0, (1, 2))}
105-
res = to_python(find_partial_matches(def_fsm, "f"))
117+
res = to_python(find_partial_matches(def_fsm, "f", full_match=False))
106118
assert res == {(0, (2, 3))}
107119
res = to_python(find_partial_matches(def_fsm, "ef foo", full_match=False))
108120
assert res == {(1, (1, 2, 3))}
@@ -112,7 +124,7 @@ def to_python(res):
112124
assert res == {(2, (0, 1, 2, 3))}
113125

114126
# `NAME` can have multiple start states for this input
115-
res = to_python(find_partial_matches(name_fsm, "d"))
127+
res = to_python(find_partial_matches(name_fsm, "d", full_match=False))
116128
assert res == {(0, (0, 1)), (0, (1, 1))}
117129
# Not this case
118130
res = to_python(find_partial_matches(name_fsm, "1d"))
@@ -133,7 +145,7 @@ def to_python(res):
133145

134146
float_fsm = float_fsm.fsm_info
135147

136-
res = to_python(find_partial_matches(float_fsm, "."))
148+
res = to_python(find_partial_matches(float_fsm, ".", full_match=False))
137149
assert res == {(0, (3, 5)), (0, (4, 5)), (0, (0, 2))}
138150

139151
joins_fsm, _ = make_deterministic_fsm(

0 commit comments

Comments
 (0)