@@ -4,6 +4,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncIterator,
     Iterator,
     Optional,
     Union,
@@ -13,7 +14,7 @@
 from pydantic import BaseModel, TypeAdapter
 
 from outlines.inputs import Chat, Image
-from outlines.models.base import Model, ModelTypeAdapter
+from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
 from outlines.models.utils import set_additional_properties_false_json_schema
 from outlines.types import JsonSchema, Regex, CFG
 from outlines.types.utils import (
@@ -25,9 +26,14 @@
 )
 
 if TYPE_CHECKING:
-    from openai import OpenAI as OpenAIClient, AzureOpenAI as AzureOpenAIClient
+    from openai import (
+        OpenAI as OpenAIClient,
+        AsyncOpenAI as AsyncOpenAIClient,
+        AzureOpenAI as AzureOpenAIClient,
+        AsyncAzureOpenAI as AsyncAzureOpenAIClient,
+    )
 
-__all__ = ["OpenAI", "from_openai"]
+__all__ = ["AsyncOpenAI", "OpenAI", "from_openai"]
 
 
 class OpenAITypeAdapter(ModelTypeAdapter):
@@ -348,36 +354,211 @@ def generate_stream(
         if "model" not in inference_kwargs and self.model_name is not None:
             inference_kwargs["model"] = self.model_name
 
-        stream = self.client.chat.completions.create(
-            stream=True,
-            messages=messages,
-            **response_format,
-            **inference_kwargs
-        )
+        try:
+            stream = self.client.chat.completions.create(
+                stream=True,
+                messages=messages,
+                **response_format,
+                **inference_kwargs
+            )
+        except openai.BadRequestError as e:
+            if e.body["message"].startswith("Invalid schema"):
+                raise TypeError(
+                    f"OpenAI does not support your schema: {e.body['message']}. "
+                    "Try a local model or dottxt instead."
+                )
+            else:
+                raise e
 
         for chunk in stream:
             if chunk.choices and chunk.choices[0].delta.content is not None:
                 yield chunk.choices[0].delta.content
 
 
+class AsyncOpenAI(AsyncModel):
+    """Thin wrapper around the `openai.AsyncOpenAI` client.
+
+    This wrapper is used to convert the input and output types specified by the
+    users at a higher level to arguments to the `openai.AsyncOpenAI` client.
+
+    """
+
+    def __init__(
+        self,
+        client: Union["AsyncOpenAIClient", "AsyncAzureOpenAIClient"],
+        model_name: Optional[str] = None,
+    ):
+        """
+        Parameters
+        ----------
+        client
+            The `openai.AsyncOpenAI` or `openai.AsyncAzureOpenAI` client.
+        model_name
+            The name of the model to use.
+
+        """
+        self.client = client
+        self.model_name = model_name
+        self.type_adapter = OpenAITypeAdapter()
+
+    async def generate(
+        self,
+        model_input: Union[Chat, list, str],
+        output_type: Optional[Union[type[BaseModel], str]] = None,
+        **inference_kwargs: Any,
+    ) -> Union[str, list[str]]:
+        """Generate text using OpenAI.
+
+        Parameters
+        ----------
+        model_input
+            The prompt based on which the model will generate a response.
+        output_type
+            The desired format of the response generated by the model. The
+            output type must be of a type that can be converted to a JSON
+            schema or an empty dictionary.
+        **inference_kwargs
+            Additional keyword arguments to pass to the client.
+
+        Returns
+        -------
+        Union[str, list[str]]
+            The text generated by the model.
+
+        """
+        import openai
+
+        messages = self.type_adapter.format_input(model_input)
+        response_format = self.type_adapter.format_output_type(output_type)
+
+        if "model" not in inference_kwargs and self.model_name is not None:
+            inference_kwargs["model"] = self.model_name
+
+        try:
+            result = await self.client.chat.completions.create(
+                messages=messages,
+                **response_format,
+                **inference_kwargs,
+            )
+        except openai.BadRequestError as e:
+            if e.body["message"].startswith("Invalid schema"):
+                raise TypeError(
+                    f"OpenAI does not support your schema: {e.body['message']}. "
+                    "Try a local model or dottxt instead."
+                )
+            else:
+                raise e
+
+        messages = [choice.message for choice in result.choices]
+        for message in messages:
+            if message.refusal is not None:
+                raise ValueError(
+                    f"OpenAI refused to answer the request: {message.refusal}"
+                )
+
+        if len(messages) == 1:
+            return messages[0].content
+        else:
+            return [message.content for message in messages]
+
+    async def generate_batch(
+        self,
+        model_input,
+        output_type=None,
+        **inference_kwargs,
+    ):
+        raise NotImplementedError(
+            "The `openai` library does not support batch inference."
+        )
+
+    async def generate_stream(  # type: ignore
+        self,
+        model_input: Union[Chat, list, str],
+        output_type: Optional[Union[type[BaseModel], str]] = None,
+        **inference_kwargs,
+    ) -> AsyncIterator[str]:
+        """Stream text using OpenAI.
+
+        Parameters
+        ----------
+        model_input
+            The prompt based on which the model will generate a response.
+        output_type
+            The desired format of the response generated by the model. The
+            output type must be of a type that can be converted to a JSON
+            schema or an empty dictionary.
+        **inference_kwargs
+            Additional keyword arguments to pass to the client.
+
+        Returns
+        -------
+        AsyncIterator[str]
+            An async iterator that yields the text generated by the model.
+
+        """
+        import openai
+
+        messages = self.type_adapter.format_input(model_input)
+        response_format = self.type_adapter.format_output_type(output_type)
+
+        if "model" not in inference_kwargs and self.model_name is not None:
+            inference_kwargs["model"] = self.model_name
+
+        try:
+            stream = await self.client.chat.completions.create(
+                stream=True,
+                messages=messages,
+                **response_format,
+                **inference_kwargs
+            )
+        except openai.BadRequestError as e:
+            if e.body["message"].startswith("Invalid schema"):
+                raise TypeError(
+                    f"OpenAI does not support your schema: {e.body['message']}. "
+                    "Try a local model or dottxt instead."
+                )
+            else:
+                raise e
+
+        async for chunk in stream:
+            if chunk.choices and chunk.choices[0].delta.content is not None:
+                yield chunk.choices[0].delta.content
+
+
 def from_openai(
-    client: Union["OpenAIClient", "AzureOpenAIClient"],
+    client: Union[
+        "OpenAIClient",
+        "AsyncOpenAIClient",
+        "AzureOpenAIClient",
+        "AsyncAzureOpenAIClient",
+    ],
     model_name: Optional[str] = None,
-) -> OpenAI:
-    """Create an Outlines `OpenAI` model instance from an `openai.OpenAI`
-    client.
+) -> Union[OpenAI, AsyncOpenAI]:
+    """Create an Outlines `OpenAI` or `AsyncOpenAI` model instance from an
+    `openai.OpenAI` or `openai.AsyncOpenAI` client.
 
     Parameters
     ----------
     client
-        An `openai.OpenAI` client instance.
+        An `openai.OpenAI`, `openai.AsyncOpenAI`, `openai.AzureOpenAI` or
+        `openai.AsyncAzureOpenAI` client instance.
     model_name
         The name of the model to use.
 
     Returns
     -------
-    OpenAI
+    Union[OpenAI, AsyncOpenAI]
-        An Outlines `OpenAI` model instance.
+        An Outlines `OpenAI` or `AsyncOpenAI` model instance.
 
     """
-    return OpenAI(client, model_name)
+    import openai
+
+    if isinstance(client, openai.OpenAI):
+        return OpenAI(client, model_name)
+    elif isinstance(client, openai.AsyncOpenAI):
+        return AsyncOpenAI(client, model_name)
+    else:
+        raise ValueError(
+            "Invalid client type. The client must be an instance of "
+            "`openai.OpenAI` or `openai.AsyncOpenAI`."
+        )
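
For orientation, a minimal usage sketch of the async path this patch adds. It assumes outlines re-exports `from_openai` at the top level and that an `OPENAI_API_KEY` environment variable is set; the model name and prompts are illustrative:

import asyncio

from openai import AsyncOpenAI as AsyncOpenAIClient
from outlines import from_openai


async def main():
    # from_openai dispatches on the client type: passing an async client
    # returns an Outlines AsyncOpenAI model (see the isinstance checks above).
    model = from_openai(AsyncOpenAIClient(), "gpt-4o-mini")  # illustrative model name

    # generate is awaitable and returns the message content as a string.
    result = await model.generate("Name one prime number.")
    print(result)

    # generate_stream is an async generator; consume it with `async for`.
    async for chunk in model.generate_stream("Count to five."):
        print(chunk, end="")


asyncio.run(main())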
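A second sketch, of the new error path: when the OpenAI API rejects a response-format schema, the wrappers now surface a `TypeError` instead of a raw `openai.BadRequestError`. This assumes the sync wrapper's `generate` mirrors the async signature shown above; the Pydantic model is hypothetical:

from openai import OpenAI as OpenAIClient
from pydantic import BaseModel

from outlines import from_openai


class City(BaseModel):
    name: str
    population: int


model = from_openai(OpenAIClient(), "gpt-4o-mini")  # illustrative model name

try:
    # Passing a Pydantic model as output_type constrains the response
    # to the model's JSON schema via the response_format argument.
    raw = model.generate("Give me a city.", City)
    print(City.model_validate_json(raw))
except TypeError as e:
    # Raised when OpenAI reports "Invalid schema" for the requested format.
    print(f"Schema not supported by OpenAI: {e}")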