Skip to content

Commit e9485cf

Browse files
authored
Add json call with multi-function enums (#1277)
This PR aims at solving #1217
1 parent 5608dd8 commit e9485cf

File tree

5 files changed

+136
-4
lines changed

5 files changed

+136
-4
lines changed

README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,33 @@ print(add(**result))
300300

301301
A great advantage of passing functions directly to specify the structure is that the structure of the LLM's output will change with the function's definition. No need to change the code in several places!
302302

303+
You can also wrap several functions in an enum; the model then generates the parameters of one of them:
304+
305+
```python
306+
from enum import Enum
307+
from functools import partial
308+
309+
import outlines
310+
311+
312+
def add(a: int, b: int) -> int:
    """Return the sum of two integers."""
    return a + b


def mul(c: float, d: float) -> float:
    """Return the product of two floats."""
    return c * d


class Operation(Enum):
    # Wrap each function in `partial` so the enum stores it as a member value;
    # a bare function assigned here would not become a member.
    add = partial(add)
    mul = partial(mul)


model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
# Pass the enum itself (not a single function) so the generated JSON may
# match the parameters of either `add` or `mul`.
generator = outlines.generate.json(model, Operation)
result = generator("Return json with two float named c and d respectively. c is negative and d greater than 1.0.")

print(result)
# {'c': -3.14, 'd': 1.5}
328+
```
329+
303330
## Prompting
304331

305332
Building prompts can get messy. **Outlines** makes it easier to write and manage

outlines/fsm/json_schema.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import re
44
import warnings
5+
from enum import Enum
56
from typing import Callable, Optional, Tuple, Type, Union
67

78
from jsonschema.protocols import Validator
@@ -306,6 +307,8 @@ def to_regex(
306307
for choice in instance["enum"]:
307308
if type(choice) in [int, float, bool, type(None), str]:
308309
choices.append(re.escape(json.dumps(choice)))
310+
elif isinstance(choice, dict):
311+
choices.append(to_regex(resolver, choice, whitespace_pattern))
309312
else:
310313
raise TypeError(f"Unsupported data type in enum: {type(choice)}")
311314
return f"({'|'.join(choices)})"
@@ -524,7 +527,7 @@ def to_regex(
524527
)
525528

526529

527-
def get_schema_from_signature(fn: Callable) -> str:
530+
def get_schema_from_signature(fn: Callable) -> dict:
528531
"""Turn a function signature into a JSON schema.
529532
530533
Every JSON object valid to the output JSON Schema can be passed
@@ -550,3 +553,16 @@ def get_schema_from_signature(fn: Callable) -> str:
550553
model = create_model(fn_name, **arguments)
551554

552555
return model.model_json_schema()
556+
557+
558+
def get_schema_from_enum(myenum: type[Enum]) -> dict:
    """Turn an Enum class into a JSON schema.

    Every member contributes one entry to the schema's ``enum`` list. A
    member whose value is callable is replaced by the JSON schema of its
    signature (built with ``get_schema_from_signature``); any other value is
    used as-is.

    Parameters
    ----------
    myenum
        The Enum class to convert.

    Returns
    -------
    dict
        A schema of the form ``{"title": <enum name>, "enum": [...]}``.

    Raises
    ------
    ValueError
        If the enum has no members.

    """
    # An Enum class is always truthy, so the member count is checked explicitly.
    if len(myenum) == 0:
        raise ValueError(
            f"Your enum class {myenum.__name__} has 0 members. If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)"
        )

    choices = []
    for member in myenum:
        value = member.value
        if callable(value):
            # `functools.partial` objects expose the wrapped function as
            # `.func`; any other callable is inspected directly instead of
            # failing with AttributeError.
            value = get_schema_from_signature(getattr(value, "func", value))
        choices.append(value)

    return {"title": myenum.__name__, "enum": choices}

outlines/generate/json.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import json as pyjson
2+
from enum import Enum
23
from functools import singledispatch
34
from typing import Callable, Optional, Union
45

56
from pydantic import BaseModel
67

7-
from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
8+
from outlines.fsm.json_schema import (
9+
build_regex_from_schema,
10+
get_schema_from_enum,
11+
get_schema_from_signature,
12+
)
813
from outlines.generate.api import SequenceGeneratorAdapter
914
from outlines.models import OpenAI
1015
from outlines.samplers import Sampler, multinomial
@@ -48,6 +53,11 @@ def json(
4853
regex_str = build_regex_from_schema(schema, whitespace_pattern)
4954
generator = regex(model, regex_str, sampler)
5055
generator.format_sequence = lambda x: schema_object.parse_raw(x)
56+
elif isinstance(schema_object, type(Enum)):
57+
schema = pyjson.dumps(get_schema_from_enum(schema_object))
58+
regex_str = build_regex_from_schema(schema, whitespace_pattern)
59+
generator = regex(model, regex_str, sampler)
60+
generator.format_sequence = lambda x: pyjson.loads(x)
5161
elif callable(schema_object):
5262
schema = pyjson.dumps(get_schema_from_signature(schema_object))
5363
regex_str = build_regex_from_schema(schema, whitespace_pattern)

tests/fsm/test_json_schema.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import json
22
import re
3+
from contextlib import nullcontext
4+
from enum import Enum
5+
from functools import partial
36
from typing import List, Literal, Union
47

58
import interegular
@@ -19,6 +22,7 @@
1922
UUID,
2023
WHITESPACE,
2124
build_regex_from_schema,
25+
get_schema_from_enum,
2226
get_schema_from_signature,
2327
to_regex,
2428
)
@@ -237,8 +241,26 @@ def test_match_number(pattern, does_match):
237241
),
238242
# Enum mix of types
239243
(
240-
{"title": "Foo", "enum": [6, 5.3, "potato", True, None]},
241-
r'(6|5\.3|"potato"|true|null)',
244+
{
245+
"title": "Foo",
246+
"enum": [
247+
6,
248+
5.3,
249+
"potato",
250+
True,
251+
None,
252+
{
253+
"properties": {
254+
"a": {"title": "A", "type": "number"},
255+
"b": {"title": "B", "type": "number"},
256+
},
257+
"required": ["a", "b"],
258+
"title": "add",
259+
"type": "object",
260+
},
261+
],
262+
},
263+
r'(6|5\.3|"potato"|true|null|\{[ ]?"a"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[ ]?,[ ]?"b"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\})',
242264
[
243265
("6", True),
244266
("5.3", True),
@@ -248,6 +270,8 @@ def test_match_number(pattern, does_match):
248270
("523", False),
249271
("True", False),
250272
("None", False),
273+
('{"a": -1.0, "b": 1.1}', True),
274+
('{"a": "a", "b": 1.1}', False),
251275
],
252276
),
253277
# integer
@@ -1039,3 +1063,34 @@ class Model(BaseModel):
10391063

10401064
# check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm()
10411065
interegular.parse_pattern(pattern).to_fsm()
1066+
1067+
1068+
def add(a: float, b: float) -> float:
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
1070+
1071+
1072+
class MyEnum(Enum):
    """Enum mixing a function member (wrapped in ``partial``) with plain values."""

    add = partial(add)
    a = "a"
    b = 2
1076+
1077+
1078+
# A bare function assigned in an Enum body does not become a member, so this
# enum ends up with no members at all.
class EmptyEnum(Enum):
    add = add
1081+
1082+
1083+
@pytest.mark.parametrize(
    "enum,expectation",
    [
        (MyEnum, nullcontext()),
        (EmptyEnum, pytest.raises(ValueError)),
    ],
)
def test_enum_schema(enum, expectation):
    """Check the schema built from an enum, or the error raised for an empty one."""
    allowed_types = [int, float, bool, type(None), str, dict]
    with expectation:
        schema = get_schema_from_enum(enum)
        assert schema["title"] == enum.__name__
        entries = schema["enum"]
        assert len(entries) == len(enum)
        # Every entry is either a JSON scalar or a nested schema dict.
        for entry in entries:
            assert type(entry) in allowed_types

tests/generate/test_integration_transformers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
import re
33
from enum import Enum
4+
from functools import partial
45
from typing import List, Union
56

67
import pytest
@@ -354,6 +355,29 @@ class User(BaseModel):
354355
assert result.user_id in [1, 2]
355356

356357

358+
def add(a: int, b: int) -> int:
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
360+
361+
362+
def mul(c: float, d: float) -> float:
    """Return the product of ``c`` and ``d``."""
    product = c * d
    return product
364+
365+
366+
def test_transformers_json_function_enum(model):
    """Generate JSON constrained by an enum of functions and check its shape."""

    class Operation(Enum):
        add = partial(add)
        mul = partial(mul)

    prompt = "Output some JSON "
    generator = generate.json(model, Operation)
    result = generator(prompt, seed=0)

    assert isinstance(result, dict)
    assert len(result) == 2
    # Keys must be parameter names of `add` or `mul`; values must be numeric.
    parameter_names = ["a", "b", "c", "d"]
    for key, value in result.items():
        assert key in parameter_names
        assert isinstance(value, (int, float))
379+
380+
357381
def test_transformers_json_array(model):
358382
prompt = "Output some JSON "
359383

0 commit comments

Comments
 (0)