
Commit 72e84ef

Support Vertex Flex API in GeminiModelHandler (#36982)
* Support Vertex Flex API in GeminiModelHandler
* Fix lint errors.
* Fix lint.
* Import the correct HttpOptions
* Fix the unit test so it runs successfully.
* Fix lint again.
1 parent b65ec28 commit 72e84ef

File tree

2 files changed: +43 -2 lines changed


sdks/python/apache_beam/ml/inference/gemini_inference.py

Lines changed: 18 additions & 2 deletions
@@ -25,6 +25,7 @@
 
 from google import genai
 from google.genai import errors
+from google.genai.types import HttpOptions
 from google.genai.types import Part
 from PIL.Image import Image
 
@@ -108,6 +109,7 @@ def __init__(
       api_key: Optional[str] = None,
       project: Optional[str] = None,
       location: Optional[str] = None,
+      use_vertex_flex_api: Optional[bool] = False,
       *,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
@@ -137,6 +139,7 @@ def __init__(
       location: the GCP project to use for Vertex AI requests. Setting this
         parameter routes requests to Vertex AI. If this paramter is provided,
         project must also be provided and api_key should not be set.
+      use_vertex_flex_api: if true, use the Vertex Flex API.
       min_batch_size: optional. the minimum batch size to use when batching
         inputs.
      max_batch_size: optional. the maximum batch size to use when batching
@@ -169,6 +172,8 @@ def __init__(
       self.location = location
       self.use_vertex = True
 
+    self.use_vertex_flex_api = use_vertex_flex_api
+
     super().__init__(
         namespace='GeminiModelHandler',
         retry_filter=_retry_on_appropriate_service_error,
@@ -180,8 +185,19 @@ def create_client(self) -> genai.Client:
     provided when the GeminiModelHandler class is instantiated.
     """
     if self.use_vertex:
-      return genai.Client(
-          vertexai=True, project=self.project, location=self.location)
+      if self.use_vertex_flex_api:
+        return genai.Client(
+            vertexai=True,
+            project=self.project,
+            location=self.location,
+            http_options=HttpOptions(
+                api_version="v1",
+                headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
+                # Set timeout in the unit of millisecond.
+                timeout=600000))
+      else:
+        return genai.Client(
+            vertexai=True, project=self.project, location=self.location)
     return genai.Client(api_key=self.api_key)
 
   def request(
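
For context, a minimal usage sketch (not part of this commit) of how a pipeline author might opt into the Flex API with the new flag. It follows the handler signature shown above and the unit test below; generate_from_string is assumed to be the module's request_fn helper, and the project id is hypothetical.

    # Sketch only: enable the Vertex Flex API from a Beam pipeline.
    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.gemini_inference import GeminiModelHandler
    from apache_beam.ml.inference.gemini_inference import generate_from_string

    handler = GeminiModelHandler(
        model_name="gemini-pro",
        request_fn=generate_from_string,
        project="my-project",  # hypothetical project id
        location="us-central1",
        use_vertex_flex_api=True)  # sends the "flex" request-type header

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create(["What is Apache Beam?"])
          | RunInference(handler)
          | beam.Map(print))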

sdks/python/apache_beam/ml/inference/gemini_inference_test.py

Lines changed: 25 additions & 0 deletions
@@ -17,6 +17,7 @@
 # pytype: skip-file
 
 import unittest
+from unittest import mock
 
 try:
   from google.genai import errors
@@ -81,5 +82,29 @@ def test_missing_all_params(self):
     )
 
 
+@mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client')
+@mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions')
+class TestGeminiModelHandler(unittest.TestCase):
+  def test_create_client_with_flex_api(
+      self, mock_http_options, mock_genai_client):
+    handler = GeminiModelHandler(
+        model_name="gemini-pro",
+        request_fn=generate_from_string,
+        project="test-project",
+        location="us-central1",
+        use_vertex_flex_api=True)
+    handler.create_client()
+    mock_http_options.assert_called_with(
+        api_version="v1",
+        headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
+        timeout=600000,
+    )
+    mock_genai_client.assert_called_with(
+        vertexai=True,
+        project="test-project",
+        location="us-central1",
+        http_options=mock_http_options.return_value)
+
+
 if __name__ == '__main__':
   unittest.main()
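
Note on the test: stacked @mock.patch decorators are applied bottom-up, so the mock for HttpOptions (the lower decorator) is injected into the test method before the mock for genai.Client, which is why mock_http_options precedes mock_genai_client in the method signature.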
