@@ -2,7 +2,7 @@
 
 import json
 import time
-from functools import lru_cache
+from functools import cached_property, lru_cache
 from pathlib import Path
 
 from anthropic import Anthropic
@@ -15,13 +15,18 @@
     ChatCompletionToolParam,
     completion_create_params,
 )
-from openai.types.chat.chat_completion import Choice, CompletionUsage
+from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall,
     Function,
 )
 from openai.types.completion_usage import CompletionUsage
-from typing_extensions import Dict, Iterable, List, Optional, Union
+from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
+from pydantic_ai.models.anthropic import AnthropicModel
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.usage import Usage
+from typing_extensions import AsyncIterator, Dict, Iterable, List, Optional, Union
 
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 
@@ -74,7 +79,46 @@ class AnthropicLlmClient(LlmClient):
     __100k_models = {"claude-2.0", "claude-instant-1.2"}
 
     def __init__(self, api_key: str):
-        self.client = Anthropic(api_key=api_key)
+        self.__api_key = api_key
+
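+    # Lazily construct and cache the underlying Anthropic SDK client on
+    # first use; __init__ only stores the API key.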
+    @cached_property
+    def __client(self):
+        return Anthropic(api_key=self.__api_key)
+
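+    # Map the per-request ModelSettings onto a pydantic-ai AnthropicModel;
+    # the settings must carry a "model" entry naming the Claude model to use.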
+    def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
+        if model_settings is None:
+            raise ValueError("Model settings cannot be None")
+        model_name = model_settings.get("model")
+        if model_name is None:
+            raise ValueError("Model must be set and cannot be None")
+
+        return AnthropicModel(model_name, api_key=self.__api_key)
+
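+    # Non-streaming completion: delegate directly to the pydantic-ai model.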
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        model = self.__get_pydantic_model(model_settings)
+        return await model.request(messages, model_settings, model_request_parameters)
+
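+    # Streaming completion: enter the pydantic-ai stream and re-yield its
+    # StreamedResponse so callers can iterate over partial output.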
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        model = self.__get_pydantic_model(model_settings)
+        # request_stream returns an async context manager, so enter it and
+        # yield the StreamedResponse rather than the manager itself.
+        async with model.request_stream(messages, model_settings, model_request_parameters) as response:
+            yield response
+
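+    # Properties required by the pydantic-ai Model interface; the concrete
+    # Claude model is chosen per request via ModelSettings, so no single
+    # model name can be reported here.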
+    @property
+    def model_name(self) -> str:
+        return "Undetermined"
+
+    @property
+    def system(self) -> str:
+        return "anthropic"
 
     def __get_model_limit(self, model: str) -> int:
         # the reported token count is observed to be inaccurate, so we apply a safety margin
@@ -250,7 +294,7 @@ def is_prompt_supported( |
             for k, v in input_kwargs.items()
             if k in {"messages", "model", "system", "tool_choice", "tools", "beta"}
         }
-        message_token_count = self.client.beta.messages.count_tokens(**count_token_input_kwargs)
+        message_token_count = self.__client.beta.messages.count_tokens(**count_token_input_kwargs)
         return model_limit - message_token_count.input_tokens
 
     def truncate_messages(
@@ -295,5 +339,5 @@ def chat_completion( |
             top_p=top_p,
         )
 
-        response = self.client.messages.create(**input_kwargs)
+        response = self.__client.messages.create(**input_kwargs)
         return _anthropic_to_openai_response(model, response)