@@ -1830,9 +1830,11 @@ async def test_connect_with_provided_credentials(mock_websocket):
1830
1830
credentials = Credentials (token = "provided_fake_token" )
1831
1831
# mock api client
1832
1832
client = mock_api_client (vertexai = True , credentials = credentials )
1833
+ capture = {}
1833
1834
1834
1835
@contextlib .asynccontextmanager
1835
1836
async def mock_connect (uri , additional_headers = None , ** kwargs ):
1837
+ capture ['headers' ] = additional_headers
1836
1838
yield mock_websocket
1837
1839
1838
1840
@patch .object (live , "ws_connect" , new = mock_connect )
@@ -1841,10 +1843,9 @@ async def _test_connect():
1841
1843
async with live_module .connect (model = "test-model" ):
1842
1844
pass
1843
1845
1844
- assert "Authorization" in live_module . _api_client . _http_options . headers
1846
+ assert "Authorization" in capture [ ' headers' ]
1845
1847
assert (
1846
- live_module ._api_client ._http_options .headers ["Authorization" ]
1847
- == "Bearer provided_fake_token"
1848
+ capture ['headers' ]['Authorization' ] == "Bearer provided_fake_token"
1848
1849
)
1849
1850
1850
1851
await _test_connect ()
@@ -1858,9 +1859,11 @@ async def test_connect_with_default_credentials(mock_websocket):
1858
1859
mock_google_auth_default = Mock (return_value = (None , None ))
1859
1860
mock_creds = Mock (token = "default_test_token" )
1860
1861
mock_google_auth_default .return_value = (mock_creds , None )
1862
+ capture = {}
1861
1863
1862
1864
@contextlib .asynccontextmanager
1863
1865
async def mock_connect (uri , additional_headers = None , ** kwargs ):
1866
+ capture ['headers' ] = additional_headers
1864
1867
yield mock_websocket
1865
1868
1866
1869
@patch ("google.auth.default" , new = mock_google_auth_default )
@@ -1870,10 +1873,9 @@ async def _test_connect():
1870
1873
async with live_module .connect (model = "test-model" ):
1871
1874
pass
1872
1875
1873
- assert "Authorization" in live_module . _api_client . _http_options . headers
1876
+ assert "Authorization" in capture [ ' headers' ]
1874
1877
assert (
1875
- live_module ._api_client ._http_options .headers ["Authorization" ]
1876
- == "Bearer default_test_token"
1878
+ capture ['headers' ]['Authorization' ] == "Bearer default_test_token"
1877
1879
)
1878
1880
1879
1881
await _test_connect ()
@@ -1892,11 +1894,12 @@ async def test_bidi_setup_to_api_with_auth_tokens(mock_websocket, vertexai):
1892
1894
mock_ws = AsyncMock ()
1893
1895
mock_ws .send = AsyncMock ()
1894
1896
mock_ws .recv = AsyncMock (return_value = b'some response' )
1895
- uri_capture = {} # Capture the uri here
1897
+ capture = {}
1896
1898
1897
1899
@contextlib .asynccontextmanager
1898
1900
async def mock_connect (uri , additional_headers = None , ** kwargs ):
1899
- uri_capture ['uri' ] = uri # Capture the uri
1901
+ capture ['uri' ] = uri
1902
+ capture ['headers' ] = additional_headers
1900
1903
yield mock_ws
1901
1904
1902
1905
with patch .object (live , 'ws_connect' , new = mock_connect ):
@@ -1906,5 +1909,41 @@ async def mock_connect(uri, additional_headers=None, **kwargs):
1906
1909
):
1907
1910
pass
1908
1911
1909
- assert 'access_token=auth_tokens/TEST_AUTH_TOKEN' in uri_capture ['uri' ]
1910
- assert 'BidiGenerateContentConstrained' in uri_capture ['uri' ]
1912
+ assert 'Authorization' in capture ['headers' ], "Authorization key is missing from headers"
1913
+ assert capture ['headers' ]['Authorization' ] == 'Token auth_tokens/TEST_AUTH_TOKEN'
1914
+ assert 'BidiGenerateContentConstrained' in capture ['uri' ]
1915
+
1916
+
1917
+ @pytest .mark .parametrize ('vertexai' , [False ])
1918
+ @pytest .mark .asyncio
1919
+ async def test_bidi_setup_to_api_with_api_key (mock_websocket , vertexai ):
1920
+ api_client_mock = mock_api_client (vertexai = vertexai )
1921
+ api_client_mock ._http_options = types .HttpOptions .model_validate (
1922
+ {'headers' : {'x-goog-api-key' : 'TEST_API_KEY' }}
1923
+ )
1924
+ result = await get_connect_message (
1925
+ api_client_mock ,
1926
+ model = 'test_model'
1927
+ )
1928
+
1929
+ mock_ws = AsyncMock ()
1930
+ mock_ws .send = AsyncMock ()
1931
+ mock_ws .recv = AsyncMock (return_value = b'some response' )
1932
+ capture = {}
1933
+
1934
+ @contextlib .asynccontextmanager
1935
+ async def mock_connect (uri , additional_headers = None , ** kwargs ):
1936
+ capture ['uri' ] = uri
1937
+ capture ['headers' ] = additional_headers
1938
+ yield mock_ws
1939
+
1940
+ with patch .object (live , 'ws_connect' , new = mock_connect ):
1941
+ live_module = live .AsyncLive (api_client_mock )
1942
+ async with live_module .connect (
1943
+ model = 'test_model' ,
1944
+ ):
1945
+ pass
1946
+
1947
+ assert 'x-goog-api-key' in capture ['headers' ], "x-goog-api-key is missing from headers"
1948
+ assert capture ['headers' ]['x-goog-api-key' ] == 'TEST_API_KEY'
1949
+ assert 'BidiGenerateContent' in capture ['uri' ]
0 commit comments