Skip to content

Commit 9c74d7c

Browse files
Andrew Lapprlouf
authored and committed
Allow users to pass custom whitespace pattern for JSON-structured generation
1 parent 7c71199 commit 9c74d7c

File tree

6 files changed

+148
-35
lines changed

6 files changed

+148
-35
lines changed

docs/reference/json.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,21 @@ class User(BaseModel):
2626
id: int
2727

2828

29-
model = models.transformers("mistralai/Mistral-7B")
29+
model = models.transformers("mistralai/Mistral-7B-v0.1")
3030
generator = text.generate.json(model, User)
3131
result = generator("Create a user profile with the fields name, last_name and id")
3232
print(result)
3333
# User(name="John", last_name="Doe", id=11)
3434
```
3535

36+
!!! warning "JSON and whitespaces"
37+
38+
By default, Outlines lets the model choose the number of line breaks and whitespace characters used to structure the JSON. Small models tend to struggle with this, in which case we recommend setting the `whitespace_pattern` parameter to the empty string:
39+
40+
```python
41+
generator = text.generate.json(model, User, whitespace_pattern="")
42+
```
43+
3644
## From a function's signature
3745

3846
Outlines can infer the structure of the output from the signature of a function. The result is a dictionary, and can be passed directly to the function using the usual dictionary expansion syntax `**`:
@@ -44,7 +52,7 @@ from outlines import text
4452
def add(a: int, b: int):
4553
return a + b
4654

47-
model = models.transformers("mistralai/Mistral-7B")
55+
model = models.transformers("mistralai/Mistral-7B-v0.1")
4856
generator = text.generate.json(model, add)
4957
result = generator("Return two integers named a and b respectively. a is odd and b even.")
5058

outlines/fsm/json_schema.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import json
33
import re
4-
from typing import Callable, Union
4+
from typing import Callable, Optional, Union
55

66
from jsonschema.protocols import Validator
77
from pydantic import BaseModel, create_model
@@ -38,7 +38,9 @@
3838
}
3939

4040

41-
def build_regex_from_object(object: Union[str, Callable, BaseModel]):
41+
def build_regex_from_object(
42+
object: Union[str, Callable, BaseModel], whitespace_pattern: Optional[str] = None
43+
):
4244
"""Turn a JSON schema into a regex that matches any JSON object that follows
4345
this schema.
4446
@@ -54,6 +56,9 @@ def build_regex_from_object(object: Union[str, Callable, BaseModel]):
5456
----------
5557
schema
5658
A string that represents a JSON Schema.
59+
whitespace_pattern
60+
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
61+
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
5762
5863
Returns
5964
-------
@@ -83,10 +88,12 @@ def build_regex_from_object(object: Union[str, Callable, BaseModel]):
8388
resolver = registry.resolver()
8489

8590
content = schema.contents
86-
return to_regex(resolver, content)
91+
return to_regex(resolver, content, whitespace_pattern)
8792

8893

89-
def to_regex(resolver: Resolver, instance: dict):
94+
def to_regex(
95+
resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None
96+
):
9097
"""Translate a JSON Schema instance into a regex that validates the schema.
9198
9299
Note
@@ -105,8 +112,15 @@ def to_regex(resolver: Resolver, instance: dict):
105112
An object that resolves references to other instances within a schema
106113
instance
107114
The instance to translate
115+
whitespace_pattern
116+
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
117+
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
108118
"""
109119

120+
# set whitespace pattern
121+
if whitespace_pattern is None:
122+
whitespace_pattern = WHITESPACE
123+
110124
if "properties" in instance:
111125
regex = ""
112126
regex += r"\{"
@@ -120,12 +134,12 @@ def to_regex(resolver: Resolver, instance: dict):
120134
if any(is_required):
121135
last_required_pos = max([i for i, value in enumerate(is_required) if value])
122136
for i, (name, value) in enumerate(properties.items()):
123-
subregex = f'{WHITESPACE}"{name}"{WHITESPACE}:{WHITESPACE}'
124-
subregex += to_regex(resolver, value)
137+
subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
138+
subregex += to_regex(resolver, value, whitespace_pattern)
125139
if i < last_required_pos:
126-
subregex = f"{subregex}{WHITESPACE},"
140+
subregex = f"{subregex}{whitespace_pattern},"
127141
elif i > last_required_pos:
128-
subregex = f"{WHITESPACE},{subregex}"
142+
subregex = f"{whitespace_pattern},{subregex}"
129143
regex += subregex if is_required[i] else f"({subregex})?"
130144
# If no property is required, we have to create a possible pattern for each property in which
131145
# it's the last one necessarily present. Then, we add the others as optional before and after
@@ -134,41 +148,47 @@ def to_regex(resolver: Resolver, instance: dict):
134148
else:
135149
property_subregexes = []
136150
for i, (name, value) in enumerate(properties.items()):
137-
subregex = f'{WHITESPACE}"{name}"{WHITESPACE}:{WHITESPACE}'
138-
subregex += to_regex(resolver, value)
151+
subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
152+
subregex += to_regex(resolver, value, whitespace_pattern)
139153
property_subregexes.append(subregex)
140154
possible_patterns = []
141155
for i in range(len(property_subregexes)):
142156
pattern = ""
143157
for subregex in property_subregexes[:i]:
144-
pattern += f"({subregex}{WHITESPACE},)?"
158+
pattern += f"({subregex}{whitespace_pattern},)?"
145159
pattern += property_subregexes[i]
146160
for subregex in property_subregexes[i + 1 :]:
147-
pattern += f"({WHITESPACE},{subregex})?"
161+
pattern += f"({whitespace_pattern},{subregex})?"
148162
possible_patterns.append(pattern)
149163
regex += f"({'|'.join(possible_patterns)})?"
150164

151-
regex += f"{WHITESPACE}" + r"\}"
165+
regex += f"{whitespace_pattern}" + r"\}"
152166

153167
return regex
154168

155169
# To validate against allOf, the given data must be valid against all of the
156170
# given subschemas.
157171
elif "allOf" in instance:
158-
subregexes = [to_regex(resolver, t) for t in instance["allOf"]]
172+
subregexes = [
173+
to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"]
174+
]
159175
subregexes_str = [f"{subregex}" for subregex in subregexes]
160176
return rf"({''.join(subregexes_str)})"
161177

162178
# To validate against `anyOf`, the given data must be valid against
163179
# any (one or more) of the given subschemas.
164180
elif "anyOf" in instance:
165-
subregexes = [to_regex(resolver, t) for t in instance["anyOf"]]
181+
subregexes = [
182+
to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"]
183+
]
166184
return rf"({'|'.join(subregexes)})"
167185

168186
# To validate against oneOf, the given data must be valid against exactly
169187
# one of the given subschemas.
170188
elif "oneOf" in instance:
171-
subregexes = [to_regex(resolver, t) for t in instance["oneOf"]]
189+
subregexes = [
190+
to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"]
191+
]
172192

173193
xor_patterns = []
174194
# json schema validation ensured there is no overlapping schemas in oneOf
@@ -195,7 +215,7 @@ def to_regex(resolver: Resolver, instance: dict):
195215
elif "$ref" in instance:
196216
path = f"{instance['$ref']}"
197217
instance = resolver.lookup(path).contents
198-
return to_regex(resolver, instance)
218+
return to_regex(resolver, instance, whitespace_pattern)
199219

200220
# The type keyword may either be a string or an array:
201221
# - If it's a string, it is the name of one of the basic types.
@@ -254,14 +274,14 @@ def to_regex(resolver: Resolver, instance: dict):
254274
num_repeats = rf"{{{max(min_items - 1, 0)},}}"
255275
else:
256276
if max_items < 1:
257-
return rf"\[{WHITESPACE}\]"
277+
return rf"\[{whitespace_pattern}\]"
258278
num_repeats = rf"{{{max(min_items - 1, 0)},{max_items - 1}}}"
259279

260280
allow_empty = "?" if min_items == 0 else ""
261281

262282
if "items" in instance:
263-
items_regex = to_regex(resolver, instance["items"])
264-
return rf"\[{WHITESPACE}(({items_regex})(,{WHITESPACE}({items_regex})){num_repeats}){allow_empty}{WHITESPACE}\]"
283+
items_regex = to_regex(resolver, instance["items"], whitespace_pattern)
284+
return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]"
265285
else:
266286
# Here we need to make the choice to exclude generating list of objects
267287
# if the specification of the object is not given, even though a JSON
@@ -273,8 +293,8 @@ def to_regex(resolver: Resolver, instance: dict):
273293
{"type": "integer"},
274294
{"type": "string"},
275295
]
276-
regexes = [to_regex(resolver, t) for t in types]
277-
return rf"\[{WHITESPACE}({'|'.join(regexes)})(,{WHITESPACE}({'|'.join(regexes)})){num_repeats}){allow_empty}{WHITESPACE}\]"
296+
regexes = [to_regex(resolver, t, whitespace_pattern) for t in types]
297+
return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}){allow_empty}{whitespace_pattern}\]"
278298

279299
elif instance_type == "boolean":
280300
return type_to_regex["boolean"]
@@ -287,7 +307,9 @@ def to_regex(resolver: Resolver, instance: dict):
287307
# if the specification of the object is not given, even though a JSON
288308
# object that contains an object here would be valid under the specification.
289309
regexes = [
290-
to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
310+
to_regex(resolver, {"type": t}, whitespace_pattern)
311+
for t in instance_type
312+
if t != "object"
291313
]
292314
return rf"({'|'.join(regexes)})"
293315

outlines/generate/json.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json as pyjson
22
from functools import singledispatch
3-
from typing import Callable, Union
3+
from typing import Callable, Optional, Union
44

55
from pydantic import BaseModel
66

@@ -14,21 +14,49 @@
1414

1515
@singledispatch
1616
def json(
17-
model, schema_object: Union[str, object, Callable], sampler: Sampler = multinomial()
17+
model,
18+
schema_object: Union[str, object, Callable],
19+
sampler: Sampler = multinomial(),
20+
whitespace_pattern: Optional[str] = None,
1821
) -> SequenceGenerator:
22+
"""
23+
Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.
24+
25+
Parameters
26+
----------
27+
model:
28+
An instance of `Transformer` that represents a model from the
29+
`transformers` library.
30+
schema_object:
31+
The JSON Schema to generate data for. Can be a JSON string, a Pydantic model, or a callable
32+
whose signature is used to derive the JSON schema.
33+
max_tokens:
34+
The maximum number of tokens to generate.
35+
sampler:
36+
The sampling algorithm to use to generate token ids from the logits
37+
distribution.
38+
whitespace_pattern
39+
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
40+
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
41+
42+
Returns
43+
-------
44+
A `SequenceGenerator` instance that generates text constrained by the schema_object and
45+
transforms the result if BaseModel is used.
46+
"""
1947
if isinstance(schema_object, type(BaseModel)):
2048
schema = pyjson.dumps(schema_object.model_json_schema())
21-
regex_str = build_regex_from_object(schema)
49+
regex_str = build_regex_from_object(schema, whitespace_pattern)
2250
generator = regex(model, regex_str, sampler)
2351
generator.format_sequence = lambda x: schema_object.parse_raw(x)
2452
elif callable(schema_object):
2553
schema = pyjson.dumps(get_schema_from_signature(schema_object))
26-
regex_str = build_regex_from_object(schema)
54+
regex_str = build_regex_from_object(schema, whitespace_pattern)
2755
generator = regex(model, regex_str, sampler)
2856
generator.format_sequence = lambda x: pyjson.loads(x)
2957
elif isinstance(schema_object, str):
3058
schema = schema_object
31-
regex_str = build_regex_from_object(schema)
59+
regex_str = build_regex_from_object(schema, whitespace_pattern)
3260
generator = regex(model, regex_str, sampler)
3361
generator.format_sequence = lambda x: pyjson.loads(x)
3462
else:

outlines/serve/vllm.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import math
44
from collections import defaultdict
5-
from typing import DefaultDict, List
5+
from typing import DefaultDict, List, Optional
66

77
import torch
88

@@ -105,18 +105,20 @@ def convert_token_to_string(token: str) -> str:
105105

106106

107107
class JSONLogitsProcessor(RegexLogitsProcessor):
108-
def __init__(self, schema, llm):
109-
"""Compile the FSM that drives the JSON-structured generation.
108+
def __init__(self, schema, llm, whitespace_pattern: Optional[str] = None):
109+
"""Compile the FSM that drives the JSON-guided generation.
110110
111111
Parameters
112112
----------
113113
schema
114114
A JSON schema that encodes the structure we want the model to generate
115115
llm
116116
An instance of `vllm.LLM`
117-
117+
whitespace_pattern
118+
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
119+
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
118120
"""
119121
if isinstance(schema, dict):
120122
schema = json.dumps(schema)
121-
regex_string = build_regex_from_object(schema)
123+
regex_string = build_regex_from_object(schema, whitespace_pattern)
122124
super().__init__(regex_string, llm)

tests/fsm/test_json_schema.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,3 +540,32 @@ def test_format(schema, regex, examples):
540540
assert match.span() == (0, len(string))
541541
else:
542542
assert match is None
543+
544+
545+
@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]?", "abc"])
546+
def test_json_schema_custom_whitespace_pattern(whitespace_pattern):
547+
"""assert whitespace_pattern setting respected"""
548+
549+
class MockModel(BaseModel):
550+
foo: int
551+
bar: str
552+
553+
# assert any ws pattern can be used
554+
if whitespace_pattern == "abc":
555+
build_regex_from_object(MockModel, whitespace_pattern)
556+
return
557+
558+
pattern = build_regex_from_object(MockModel, whitespace_pattern)
559+
560+
mock_result_mult_ws = (
561+
"""{ "foo" : 4, \n\n\n "bar": "baz baz baz bar"\n\n}"""
562+
)
563+
mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}"""
564+
565+
match_default_ws = re.fullmatch(pattern, mock_result_mult_ws)
566+
if whitespace_pattern is None:
567+
assert match_default_ws
568+
else:
569+
assert match_default_ws is None
570+
571+
assert re.fullmatch(pattern, mock_result_maybe_ws)

tests/generate/test_integration_transfomers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,30 @@ def test_transformers_logits_vocab_size():
543543
assert sequence == "False"
544544

545545

546+
def test_transformers_json_custom_ws():
547+
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
548+
model = models.transformers(model_name, device="cpu")
549+
prompt = "Output some JSON with newlines" # try to force model to use newlines
550+
551+
schema = """{
552+
"title": "spam",
553+
"type": "object",
554+
"properties": {
555+
"foo" : {"type": "integer"},
556+
"bar": {"type": "integer"}
557+
},
558+
"required": ["foo", "bar"]
559+
}
560+
"""
561+
562+
rng = torch.Generator()
563+
rng.manual_seed(0)
564+
565+
generator = generate.json(model, schema, whitespace_pattern=r"[ ]?")
566+
generator.format_sequence = lambda x: x # patch to return raw text
567+
assert "\n" not in generator(prompt, max_tokens=500, rng=rng)
568+
569+
546570
def test_transformers_reduced_vocabulary_caching():
547571
tokenizer = TransformerTokenizer("gpt2")
548572
tokenizer2 = TransformerTokenizer("gpt2")

0 commit comments

Comments
 (0)