
Commit b3fba13

Create AsyncOpenAI model

RobinPicard authored and rlouf committed

1 parent 8cecc55 commit b3fba13

File tree

5 files changed: +423 -25 lines changed

docs/features/models/index.md

Lines changed: 2 additions & 1 deletion
@@ -93,7 +93,7 @@ In alphabetical order:
 | Regex ||||||||||||||
 | Grammar |||||||| 🟠 ||||||
 | **Generation Features** | | | | | | | | | | | | | |
-| Async ||||||| |||||||
+| Async ||||||| |||||||
 | Streaming ||||||||||||||
 | Vision ||||||||||||||
 | Batching ||||||||||||||

@@ -142,6 +142,7 @@ print(type(model)) # outlines.models.tgi.AsyncTGI
 The models that have an async version are the following:

 - Ollama
+- OpenAI
 - SgLang
 - TGI
 - VLLM
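The new docs entry mirrors the async pattern already shown for TGI in the context line above; a minimal sketch of what it implies (assuming the `openai` package is installed and `OPENAI_API_KEY` is set; the printed class path follows the `AsyncTGI` example):

```python
# Minimal sketch: passing an async client to from_openai should yield
# the async model variant introduced by this commit.
import openai
import outlines

model = outlines.from_openai(openai.AsyncOpenAI(), "gpt-4o")
print(type(model))  # outlines.models.openai.AsyncOpenAI
```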

docs/features/models/openai.md

Lines changed: 53 additions & 3 deletions
@@ -10,23 +10,32 @@

 To create an OpenAI model instance, you can use the `from_openai` function. It takes 2 arguments:

-- `client`: an `openai.OpenAI` or `openai.AzureOpenAI` instance
+- `client`: an `openai.OpenAI`, `openai.AzureOpenAI`, `openai.AsyncOpenAI` or `openai.AsyncAzureOpenAI` instance
 - `model_name`: the name of the model you want to use

+Based on whether the inference client instance is synchronous or asynchronous, you will receive an `OpenAI` or an `AsyncOpenAI` model instance.
+
 For instance:

 ```python
 import outlines
 import openai

-# Create the client
+# Create the client or async client
 client = openai.OpenAI()
+async_client = openai.AsyncOpenAI()

-# Create the model
+# Create a sync model
 model = outlines.from_openai(
     client,
     "gpt-4o"
 )
+
+# Create an async model
+model = outlines.from_openai(
+    async_client,
+    "gpt-4o"
+)
 ```

 Check the [OpenAI documentation](https://platform.openai.com/docs/models) for an up-to-date list of available models. As shown above, you can use Azure OpenAI in Outlines the same way you would use OpenAI, just provide an `openai.AzureOpenAI` instance to the Outlines model class.
@@ -190,6 +199,47 @@ result = model("Create a character, use the json format.", dict, temperature=0.5
 print(result) # '{"first_name": "Henri", "last_name": "Smith", "height": "170"}'
 ```

+## Asynchronous Calls
+
+All features presented above for the sync model are also available for the async model.
+
+For instance:
+
+```python
+import asyncio
+import openai
+import outlines
+from pydantic import BaseModel
+from typing import List
+
+class Character(BaseModel):
+    name: str
+    age: int
+    skills: List[str]
+
+# Create the model
+model = outlines.from_openai(
+    openai.AsyncOpenAI(),
+    "gpt-4o"
+)
+
+async def text_generation():
+    # Regular generation
+    response = await model("What's the capital of Latvia?", max_tokens=20)
+    print(response) # 'Riga'
+
+    # Streaming
+    async for chunk in model.stream("Tell me a short story about a cat.", max_tokens=50):
+        print(chunk, end="") # 'Once...'
+
+    # Structured generation
+    result = await model("Create a character, use the json format.", Character, top_p=0.1)
+    print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}'
+    print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy']
+
+asyncio.run(text_generation())
+```
+
 ## Inference arguments

 When calling the model, you can provide keyword arguments that will be passed down to the `chat.completions.create` method of the OpenAI client. Some of the most common arguments include `max_tokens`, `temperature`, `stop` and `top_p`.
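The "Inference arguments" paragraph above describes a plain pass-through; a minimal sketch of how such arguments reach `chat.completions.create` (assuming a configured sync client with `OPENAI_API_KEY` set):

```python
# Minimal sketch: keyword arguments given at call time are forwarded to
# chat.completions.create by the wrapper.
import openai
import outlines

model = outlines.from_openai(openai.OpenAI(), "gpt-4o")

result = model(
    "Name one planet.",
    max_tokens=10,    # cap the length of the completion
    temperature=0.0,  # make sampling as deterministic as possible
    stop=["\n"],      # stop at the first newline
    top_p=1.0,        # keep the full sampling nucleus
)
print(result)
```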

outlines/models/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,7 @@
 from .llamacpp import LlamaCpp, from_llamacpp
 from .mlxlm import MLXLM, from_mlxlm
 from .ollama import AsyncOllama, Ollama, from_ollama
-from .openai import OpenAI, from_openai
+from .openai import AsyncOpenAI, OpenAI, from_openai
 from .sglang import AsyncSGLang, SGLang, from_sglang
 from .tgi import AsyncTGI, TGI, from_tgi
 from .transformers import (

@@ -41,6 +41,7 @@
 ]
 AsyncBlackBoxModel = Union[
     AsyncOllama,
+    AsyncOpenAI,
     AsyncTGI,
     AsyncSGLang,
     AsyncVLLM,
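With the re-export and union update in place, `AsyncOpenAI` should be importable from `outlines.models` and appear among the `AsyncBlackBoxModel` members; a small sketch (the `get_args` introspection is illustrative and not part of the commit):

```python
# Small sketch: check that the new class is re-exported and is a member
# of the AsyncBlackBoxModel union defined in outlines/models/__init__.py.
from typing import get_args

from outlines.models import AsyncBlackBoxModel, AsyncOpenAI

print(AsyncOpenAI in get_args(AsyncBlackBoxModel))  # True
```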

outlines/models/openai.py

Lines changed: 197 additions & 16 deletions
@@ -4,6 +4,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncIterator,
     Iterator,
     Optional,
     Union,

@@ -13,7 +14,7 @@
 from pydantic import BaseModel, TypeAdapter

 from outlines.inputs import Chat, Image
-from outlines.models.base import Model, ModelTypeAdapter
+from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
 from outlines.models.utils import set_additional_properties_false_json_schema
 from outlines.types import JsonSchema, Regex, CFG
 from outlines.types.utils import (

@@ -25,9 +26,14 @@
 )

 if TYPE_CHECKING:
-    from openai import OpenAI as OpenAIClient, AzureOpenAI as AzureOpenAIClient
+    from openai import (
+        OpenAI as OpenAIClient,
+        AsyncOpenAI as AsyncOpenAIClient,
+        AzureOpenAI as AzureOpenAIClient,
+        AsyncAzureOpenAI as AsyncAzureOpenAIClient,
+    )

-__all__ = ["OpenAI", "from_openai"]
+__all__ = ["AsyncOpenAI", "OpenAI", "from_openai"]


 class OpenAITypeAdapter(ModelTypeAdapter):

@@ -348,36 +354,211 @@ def generate_stream(
         if "model" not in inference_kwargs and self.model_name is not None:
             inference_kwargs["model"] = self.model_name

-        stream = self.client.chat.completions.create(
-            stream=True,
-            messages=messages,
-            **response_format,
-            **inference_kwargs
-        )
+        try:
+            stream = self.client.chat.completions.create(
+                stream=True,
+                messages=messages,
+                **response_format,
+                **inference_kwargs
+            )
+        except openai.BadRequestError as e:
+            if e.body["message"].startswith("Invalid schema"):
+                raise TypeError(
+                    f"OpenAI does not support your schema: {e.body['message']}. "
+                    "Try a local model or dottxt instead."
+                )
+            else:
+                raise e

         for chunk in stream:
             if chunk.choices and chunk.choices[0].delta.content is not None:
                 yield chunk.choices[0].delta.content


+class AsyncOpenAI(AsyncModel):
+    """Thin wrapper around the `openai.AsyncOpenAI` client.
+
+    This wrapper is used to convert the input and output types specified by the
+    users at a higher level to arguments to the `openai.AsyncOpenAI` client.
+
+    """
+
+    def __init__(
+        self,
+        client: Union["AsyncOpenAIClient", "AsyncAzureOpenAIClient"],
+        model_name: Optional[str] = None,
+    ):
+        """
+        Parameters
+        ----------
+        client
+            The `openai.AsyncOpenAI` or `openai.AsyncAzureOpenAI` client.
+        model_name
+            The name of the model to use.
+
+        """
+        self.client = client
+        self.model_name = model_name
+        self.type_adapter = OpenAITypeAdapter()
+
+    async def generate(
+        self,
+        model_input: Union[Chat, list, str],
+        output_type: Optional[Union[type[BaseModel], str]] = None,
+        **inference_kwargs: Any,
+    ) -> Union[str, list[str]]:
+        """Generate text using OpenAI.
+
+        Parameters
+        ----------
+        model_input
+            The prompt based on which the model will generate a response.
+        output_type
+            The desired format of the response generated by the model. The
+            output type must be of a type that can be converted to a JSON
+            schema or an empty dictionary.
+        **inference_kwargs
+            Additional keyword arguments to pass to the client.
+
+        Returns
+        -------
+        Union[str, list[str]]
+            The text generated by the model.
+
+        """
+        import openai
+
+        messages = self.type_adapter.format_input(model_input)
+        response_format = self.type_adapter.format_output_type(output_type)
+
+        if "model" not in inference_kwargs and self.model_name is not None:
+            inference_kwargs["model"] = self.model_name
+
+        try:
+            result = await self.client.chat.completions.create(
+                messages=messages,
+                **response_format,
+                **inference_kwargs,
+            )
+        except openai.BadRequestError as e:
+            if e.body["message"].startswith("Invalid schema"):
+                raise TypeError(
+                    f"OpenAI does not support your schema: {e.body['message']}. "
+                    "Try a local model or dottxt instead."
+                )
+            else:
+                raise e
+
+        messages = [choice.message for choice in result.choices]
+        for message in messages:
+            if message.refusal is not None:
+                raise ValueError(
+                    f"OpenAI refused to answer the request: {message.refusal}"
+                )
+
+        if len(messages) == 1:
+            return messages[0].content
+        else:
+            return [message.content for message in messages]
+
+    async def generate_batch(
+        self,
+        model_input,
+        output_type=None,
+        **inference_kwargs,
+    ):
+        raise NotImplementedError(
+            "The `openai` library does not support batch inference."
+        )
+
+    async def generate_stream(  # type: ignore
+        self,
+        model_input: Union[Chat, list, str],
+        output_type: Optional[Union[type[BaseModel], str]] = None,
+        **inference_kwargs,
+    ) -> AsyncIterator[str]:
+        """Stream text using OpenAI.
+
+        Parameters
+        ----------
+        model_input
+            The prompt based on which the model will generate a response.
+        output_type
+            The desired format of the response generated by the model. The
+            output type must be of a type that can be converted to a JSON
+            schema or an empty dictionary.
+        **inference_kwargs
+            Additional keyword arguments to pass to the client.
+
+        Returns
+        -------
+        AsyncIterator[str]
+            An iterator that yields the text generated by the model.
+
+        """
+        import openai
+
+        messages = self.type_adapter.format_input(model_input)
+        response_format = self.type_adapter.format_output_type(output_type)
+
+        if "model" not in inference_kwargs and self.model_name is not None:
+            inference_kwargs["model"] = self.model_name
+
+        try:
+            stream = await self.client.chat.completions.create(
+                stream=True,
+                messages=messages,
+                **response_format,
+                **inference_kwargs
+            )
+        except openai.BadRequestError as e:
+            if e.body["message"].startswith("Invalid schema"):
+                raise TypeError(
+                    f"OpenAI does not support your schema: {e.body['message']}. "
+                    "Try a local model or dottxt instead."
+                )
+            else:
+                raise e
+
+        async for chunk in stream:
+            if chunk.choices and chunk.choices[0].delta.content is not None:
+                yield chunk.choices[0].delta.content
+
+
 def from_openai(
-    client: Union["OpenAIClient", "AzureOpenAIClient"],
+    client: Union[
+        "OpenAIClient",
+        "AsyncOpenAIClient",
+        "AzureOpenAIClient",
+        "AsyncAzureOpenAIClient",
+    ],
     model_name: Optional[str] = None,
-) -> OpenAI:
-    """Create an Outlines `OpenAI` model instance from an `openai.OpenAI`
-    client.
+) -> Union[OpenAI, AsyncOpenAI]:
+    """Create an Outlines `OpenAI` or `AsyncOpenAI` model instance from an
+    `openai.OpenAI` or `openai.AsyncOpenAI` client.

     Parameters
     ----------
     client
-        An `openai.OpenAI` client instance.
+        An `openai.OpenAI`, `openai.AsyncOpenAI`, `openai.AzureOpenAI` or
+        `openai.AsyncAzureOpenAI` client instance.
     model_name
         The name of the model to use.

     Returns
     -------
     OpenAI
-        An Outlines `OpenAI` model instance.
+        An Outlines `OpenAI` or `AsyncOpenAI` model instance.

     """
-    return OpenAI(client, model_name)
+    import openai
+
+    if isinstance(client, openai.OpenAI):
+        return OpenAI(client, model_name)
+    elif isinstance(client, openai.AsyncOpenAI):
+        return AsyncOpenAI(client, model_name)
+    else:
+        raise ValueError(
+            "Invalid client type. The client must be an instance of "
+            "`openai.OpenAI` or `openai.AsyncOpenAI`."
+        )
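The dispatch added to `from_openai` can be exercised directly; a minimal sketch (assuming `OPENAI_API_KEY` is set so the clients construct without error):

```python
# Minimal sketch of the new dispatch: the client type decides which
# wrapper is returned, and any other value raises ValueError.
import openai
import outlines

sync_model = outlines.from_openai(openai.OpenAI(), "gpt-4o")
print(type(sync_model))   # outlines.models.openai.OpenAI

async_model = outlines.from_openai(openai.AsyncOpenAI(), "gpt-4o")
print(type(async_model))  # outlines.models.openai.AsyncOpenAI

try:
    outlines.from_openai("not a client")
except ValueError as e:
    print(e)  # Invalid client type. ...
```

Note that the two `isinstance` checks also cover the Azure clients, since `openai.AzureOpenAI` subclasses `openai.OpenAI` and `openai.AsyncAzureOpenAI` subclasses `openai.AsyncOpenAI`.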
