Skip to content

Commit 12b672b

Browse files
committed
Parse JSON schema into a generation schedule
1 parent c7ece19 commit 12b672b

File tree

8 files changed

+604
-8
lines changed

8 files changed

+604
-8
lines changed

.github/workflows/build_documentation.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ jobs:
1414
with:
1515
persist-credentials: false
1616

17-
- name: Set up Python 3.9
17+
- name: Set up Python 3.10
1818
uses: actions/setup-python@v1
1919
with:
20-
python-version: 3.9
20+
python-version: "3.10"
2121

2222
- name: Build the documentation with Sphinx
2323
run: |

.github/workflows/publish_documentation.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ jobs:
1515
- name: Checkout the branch
1616
uses: actions/[email protected]
1717

18-
- name: Set up Python 3.9
18+
- name: Set up Python 3.10
1919
uses: actions/setup-python@v1
2020
with:
21-
python-version: 3.9
21+
python-version: "3.10"
2222

2323
- name: Build the documentation with Sphinx
2424
run: |

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
- name: Set up Python
1515
uses: actions/setup-python@v2
1616
with:
17-
python-version: 3.9
17+
python-version: "3.10"
1818
- name: Build sdist and wheel
1919
run: |
2020
python -m pip install -U pip

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
- uses: actions/checkout@v3
1515
- uses: actions/setup-python@v4
1616
with:
17-
python-version: 3.9
17+
python-version: "3.10"
1818
- uses: pre-commit/[email protected]
1919

2020
tests:
@@ -24,7 +24,7 @@ jobs:
2424
- uses: actions/checkout@v3
2525
- uses: actions/setup-python@v4
2626
with:
27-
python-version: 3.9
27+
python-version: "3.10"
2828
- name: Set up test environment
2929
run: |
3030
python -m pip install --upgrade pip

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ channels:
88
- conda-forge
99
- huggingface
1010
dependencies:
11-
- python<3.11.0
11+
- python==3.10.0
1212
- jinja2
1313
- numpy
1414
- pillow

outlines/text/json_schema.py

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
import itertools
2+
import json
3+
from typing import Dict
4+
5+
STRING = r'".*"'
6+
INTEGER = r"(0|[1-9][0-9]*)"
7+
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
8+
BOOLEAN = r"(true|false)"
9+
NULL = r"null"
10+
11+
type_to_regex = {
12+
"string": STRING,
13+
"integer": INTEGER,
14+
"number": NUMBER,
15+
"boolean": BOOLEAN,
16+
"null": NULL,
17+
}
18+
19+
20+
def build_regex_from_schema(schema: str):
21+
"""Turn a JSON schema into a regex that matches any JSON object that follows
22+
this schema.
23+
24+
Parameters
25+
----------
26+
schema
27+
A string that contains the JSON schema.
28+
29+
Returns
30+
-------
31+
A string that contains a regular expression that matches any JSON object that
32+
follows the schema.
33+
34+
"""
35+
schedule = build_schedule_from_schema(schema)
36+
37+
regex = ""
38+
for step in schedule:
39+
regex += match_step_to_regex(step)
40+
41+
return regex
42+
43+
44+
def build_schedule_from_schema(schema: str):
45+
"""Turn a JSON schema into a regex that matches any JSON object that follows
46+
this schema.
47+
48+
JSON Schema is a declarative language that allows to annotate JSON documents
49+
with types and descriptions. These schemas can be generated from any Python
50+
datastructure that has type annotation: namedtuples, dataclasses, Pydantic
51+
models. And by ensuring that the generation respects the schema we ensure
52+
that the output can be parsed into these objects.
53+
This function parses the provided schema and builds a generation schedule which
54+
mixes deterministic generation (fixed strings), and sampling with constraints.
55+
56+
Parameters
57+
----------
58+
schema
59+
A string that represents a JSON Schema.
60+
61+
Returns
62+
-------
63+
A generation schedule. A list of strings that represent the JSON
64+
schema's structure and regular expression that define the structure of
65+
the fields.
66+
67+
References
68+
----------
69+
.. [0] JSON Schema. https://json-schema.org/
70+
71+
"""
72+
schema = json.loads(schema)
73+
74+
# Find object definitions in the schema, if any
75+
definitions = {}
76+
if "$defs" in schema:
77+
for definition, annotation in schema["$defs"].items():
78+
definitions[f"#/$defs/{definition}"] = annotation
79+
80+
schema = expand_json_schema(schema, definitions)
81+
schedule = build_schedule_from_instance(schema)
82+
83+
# Concatenate adjacent strings
84+
reduced_schedule = [
85+
x
86+
for cls, grp in itertools.groupby(schedule, type)
87+
for x in (("".join(grp),) if cls is str else grp)
88+
]
89+
90+
return reduced_schedule
91+
92+
93+
def expand_json_schema(raw_schema: Dict, definitions: Dict):
94+
"""Replace references by their value in the JSON Schema.
95+
96+
This recursively follows the references to other schemas in case
97+
of nested models. Other schemas are stored under the "definitions"
98+
key in the schema of the top-level model.
99+
100+
Parameters
101+
---------
102+
raw_schema
103+
The raw JSON schema as a Python dictionary, possibly with definitions
104+
and references.
105+
definitions
106+
The currently known definitions.
107+
108+
Returns
109+
-------
110+
A dictionary that represents the flattened equivalent of the input
111+
JSON schema.
112+
113+
"""
114+
expanded_properties = {}
115+
116+
if "properties" in raw_schema:
117+
for name, value in raw_schema["properties"].items():
118+
if "$ref" in value: # if item is a single element
119+
expanded_properties[name] = expand_json_schema(
120+
definitions[value["$ref"]], definitions
121+
)
122+
elif "type" in value and value["type"] == "array": # if item is a list
123+
expanded_properties[name] = value
124+
if "$ref" in value["items"]:
125+
expanded_properties[name]["items"] = expand_json_schema(
126+
definitions[value["items"]["$ref"]], definitions
127+
)
128+
else:
129+
expanded_properties[name]["items"] = value["items"]
130+
else:
131+
expanded_properties[name] = value
132+
133+
return {
134+
"title": raw_schema["title"],
135+
"type": raw_schema["type"],
136+
"properties": expanded_properties,
137+
}
138+
139+
else:
140+
return raw_schema
141+
142+
143+
def build_schedule_from_instance(instance: Dict, indent: int = 0):
144+
"""Build a generation schedule from a instance.
145+
146+
This recursively follows the references to other instances.
147+
148+
Parameters
149+
----------
150+
instance
151+
An instance, can be the JSON schema itself.
152+
indent
153+
The current indentation level
154+
155+
Returns
156+
-------
157+
A generation schedule for the instance, a list of strings that represent
158+
the structure of the JSON schema and dictionaries that contain the
159+
instance definition.
160+
161+
"""
162+
schedule = []
163+
if "properties" in instance:
164+
schedule.append("{\n")
165+
schedule += build_schedule_from_instance(instance["properties"], indent + 2)
166+
if indent > 0:
167+
schedule.append(" " * indent)
168+
schedule.append("}")
169+
else:
170+
for i, (name, annotation) in enumerate(instance.items()):
171+
schedule.append(" " * indent)
172+
schedule.append(f'"{name}": ')
173+
if "anyOf" in annotation:
174+
schedule.append(annotation)
175+
elif annotation["type"] == "object":
176+
schedule += build_schedule_from_instance(annotation, indent)
177+
else:
178+
schedule.append(annotation)
179+
180+
# We cannot add commas after the last key-value pair in JSON
181+
if i == len(instance) - 1:
182+
schedule.append("\n")
183+
else:
184+
schedule.append(",\n")
185+
186+
return schedule
187+
188+
189+
def match_step_to_regex(step):
190+
"""Translate an element of a JSON schema to a regex that defines its content.
191+
192+
Parameters
193+
----------
194+
step:
195+
A string that represents the schema's structure, or a dictionnary
196+
that represents a field in the schema.
197+
198+
Returns
199+
-------
200+
A string that represents a regular expression that defines the value of the
201+
schedule's step.
202+
203+
"""
204+
match step:
205+
case str() as step:
206+
return step
207+
208+
case {"enum": choices, "type": "string"}:
209+
choices = [f'"{choice}"' for choice in choices]
210+
return f"({'|'.join(choices)})"
211+
case {"enum": choices}:
212+
choices = [str(choice) for choice in choices]
213+
return f"({'|'.join(choices)})"
214+
215+
case {"type": "array", "items": items}:
216+
item_regexes = match_step_to_regex(items)
217+
return rf"\[({item_regexes})(,({item_regexes}))*\]"
218+
219+
case {"type": "object"} as object:
220+
steps = build_schedule_from_schema(json.dumps(object))
221+
regex_str = ""
222+
for step in steps:
223+
regex_str += match_step_to_regex(step)
224+
return regex_str
225+
226+
case {"type": "string", "maxLength": max_length}:
227+
return f'".{{,{max_length}}}"'
228+
case {"type": "string", "minLength": min_length}:
229+
return f'".{{{min_length},}}"'
230+
231+
case {"type": field_type}:
232+
return type_to_regex[field_type]
233+
234+
case {"anyOf": choices}:
235+
regexes = [match_step_to_regex(choice) for choice in choices]
236+
return rf"({'|'.join(regexes)})"
237+
238+
case _:
239+
raise NotImplementedError

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ dynamic = ["version"]
4242
test = [
4343
"diffusers",
4444
"pre-commit",
45+
"pydantic>=2.0",
4546
"pytest",
4647
"pytest-cov",
4748
"transformers",

0 commit comments

Comments
 (0)