Skip to content

Commit cfeb812

Browse files
committed
Generate a choice between different strings
1 parent 4a3126b commit cfeb812

File tree

4 files changed

+49
-1
lines changed

4 files changed

+49
-1
lines changed

outlines/text/generate/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .continuation import continuation
2-
from .regex import float, integer, regex
2+
from .regex import choice, float, integer, regex

outlines/text/generate/regex.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,9 @@ def float(model, max_tokens: Optional[int] = None):
200200
201201
"""
202202
return Regex(model, r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))", max_tokens)
203+
204+
205+
def choice(model, choices: List[str], max_tokens: Optional[int] = None):
    """Choose between different sequences.

    Builds a regex that is the literal alternation of ``choices`` and
    constrains generation to it.

    Parameters
    ----------
    model
        The language model used to generate the sequence.
    choices
        The list of candidate strings; the generated text must be exactly
        one of them.
    max_tokens
        The maximum number of tokens to generate.

    """
    import re  # local import: keeps the module's top-level imports untouched

    # Escape each candidate so regex metacharacters (".", "+", "(", ...) in a
    # choice are matched literally instead of being interpreted as operators.
    regex_str = r"(" + r"|".join(re.escape(c) for c in choices) + r")"
    return Regex(model, regex_str, max_tokens)

tests/text/generate/test_integration_transfomers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,19 @@ def test_transformers_integration_float():
9898
float(generated)
9999

100100

101+
def test_transformers_integration_choice():
    """The completion produced by ``generate.choice`` must be one of the choices."""
    generator = torch.Generator()
    generator.manual_seed(0)

    model = models.transformers(
        "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu"
    )
    prompt = "Write a short sentence "
    sequence = generate.choice(model, ["test", "choice"])(prompt, rng=generator)

    completion = sequence[len(prompt) :]
    assert completion in ("test", "choice")
101114
def test_transformers_integration_with_pad_token():
102115
model_name = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
103116
model = models.transformers(model_name, device="cpu")

tests/text/generate/test_regex.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,35 @@ def test_integer_proposal(input_ids, proposal):
106106
)
107107

108108

109+
def test_choice_proposal():
    """The proposal must only keep logits of tokens that continue a choice."""
    model = Model()
    generator = generate.choice(model, ["1", "431a", "431A-"])
    logits = torch.ones(len(model.tokenizer.vocabulary))

    # (input_ids so far, expected masked logits) — unreachable tokens get -inf.
    cases = [
        (
            [[]],
            [[-math.inf, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf]],
        ),
        (
            [[4]],
            [[-math.inf, -math.inf, -math.inf, -math.inf, -math.inf, 1.0, 1.0]],
        ),
        (
            [[4, 6]],
            [[-math.inf, 1.0, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf]],
        ),
    ]
    for input_ids, expected in cases:
        masked = generator.create_proposal(torch.tensor(input_ids), logits)
        assert torch.equal(masked, torch.tensor(expected))
109138
@pytest.mark.parametrize(
110139
"input_ids, proposal",
111140
[

0 commit comments

Comments
 (0)