2525
2626from google import genai
2727from google .genai import errors
28+ from google .genai .types import HttpOptions
2829from google .genai .types import Part
2930from PIL .Image import Image
3031
@@ -108,6 +109,7 @@ def __init__(
108109 api_key : Optional [str ] = None ,
109110 project : Optional [str ] = None ,
110111 location : Optional [str ] = None ,
112+ use_vertex_flex_api : Optional [bool ] = False ,
111113 * ,
112114 min_batch_size : Optional [int ] = None ,
113115 max_batch_size : Optional [int ] = None ,
@@ -137,6 +139,7 @@ def __init__(
137139 location: the GCP location (region) to use for Vertex AI requests. Setting this
138140 parameter routes requests to Vertex AI. If this parameter is provided,
139141 project must also be provided and api_key should not be set.
142+ use_vertex_flex_api: if true, use the Vertex Flex API.
140143 min_batch_size: optional. the minimum batch size to use when batching
141144 inputs.
142145 max_batch_size: optional. the maximum batch size to use when batching
@@ -169,6 +172,8 @@ def __init__(
169172 self .location = location
170173 self .use_vertex = True
171174
175+ self .use_vertex_flex_api = use_vertex_flex_api
176+
172177 super ().__init__ (
173178 namespace = 'GeminiModelHandler' ,
174179 retry_filter = _retry_on_appropriate_service_error ,
@@ -180,8 +185,19 @@ def create_client(self) -> genai.Client:
180185 provided when the GeminiModelHandler class is instantiated.
181186 """
182187 if self .use_vertex :
183- return genai .Client (
184- vertexai = True , project = self .project , location = self .location )
188+ if self .use_vertex_flex_api :
189+ return genai .Client (
190+ vertexai = True ,
191+ project = self .project ,
192+ location = self .location ,
193+ http_options = HttpOptions (
194+ api_version = "v1" ,
195+ headers = {"X-Vertex-AI-LLM-Request-Type" : "flex" },
196+ # Timeout is specified in milliseconds (600000 ms = 10 minutes).
197+ timeout = 600000 ))
198+ else :
199+ return genai .Client (
200+ vertexai = True , project = self .project , location = self .location )
185201 return genai .Client (api_key = self .api_key )
186202
187203 def request (
0 commit comments