Skip to content

Commit d78add3

Browse files
hkt74copybara-github
authored andcommitted
fix(live): Enhance security by moving api key from query parameters to header
PiperOrigin-RevId: 783011026
1 parent 72e6859 commit d78add3

File tree

2 files changed

+57
-20
lines changed

2 files changed

+57
-20
lines changed

google/genai/live.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,8 @@ async def connect(
929929
version = self._api_client._http_options.api_version
930930
api_key = self._api_client.api_key
931931
method = 'BidiGenerateContent'
932-
key_name = 'key'
932+
original_headers = self._api_client._http_options.headers
933+
headers = original_headers.copy() if original_headers is not None else {}
933934
if api_key.startswith('auth_tokens/'):
934935
warnings.warn(
935936
message=(
@@ -939,7 +940,7 @@ async def connect(
939940
category=errors.ExperimentalWarning,
940941
)
941942
method = 'BidiGenerateContentConstrained'
942-
key_name = 'access_token'
943+
headers['Authorization'] = f'Token {api_key}'
943944
if version != 'v1alpha':
944945
warnings.warn(
945946
message=(
@@ -950,8 +951,7 @@ async def connect(
950951
),
951952
category=errors.ExperimentalWarning,
952953
)
953-
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}?{key_name}={api_key}'
954-
headers = self._api_client._http_options.headers
954+
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}'
955955

956956
request_dict = _common.convert_to_dict(
957957
live_converters._LiveConnectParameters_to_mldev(
@@ -972,7 +972,7 @@ async def connect(
972972
api_key = self._api_client.api_key
973973
version = self._api_client._http_options.api_version
974974
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
975-
headers = self._api_client._http_options.headers
975+
headers = self._api_client._http_options.headers or {}
976976

977977
request_dict = _common.convert_to_dict(
978978
live_converters._LiveConnectParameters_to_vertex(
@@ -1002,11 +1002,9 @@ async def connect(
10021002
auth_req = google.auth.transport.requests.Request() # type: ignore
10031003
creds.refresh(auth_req)
10041004
bearer_token = creds.token
1005-
headers = self._api_client._http_options.headers
1006-
if headers is not None:
1007-
headers.update({
1008-
'Authorization': 'Bearer {}'.format(bearer_token),
1009-
})
1005+
original_headers = self._api_client._http_options.headers
1006+
headers = original_headers.copy() if original_headers is not None else {}
1007+
headers['Authorization'] = f'Bearer {bearer_token}'
10101008
version = self._api_client._http_options.api_version
10111009
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
10121010
location = self._api_client.location

google/genai/tests/live/test_live.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,9 +1830,11 @@ async def test_connect_with_provided_credentials(mock_websocket):
18301830
credentials = Credentials(token="provided_fake_token")
18311831
# mock api client
18321832
client = mock_api_client(vertexai=True, credentials=credentials)
1833+
capture = {}
18331834

18341835
@contextlib.asynccontextmanager
18351836
async def mock_connect(uri, additional_headers=None, **kwargs):
1837+
capture['headers'] = additional_headers
18361838
yield mock_websocket
18371839

18381840
@patch.object(live, "ws_connect", new=mock_connect)
@@ -1841,10 +1843,9 @@ async def _test_connect():
18411843
async with live_module.connect(model="test-model"):
18421844
pass
18431845

1844-
assert "Authorization" in live_module._api_client._http_options.headers
1846+
assert "Authorization" in capture['headers']
18451847
assert (
1846-
live_module._api_client._http_options.headers["Authorization"]
1847-
== "Bearer provided_fake_token"
1848+
capture['headers']['Authorization'] == "Bearer provided_fake_token"
18481849
)
18491850

18501851
await _test_connect()
@@ -1858,9 +1859,11 @@ async def test_connect_with_default_credentials(mock_websocket):
18581859
mock_google_auth_default = Mock(return_value=(None, None))
18591860
mock_creds = Mock(token="default_test_token")
18601861
mock_google_auth_default.return_value = (mock_creds, None)
1862+
capture = {}
18611863

18621864
@contextlib.asynccontextmanager
18631865
async def mock_connect(uri, additional_headers=None, **kwargs):
1866+
capture['headers'] = additional_headers
18641867
yield mock_websocket
18651868

18661869
@patch("google.auth.default", new=mock_google_auth_default)
@@ -1870,10 +1873,9 @@ async def _test_connect():
18701873
async with live_module.connect(model="test-model"):
18711874
pass
18721875

1873-
assert "Authorization" in live_module._api_client._http_options.headers
1876+
assert "Authorization" in capture['headers']
18741877
assert (
1875-
live_module._api_client._http_options.headers["Authorization"]
1876-
== "Bearer default_test_token"
1878+
capture['headers']['Authorization'] == "Bearer default_test_token"
18771879
)
18781880

18791881
await _test_connect()
@@ -1892,11 +1894,12 @@ async def test_bidi_setup_to_api_with_auth_tokens(mock_websocket, vertexai):
18921894
mock_ws = AsyncMock()
18931895
mock_ws.send = AsyncMock()
18941896
mock_ws.recv = AsyncMock(return_value=b'some response')
1895-
uri_capture = {} # Capture the uri here
1897+
capture = {}
18961898

18971899
@contextlib.asynccontextmanager
18981900
async def mock_connect(uri, additional_headers=None, **kwargs):
1899-
uri_capture['uri'] = uri # Capture the uri
1901+
capture['uri'] = uri
1902+
capture['headers'] = additional_headers
19001903
yield mock_ws
19011904

19021905
with patch.object(live, 'ws_connect', new=mock_connect):
@@ -1906,5 +1909,41 @@ async def mock_connect(uri, additional_headers=None, **kwargs):
19061909
):
19071910
pass
19081911

1909-
assert 'access_token=auth_tokens/TEST_AUTH_TOKEN' in uri_capture['uri']
1910-
assert 'BidiGenerateContentConstrained' in uri_capture['uri']
1912+
assert 'Authorization' in capture['headers'], "Authorization key is missing from headers"
1913+
assert capture['headers']['Authorization'] == 'Token auth_tokens/TEST_AUTH_TOKEN'
1914+
assert 'BidiGenerateContentConstrained' in capture['uri']
1915+
1916+
1917+
@pytest.mark.parametrize('vertexai', [False])
1918+
@pytest.mark.asyncio
1919+
async def test_bidi_setup_to_api_with_api_key(mock_websocket, vertexai):
1920+
api_client_mock = mock_api_client(vertexai=vertexai)
1921+
api_client_mock._http_options = types.HttpOptions.model_validate(
1922+
{'headers': {'x-goog-api-key': 'TEST_API_KEY'}}
1923+
)
1924+
result = await get_connect_message(
1925+
api_client_mock,
1926+
model='test_model'
1927+
)
1928+
1929+
mock_ws = AsyncMock()
1930+
mock_ws.send = AsyncMock()
1931+
mock_ws.recv = AsyncMock(return_value=b'some response')
1932+
capture = {}
1933+
1934+
@contextlib.asynccontextmanager
1935+
async def mock_connect(uri, additional_headers=None, **kwargs):
1936+
capture['uri'] = uri
1937+
capture['headers'] = additional_headers
1938+
yield mock_ws
1939+
1940+
with patch.object(live, 'ws_connect', new=mock_connect):
1941+
live_module = live.AsyncLive(api_client_mock)
1942+
async with live_module.connect(
1943+
model='test_model',
1944+
):
1945+
pass
1946+
1947+
assert 'x-goog-api-key' in capture['headers'], "x-goog-api-key is missing from headers"
1948+
assert capture['headers']['x-goog-api-key'] == 'TEST_API_KEY'
1949+
assert 'BidiGenerateContent' in capture['uri']

0 commit comments

Comments
 (0)