diff --git a/Directory.Packages.props b/Directory.Packages.props index 78133ed9..78ec76dd 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -19,6 +19,7 @@ + @@ -26,6 +27,8 @@ + + diff --git a/ModelContextProtocol.slnx b/ModelContextProtocol.slnx index e4fd42fe..5233cb46 100644 --- a/ModelContextProtocol.slnx +++ b/ModelContextProtocol.slnx @@ -12,6 +12,8 @@ + + diff --git a/samples/ProtectedMCPClient/Program.cs b/samples/ProtectedMCPClient/Program.cs new file mode 100644 index 00000000..9343a999 --- /dev/null +++ b/samples/ProtectedMCPClient/Program.cs @@ -0,0 +1,173 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Authentication; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using System.Diagnostics; +using System.Net; +using System.Text; +using System.Web; + +Console.WriteLine("Protected MCP Weather Server"); +Console.WriteLine(); + +var serverUrl = "http://localhost:7071/"; +var clientId = Environment.GetEnvironmentVariable("CLIENT_ID") ?? throw new Exception("The CLIENT_ID environment variable is not set."); + +// We can customize a shared HttpClient with a custom handler if desired +var sharedHandler = new SocketsHttpHandler +{ + PooledConnectionLifetime = TimeSpan.FromMinutes(2), + PooledConnectionIdleTimeout = TimeSpan.FromMinutes(1) +}; + +var consoleLoggerFactory = LoggerFactory.Create(builder => +{ + builder.AddConsole(); +}); + +var httpClient = new HttpClient(sharedHandler); +// Create the token provider with our custom HttpClient and authorization URL handler +var tokenProvider = new GenericOAuthProvider( + new Uri(serverUrl), + httpClient, + clientId: clientId, + redirectUri: new Uri("http://localhost:1179/callback"), + authorizationRedirectDelegate: HandleAuthorizationUrlAsync, + loggerFactory: consoleLoggerFactory); + +Console.WriteLine(); +Console.WriteLine($"Connecting to weather server at {serverUrl}..."); + +try +{ + var transport = new SseClientTransport(new() + { + Endpoint = new Uri(serverUrl), + Name = "Secure Weather Client", + CredentialProvider = tokenProvider, + }, httpClient, consoleLoggerFactory); + + var client = await McpClientFactory.CreateAsync(transport, loggerFactory: consoleLoggerFactory); + + var tools = await client.ListToolsAsync(); + if (tools.Count == 0) + { + Console.WriteLine("No tools available on the server."); + return; + } + + Console.WriteLine($"Found {tools.Count} tools on the server."); + Console.WriteLine(); + + if (tools.Any(t => t.Name == "GetAlerts")) + { + Console.WriteLine("Calling GetAlerts tool..."); + + var result = await client.CallToolAsync( + "GetAlerts", + new Dictionary { { "state", "WA" } } + ); + + Console.WriteLine("Result: " + ((TextContentBlock)result.Content[0]).Text); + Console.WriteLine(); + } +} +catch (Exception ex) +{ + Console.WriteLine($"Error: {ex.Message}"); + if (ex.InnerException != null) + { + Console.WriteLine($"Inner error: {ex.InnerException.Message}"); + } + +#if DEBUG + Console.WriteLine($"Stack trace: {ex.StackTrace}"); +#endif +} +Console.WriteLine("Press any key to exit..."); +Console.ReadKey(); + +/// +/// Handles the OAuth authorization URL by starting a local HTTP server and opening a browser. +/// This implementation demonstrates how SDK consumers can provide their own authorization flow. +/// +/// The authorization URL to open in the browser. +/// The redirect URI where the authorization code will be sent. +/// The cancellation token. +/// The authorization code extracted from the callback, or null if the operation failed. +static async Task HandleAuthorizationUrlAsync(Uri authorizationUrl, Uri redirectUri, CancellationToken cancellationToken) +{ + Console.WriteLine("Starting OAuth authorization flow..."); + Console.WriteLine($"Opening browser to: {authorizationUrl}"); + + var listenerPrefix = redirectUri.GetLeftPart(UriPartial.Authority); + if (!listenerPrefix.EndsWith("/")) listenerPrefix += "/"; + + using var listener = new HttpListener(); + listener.Prefixes.Add(listenerPrefix); + + try + { + listener.Start(); + Console.WriteLine($"Listening for OAuth callback on: {listenerPrefix}"); + + OpenBrowser(authorizationUrl); + + var context = await listener.GetContextAsync(); + var query = HttpUtility.ParseQueryString(context.Request.Url?.Query ?? string.Empty); + var code = query["code"]; + var error = query["error"]; + + string responseHtml = "

Authentication complete

You can close this window now.

"; + byte[] buffer = Encoding.UTF8.GetBytes(responseHtml); + context.Response.ContentLength64 = buffer.Length; + context.Response.ContentType = "text/html"; + context.Response.OutputStream.Write(buffer, 0, buffer.Length); + context.Response.Close(); + + if (!string.IsNullOrEmpty(error)) + { + Console.WriteLine($"Auth error: {error}"); + return null; + } + + if (string.IsNullOrEmpty(code)) + { + Console.WriteLine("No authorization code received"); + return null; + } + + Console.WriteLine("Authorization code received successfully."); + return code; + } + catch (Exception ex) + { + Console.WriteLine($"Error getting auth code: {ex.Message}"); + return null; + } + finally + { + if (listener.IsListening) listener.Stop(); + } +} + +/// +/// Opens the specified URL in the default browser. +/// +/// The URL to open. +static void OpenBrowser(Uri url) +{ + try + { + var psi = new ProcessStartInfo + { + FileName = url.ToString(), + UseShellExecute = true + }; + Process.Start(psi); + } + catch (Exception ex) + { + Console.WriteLine($"Error opening browser. {ex.Message}"); + } +} \ No newline at end of file diff --git a/samples/ProtectedMCPClient/ProtectedMCPClient.csproj b/samples/ProtectedMCPClient/ProtectedMCPClient.csproj new file mode 100644 index 00000000..d1d47637 --- /dev/null +++ b/samples/ProtectedMCPClient/ProtectedMCPClient.csproj @@ -0,0 +1,18 @@ + + + + Exe + net9.0 + enable + enable + + + + + + + + + + + \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Program.cs b/samples/ProtectedMCPServer/Program.cs new file mode 100644 index 00000000..6f8bc74e --- /dev/null +++ b/samples/ProtectedMCPServer/Program.cs @@ -0,0 +1,105 @@ +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.IdentityModel.Tokens; +using ModelContextProtocol.AspNetCore.Authentication; +using ModelContextProtocol.Authentication; +using ProtectedMCPServer.Tools; +using System.Net.Http.Headers; +using System.Security.Claims; + +var builder = WebApplication.CreateBuilder(args); + +var serverUrl = "http://localhost:7071/"; +var tenantId = builder.Configuration["TenantId"]; +var clientId = builder.Configuration["ClientId"]; +var instance = "https://login.microsoftonline.com/"; + +builder.Services.AddAuthentication(options => +{ + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; +}) +.AddJwtBearer(options => +{ + options.Authority = $"{instance}{tenantId}/v2.0"; + options.TokenValidationParameters = new TokenValidationParameters + { + ValidateIssuer = true, + ValidateAudience = true, + ValidateLifetime = true, + ValidateIssuerSigningKey = true, + ValidAudience = clientId, + ValidIssuer = $"{instance}{tenantId}/v2.0", + NameClaimType = "name", + RoleClaimType = "roles" + }; + + options.MetadataAddress = $"{instance}{tenantId}/v2.0/.well-known/openid-configuration"; + + options.Events = new JwtBearerEvents + { + OnTokenValidated = context => + { + var name = context.Principal?.Identity?.Name ?? "unknown"; + var email = context.Principal?.FindFirstValue("preferred_username") ?? "unknown"; + Console.WriteLine($"Token validated for: {name} ({email})"); + return Task.CompletedTask; + }, + OnAuthenticationFailed = context => + { + Console.WriteLine($"Authentication failed: {context.Exception.Message}"); + return Task.CompletedTask; + }, + OnChallenge = context => + { + Console.WriteLine($"Challenging client to authenticate with Entra ID"); + return Task.CompletedTask; + } + }; +}) +.AddMcp(options => +{ + options.ProtectedResourceMetadataProvider = context => + { + var metadata = new ProtectedResourceMetadata + { + Resource = new Uri("http://localhost:7071/"), + BearerMethodsSupported = { "header" }, + ResourceDocumentation = new Uri("https://docs.example.com/api/weather"), + AuthorizationServers = { new Uri($"{instance}{tenantId}/v2.0") } + }; + + metadata.ScopesSupported.AddRange([ + $"api://{clientId}/weather.read" + ]); + + return metadata; + }; +}); + +builder.Services.AddAuthorization(); + +builder.Services.AddHttpContextAccessor(); +builder.Services.AddMcpServer() + .WithTools() + .WithHttpTransport(); + +// Configure HttpClientFactory for weather.gov API +builder.Services.AddHttpClient("WeatherApi", client => +{ + client.BaseAddress = new Uri("https://api.weather.gov"); + client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("weather-tool", "1.0")); +}); + +var app = builder.Build(); + +app.UseAuthentication(); +app.UseAuthorization(); + +// Use the default MCP policy name that we've configured +app.MapMcp().RequireAuthorization(); + +Console.WriteLine($"Starting MCP server with authorization at {serverUrl}"); +Console.WriteLine($"PRM Document URL: {serverUrl}.well-known/oauth-protected-resource"); +Console.WriteLine("Press Ctrl+C to stop the server"); + +app.Run(serverUrl); diff --git a/samples/ProtectedMCPServer/Properties/launchSettings.json b/samples/ProtectedMCPServer/Properties/launchSettings.json new file mode 100644 index 00000000..03646532 --- /dev/null +++ b/samples/ProtectedMCPServer/Properties/launchSettings.json @@ -0,0 +1,12 @@ +{ + "profiles": { + "ProtectedMCPServer": { + "commandName": "Project", + "launchBrowser": true, + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + }, + "applicationUrl": "https://localhost:55598;http://localhost:55599" + } + } +} \ No newline at end of file diff --git a/samples/ProtectedMCPServer/ProtectedMCPServer.csproj b/samples/ProtectedMCPServer/ProtectedMCPServer.csproj new file mode 100644 index 00000000..b4c35c77 --- /dev/null +++ b/samples/ProtectedMCPServer/ProtectedMCPServer.csproj @@ -0,0 +1,15 @@ + + + + net9.0 + enable + enable + 783daef3-9c45-408d-a1d3-7caf44724f39 + + + + + + + + \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Tools/HttpClientExt.cs b/samples/ProtectedMCPServer/Tools/HttpClientExt.cs new file mode 100644 index 00000000..f7b2b549 --- /dev/null +++ b/samples/ProtectedMCPServer/Tools/HttpClientExt.cs @@ -0,0 +1,13 @@ +using System.Text.Json; + +namespace ModelContextProtocol; + +internal static class HttpClientExt +{ + public static async Task ReadJsonDocumentAsync(this HttpClient client, string requestUri) + { + using var response = await client.GetAsync(requestUri); + response.EnsureSuccessStatusCode(); + return await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); + } +} \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Tools/WeatherTools.cs b/samples/ProtectedMCPServer/Tools/WeatherTools.cs new file mode 100644 index 00000000..7c8c0851 --- /dev/null +++ b/samples/ProtectedMCPServer/Tools/WeatherTools.cs @@ -0,0 +1,67 @@ +using ModelContextProtocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Globalization; +using System.Text.Json; + +namespace ProtectedMCPServer.Tools; + +[McpServerToolType] +public sealed class WeatherTools +{ + private readonly IHttpClientFactory _httpClientFactory; + + public WeatherTools(IHttpClientFactory httpClientFactory) + { + _httpClientFactory = httpClientFactory; + } + + [McpServerTool, Description("Get weather alerts for a US state.")] + public async Task GetAlerts( + [Description("The US state to get alerts for. Use the 2 letter abbreviation for the state (e.g. NY).")] string state) + { + var client = _httpClientFactory.CreateClient("WeatherApi"); + using var jsonDocument = await client.ReadJsonDocumentAsync($"/alerts/active/area/{state}"); + var jsonElement = jsonDocument.RootElement; + var alerts = jsonElement.GetProperty("features").EnumerateArray(); + + if (!alerts.Any()) + { + return "No active alerts for this state."; + } + + return string.Join("\n--\n", alerts.Select(alert => + { + JsonElement properties = alert.GetProperty("properties"); + return $""" + Event: {properties.GetProperty("event").GetString()} + Area: {properties.GetProperty("areaDesc").GetString()} + Severity: {properties.GetProperty("severity").GetString()} + Description: {properties.GetProperty("description").GetString()} + Instruction: {properties.GetProperty("instruction").GetString()} + """; + })); + } + + [McpServerTool, Description("Get weather forecast for a location.")] + public async Task GetForecast( + [Description("Latitude of the location.")] double latitude, + [Description("Longitude of the location.")] double longitude) + { + var client = _httpClientFactory.CreateClient("WeatherApi"); + var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}"); + using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl); + var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString() + ?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); + + using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl); + var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray(); + + return string.Join("\n---\n", periods.Select(period => $""" + {period.GetProperty("name").GetString()} + Temperature: {period.GetProperty("temperature").GetInt32()}°F + Wind: {period.GetProperty("windSpeed").GetString()} {period.GetProperty("windDirection").GetString()} + Forecast: {period.GetProperty("detailedForecast").GetString()} + """)); + } +} diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationDefaults.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationDefaults.cs new file mode 100644 index 00000000..4c720c65 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationDefaults.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Default values used by MCP authentication. +/// +public static class McpAuthenticationDefaults +{ + /// + /// The default value used for authentication scheme name. + /// + public const string AuthenticationScheme = "McpAuth"; + + /// + /// The default value used for authentication scheme display name. + /// + public const string DisplayName = "MCP Authentication"; +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationEvents.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationEvents.cs new file mode 100644 index 00000000..10762a0f --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationEvents.cs @@ -0,0 +1,8 @@ +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Represents the authentication events for Model Context Protocol. +/// +public class McpAuthenticationEvents +{ +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationExtensions.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationExtensions.cs new file mode 100644 index 00000000..49dd3dfe --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationExtensions.cs @@ -0,0 +1,47 @@ +using Microsoft.AspNetCore.Authentication; +using ModelContextProtocol.AspNetCore.Authentication; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension methods for adding MCP authorization support to ASP.NET Core applications. +/// +public static class McpAuthenticationExtensions +{ + /// + /// Adds MCP authorization support to the application. + /// + /// The authentication builder. + /// An action to configure MCP authentication options. + /// The authentication builder for chaining. + public static AuthenticationBuilder AddMcp( + this AuthenticationBuilder builder, + Action? configureOptions = null) + { + return AddMcp( + builder, + McpAuthenticationDefaults.AuthenticationScheme, + McpAuthenticationDefaults.DisplayName, + configureOptions); + } + + /// + /// Adds MCP authorization support to the application with a custom scheme name. + /// + /// The authentication builder. + /// The authentication scheme name to use. + /// The display name for the authentication scheme. + /// An action to configure MCP authentication options. + /// The authentication builder for chaining. + public static AuthenticationBuilder AddMcp( + this AuthenticationBuilder builder, + string authenticationScheme, + string displayName, + Action? configureOptions = null) + { + return builder.AddScheme( + authenticationScheme, + displayName, + configureOptions); // No-op to avoid overriding + } +} diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs new file mode 100644 index 00000000..63c6db03 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs @@ -0,0 +1,149 @@ +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Authentication; +using System.Text.Encodings.Web; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Authentication handler for MCP protocol that adds resource metadata to challenge responses +/// and handles resource metadata endpoint requests. +/// +public class McpAuthenticationHandler : AuthenticationHandler, IAuthenticationRequestHandler +{ + private readonly IOptionsMonitor _optionsMonitor; + + /// + /// Initializes a new instance of the class. + /// + public McpAuthenticationHandler( + IOptionsMonitor options, + ILoggerFactory logger, + UrlEncoder encoder) + : base(options, logger, encoder) + { + _optionsMonitor = options; + } + + /// + public async Task HandleRequestAsync() + { + // Check if the request is for the resource metadata endpoint + string requestPath = Request.Path.Value ?? string.Empty; + + string expectedMetadataPath = this.Options.ResourceMetadataUri?.ToString() ?? string.Empty; + if (this.Options.ResourceMetadataUri != null && !this.Options.ResourceMetadataUri.IsAbsoluteUri) + { + // For relative URIs, it's just the path component. + expectedMetadataPath = this.Options.ResourceMetadataUri.OriginalString; + } + + // If the path doesn't match, let the request continue through the pipeline + if (!string.Equals(requestPath, expectedMetadataPath, StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + var cancellationToken = Request.HttpContext.RequestAborted; + await HandleResourceMetadataRequestAsync(cancellationToken); + return true; + } + + /// + /// Gets the base URL from the current request, including scheme, host, and path base. + /// + private string GetBaseUrl() => $"{Request.Scheme}://{Request.Host}{Request.PathBase}"; + + /// + /// Gets the absolute URI for the resource metadata endpoint. + /// + private string GetAbsoluteResourceMetadataUri() + { + var options = this.Options; + var resourceMetadataUri = options.ResourceMetadataUri; + + string currentPath = resourceMetadataUri?.ToString() ?? string.Empty; + + if (resourceMetadataUri != null && resourceMetadataUri.IsAbsoluteUri) + { + return currentPath; + } + + // For relative URIs, combine with the base URL + string baseUrl = GetBaseUrl(); + string relativePath = resourceMetadataUri?.OriginalString.TrimStart('/') ?? string.Empty; + + if (!Uri.TryCreate($"{baseUrl.TrimEnd('/')}/{relativePath}", UriKind.Absolute, out var absoluteUri)) + { + throw new InvalidOperationException($"Could not create absolute URI for resource metadata. Base URL: {baseUrl}, Relative Path: {relativePath}"); + } + + return absoluteUri.ToString(); + } + + /// + /// Handles the resource metadata request. + /// + /// A token to cancel the operation. + private Task HandleResourceMetadataRequestAsync(CancellationToken cancellationToken = default) + { + var options = this.Options; + var resourceMetadata = options.GetResourceMetadata(Request.HttpContext); + + // Create a copy to avoid modifying the original + var metadata = new ProtectedResourceMetadata + { + Resource = resourceMetadata.Resource ?? new Uri(GetBaseUrl()), + AuthorizationServers = [.. resourceMetadata.AuthorizationServers], + BearerMethodsSupported = [.. resourceMetadata.BearerMethodsSupported], + ScopesSupported = [.. resourceMetadata.ScopesSupported], + ResourceDocumentation = resourceMetadata.ResourceDocumentation + }; + + Response.StatusCode = StatusCodes.Status200OK; + Response.ContentType = "application/json"; + + var json = JsonSerializer.Serialize( + metadata, + McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))); + + return Response.WriteAsync(json, cancellationToken); + } + + /// + protected override async Task HandleAuthenticateAsync() + { + // If ForwardAuthenticate is set, forward the authentication to the specified scheme + if (!string.IsNullOrEmpty(Options.ForwardAuthenticate) && + Options.ForwardAuthenticate != Scheme.Name) + { + // Simply forward the authentication request to the specified scheme and return its result + // This ensures we don't interfere with the authentication process + return await Context.AuthenticateAsync(Options.ForwardAuthenticate); + } + + // If no forwarding is configured, this handler doesn't perform authentication + return AuthenticateResult.NoResult(); + } + + /// + protected override Task HandleChallengeAsync(AuthenticationProperties properties) + { + // Get the absolute URI for the resource metadata + string rawPrmDocumentUri = GetAbsoluteResourceMetadataUri(); + + properties ??= new AuthenticationProperties(); + + // Store the resource_metadata in properties in case other handlers need it + properties.Items["resource_metadata"] = rawPrmDocumentUri; + + // Add the WWW-Authenticate header with Bearer scheme and resource metadata + string headerValue = $"Bearer realm=\"{Scheme.Name}\", resource_metadata=\"{rawPrmDocumentUri}\""; + Response.Headers.Append("WWW-Authenticate", headerValue); + + return base.HandleChallengeAsync(properties); + } +} diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationOptions.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationOptions.cs new file mode 100644 index 00000000..3a989cf8 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationOptions.cs @@ -0,0 +1,104 @@ +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using ModelContextProtocol.Authentication; + +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Options for the MCP authentication handler. +/// +public class McpAuthenticationOptions : AuthenticationSchemeOptions +{ + private static readonly Uri DefaultResourceMetadataUri = new("/.well-known/oauth-protected-resource", UriKind.Relative); + + private Func? _resourceMetadataProvider; + + private ProtectedResourceMetadata? _resourceMetadata; + + /// + /// Initializes a new instance of the class. + /// + public McpAuthenticationOptions() + { + // "Bearer" is JwtBearerDefaults.AuthenticationScheme, but we don't have a reference to the JwtBearer package here. + ForwardAuthenticate = "Bearer"; + ResourceMetadataUri = DefaultResourceMetadataUri; + Events = new McpAuthenticationEvents(); + } + + /// + /// Gets or sets the events used to handle authentication events. + /// + public new McpAuthenticationEvents Events + { + get { return (McpAuthenticationEvents)base.Events!; } + set { base.Events = value; } + } + + /// + /// The URI to the resource metadata document. + /// + /// + /// This URI will be included in the WWW-Authenticate header when a 401 response is returned. + /// + public Uri ResourceMetadataUri { get; set; } + + /// + /// Gets or sets the static protected resource metadata. + /// + /// + /// This contains the OAuth metadata for the protected resource, including authorization servers, + /// supported scopes, and other information needed for clients to authenticate. + /// Setting this property will automatically update the + /// to return this static instance. + /// + /// Thrown when trying to set a null value. + /// Thrown when the Resource property of the metadata is null. + public ProtectedResourceMetadata ResourceMetadata + { + get => _resourceMetadata ?? throw new InvalidOperationException( + "ResourceMetadata has not been configured."); + set + { + ArgumentNullException.ThrowIfNull(value); + if (value.Resource == null) + { + throw new ArgumentException("The Resource property of the metadata cannot be null. A valid resource URI is required.", nameof(value)); + } + + _resourceMetadata = value; + // When static metadata is set, update the provider to use it + _resourceMetadataProvider = _ => _resourceMetadata; + } + } + + /// + /// Gets or sets a delegate that dynamically provides resource metadata based on the HTTP context. + /// + /// + /// When set, this delegate will be called to generate resource metadata for each request, + /// allowing dynamic customization based on the caller or other contextual information. + /// This takes precedence over the static property. + /// + public Func? ProtectedResourceMetadataProvider + { + get => _resourceMetadataProvider; + set => _resourceMetadataProvider = value; + } + + /// + /// Gets the resource metadata for the current request. + /// + /// The HTTP context for the current request. + /// The resource metadata to use for the current request. + /// Thrown when no resource metadata has been configured. + internal ProtectedResourceMetadata GetResourceMetadata(HttpContext context) + { + var provider = _resourceMetadataProvider; + + return provider != null + ? provider(context) + : _resourceMetadata ?? throw new InvalidOperationException( + "ResourceMetadata has not been configured."); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs b/src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs new file mode 100644 index 00000000..53d08d7e --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs @@ -0,0 +1,118 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using System.Net.Http.Headers; + +namespace ModelContextProtocol.Authentication; + +/// +/// A delegating handler that adds authentication tokens to requests and handles 401 responses. +/// +internal sealed class AuthenticatingMcpHttpClient(HttpClient httpClient, IMcpCredentialProvider credentialProvider) : McpHttpClient(httpClient) +{ + // Select first supported scheme as the default + private string _currentScheme = credentialProvider.SupportedSchemes.FirstOrDefault() ?? + throw new ArgumentException("Authorization provider must support at least one authentication scheme.", nameof(credentialProvider)); + + /// + /// Sends an HTTP request with authentication handling. + /// + internal override async Task SendAsync(HttpRequestMessage request, JsonRpcMessage? message, CancellationToken cancellationToken) + { + if (request.Headers.Authorization == null) + { + await AddAuthorizationHeaderAsync(request, _currentScheme, cancellationToken).ConfigureAwait(false); + } + + var response = await base.SendAsync(request, message, cancellationToken).ConfigureAwait(false); + + if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized) + { + return await HandleUnauthorizedResponseAsync(request, message, response, cancellationToken).ConfigureAwait(false); + } + + return response; + } + + /// + /// Handles a 401 Unauthorized response by attempting to authenticate and retry the request. + /// + private async Task HandleUnauthorizedResponseAsync( + HttpRequestMessage originalRequest, + JsonRpcMessage? originalJsonRpcMessage, + HttpResponseMessage response, + CancellationToken cancellationToken) + { + // Gather the schemes the server wants us to use from WWW-Authenticate headers + var serverSchemes = ExtractServerSupportedSchemes(response); + + if (!serverSchemes.Contains(_currentScheme)) + { + // Find the first server scheme that's in our supported set + var bestSchemeMatch = serverSchemes.Intersect(credentialProvider.SupportedSchemes, StringComparer.OrdinalIgnoreCase).FirstOrDefault(); + + if (bestSchemeMatch is not null) + { + _currentScheme = bestSchemeMatch; + } + else if (serverSchemes.Count > 0) + { + // If no match was found, either throw an exception or use default + throw new McpException( + $"The server does not support any of the provided authentication schemes." + + $"Server supports: [{string.Join(", ", serverSchemes)}], " + + $"Provider supports: [{string.Join(", ", credentialProvider.SupportedSchemes)}]."); + } + } + + // Try to handle the 401 response with the selected scheme + await credentialProvider.HandleUnauthorizedResponseAsync(_currentScheme, response, cancellationToken).ConfigureAwait(false); + + using var retryRequest = new HttpRequestMessage(originalRequest.Method, originalRequest.RequestUri); + + // Copy headers except Authorization which we'll set separately + foreach (var header in originalRequest.Headers) + { + if (!header.Key.Equals("Authorization", StringComparison.OrdinalIgnoreCase)) + { + retryRequest.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + } + + await AddAuthorizationHeaderAsync(retryRequest, _currentScheme, cancellationToken).ConfigureAwait(false); + return await base.SendAsync(retryRequest, originalJsonRpcMessage, cancellationToken).ConfigureAwait(false); + } + + /// + /// Extracts the authentication schemes that the server supports from the WWW-Authenticate headers. + /// + private static HashSet ExtractServerSupportedSchemes(HttpResponseMessage response) + { + var serverSchemes = new HashSet(StringComparer.OrdinalIgnoreCase); + + foreach (var header in response.Headers.WwwAuthenticate) + { + serverSchemes.Add(header.Scheme); + } + + return serverSchemes; + } + + /// + /// Adds an authorization header to the request. + /// + private async Task AddAuthorizationHeaderAsync(HttpRequestMessage request, string scheme, CancellationToken cancellationToken) + { + if (request.RequestUri is null) + { + return; + } + + var token = await credentialProvider.GetCredentialAsync(scheme, request.RequestUri, cancellationToken).ConfigureAwait(false); + if (string.IsNullOrEmpty(token)) + { + return; + } + + request.Headers.Authorization = new AuthenticationHeaderValue(scheme, token); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/AuthorizationRedirectDelegate.cs b/src/ModelContextProtocol.Core/Authentication/AuthorizationRedirectDelegate.cs new file mode 100644 index 00000000..5904130e --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/AuthorizationRedirectDelegate.cs @@ -0,0 +1,28 @@ + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a method that handles the OAuth authorization URL and returns the authorization code. +/// +/// The authorization URL that the user needs to visit. +/// The redirect URI where the authorization code will be sent. +/// The cancellation token. +/// A task that represents the asynchronous operation. The task result contains the authorization code if successful, or null if the operation failed or was cancelled. +/// +/// +/// This delegate provides SDK consumers with full control over how the OAuth authorization flow is handled. +/// Implementers can choose to: +/// +/// +/// Start a local HTTP server and open a browser (default behavior) +/// Display the authorization URL to the user for manual handling +/// Integrate with a custom UI or authentication flow +/// Use a different redirect mechanism altogether +/// +/// +/// The implementation should handle user interaction to visit the authorization URL and extract +/// the authorization code from the callback. The authorization code is typically provided as +/// a query parameter in the redirect URI callback. +/// +/// +public delegate Task AuthorizationRedirectDelegate(Uri authorizationUrl, Uri redirectUri, CancellationToken cancellationToken); \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/AuthorizationServerMetadata.cs b/src/ModelContextProtocol.Core/Authentication/AuthorizationServerMetadata.cs new file mode 100644 index 00000000..e94fce7a --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/AuthorizationServerMetadata.cs @@ -0,0 +1,69 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents the metadata about an OAuth authorization server. +/// +internal sealed class AuthorizationServerMetadata +{ + /// + /// The authorization endpoint URI. + /// + [JsonPropertyName("authorization_endpoint")] + public Uri AuthorizationEndpoint { get; set; } = null!; + + /// + /// The token endpoint URI. + /// + [JsonPropertyName("token_endpoint")] + public Uri TokenEndpoint { get; set; } = null!; + + /// + /// The registration endpoint URI. + /// + [JsonPropertyName("registration_endpoint")] + public Uri? RegistrationEndpoint { get; set; } + + /// + /// The revocation endpoint URI. + /// + [JsonPropertyName("revocation_endpoint")] + public Uri? RevocationEndpoint { get; set; } + + /// + /// The response types supported by the authorization server. + /// + [JsonPropertyName("response_types_supported")] + public List? ResponseTypesSupported { get; set; } + + /// + /// The grant types supported by the authorization server. + /// + [JsonPropertyName("grant_types_supported")] + public List? GrantTypesSupported { get; set; } + + /// + /// The token endpoint authentication methods supported by the authorization server. + /// + [JsonPropertyName("token_endpoint_auth_methods_supported")] + public List? TokenEndpointAuthMethodsSupported { get; set; } + + /// + /// The code challenge methods supported by the authorization server. + /// + [JsonPropertyName("code_challenge_methods_supported")] + public List? CodeChallengeMethodsSupported { get; set; } + + /// + /// The issuer URI of the authorization server. + /// + [JsonPropertyName("issuer")] + public Uri? Issuer { get; set; } + + /// + /// The scopes supported by the authorization server. + /// + [JsonPropertyName("scopes_supported")] + public List? ScopesSupported { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Authentication/GenericOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/GenericOAuthProvider.cs new file mode 100644 index 00000000..b523930a --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/GenericOAuthProvider.cs @@ -0,0 +1,557 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using System.Diagnostics.CodeAnalysis; +using System.Net.Http.Headers; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using System.Web; + +namespace ModelContextProtocol.Authentication; + +/// +/// A generic implementation of an OAuth authorization provider for MCP. This does not do any advanced token +/// protection or caching - it acquires a token and server metadata and holds it in memory. +/// This is suitable for demonstration and development purposes. +/// +public sealed class GenericOAuthProvider : IMcpCredentialProvider +{ + /// + /// The Bearer authentication scheme. + /// + private const string BearerScheme = "Bearer"; + + private readonly Uri _serverUrl; + private readonly Uri _redirectUri; + private readonly List _additionalScopes; + private readonly string _clientId; + private readonly string? _clientSecret; + private readonly HttpClient _httpClient; + private readonly ILogger _logger; + private readonly Func, Uri?> _authServerSelector; + private readonly AuthorizationRedirectDelegate _authorizationRedirectDelegate; + + private TokenContainer? _token; + private AuthorizationServerMetadata? _authServerMetadata; + + /// + /// Initializes a new instance of the class with explicit authorization server selection. + /// + /// The MCP server URL. + /// The HTTP client to use for OAuth requests. If null, a default HttpClient will be used. + /// OAuth client ID. + /// OAuth client secret. + /// OAuth redirect URI. + /// Custom handler for processing the OAuth authorization URL. If null, uses the default HTTP listener approach. + /// Additional OAuth scopes to request beyond those specified in the scopes_supported specified in the .well-known/oauth-protected-resource response. + /// A logger factory to handle diagnostic messages. + /// Function to select which authorization server to use from available servers. If null, uses default selection strategy. + /// Thrown when serverUrl is null. + public GenericOAuthProvider( + Uri serverUrl, + HttpClient? httpClient, + string clientId, + Uri redirectUri, + AuthorizationRedirectDelegate? authorizationRedirectDelegate = null, + string? clientSecret = null, + IEnumerable? additionalScopes = null, + Func, Uri?>? authServerSelector = null, + ILoggerFactory? loggerFactory = null) + { + _serverUrl = serverUrl ?? throw new ArgumentNullException(nameof(serverUrl)); + _httpClient = httpClient ?? new HttpClient(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + _redirectUri = redirectUri; + _additionalScopes = additionalScopes?.ToList() ?? []; + _clientId = clientId; + _clientSecret = clientSecret; + + // Set up authorization server selection strategy + _authServerSelector = authServerSelector ?? DefaultAuthServerSelector; + + // Set up authorization URL handler (use default if not provided) + _authorizationRedirectDelegate = authorizationRedirectDelegate ?? DefaultAuthorizationUrlHandler; + } + + /// + /// Default authorization server selection strategy that selects the first available server. + /// + /// List of available authorization servers. + /// The selected authorization server, or null if none are available. + private static Uri? DefaultAuthServerSelector(IReadOnlyList availableServers) => availableServers.FirstOrDefault(); + + /// + /// Default authorization URL handler that displays the URL to the user for manual input. + /// + /// The authorization URL to handle. + /// The redirect URI where the authorization code will be sent. + /// The cancellation token. + /// The authorization code entered by the user, or null if none was provided. + private static Task DefaultAuthorizationUrlHandler(Uri authorizationUrl, Uri redirectUri, CancellationToken cancellationToken) + { + Console.WriteLine($"Please open the following URL in your browser to authorize the application:"); + Console.WriteLine($"{authorizationUrl}"); + Console.WriteLine(); + Console.Write("Enter the authorization code from the redirect URL: "); + var authorizationCode = Console.ReadLine(); + return Task.FromResult(authorizationCode); + } + + /// + public IEnumerable SupportedSchemes => [BearerScheme]; + + /// + public async Task GetCredentialAsync(string scheme, Uri resourceUri, CancellationToken cancellationToken = default) + { + ThrowIfNotBearerScheme(scheme); + + // REVIEW: Should we be doing anything with the resourceUri? If not, why is it part of the IMcpCredentialProvider interface? + + // Return the token if it's valid + if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + { + return _token.AccessToken; + } + + // Try to refresh the token if we have a refresh token + if (_token?.RefreshToken != null && _authServerMetadata != null) + { + var newToken = await RefreshTokenAsync(_token.RefreshToken, _authServerMetadata, cancellationToken).ConfigureAwait(false); + if (newToken != null) + { + _token = newToken; + return _token.AccessToken; + } + } + + // No valid token - auth handler will trigger the 401 flow + return null; + } + + /// + public async Task HandleUnauthorizedResponseAsync( + string scheme, + HttpResponseMessage response, + CancellationToken cancellationToken = default) + { + // This provider only supports Bearer scheme + if (!string.Equals(scheme, BearerScheme, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException("This credential provider only supports the Bearer scheme"); + } + + await PerformOAuthAuthorizationAsync(response, cancellationToken).ConfigureAwait(false); + } + + /// + /// Performs OAuth authorization by selecting an appropriate authorization server and completing the OAuth flow. + /// + /// The 401 Unauthorized response containing authentication challenge. + /// Cancellation token. + /// Result indicating whether authorization was successful. + private async Task PerformOAuthAuthorizationAsync( + HttpResponseMessage response, + CancellationToken cancellationToken) + { + // Get available authorization servers from the 401 response + var protectedResourceMetadata = await ExtractProtectedResourceMetadata(response, _serverUrl, cancellationToken).ConfigureAwait(false); + var availableAuthorizationServers = protectedResourceMetadata.AuthorizationServers; + + if (availableAuthorizationServers.Count == 0) + { + ThrowFailedToHandleUnauthorizedResponse("No authorization servers found in authentication challenge"); + } + + // Select authorization server using configured strategy + var selectedAuthServer = _authServerSelector(availableAuthorizationServers); + + if (selectedAuthServer is null) + { + ThrowFailedToHandleUnauthorizedResponse($"Authorization server selection returned null. Available servers: {string.Join(", ", availableAuthorizationServers)}"); + } + + if (!availableAuthorizationServers.Contains(selectedAuthServer)) + { + ThrowFailedToHandleUnauthorizedResponse($"Authorization server selector returned a server not in the available list: {selectedAuthServer}. Available servers: {string.Join(", ", availableAuthorizationServers)}"); + } + + _logger.LogInformation("Selected authorization server: {Server} from {Count} available servers", selectedAuthServer, availableAuthorizationServers.Count); + + // Get auth server metadata + var authServerMetadata = await GetAuthServerMetadataAsync(selectedAuthServer, cancellationToken).ConfigureAwait(false); + + if (authServerMetadata is null) + { + ThrowFailedToHandleUnauthorizedResponse($"Failed to retrieve metadata for authorization server: '{selectedAuthServer}'"); + } + + // Store auth server metadata for future refresh operations + _authServerMetadata = authServerMetadata; + + // Perform the OAuth flow + var token = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); + + if (token is null) + { + ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); + } + + _token = token; + _logger.LogInformation("OAuth authorization completed successfully"); + } + + private async Task GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken) + { + if (!authServerUri.OriginalString.EndsWith("/")) + { + authServerUri = new Uri(authServerUri.OriginalString + "/"); + } + + foreach (var path in new[] { ".well-known/openid-configuration", ".well-known/oauth-authorization-server" }) + { + try + { + var response = await _httpClient.GetAsync(new Uri(authServerUri, path), cancellationToken).ConfigureAwait(false); + if (!response.IsSuccessStatusCode) + { + continue; + } + + using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var metadata = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.AuthorizationServerMetadata, cancellationToken).ConfigureAwait(false); + + if (metadata != null) + { + metadata.ResponseTypesSupported ??= ["code"]; + metadata.GrantTypesSupported ??= ["authorization_code", "refresh_token"]; + metadata.TokenEndpointAuthMethodsSupported ??= ["client_secret_basic"]; + metadata.CodeChallengeMethodsSupported ??= ["S256"]; + + return metadata; + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Error fetching auth server metadata from {Path}", path); + } + } + + return null; + } + + private async Task RefreshTokenAsync(string refreshToken, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) + { + var requestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "refresh_token", + ["refresh_token"] = refreshToken, + ["client_id"] = _clientId + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint) + { + Content = requestContent + }; + + if (!string.IsNullOrEmpty(_clientSecret)) + { + var authValue = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{_clientId}:{_clientSecret}")); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", authValue); + } + + return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); + } + + private async Task InitiateAuthorizationCodeFlowAsync( + ProtectedResourceMetadata protectedResourceMetadata, + AuthorizationServerMetadata authServerMetadata, + CancellationToken cancellationToken) + { + var codeVerifier = GenerateCodeVerifier(); + var codeChallenge = GenerateCodeChallenge(codeVerifier); + + var authUrl = BuildAuthorizationUrl(protectedResourceMetadata, authServerMetadata, codeChallenge); + var authCode = await _authorizationRedirectDelegate(authUrl, _redirectUri, cancellationToken).ConfigureAwait(false); + + if (string.IsNullOrEmpty(authCode)) + { + return null; + } + + return await ExchangeCodeForTokenAsync(authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false); + } + + private Uri BuildAuthorizationUrl( + ProtectedResourceMetadata protectedResourceMetadata, + AuthorizationServerMetadata authServerMetadata, + string codeChallenge) + { + if (authServerMetadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttp && + authServerMetadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttps) + { + throw new ArgumentException("AuthorizationEndpoint must use HTTP or HTTPS.", nameof(authServerMetadata)); + } + + var queryParams = HttpUtility.ParseQueryString(string.Empty); + queryParams["client_id"] = _clientId; + queryParams["redirect_uri"] = _redirectUri.ToString(); + queryParams["response_type"] = "code"; + queryParams["code_challenge"] = codeChallenge; + queryParams["code_challenge_method"] = "S256"; + + var scopesSupported = protectedResourceMetadata.ScopesSupported; + if (_additionalScopes.Count > 0 || scopesSupported.Count > 0) + { + queryParams["scope"] = string.Join(" ", [.._additionalScopes, ..scopesSupported]); + } + + var uriBuilder = new UriBuilder(authServerMetadata.AuthorizationEndpoint) + { + Query = queryParams.ToString() + }; + + return uriBuilder.Uri; + } + + private async Task ExchangeCodeForTokenAsync( + AuthorizationServerMetadata authServerMetadata, + string authorizationCode, + string codeVerifier, + CancellationToken cancellationToken) + { + var requestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "authorization_code", + ["code"] = authorizationCode, + ["redirect_uri"] = _redirectUri.ToString(), + ["client_id"] = _clientId, + ["code_verifier"] = codeVerifier + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint) + { + Content = requestContent + }; + + if (!string.IsNullOrEmpty(_clientSecret)) + { + var authValue = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{_clientId}:{_clientSecret}")); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", authValue); + } + + return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); + } + + private async Task FetchTokenAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); + httpResponse.EnsureSuccessStatusCode(); + + using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenContainer, cancellationToken).ConfigureAwait(false); + + if (tokenResponse is null) + { + ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response."); + } + + tokenResponse.ObtainedAt = DateTimeOffset.UtcNow; + return tokenResponse; + } + + /// + /// Fetches the protected resource metadata from the provided URL. + /// + /// The URL to fetch the metadata from. + /// A token to cancel the operation. + /// The fetched ProtectedResourceMetadata, or null if it couldn't be fetched. + private async Task FetchProtectedResourceMetadataAsync(Uri metadataUrl, CancellationToken cancellationToken = default) + { + using var httpResponse = await _httpClient.GetAsync(metadataUrl, cancellationToken).ConfigureAwait(false); + httpResponse.EnsureSuccessStatusCode(); + + using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + return await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.ProtectedResourceMetadata, cancellationToken).ConfigureAwait(false); + } + + /// + /// Verifies that the resource URI in the metadata exactly matches the original request URL as required by the RFC. + /// Per RFC: The resource value must be identical to the URL that the client used to make the request to the resource server. + /// + /// The metadata to verify. + /// The original URL the client used to make the request to the resource server. + /// True if the resource URI exactly matches the original request URL, otherwise false. + private static bool VerifyResourceMatch(ProtectedResourceMetadata protectedResourceMetadata, Uri resourceLocation) + { + if (protectedResourceMetadata.Resource == null || resourceLocation == null) + { + return false; + } + + // Per RFC: The resource value must be identical to the URL that the client used + // to make the request to the resource server. Compare entire URIs, not just the host. + + // Normalize the URIs to ensure consistent comparison + string normalizedMetadataResource = NormalizeUri(protectedResourceMetadata.Resource); + string normalizedResourceLocation = NormalizeUri(resourceLocation); + + return string.Equals(normalizedMetadataResource, normalizedResourceLocation, StringComparison.OrdinalIgnoreCase); + } + + /// + /// Normalizes a URI for consistent comparison. + /// + /// The URI to normalize. + /// A normalized string representation of the URI. + private static string NormalizeUri(Uri uri) + { + var builder = new UriBuilder(uri) + { + Port = -1 // Always remove port + }; + + if (builder.Path == "/") + { + builder.Path = string.Empty; + } + else if (builder.Path.Length > 1 && builder.Path.EndsWith("/")) + { + builder.Path = builder.Path.TrimEnd('/'); + } + + return builder.Uri.ToString(); + } + + /// + /// Responds to a 401 challenge by parsing the WWW-Authenticate header, fetching the resource metadata, + /// verifying the resource match, and returning the metadata if valid. + /// + /// The HTTP response containing the WWW-Authenticate header. + /// The server URL to verify against the resource metadata. + /// A token to cancel the operation. + /// The resource metadata if the resource matches the server, otherwise throws an exception. + /// Thrown when the response is not a 401, lacks a WWW-Authenticate header, + /// lacks a resource_metadata parameter, the metadata can't be fetched, or the resource URI doesn't match the server URL. + private async Task ExtractProtectedResourceMetadata(HttpResponseMessage response, Uri serverUrl, CancellationToken cancellationToken = default) + { + if (response.StatusCode != System.Net.HttpStatusCode.Unauthorized) + { + throw new InvalidOperationException($"Expected a 401 Unauthorized response, but received {(int)response.StatusCode} {response.StatusCode}"); + } + + // Extract the WWW-Authenticate header + if (response.Headers.WwwAuthenticate.Count == 0) + { + throw new McpException("The 401 response does not contain a WWW-Authenticate header"); + } + + // Look for the Bearer authentication scheme with resource_metadata parameter + string? resourceMetadataUrl = null; + foreach (var header in response.Headers.WwwAuthenticate) + { + if (string.Equals(header.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter)) + { + resourceMetadataUrl = ParseWwwAuthenticateParameters(header.Parameter, "resource_metadata"); + if (resourceMetadataUrl != null) + { + break; + } + } + } + + if (resourceMetadataUrl == null) + { + throw new McpException("The WWW-Authenticate header does not contain a resource_metadata parameter"); + } + + Uri metadataUri = new(resourceMetadataUrl); + var metadata = await FetchProtectedResourceMetadataAsync(metadataUri, cancellationToken).ConfigureAwait(false) + ?? throw new McpException($"Failed to fetch resource metadata from {resourceMetadataUrl}"); + + // Per RFC: The resource value must be identical to the URL that the client used + // to make the request to the resource server + _logger.LogDebug($"Validating resource metadata against original server URL: {serverUrl}"); + + if (!VerifyResourceMatch(metadata, serverUrl)) + { + throw new McpException($"Resource URI in metadata ({metadata.Resource}) does not match the expected URI ({serverUrl})"); + } + + return metadata; + } + + /// + /// Parses the WWW-Authenticate header parameters to extract a specific parameter. + /// + /// The parameter string from the WWW-Authenticate header. + /// The name of the parameter to extract. + /// The value of the parameter, or null if not found. + private static string? ParseWwwAuthenticateParameters(string parameters, string parameterName) + { + if (parameters.IndexOf(parameterName, StringComparison.OrdinalIgnoreCase) == -1) + { + return null; + } + + foreach (var part in parameters.Split(',')) + { + string trimmedPart = part.Trim(); + int equalsIndex = trimmedPart.IndexOf('='); + + if (equalsIndex <= 0) + { + continue; + } + + string key = trimmedPart.Substring(0, equalsIndex).Trim(); + + if (string.Equals(key, parameterName, StringComparison.OrdinalIgnoreCase)) + { + string value = trimmedPart.Substring(equalsIndex + 1).Trim(); + + if (value.StartsWith("\"") && value.EndsWith("\"")) + { + value = value.Substring(1, value.Length - 2); + } + + return value; + } + } + + return null; + } + + private static string GenerateCodeVerifier() + { + var bytes = new byte[32]; + using var rng = RandomNumberGenerator.Create(); + rng.GetBytes(bytes); + return Convert.ToBase64String(bytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + } + + private static string GenerateCodeChallenge(string codeVerifier) + { + using var sha256 = SHA256.Create(); + var challengeBytes = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier)); + return Convert.ToBase64String(challengeBytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + } + + private static void ThrowIfNotBearerScheme(string scheme) + { + if (!string.Equals(scheme, BearerScheme, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException($"The '{scheme}' is not supported. This credential provider only supports the '{BearerScheme}' scheme"); + } + } + + [DoesNotReturn] + private static void ThrowFailedToHandleUnauthorizedResponse(string message) => + throw new McpException($"Failed to handle unauthorized response with 'Bearer' scheme. {message}"); +} diff --git a/src/ModelContextProtocol.Core/Authentication/IMcpCredentialProvider.cs b/src/ModelContextProtocol.Core/Authentication/IMcpCredentialProvider.cs new file mode 100644 index 00000000..7c3aa722 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/IMcpCredentialProvider.cs @@ -0,0 +1,45 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Defines an interface for providing authentication for requests. +/// This is the main extensibility point for authentication in MCP clients. +/// +public interface IMcpCredentialProvider +{ + /// + /// Gets the collection of authentication schemes supported by this provider. + /// + /// + /// + /// This property returns all authentication schemes that this provider can handle, + /// allowing clients to select the appropriate scheme based on server capabilities. + /// + /// + /// Common values include "Bearer" for JWT tokens, "Basic" for username/password authentication, + /// and "Negotiate" for integrated Windows authentication. + /// + /// + IEnumerable SupportedSchemes { get; } + + /// + /// Gets an authentication token or credential for authenticating requests to a resource + /// using the specified authentication scheme. + /// + /// The authentication scheme to use. + /// The URI of the resource requiring authentication. + /// A token to cancel the operation. + /// An authentication token string or null if no token could be obtained for the specified scheme. + Task GetCredentialAsync(string scheme, Uri resourceUri, CancellationToken cancellationToken = default); + + /// + /// Handles a 401 Unauthorized response from a resource. + /// + /// The authentication scheme that was used when the unauthorized response was received. + /// The HTTP response that contained the 401 status code. + /// A token to cancel the operation. + /// + /// A result object indicating if the provider was able to handle the unauthorized response, + /// and the authentication scheme that should be used for the next attempt, if any. + /// + Task HandleUnauthorizedResponseAsync(string scheme, HttpResponseMessage response, CancellationToken cancellationToken = default); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ProtectedResourceMetadata.cs b/src/ModelContextProtocol.Core/Authentication/ProtectedResourceMetadata.cs new file mode 100644 index 00000000..061b729d --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ProtectedResourceMetadata.cs @@ -0,0 +1,145 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents the resource metadata for OAuth authorization as defined in RFC 9396. +/// Defined by RFC 9728. +/// +public sealed class ProtectedResourceMetadata +{ + /// + /// The resource URI. + /// + /// + /// REQUIRED. The protected resource's resource identifier. + /// + [JsonPropertyName("resource")] + public required Uri Resource { get; init; } + + /// + /// The list of authorization server URIs. + /// + /// + /// OPTIONAL. JSON array containing a list of OAuth authorization server issuer identifiers + /// for authorization servers that can be used with this protected resource. + /// + [JsonPropertyName("authorization_servers")] + public List AuthorizationServers { get; set; } = []; + + /// + /// The supported bearer token methods. + /// + /// + /// OPTIONAL. JSON array containing a list of the supported methods of sending an OAuth 2.0 bearer token + /// to the protected resource. Defined values are ["header", "body", "query"]. + /// + [JsonPropertyName("bearer_methods_supported")] + public List BearerMethodsSupported { get; set; } = []; + + /// + /// The supported scopes. + /// + /// + /// RECOMMENDED. JSON array containing a list of scope values that are used in authorization + /// requests to request access to this protected resource. + /// + [JsonPropertyName("scopes_supported")] + public List ScopesSupported { get; set; } = []; + + /// + /// URL of the protected resource's JSON Web Key (JWK) Set document. + /// + /// + /// OPTIONAL. This contains public keys belonging to the protected resource, such as signing key(s) + /// that the resource server uses to sign resource responses. This URL MUST use the https scheme. + /// + [JsonPropertyName("jwks_uri")] + public Uri? JwksUri { get; set; } + + /// + /// List of the JWS signing algorithms supported by the protected resource for signing resource responses. + /// + /// + /// OPTIONAL. JSON array containing a list of the JWS signing algorithms (alg values) supported by the protected resource + /// for signing resource responses. No default algorithms are implied if this entry is omitted. The value none MUST NOT be used. + /// + [JsonPropertyName("resource_signing_alg_values_supported")] + public List? ResourceSigningAlgValuesSupported { get; set; } + + /// + /// Human-readable name of the protected resource intended for display to the end user. + /// + /// + /// RECOMMENDED. It is recommended that protected resource metadata include this field. + /// The value of this field MAY be internationalized. + /// + [JsonPropertyName("resource_name")] + public string? ResourceName { get; set; } + + /// + /// The URI to the resource documentation. + /// + /// + /// OPTIONAL. URL of a page containing human-readable information that developers might want or need to know + /// when using the protected resource. + /// + [JsonPropertyName("resource_documentation")] + public Uri? ResourceDocumentation { get; set; } + + /// + /// URL of a page containing human-readable information about the protected resource's requirements. + /// + /// + /// OPTIONAL. Information about how the client can use the data provided by the protected resource. + /// + [JsonPropertyName("resource_policy_uri")] + public Uri? ResourcePolicyUri { get; set; } + + /// + /// URL of a page containing human-readable information about the protected resource's terms of service. + /// + /// + /// OPTIONAL. The value of this field MAY be internationalized. + /// + [JsonPropertyName("resource_tos_uri")] + public Uri? ResourceTosUri { get; set; } + + /// + /// Boolean value indicating protected resource support for mutual-TLS client certificate-bound access tokens. + /// + /// + /// OPTIONAL. If omitted, the default value is false. + /// + [JsonPropertyName("tls_client_certificate_bound_access_tokens")] + public bool? TlsClientCertificateBoundAccessTokens { get; set; } + + /// + /// List of the authorization details type values supported by the resource server. + /// + /// + /// OPTIONAL. JSON array containing a list of the authorization details type values supported by the resource server + /// when the authorization_details request parameter is used. + /// + [JsonPropertyName("authorization_details_types_supported")] + public List? AuthorizationDetailsTypesSupported { get; set; } + + /// + /// List of the JWS algorithm values supported by the resource server for validating DPoP proof JWTs. + /// + /// + /// OPTIONAL. JSON array containing a list of the JWS alg values supported by the resource server + /// for validating Demonstrating Proof of Possession (DPoP) proof JWTs. + /// + [JsonPropertyName("dpop_signing_alg_values_supported")] + public List? DpopSigningAlgValuesSupported { get; set; } + + /// + /// Boolean value specifying whether the protected resource always requires the use of DPoP-bound access tokens. + /// + /// + /// OPTIONAL. If omitted, the default value is false. + /// + [JsonPropertyName("dpop_bound_access_tokens_required")] + public bool? DpopBoundAccessTokensRequired { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs new file mode 100644 index 00000000..dc55292b --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -0,0 +1,57 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a token response from the OAuth server. +/// +internal sealed class TokenContainer +{ + /// + /// Gets or sets the access token. + /// + [JsonPropertyName("access_token")] + public string AccessToken { get; set; } = string.Empty; + + /// + /// Gets or sets the refresh token. + /// + [JsonPropertyName("refresh_token")] + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the number of seconds until the access token expires. + /// + [JsonPropertyName("expires_in")] + public int ExpiresIn { get; set; } + + /// + /// Gets or sets the extended expiration time in seconds. + /// + [JsonPropertyName("ext_expires_in")] + public int ExtExpiresIn { get; set; } + + /// + /// Gets or sets the token type (typically "Bearer"). + /// + [JsonPropertyName("token_type")] + public string TokenType { get; set; } = string.Empty; + + /// + /// Gets or sets the scope of the access token. + /// + [JsonPropertyName("scope")] + public string Scope { get; set; } = string.Empty; + + /// + /// Gets or sets the timestamp when the token was obtained. + /// + [JsonIgnore] + public DateTimeOffset ObtainedAt { get; set; } + + /// + /// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn. + /// + [JsonIgnore] + public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn); +} diff --git a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs index 39ae7e81..06f2e0bf 100644 --- a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs @@ -13,13 +13,13 @@ namespace ModelContextProtocol.Client; internal sealed partial class AutoDetectingClientSessionTransport : ITransport { private readonly SseClientTransportOptions _options; - private readonly HttpClient _httpClient; + private readonly McpHttpClient _httpClient; private readonly ILoggerFactory? _loggerFactory; private readonly ILogger _logger; private readonly string _name; private readonly Channel _messageChannel; - public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) + public AutoDetectingClientSessionTransport(string endpointName, SseClientTransportOptions transportOptions, McpHttpClient httpClient, ILoggerFactory? loggerFactory) { Throw.IfNull(transportOptions); Throw.IfNull(httpClient); diff --git a/src/ModelContextProtocol.Core/Client/McpHttpClient.cs b/src/ModelContextProtocol.Core/Client/McpHttpClient.cs new file mode 100644 index 00000000..f5e1b596 --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpHttpClient.cs @@ -0,0 +1,42 @@ +using ModelContextProtocol.Protocol; +using System.Diagnostics; + +#if NET +using System.Net.Http.Json; +#else +using System.Text; +using System.Text.Json; +#endif + +namespace ModelContextProtocol.Client; + +internal class McpHttpClient(HttpClient httpClient) +{ + internal virtual async Task SendAsync(HttpRequestMessage request, JsonRpcMessage? message, CancellationToken cancellationToken) + { + Debug.Assert(request.Content is null, "The request body should only be supplied as a JsonRpcMessage"); + Debug.Assert(message is null || request.Method == HttpMethod.Post, "All messages should be sent in POST requests."); + + using var content = CreatePostBodyContent(message); + request.Content = content; + return await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + } + + private HttpContent? CreatePostBodyContent(JsonRpcMessage? message) + { + if (message is null) + { + return null; + } + +#if NET + return JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); +#else + return new StringContent( + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), + Encoding.UTF8, + "application/json; charset=utf-8" + ); +#endif + } +} diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index 93559b7d..aba7bbcf 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -4,7 +4,6 @@ using System.Diagnostics; using System.Net.Http.Headers; using System.Net.ServerSentEvents; -using System.Text; using System.Text.Json; using System.Threading.Channels; @@ -15,7 +14,7 @@ namespace ModelContextProtocol.Client; /// internal sealed partial class SseClientSessionTransport : TransportBase { - private readonly HttpClient _httpClient; + private readonly McpHttpClient _httpClient; private readonly SseClientTransportOptions _options; private readonly Uri _sseEndpoint; private Uri? _messageEndpoint; @@ -31,7 +30,7 @@ internal sealed partial class SseClientSessionTransport : TransportBase public SseClientSessionTransport( string endpointName, SseClientTransportOptions transportOptions, - HttpClient httpClient, + McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) : base(endpointName, messageChannel, loggerFactory) @@ -74,12 +73,6 @@ public override async Task SendMessageAsync( if (_messageEndpoint == null) throw new InvalidOperationException("Transport not connected"); - using var content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json" - ); - string messageId = "(no id)"; if (message is JsonRpcMessageWithId messageWithId) @@ -87,12 +80,9 @@ public override async Task SendMessageAsync( messageId = messageWithId.Id.ToString(); } - using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) - { - Content = content, - }; + using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint); StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); - var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { @@ -154,11 +144,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); - using var response = await _httpClient.SendAsync( - request, - HttpCompletionOption.ResponseHeadersRead, - cancellationToken - ).ConfigureAwait(false); + using var response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false); response.EnsureSuccessStatusCode(); diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientTransport.cs index 3fba349b..5253c546 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientTransport.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.Logging; +using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Client; @@ -15,9 +16,9 @@ namespace ModelContextProtocol.Client; public sealed class SseClientTransport : IClientTransport, IAsyncDisposable { private readonly SseClientTransportOptions _options; - private readonly HttpClient _httpClient; + private readonly HttpClient? _httpClient; + private readonly McpHttpClient _mcpHttpClient; private readonly ILoggerFactory? _loggerFactory; - private readonly bool _ownsHttpClient; /// /// Initializes a new instance of the class. @@ -45,10 +46,22 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient Throw.IfNull(httpClient); _options = transportOptions; - _httpClient = httpClient; _loggerFactory = loggerFactory; - _ownsHttpClient = ownsHttpClient; Name = transportOptions.Name ?? transportOptions.Endpoint.ToString(); + + if (transportOptions.CredentialProvider is { } credentialProvider) + { + _mcpHttpClient = new AuthenticatingMcpHttpClient(httpClient, credentialProvider); + } + else + { + _mcpHttpClient = new(httpClient); + } + + if (ownsHttpClient) + { + _httpClient = httpClient; + } } /// @@ -59,8 +72,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken = { return _options.TransportMode switch { - HttpTransportMode.AutoDetect => new AutoDetectingClientSessionTransport(_options, _httpClient, _loggerFactory, Name), - HttpTransportMode.StreamableHttp => new StreamableHttpClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory), + HttpTransportMode.AutoDetect => new AutoDetectingClientSessionTransport(Name, _options, _mcpHttpClient, _loggerFactory), + HttpTransportMode.StreamableHttp => new StreamableHttpClientSessionTransport(Name, _options, _mcpHttpClient, messageChannel: null, _loggerFactory), HttpTransportMode.Sse => await ConnectSseTransportAsync(cancellationToken).ConfigureAwait(false), _ => throw new InvalidOperationException($"Unsupported transport mode: {_options.TransportMode}"), }; @@ -68,7 +81,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = private async Task ConnectSseTransportAsync(CancellationToken cancellationToken) { - var sessionTransport = new SseClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory); + var sessionTransport = new SseClientSessionTransport(Name, _options, _mcpHttpClient, messageChannel: null, _loggerFactory); try { @@ -85,11 +98,7 @@ private async Task ConnectSseTransportAsync(CancellationToken cancel /// public ValueTask DisposeAsync() { - if (_ownsHttpClient) - { - _httpClient.Dispose(); - } - + _httpClient?.Dispose(); return default; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs index cd522c42..ba723ecb 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs @@ -1,3 +1,5 @@ +using ModelContextProtocol.Authentication; + namespace ModelContextProtocol.Client; /// @@ -46,7 +48,7 @@ public required Uri Endpoint public HttpTransportMode TransportMode { get; set; } = HttpTransportMode.AutoDetect; /// - /// Gets a transport identifier used for logging purposes. + /// Gets or sets a transport identifier used for logging purposes. /// public string? Name { get; set; } @@ -70,4 +72,9 @@ public required Uri Endpoint /// Use this property to specify custom HTTP headers that should be sent with each request to the server. /// public IDictionary? AdditionalHeaders { get; set; } + + /// + /// Gets sor sets the authorization provider to use for authentication. + /// + public IMcpCredentialProvider? CredentialProvider { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 14df5c35..20ebe453 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -6,12 +6,6 @@ using ModelContextProtocol.Protocol; using System.Threading.Channels; -#if NET -using System.Net.Http.Json; -#else -using System.Text; -#endif - namespace ModelContextProtocol.Client; /// @@ -22,7 +16,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private static readonly MediaTypeWithQualityHeaderValue s_applicationJsonMediaType = new("application/json"); private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream"); - private readonly HttpClient _httpClient; + private readonly McpHttpClient _httpClient; private readonly SseClientTransportOptions _options; private readonly CancellationTokenSource _connectionCts; private readonly ILogger _logger; @@ -36,7 +30,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa public StreamableHttpClientSessionTransport( string endpointName, SseClientTransportOptions transportOptions, - HttpClient httpClient, + McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) : base(endpointName, messageChannel, loggerFactory) @@ -69,19 +63,8 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token); cancellationToken = sendCts.Token; -#if NET - using var content = JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); -#else - using var content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json; charset=utf-8" - ); -#endif - using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _options.Endpoint) { - Content = content, Headers = { Accept = { s_applicationJsonMediaType, s_textEventStreamMediaType }, @@ -90,7 +73,7 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); - var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false); // We'll let the caller decide whether to throw or fall back given an unsuccessful response. if (!response.IsSuccessStatusCode) @@ -192,7 +175,7 @@ private async Task ReceiveUnsolicitedMessagesAsync() request.Headers.Accept.Add(s_textEventStreamMediaType); CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); - using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false); + using var response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { @@ -261,7 +244,7 @@ private async Task SendDeleteRequest() try { // Do not validate we get a successful status code, because server support for the DELETE request is optional - (await _httpClient.SendAsync(deleteRequest, CancellationToken.None).ConfigureAwait(false)).Dispose(); + (await _httpClient.SendAsync(deleteRequest, message: null, CancellationToken.None).ConfigureAwait(false)).Dispose(); } catch (Exception ex) { diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index 696e0ec0..9ae59b20 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.AI; +using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; using System.Diagnostics.CodeAnalysis; using System.Text.Json; @@ -154,6 +155,10 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(IReadOnlyDictionary))] [JsonSerializable(typeof(ProgressToken))] + [JsonSerializable(typeof(ProtectedResourceMetadata))] + [JsonSerializable(typeof(AuthorizationServerMetadata))] + [JsonSerializable(typeof(TokenContainer))] + // Primitive types for use in consuming AIFunctions [JsonSerializable(typeof(string))] [JsonSerializable(typeof(byte))] diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index cc6b4e0a..3ff50430 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -35,7 +35,7 @@ public void Constructor_Throws_For_Null_Options() [Fact] public void Constructor_Throws_For_Null_HttpClient() { - var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, null!, LoggerFactory)); + var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, httpClient: null!, LoggerFactory)); Assert.Equal("httpClient", exception.ParamName); }