diff --git a/src/msgraph_core/graph_client_factory.py b/src/msgraph_core/graph_client_factory.py index 6cbbbcf7..3c1ac04f 100644 --- a/src/msgraph_core/graph_client_factory.py +++ b/src/msgraph_core/graph_client_factory.py @@ -8,11 +8,10 @@ import httpx from kiota_http.kiota_client_factory import KiotaClientFactory -from kiota_http.middleware import AsyncKiotaTransport from kiota_http.middleware.middleware import BaseMiddleware from ._enums import APIVersion, NationalClouds -from .middleware import GraphTelemetryHandler +from .middleware import AsyncGraphTransport, GraphTelemetryHandler class GraphClientFactory(KiotaClientFactory): @@ -40,9 +39,10 @@ def create_with_default_middleware( middleware, current_transport ) - client._transport = AsyncKiotaTransport( + client._transport = AsyncGraphTransport( transport=current_transport, pipeline=middleware_pipeline ) + client._transport.pipeline return client @staticmethod @@ -66,7 +66,7 @@ def create_with_custom_middleware( middleware, current_transport ) - client._transport = AsyncKiotaTransport( + client._transport = AsyncGraphTransport( transport=current_transport, pipeline=middleware_pipeline ) return client diff --git a/src/msgraph_core/middleware/__init__.py b/src/msgraph_core/middleware/__init__.py index 508a0819..1ecb083a 100644 --- a/src/msgraph_core/middleware/__init__.py +++ b/src/msgraph_core/middleware/__init__.py @@ -2,5 +2,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .async_graph_transport import AsyncGraphTransport from .request_context import GraphRequestContext from .telemetry import GraphTelemetryHandler diff --git a/src/msgraph_core/middleware/async_graph_transport.py b/src/msgraph_core/middleware/async_graph_transport.py new file mode 100644 index 00000000..9186231b --- /dev/null +++ b/src/msgraph_core/middleware/async_graph_transport.py @@ -0,0 +1,44 @@ +import json + +import httpx +from kiota_http.middleware import MiddlewarePipeline, RedirectHandler, RetryHandler + +from .._enums import FeatureUsageFlag +from .request_context import GraphRequestContext + + +class AsyncGraphTransport(httpx.AsyncBaseTransport): + """A custom transport for requests to the Microsoft Graph API + """ + + def __init__(self, transport: httpx.AsyncBaseTransport, pipeline: MiddlewarePipeline) -> None: + self.transport = transport + self.pipeline = pipeline + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + if self.pipeline: + self.set_request_context_and_feature_usage(request) + response = await self.pipeline.send(request) + return response + + response = await self.transport.handle_async_request(request) + return response + + def set_request_context_and_feature_usage(self, request: httpx.Request) -> httpx.Request: + + request_options = {} + options = request.headers.get('request_options', None) + if options: + request_options = json.loads(options) + + context = GraphRequestContext(request_options, request.headers) + middleware = self.pipeline._first_middleware + while middleware: + if isinstance(middleware, RedirectHandler): + context.feature_usage = FeatureUsageFlag.REDIRECT_HANDLER_ENABLED + if isinstance(middleware, RetryHandler): + context.feature_usage = FeatureUsageFlag.RETRY_HANDLER_ENABLED + + middleware = middleware.next + request.context = context #type: ignore + return request diff --git a/src/msgraph_core/middleware/telemetry.py b/src/msgraph_core/middleware/telemetry.py index 01d66f9a..33ac0b66 100644 --- a/src/msgraph_core/middleware/telemetry.py +++ b/src/msgraph_core/middleware/telemetry.py @@ -3,11 +3,12 @@ import platform import httpx -from kiota_http.middleware import AsyncKiotaTransport, BaseMiddleware, RedirectHandler, RetryHandler +from kiota_http.middleware import BaseMiddleware from urllib3.util import parse_url from .._constants import SDK_VERSION -from .._enums import FeatureUsageFlag, NationalClouds +from .._enums import NationalClouds +from .async_graph_transport import AsyncGraphTransport from .request_context import GraphRequestContext @@ -20,10 +21,9 @@ class GraphTelemetryHandler(BaseMiddleware): the SDK team improve the developer experience. """ - async def send(self, request: GraphRequest, transport: AsyncKiotaTransport): + async def send(self, request: GraphRequest, transport: AsyncGraphTransport): """Adds telemetry headers and sends the http request. """ - self.set_request_context_and_feature_usage(request, transport) if self.is_graph_url(request.url): self._add_client_request_id_header(request) @@ -34,27 +34,6 @@ async def send(self, request: GraphRequest, transport: AsyncKiotaTransport): response = await super().send(request, transport) return response - def set_request_context_and_feature_usage( - self, request: GraphRequest, transport: AsyncKiotaTransport - ) -> GraphRequest: - - request_options = {} - options = request.headers.pop('request_options', None) - if options: - request_options = json.loads(options) - - request.context = GraphRequestContext(request_options, request.headers) - middleware = transport.pipeline._first_middleware - while middleware: - if isinstance(middleware, RedirectHandler): - request.context.feature_usage = FeatureUsageFlag.REDIRECT_HANDLER_ENABLED - if isinstance(middleware, RetryHandler): - request.context.feature_usage = FeatureUsageFlag.RETRY_HANDLER_ENABLED - - middleware = middleware.next - - return request - def is_graph_url(self, url): """Check if the request is made to a graph endpoint. We do not add telemetry headers to non-graph endpoints""" diff --git a/tests/unit/test_async_graph_transport.py b/tests/unit/test_async_graph_transport.py new file mode 100644 index 00000000..619ced98 --- /dev/null +++ b/tests/unit/test_async_graph_transport.py @@ -0,0 +1,18 @@ +import pytest +from kiota_http.kiota_client_factory import KiotaClientFactory + +from msgraph_core._enums import FeatureUsageFlag +from msgraph_core.middleware import AsyncGraphTransport, GraphRequestContext + + +def test_set_request_context_and_feature_usage(mock_request, mock_transport): + middleware = KiotaClientFactory.get_default_middleware() + pipeline = KiotaClientFactory.create_middleware_pipeline(middleware, mock_transport) + transport = AsyncGraphTransport(mock_transport, pipeline) + transport.set_request_context_and_feature_usage(mock_request) + + assert hasattr(mock_request, 'context') + assert isinstance(mock_request.context, GraphRequestContext) + assert mock_request.context.feature_usage == hex( + FeatureUsageFlag.RETRY_HANDLER_ENABLED | FeatureUsageFlag.REDIRECT_HANDLER_ENABLED + ) diff --git a/tests/unit/test_graph_client_factory.py b/tests/unit/test_graph_client_factory.py index ae72531e..621a6ebb 100644 --- a/tests/unit/test_graph_client_factory.py +++ b/tests/unit/test_graph_client_factory.py @@ -4,10 +4,10 @@ # ------------------------------------ import httpx import pytest -from kiota_http.middleware import AsyncKiotaTransport, MiddlewarePipeline, RedirectHandler +from kiota_http.middleware import MiddlewarePipeline, RedirectHandler from msgraph_core import APIVersion, GraphClientFactory, NationalClouds -from msgraph_core.middleware.telemetry import GraphTelemetryHandler +from msgraph_core.middleware import AsyncGraphTransport, GraphTelemetryHandler def test_create_with_default_middleware(): @@ -15,7 +15,7 @@ def test_create_with_default_middleware(): client = GraphClientFactory.create_with_default_middleware() assert isinstance(client, httpx.AsyncClient) - assert isinstance(client._transport, AsyncKiotaTransport) + assert isinstance(client._transport, AsyncGraphTransport) pipeline = client._transport.pipeline assert isinstance(pipeline, MiddlewarePipeline) assert isinstance(pipeline._first_middleware, RedirectHandler) @@ -30,7 +30,7 @@ def test_create_with_custom_middleware(): client = GraphClientFactory.create_with_custom_middleware(middleware=middleware) assert isinstance(client, httpx.AsyncClient) - assert isinstance(client._transport, AsyncKiotaTransport) + assert isinstance(client._transport, AsyncGraphTransport) pipeline = client._transport.pipeline assert isinstance(pipeline._first_middleware, GraphTelemetryHandler) diff --git a/tests/unit/test_graph_telemetry_handler.py b/tests/unit/test_graph_telemetry_handler.py index 33e89260..de9684a6 100644 --- a/tests/unit/test_graph_telemetry_handler.py +++ b/tests/unit/test_graph_telemetry_handler.py @@ -10,22 +10,11 @@ import pytest from msgraph_core import SDK_VERSION, APIVersion, NationalClouds -from msgraph_core._enums import FeatureUsageFlag from msgraph_core.middleware import GraphRequestContext, GraphTelemetryHandler BASE_URL = NationalClouds.Global + '/' + APIVersion.v1 -def test_set_request_context_and_feature_usage(mock_request, mock_transport): - telemetry_handler = GraphTelemetryHandler() - telemetry_handler.set_request_context_and_feature_usage(mock_request, mock_transport) - - assert hasattr(mock_request, 'context') - assert mock_request.context.feature_usage == hex( - FeatureUsageFlag.RETRY_HANDLER_ENABLED | FeatureUsageFlag.REDIRECT_HANDLER_ENABLED - ) - - def test_is_graph_url(mock_graph_request): """ Test method that checks whether a request url is a graph endpoint