1
1
import os
2
2
from asyncio import sleep as asleep
3
3
from time import sleep
4
- from typing import List , Optional , Union
4
+ from typing import Any , Callable , Dict , List , Optional , Union
5
5
6
+ import httpx
6
7
import openai
7
8
from openai import OpenAIError
8
9
from openai ._types import NotGiven
@@ -24,100 +25,135 @@ class AzureOpenAIEncoder(DenseEncoder):
24
25
async_client : Optional [openai .AsyncAzureOpenAI ] = None
25
26
dimensions : Union [int , NotGiven ] = NotGiven ()
26
27
type : str = "azure"
27
- api_key : Optional [str ] = None
28
- deployment_name : Optional [str ] = None
29
- azure_endpoint : Optional [str ] = None
30
- api_version : Optional [str ] = None
31
- model : Optional [str ] = None
28
+ deployment_name : str | None = None
32
29
max_retries : int = 3
33
30
34
31
def __init__ (
35
32
self ,
36
- api_key : Optional [str ] = None ,
37
- deployment_name : Optional [str ] = None ,
38
- azure_endpoint : Optional [str ] = None ,
39
- api_version : Optional [str ] = None ,
40
- model : Optional [str ] = None , # TODO we should change to `name` JB
33
+ name : Optional [str ] = None ,
34
+ azure_endpoint : str | None = None ,
35
+ api_version : str | None = None ,
36
+ api_key : str | None = None ,
37
+ azure_ad_token : str | None = None ,
38
+ azure_ad_token_provider : Callable [[], str ] | None = None ,
39
+ http_client_options : Optional [Dict [str , Any ]] = None ,
40
+ deployment_name : str = EncoderDefault .AZURE .value ["deployment_name" ],
41
41
score_threshold : float = 0.82 ,
42
42
dimensions : Union [int , NotGiven ] = NotGiven (),
43
43
max_retries : int = 3 ,
44
44
):
45
45
"""Initialize the AzureOpenAIEncoder.
46
46
47
- :param api_key: The API key for the Azure OpenAI API.
48
- :type api_key: str
49
- :param deployment_name: The name of the deployment to use.
50
- :type deployment_name: str
51
47
:param azure_endpoint: The endpoint for the Azure OpenAI API.
52
- :type azure_endpoint: str
48
+ Example: ``https://accountname.openai.azure.com``
49
+ :type azure_endpoint: str, optional
50
+
53
51
:param api_version: The version of the API to use.
54
- :type api_version: str
55
- :param model: The model to use.
56
- :type model: str
57
- :param score_threshold: The score threshold for the embeddings.
58
- :type score_threshold: float
59
- :param dimensions: The dimensions of the embeddings.
60
- :type dimensions: int
61
- :param max_retries: The maximum number of retries for the API call.
62
- :type max_retries: int
52
+ Example: ``"2025-02-01-preview"``
53
+ :type api_version: str, optional
54
+
55
+ :param api_key: The API key for the Azure OpenAI API.
56
+ :type api_key: str, optional
57
+
58
+ :param azure_ad_token: The Azure AD/Entra ID token for authentication.
59
+ See https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
60
+ :type azure_ad_token: str, optional
61
+
62
+ :param azure_ad_token_provider: A callable function that returns an Azure AD/Entra ID token.
63
+ :type azure_ad_token_provider: Callable[[], str], optional
64
+
65
+ :param http_client_options: Dictionary of keyword arguments used to construct both the sync ``httpx.Client`` and the async ``httpx.AsyncClient``.
66
+ Example:
67
+ {
68
+ "proxies": "http://proxy.server:8080",
69
+ "timeout": 20.0,
70
+ "headers": {"Authorization": "Bearer xyz"}
71
+ }
72
+ :type http_client_options: Dict[str, Any], optional
73
+
74
+ :param deployment_name: The name of the model deployment to use; it is sent as the ``model`` argument of embedding requests.
75
+ :type deployment_name: str, optional
76
+
77
+ :param score_threshold: The score threshold for filtering embeddings.
78
+ Default is ``0.82``.
79
+ :type score_threshold: float, optional
80
+
81
+ :param dimensions: The number of dimensions for the embeddings. If not given, the model's default dimensionality is used.
82
+ :type dimensions: int, optional
83
+
84
+ :param max_retries: The maximum number of retries for API calls in case of failures.
85
+ Default is ``3``.
86
+ :type max_retries: int, optional
63
87
"""
64
- name = deployment_name
65
88
if name is None :
66
- name = EncoderDefault .AZURE .value ["embedding_model" ]
89
+ name = deployment_name
90
+ if name is None :
91
+ name = EncoderDefault .AZURE .value ["embedding_model" ]
67
92
super ().__init__ (name = name , score_threshold = score_threshold )
68
- self .api_key = api_key
93
+
94
+ azure_endpoint = azure_endpoint or os .getenv ("AZURE_OPENAI_ENDPOINT" )
95
+ if not azure_endpoint :
96
+ raise ValueError ("No Azure OpenAI endpoint provided." )
97
+
98
+ api_version = api_version or os .getenv ("AZURE_OPENAI_API_VERSION" )
99
+ if not api_version :
100
+ raise ValueError ("No Azure OpenAI API version provided." )
101
+
102
+ if not (
103
+ azure_ad_token
104
+ or azure_ad_token_provider
105
+ or api_key
106
+ or os .getenv ("AZURE_OPENAI_API_KEY" )
107
+ ):
108
+ raise ValueError (
109
+ "No authentication method provided. Please provide either `azure_ad_token`, "
110
+ "`azure_ad_token_provider`, or `api_key`."
111
+ )
112
+
113
+ # Resolve the API key (argument first, then AZURE_OPENAI_API_KEY) only when neither an AD token nor a token provider is supplied
114
+ if not azure_ad_token and not azure_ad_token_provider :
115
+ api_key = api_key or os .getenv ("AZURE_OPENAI_API_KEY" )
116
+ if not api_key :
117
+ raise ValueError ("No Azure OpenAI API key provided." )
118
+
69
119
self .deployment_name = deployment_name
70
- self .azure_endpoint = azure_endpoint
71
- self .api_version = api_version
72
- self .model = model
120
+
73
121
# Store `dimensions` so it can be passed as the `dimensions` parameter supported by OpenAI's embedding-3 models
74
122
self .dimensions = dimensions
75
- if self .api_key is None :
76
- self .api_key = os .getenv ("AZURE_OPENAI_API_KEY" )
77
- if self .api_key is None :
78
- raise ValueError ("No Azure OpenAI API key provided." )
123
+
79
124
if max_retries is not None :
80
125
self .max_retries = max_retries
81
- if self .deployment_name is None :
82
- self .deployment_name = EncoderDefault .AZURE .value ["deployment_name" ]
83
- # deployment_name may still be None, but it is optional in the API
84
- if self .azure_endpoint is None :
85
- self .azure_endpoint = os .getenv ("AZURE_OPENAI_ENDPOINT" )
86
- if self .azure_endpoint is None :
87
- raise ValueError ("No Azure OpenAI endpoint provided." )
88
- if self .api_version is None :
89
- self .api_version = os .getenv ("AZURE_OPENAI_API_VERSION" )
90
- if self .api_version is None :
91
- raise ValueError ("No Azure OpenAI API version provided." )
92
- if self .model is None :
93
- self .model = os .getenv ("AZURE_OPENAI_MODEL" )
94
- if self .model is None :
95
- raise ValueError ("No Azure OpenAI model provided." )
96
- assert (
97
- self .api_key is not None
98
- and self .azure_endpoint is not None
99
- and self .api_version is not None
100
- and self .model is not None
126
+
127
+ # Only create HTTP clients if options are provided
128
+ sync_http_client = (
129
+ httpx .Client (** http_client_options ) if http_client_options else None
130
+ )
131
+ async_http_client = (
132
+ httpx .AsyncClient (** http_client_options ) if http_client_options else None
101
133
)
102
134
135
+ assert azure_endpoint is not None and self .deployment_name is not None
136
+
103
137
try :
104
138
self .client = openai .AzureOpenAI (
105
- azure_deployment = (
106
- str ( self . deployment_name ) if self . deployment_name else None
107
- ) ,
108
- api_key = str ( self . api_key ) ,
109
- azure_endpoint = str ( self . azure_endpoint ) ,
110
- api_version = str ( self . api_version ) ,
139
+ azure_endpoint = azure_endpoint ,
140
+ api_version = api_version ,
141
+ api_key = api_key ,
142
+ azure_ad_token = azure_ad_token ,
143
+ azure_ad_token_provider = azure_ad_token_provider ,
144
+ http_client = sync_http_client ,
111
145
)
112
146
self .async_client = openai .AsyncAzureOpenAI (
113
- azure_deployment = (
114
- str ( self . deployment_name ) if self . deployment_name else None
115
- ) ,
116
- api_key = str ( self . api_key ) ,
117
- azure_endpoint = str ( self . azure_endpoint ) ,
118
- api_version = str ( self . api_version ) ,
147
+ azure_endpoint = azure_endpoint ,
148
+ api_version = api_version ,
149
+ api_key = api_key ,
150
+ azure_ad_token = azure_ad_token ,
151
+ azure_ad_token_provider = azure_ad_token_provider ,
152
+ http_client = async_http_client ,
119
153
)
154
+
120
155
except Exception as e :
156
+ logger .error ("OpenAI API client failed to initialize. Error: %s" , e )
121
157
raise ValueError (
122
158
f"OpenAI API client failed to initialize. Error: { e } "
123
159
) from e
@@ -139,7 +175,7 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
139
175
try :
140
176
embeds = self .client .embeddings .create (
141
177
input = docs ,
142
- model = str (self .model ),
178
+ model = str (self .deployment_name ),
143
179
dimensions = self .dimensions ,
144
180
)
145
181
if embeds .data :
@@ -149,12 +185,12 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
149
185
if self .max_retries != 0 and j < self .max_retries :
150
186
sleep (2 ** j )
151
187
logger .warning (
152
- f "Retrying in { 2 ** j } seconds due to OpenAIError: { e } "
188
+ "Retrying in %d seconds due to OpenAIError: %s" , 2 ** j , e
153
189
)
154
190
else :
155
191
raise
156
192
except Exception as e :
157
- logger .error (f "Azure OpenAI API call failed. Error: { e } " )
193
+ logger .error ("Azure OpenAI API call failed. Error: %s" , e )
158
194
raise ValueError (f"Azure OpenAI API call failed. Error: { e } " ) from e
159
195
160
196
if (
@@ -183,23 +219,22 @@ async def acall(self, docs: List[str]) -> List[List[float]]:
183
219
try :
184
220
embeds = await self .async_client .embeddings .create (
185
221
input = docs ,
186
- model = str (self .model ),
222
+ model = str (self .deployment_name ),
187
223
dimensions = self .dimensions ,
188
224
)
189
225
if embeds .data :
190
226
break
191
-
192
227
except OpenAIError as e :
193
228
logger .error ("Exception occurred" , exc_info = True )
194
229
if self .max_retries != 0 and j < self .max_retries :
195
230
await asleep (2 ** j )
196
231
logger .warning (
197
- f "Retrying in { 2 ** j } seconds due to OpenAIError: { e } "
232
+ "Retrying in %d seconds due to OpenAIError: %s" , 2 ** j , e
198
233
)
199
234
else :
200
235
raise
201
236
except Exception as e :
202
- logger .error (f "Azure OpenAI API call failed. Error: { e } " )
237
+ logger .error ("Azure OpenAI API call failed. Error: %s" , e )
203
238
raise ValueError (f"Azure OpenAI API call failed. Error: { e } " ) from e
204
239
205
240
if (
0 commit comments