Skip to content

Commit 1bef4d6

Browse files
authored
Support VertexAI (#1482)
* initial
* update pydantic ai and pin boto3
* update
* update numpy and pandas
* pin scipy
* update
* fix typing validation
* bump version
1 parent 126f2c5 commit 1bef4d6

File tree

16 files changed: +752 additions, -301 deletions

patchwork/common/client/llm/aio.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from typing_extensions import AsyncIterator, Dict, Iterable, List, Optional, Union
1818

1919
from patchwork.common.client.llm.anthropic import AnthropicLlmClient
20-
from patchwork.common.client.llm.google import GoogleLlmClient
20+
from patchwork.common.client.llm.google_ import GoogleLlmClient
2121
from patchwork.common.client.llm.openai_ import OpenAiLlmClient
2222
from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
2323
from patchwork.common.constants import DEFAULT_PATCH_URL
@@ -31,10 +31,10 @@ def __init__(self, *clients: LlmClient):
3131
self.__supported_models = set()
3232
for client in clients:
3333
try:
34-
self.__supported_models.update(client.get_models())
34+
client.test()
3535
self.__clients.append(client)
36-
except Exception:
37-
pass
36+
except Exception as e:
37+
logger.error(f"{client.__class__.__name__} Failed with exception: {e}")
3838

3939
def __get_model(self, model_settings: ModelSettings | None) -> Optional[str]:
4040
if model_settings is None:
@@ -45,6 +45,9 @@ def __get_model(self, model_settings: ModelSettings | None) -> Optional[str]:
4545

4646
return model_name
4747

48+
def test(self) -> None:
49+
pass
50+
4851
async def request(
4952
self,
5053
messages: list[ModelMessage],
@@ -94,9 +97,6 @@ def model_name(self) -> str:
9497
def system(self) -> str:
9598
return next(iter(self.__clients)).system
9699

97-
def get_models(self) -> set[str]:
98-
return self.__supported_models
99-
100100
def is_model_supported(self, model: str) -> bool:
101101
return any(client.is_model_supported(model) for client in self.__clients)
102102

@@ -216,8 +216,9 @@ def create_aio_client(inputs) -> "AioLlmClient" | None:
216216
clients.append(client)
217217

218218
google_key = inputs.get("google_api_key")
219-
if google_key is not None:
220-
client = GoogleLlmClient(google_key, **client_args)
219+
is_gcp = bool(client_args.get("is_gcp") or os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") or False)
220+
if google_key is not None or is_gcp:
221+
client = GoogleLlmClient(api_key=google_key, is_gcp=is_gcp)
221222
clients.append(client)
222223

223224
anthropic_key = inputs.get("anthropic_api_key")

patchwork/common/client/llm/anthropic.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import time
5-
from functools import cached_property, lru_cache
5+
from functools import cached_property
66
from pathlib import Path
77

88
from anthropic import Anthropic
@@ -245,9 +245,8 @@ def __adapt_chat_completion_request(
245245

246246
return NotGiven.remove_not_given(input_kwargs)
247247

248-
@lru_cache(maxsize=None)
249-
def get_models(self) -> set[str]:
250-
return self.__definitely_allowed_models.union(set(f"{self.__allowed_model_prefix}*"))
248+
def test(self):
249+
return
251250

252251
def is_model_supported(self, model: str) -> bool:
253252
return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix)
patchwork/common/client/llm/google.py → patchwork/common/client/llm/google_.py

Lines changed: 84 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,16 @@
11
from __future__ import annotations
22

3+
import os
34
import time
4-
from functools import lru_cache
5+
from functools import lru_cache, partial
56
from pathlib import Path
67

78
import magic
9+
import vertexai
810
from google import genai
11+
from google.auth.exceptions import GoogleAuthError
912
from google.genai import types
13+
from google.genai.errors import APIError
1014
from google.genai.types import (
1115
CountTokensConfig,
1216
File,
@@ -26,7 +30,8 @@
2630
from openai.types.chat.chat_completion import ChatCompletion, Choice
2731
from pydantic import BaseModel
2832
from pydantic_ai.messages import ModelMessage, ModelResponse
29-
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
33+
from pydantic_ai.models import Model as PydanticAiModel
34+
from pydantic_ai.models import ModelRequestParameters, StreamedResponse
3035
from pydantic_ai.models.gemini import GeminiModel
3136
from pydantic_ai.settings import ModelSettings
3237
from pydantic_ai.usage import Usage
@@ -40,9 +45,11 @@
4045
Type,
4146
Union,
4247
)
48+
from vertexai.generative_models import GenerativeModel, SafetySetting
4349

4450
from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
4551
from patchwork.common.client.llm.utils import json_schema_to_model
52+
from patchwork.logger import logger
4653

4754

4855
class GoogleLlmClient(LlmClient):
@@ -51,30 +58,63 @@ class GoogleLlmClient(LlmClient):
5158
dict(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
5259
dict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
5360
dict(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
61+
dict(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
62+
]
63+
__VERTEX_SAFETY_SETTINGS = [
64+
SafetySetting(
65+
category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
66+
threshold=SafetySetting.HarmBlockThreshold.OFF,
67+
),
68+
SafetySetting(
69+
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
70+
threshold=SafetySetting.HarmBlockThreshold.OFF,
71+
),
72+
SafetySetting(
73+
category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
74+
threshold=SafetySetting.HarmBlockThreshold.OFF,
75+
),
76+
SafetySetting(
77+
category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=SafetySetting.HarmBlockThreshold.OFF
78+
),
79+
SafetySetting(
80+
category=SafetySetting.HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
81+
threshold=SafetySetting.HarmBlockThreshold.OFF,
82+
),
5483
]
5584
__MODEL_PREFIX = "models/"
5685

57-
def __init__(self, api_key: str, location: Optional[str] = None):
86+
def __init__(self, api_key: Optional[str] = None, is_gcp: bool = False):
5887
self.__api_key = api_key
59-
self.__location = location
60-
self.client = genai.Client(api_key=api_key, location=location)
88+
self.__is_gcp = is_gcp
89+
if not self.__is_gcp:
90+
self.client = genai.Client(api_key=api_key)
91+
else:
92+
self.client = genai.Client(api_key=api_key, vertexai=True)
93+
location = os.environ.get("GOOGLE_CLOUD_LOCATION", "global")
94+
vertexai.init(
95+
project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
96+
location=location,
97+
api_endpoint=f"{location}-aiplatform.googleapis.com",
98+
)
6199

62100
@lru_cache(maxsize=1)
63101
def __get_models_info(self) -> list[Model]:
64-
return list(self.client.models.list())
102+
if not self.__is_gcp:
103+
return list(self.client.models.list())
104+
else:
105+
return list()
65106

66-
def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
107+
def __get_pydantic_model(self, model_settings: ModelSettings | None) -> PydanticAiModel:
67108
if model_settings is None:
68109
raise ValueError("Model settings cannot be None")
69110
model_name = model_settings.get("model")
70111
if model_name is None:
71112
raise ValueError("Model must be set cannot be None")
72113

73-
if self.__location is None:
114+
if not self.__is_gcp:
74115
return GeminiModel(model_name, api_key=self.__api_key)
75-
76-
url_template = f"https://{self.__location}-generativelanguage.googleapis.com/v1beta/models/{{model}}:"
77-
return GeminiModel(model_name, api_key=self.__api_key, url_template=url_template)
116+
else:
117+
return GeminiModel(model_name, provider="google-vertex")
78118

79119
async def request(
80120
self,
@@ -108,12 +148,15 @@ def __get_model_limits(self, model: str) -> int:
108148
return model_info.input_token_limit
109149
return 1_000_000
110150

111-
@lru_cache
112-
def get_models(self) -> set[str]:
113-
return {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}
151+
def test(self):
152+
return
114153

115154
def is_model_supported(self, model: str) -> bool:
116-
return model in self.get_models()
155+
if not self.__is_gcp:
156+
model_names = {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}
157+
return model in model_names
158+
else:
159+
return True
117160

118161
def __upload(self, file: Path | NotGiven) -> Part | File | None:
119162
if isinstance(file, NotGiven):
@@ -163,6 +206,8 @@ def is_prompt_supported(
163206
top_p: Optional[float] | NotGiven = NOT_GIVEN,
164207
file: Path | NotGiven = NOT_GIVEN,
165208
) -> int:
209+
if self.__is_gcp:
210+
return 1
166211
system, contents = self.__openai_messages_to_google_messages(messages)
167212

168213
file_ref = self.__upload(file)
@@ -178,7 +223,12 @@ def is_prompt_supported(
178223
),
179224
)
180225
token_count = token_response.total_tokens
226+
except GoogleAuthError:
227+
raise
228+
except APIError:
229+
raise
181230
except Exception as e:
231+
logger.debug(f"Error during token count at GoogleLlmClient: {e}")
182232
return -1
183233
model_limit = self.__get_model_limits(model)
184234
return model_limit - token_count
@@ -245,15 +295,25 @@ def chat_completion(
245295
if file_ref is not None:
246296
contents.append(file_ref)
247297

248-
response = self.client.models.generate_content(
249-
model=model,
250-
contents=contents,
251-
config=GenerateContentConfig(
252-
system_instruction=system_content,
253-
safety_settings=self.__SAFETY_SETTINGS,
254-
**NotGiven.remove_not_given(generation_dict),
255-
),
256-
)
298+
if not self.__is_gcp:
299+
generate_content_func = partial(
300+
self.client.models.generate_content,
301+
model=model,
302+
config=GenerateContentConfig(
303+
system_instruction=system_content,
304+
safety_settings=self.__SAFETY_SETTINGS,
305+
**NotGiven.remove_not_given(generation_dict),
306+
),
307+
)
308+
else:
309+
vertexai_model = GenerativeModel(model, system_instruction=system_content)
310+
generate_content_func = partial(
311+
vertexai_model.generate_content,
312+
safety_settings=self.__VERTEX_SAFETY_SETTINGS,
313+
generation_config=NotGiven.remove_not_given(generation_dict),
314+
)
315+
316+
response = generate_content_func(contents=contents)
257317
return self.__google_response_to_openai_response(response, model)
258318

259319
@staticmethod

patchwork/common/client/llm/openai_.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -96,17 +96,18 @@ def __is_not_openai_url(self):
9696
# We mainly use this to skip using the model endpoints.
9797
return self.__base_url is not None and self.__base_url != "https://api.openai.com/v1"
9898

99-
def get_models(self) -> set[str]:
99+
def test(self):
100100
if self.__is_not_openai_url():
101-
return set()
101+
return
102102

103-
return _cached_list_models_from_openai(self.__api_key)
103+
_cached_list_models_from_openai(self.__api_key)
104+
return
104105

105106
def is_model_supported(self, model: str) -> bool:
106107
# might not implement model endpoint
107108
if self.__is_not_openai_url():
108109
return True
109-
return model in self.get_models()
110+
return model in _cached_list_models_from_openai(self.__api_key)
110111

111112
def __get_model_limits(self, model: str) -> int:
112113
return self.__MODEL_LIMITS.get(model, 128_000)

patchwork/common/client/llm/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ def remove_not_given(obj: Any) -> Union[None, dict[Any, Any], list[Any], Any]:
3333

3434
class LlmClient(Model):
3535
@abstractmethod
36-
def get_models(self) -> set[str]:
36+
def test(self) -> None:
3737
...
3838

3939
@abstractmethod

patchwork/common/server.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
from patchwork.common.client.llm.aio import AioLlmClient
77
from patchwork.common.client.llm.anthropic import AnthropicLlmClient
8-
from patchwork.common.client.llm.google import GoogleLlmClient
8+
from patchwork.common.client.llm.google_ import GoogleLlmClient
99
from patchwork.common.client.llm.openai_ import OpenAiLlmClient
1010

1111
app = FastAPI()

patchwork/steps/AgenticLLM/typed.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,16 +11,16 @@ class AgenticLLMInputs(TypedDict, total=False):
1111
user_prompt: str
1212
max_llm_calls: Annotated[int, StepTypeConfig(is_config=True)]
1313
openai_api_key: Annotated[
14-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "anthropic_api_key"])
14+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
1515
]
1616
anthropic_api_key: Annotated[
17-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "openai_api_key"])
17+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
1818
]
1919
patched_api_key: Annotated[
2020
str,
2121
StepTypeConfig(
2222
is_config=True,
23-
or_op=["openai_api_key", "google_api_key", "anthropic_api_key"],
23+
or_op=["openai_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"],
2424
msg=f"""\
2525
Model API key not found.
2626
Please login at: "{TOKEN_URL}"
@@ -31,7 +31,10 @@ class AgenticLLMInputs(TypedDict, total=False):
3131
),
3232
]
3333
google_api_key: Annotated[
34-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
34+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
35+
]
36+
client_is_gcp: Annotated[
37+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
3538
]
3639

3740

patchwork/steps/CallLLM/typed.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,16 @@ class CallLLMInputs(TypedDict, total=False):
1313
model_args: Annotated[str, StepTypeConfig(is_config=True)]
1414
client_args: Annotated[str, StepTypeConfig(is_config=True)]
1515
openai_api_key: Annotated[
16-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "anthropic_api_key"])
16+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
1717
]
1818
anthropic_api_key: Annotated[
19-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "openai_api_key"])
19+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
2020
]
2121
patched_api_key: Annotated[
2222
str,
2323
StepTypeConfig(
2424
is_config=True,
25-
or_op=["openai_api_key", "google_api_key", "anthropic_api_key"],
25+
or_op=["openai_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"],
2626
msg=f"""\
2727
Model API key not found.
2828
Please login at: "{TOKEN_URL}"
@@ -33,7 +33,10 @@ class CallLLMInputs(TypedDict, total=False):
3333
),
3434
]
3535
google_api_key: Annotated[
36-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
36+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
37+
]
38+
client_is_gcp: Annotated[
39+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
3740
]
3841
file: Annotated[str, StepTypeConfig(is_path=True)]
3942

patchwork/steps/FileAgent/typed.py

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -11,15 +11,7 @@ class FileAgentInputs(__ReconcilationAgentRequiredInputs, total=False):
1111
base_path: str
1212
prompt_value: Dict[str, Any]
1313
max_llm_calls: Annotated[int, StepTypeConfig(is_config=True)]
14-
openai_api_key: Annotated[
15-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "anthropic_api_key"])
16-
]
17-
anthropic_api_key: Annotated[
18-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "openai_api_key"])
19-
]
20-
google_api_key: Annotated[
21-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
22-
]
14+
anthropic_api_key: Annotated[str, StepTypeConfig(is_config=True)]
2315

2416

2517
class FileAgentOutputs(TypedDict):

0 commit comments

Comments
 (0)