Skip to content

Commit dce9265

Browse files
committed
Add integration test for JSON schema
1 parent eb748a1 commit dce9265

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

outlines/generate/api.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,5 +248,11 @@ def json(
248248
regex_str = build_regex_from_object(schema)
249249
generator = regex(model, regex_str, max_tokens, sampler)
250250
generator.format_sequence = lambda x: pyjson.loads(x)
251+
else:
252+
raise ValueError(
253+
f"Cannot parse schema {schema_object}. The schema must be either "
254+
+ "a Pydantic object, a function or a string that contains the JSON "
255+
+ "Schema specification"
256+
)
251257

252258
return generator

tests/generate/test_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121

2222
def test_sequence_generator_class():
2323
class MockFSM:
24-
def next_state(self, state, next_token_ids):
24+
def next_state(self, state, next_token_ids, _):
2525
return 4
2626

27-
def allowed_token_ids(self, _):
27+
def allowed_token_ids(self, *_):
2828
return [4]
2929

3030
def is_final_state(self, _, idx=0):

tests/generate/test_integration_transfomers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,30 @@ class Spam(BaseModel):
276276
assert len(result.spam) <= 10
277277

278278

279+
def test_transformers_json_schema():
280+
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
281+
model = models.transformers(model_name, device="cpu")
282+
prompt = "Output some JSON "
283+
284+
schema = """{
285+
"title": "spam",
286+
"type": "object",
287+
"properties": {
288+
"foo" : {"type": "integer"},
289+
"bar": {"type": "string", "maxLength": 4}
290+
}
291+
}
292+
"""
293+
294+
rng = torch.Generator()
295+
rng.manual_seed(0) # make sure that `bar` is not an int
296+
297+
result = generate.json(model, schema, max_tokens=500)(prompt, rng=rng)
298+
assert isinstance(result, dict)
299+
assert isinstance(result["foo"], int)
300+
assert isinstance(result["bar"], str)
301+
302+
279303
def test_transformers_json_batch():
280304
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
281305
model = models.transformers(model_name, device="cpu")

0 commit comments

Comments
 (0)