Skip to content

Commit ded2690

Browse files
committed
Generate from JSON schema with JSON class
1 parent 12b672b commit ded2690

File tree

3 files changed

+138
-2
lines changed

3 files changed

+138
-2
lines changed

outlines/text/generate/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .continuation import continuation
2-
from .regex import choice, float, integer, regex
2+
from .regex import choice, float, integer, json, regex

outlines/text/generate/regex.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import collections
22
import math
3-
from typing import List, Optional, Tuple
3+
from json import dumps
4+
from typing import List, Optional, Tuple, Union
45

56
import interegular
67
import torch
8+
from pydantic import BaseModel
79

810
from outlines.text.generate.continuation import Continuation
11+
from outlines.text.json_schema import build_regex_from_schema
912
from outlines.text.parsing import find_partial_matches, map_partial_states_to_vocab
1013

1114

@@ -204,3 +207,24 @@ def choice(model, choices: List[str], max_tokens: Optional[int] = None):
204207
"""Choose between different sequences."""
205208
regex_str = r"(" + r"|".join(choices) + r")"
206209
return Regex(model, regex_str, max_tokens)
210+
211+
212+
def json(model, schema: Union[str, BaseModel], max_tokens: Optional[int] = None):
213+
"""Generate a text sequence that follows a JSON schema.
214+
215+
Parameters
216+
---------
217+
model
218+
The model to use to computes the next-token logits.
219+
schema
220+
The JSON schema, or Pydantic model, that guides the generation.
221+
max_tokens
222+
The maximum number of tokens to generate at each step.
223+
224+
"""
225+
if isinstance(schema, type(BaseModel)):
226+
schema = dumps(schema.model_json_schema())
227+
228+
regex_str = build_regex_from_schema(schema)
229+
230+
return Regex(model, regex_str, max_tokens)

tests/text/generate/test_integration_transfomers.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import json
12
import re
3+
from enum import Enum
4+
from typing import List, Union
25

36
import pytest
47
import torch
8+
from pydantic import BaseModel, constr
59

610
import outlines.models as models
711
import outlines.text.generate as generate
@@ -113,3 +117,111 @@ def test_transformers_integration_with_pad_token():
113117
model = models.transformers(model_name, device="cpu")
114118
assert model.tokenizer.pad_token_id == 1
115119
assert model.tokenizer.pad_token == "<pad>"
120+
121+
122+
def test_transformers_json_basic():
123+
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
124+
model = models.transformers(model_name, device="cpu")
125+
prompt = "Output some JSON "
126+
127+
class Spam(BaseModel):
128+
foo: int
129+
bar: float
130+
spam: constr(max_length=10)
131+
fuzz: bool
132+
133+
rng = torch.Generator()
134+
rng.manual_seed(0) # make sure that `bar` is not an int
135+
136+
sequence = generate.json(model, Spam, max_tokens=1000)(prompt, rng=rng)
137+
parsed = json.loads(sequence)
138+
assert isinstance(parsed["foo"], int)
139+
assert isinstance(parsed["bar"], float)
140+
assert isinstance(parsed["spam"], str)
141+
assert isinstance(parsed["fuzz"], bool)
142+
assert len(parsed["spam"]) == 10
143+
144+
145+
def test_transformers_json_str_enum():
146+
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
147+
model = models.transformers(model_name, device="cpu")
148+
prompt = "Output some JSON "
149+
150+
rng = torch.Generator()
151+
rng.manual_seed(0)
152+
153+
class Name(str, Enum):
154+
john = "John"
155+
marc = "Marc"
156+
michel = "Michel"
157+
158+
class User(BaseModel):
159+
user_id: int
160+
name: Name
161+
162+
sequence = generate.json(model, User)(prompt, rng=rng)
163+
parsed = json.loads(sequence)
164+
assert isinstance(parsed["user_id"], int)
165+
assert parsed["name"] in ["John", "Marc", "Michel"]
166+
167+
168+
def test_transformers_json_int_enum():
169+
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
170+
model = models.transformers(model_name, device="cpu")
171+
prompt = "Output some JSON "
172+
173+
rng = torch.Generator()
174+
rng.manual_seed(0)
175+
176+
class Id(int, Enum):
177+
one = 1
178+
two = 2
179+
180+
class User(BaseModel):
181+
user_id: Id
182+
183+
sequence = generate.json(model, User)(prompt, rng=rng)
184+
parsed = json.loads(sequence)
185+
assert isinstance(parsed["user_id"], int)
186+
assert parsed["user_id"] in [1, 2]
187+
188+
189+
def test_transformers_json_array():
190+
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
191+
model = models.transformers(model_name, device="cpu")
192+
prompt = "Output some JSON "
193+
194+
class User(BaseModel):
195+
user_id: int
196+
value: List[float]
197+
198+
rng = torch.Generator()
199+
rng.manual_seed(0)
200+
201+
sequence = generate.json(model, User)(prompt, rng=rng)
202+
parsed = json.loads(sequence)
203+
assert isinstance(parsed["user_id"], int)
204+
assert isinstance(parsed["value"], list)
205+
for value in parsed["value"]:
206+
assert isinstance(value, float) or isinstance(value, int)
207+
208+
209+
def test_transformers_json_union():
210+
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
211+
model = models.transformers(model_name, device="cpu")
212+
prompt = "Output some JSON "
213+
214+
class Spam(BaseModel):
215+
foo: int
216+
bar: Union[constr(max_length=10), float]
217+
218+
rng = torch.Generator()
219+
rng.manual_seed(4)
220+
221+
sequence = generate.json(model, Spam, max_tokens=100)(prompt, rng=rng)
222+
parsed = json.loads(sequence)
223+
assert (
224+
isinstance(parsed["bar"], int)
225+
or isinstance(parsed["bar"], float)
226+
or isinstance(parsed["bar"], str)
227+
)

0 commit comments

Comments
 (0)