Skip to content

Commit a740aa4

Browse files
feat: propagating headers on requests (#214)
## Problem Currently we are not propagating / accepting headers for StreamableHttp transport ## Solution Implementation of header optional param ## Rationale Will be handled only for StreamableHTTP for now.
1 parent a4d9ae9 commit a740aa4

File tree

4 files changed

+136
-24
lines changed

4 files changed

+136
-24
lines changed

lib/hermes/client/base.ex

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ defmodule Hermes.Client.Base do
1515
alias Hermes.MCP.Response
1616
alias Hermes.Protocol
1717
alias Hermes.Telemetry
18+
alias Hermes.Transport.StreamableHTTP
1819

1920
require Message
2021

@@ -81,7 +82,7 @@ defmodule Hermes.Client.Base do
8182
Hermes.Transport.STDIO
8283
| Hermes.Transport.SSE
8384
| Hermes.Transport.WebSocket
84-
| Hermes.Transport.StreamableHTTP}
85+
| StreamableHTTP}
8586
| {:name, GenServer.server()}
8687
)
8788

@@ -325,6 +326,7 @@ defmodule Hermes.Client.Base do
325326
* `:progress` - Progress tracking options
326327
* `:token` - A unique token to track progress (string or integer)
327328
* `:callback` - A function to call when progress updates are received
329+
* `:headers` - HTTP headers to send with this request (map, only for StreamableHTTP transport)
328330
"""
329331
@spec call_tool(t, String.t(), map() | nil, keyword) ::
330332
{:ok, Response.t()} | {:error, Error.t()}
@@ -337,7 +339,8 @@ defmodule Hermes.Client.Base do
337339
method: "tools/call",
338340
params: params,
339341
progress_opts: Keyword.get(opts, :progress),
340-
timeout: Keyword.get(opts, :timeout)
342+
timeout: Keyword.get(opts, :timeout),
343+
headers: Keyword.get(opts, :headers)
341344
})
342345

343346
buffer_timeout = operation.timeout + to_timeout(second: 1)
@@ -429,7 +432,7 @@ defmodule Hermes.Client.Base do
429432
ref = %{"type" => "ref/prompt", "name" => "code_review"}
430433
argument = %{"name" => "language", "value" => "py"}
431434
{:ok, response} = Hermes.Client.complete(client, ref, argument)
432-
435+
433436
# Access the completion values
434437
values = get_in(Response.unwrap(response), ["completion", "values"])
435438
"""
@@ -706,14 +709,14 @@ defmodule Hermes.Client.Base do
706709
707710
MyClient.register_sampling_callback(fn params ->
708711
messages = params["messages"]
709-
712+
710713
# Show UI for user approval
711714
case MyUI.approve_sampling(messages) do
712715
{:approved, edited_messages} ->
713716
# Call LLM with approved/edited messages
714717
response = MyLLM.generate(edited_messages, params["modelPreferences"])
715718
{:ok, response}
716-
719+
717720
:rejected ->
718721
{:error, "User rejected sampling request"}
719722
end
@@ -799,7 +802,7 @@ defmodule Hermes.Client.Base do
799802
{request_id, updated_state} =
800803
State.add_request_from_operation(state, operation, from),
801804
{:ok, request_data} <- encode_request(method, params_with_token, request_id),
802-
:ok <- send_to_transport(state.transport, request_data) do
805+
:ok <- send_to_transport(state.transport, request_data, operation.headers) do
803806
Telemetry.execute(
804807
Telemetry.event_client_request(),
805808
%{system_time: System.system_time()},
@@ -1466,6 +1469,18 @@ defmodule Hermes.Client.Base do
14661469
send_notification(state, "notifications/cancelled", params)
14671470
end
14681471

1472+
defp send_to_transport(transport, data, headers) when not is_nil(headers) do
1473+
opts = [headers: headers]
1474+
1475+
with {:error, reason} <- transport.layer.send_message(transport.name, data, opts) do
1476+
{:error, Error.transport(:send_failure, %{original_reason: reason})}
1477+
end
1478+
end
1479+
1480+
defp send_to_transport(transport, data, _headers) do
1481+
send_to_transport(transport, data)
1482+
end
1483+
14691484
defp send_to_transport(transport, data) do
14701485
with {:error, reason} <- transport.layer.send_message(transport.name, data) do
14711486
{:error, Error.transport(:send_failure, %{original_reason: reason})}

lib/hermes/client/operation.ex

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@ defmodule Hermes.Client.Operation do
2020
method: String.t(),
2121
params: map(),
2222
progress_opts: progress_options() | nil,
23-
timeout: pos_integer()
23+
timeout: pos_integer(),
24+
headers: map() | nil
2425
}
2526

2627
defstruct [
2728
:method,
2829
params: %{},
2930
progress_opts: [],
30-
timeout: @default_timeout
31+
timeout: @default_timeout,
32+
headers: nil
3133
]
3234

3335
@doc """
@@ -40,19 +42,22 @@ defmodule Hermes.Client.Operation do
4042
* `:params` - The parameters to send to the server (required)
4143
* `:progress_opts` - Progress tracking options (optional)
4244
* `:timeout` - The timeout for this operation in milliseconds (optional, defaults to 30s)
45+
* `:headers` - HTTP headers to send with this request (optional, only for StreamableHTTP transport)
4346
"""
4447
@spec new(%{
4548
required(:method) => String.t(),
4649
optional(:params) => map(),
4750
optional(:progress_opts) => progress_options() | nil,
48-
optional(:timeout) => pos_integer()
51+
optional(:timeout) => pos_integer(),
52+
optional(:headers) => map()
4953
}) :: t()
5054
def new(%{method: method} = attrs) do
5155
%__MODULE__{
5256
method: method,
5357
params: Map.get(attrs, :params) || %{},
5458
progress_opts: Map.get(attrs, :progress_opts),
55-
timeout: Map.get(attrs, :timeout) || @default_timeout
59+
timeout: Map.get(attrs, :timeout) || @default_timeout,
60+
headers: Map.get(attrs, :headers)
5661
}
5762
end
5863
end

lib/hermes/transport/streamable_http.ex

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ defmodule Hermes.Transport.StreamableHTTP do
8181
name: {{:custom, &Hermes.genserver_name/1}, {:default, __MODULE__}},
8282
client: {:required, Hermes.get_schema(:process_name)},
8383
base_url: {:required, {:string, {:transform, &URI.new!/1}}},
84-
mcp_path: {:string, {:default, "/mcp"}},
84+
mcp_path: {:string, {:default, "/mcp/"}},
8585
headers: {:map, {:default, %{}}},
8686
transport_opts: {:any, {:default, []}},
8787
http_options: {:any, {:default, []}},
@@ -96,8 +96,8 @@ defmodule Hermes.Transport.StreamableHTTP do
9696
end
9797

9898
@impl Transport
99-
def send_message(pid \\ __MODULE__, message) when is_binary(message) do
100-
GenServer.call(pid, {:send, message})
99+
def send_message(pid \\ __MODULE__, message, opts \\ []) when is_binary(message) do
100+
GenServer.call(pid, {:send, message, opts})
101101
end
102102

103103
@impl Transport
@@ -136,31 +136,36 @@ defmodule Hermes.Transport.StreamableHTTP do
136136
end
137137

138138
@impl GenServer
139-
def handle_call({:send, message}, from, state) do
139+
def handle_call({:send, message, opts}, from, state) do
140140
emit_telemetry(:send, state, %{message_size: byte_size(message)})
141141

142-
Logging.transport_event("sending_http_request", %{
142+
headers = Keyword.get(opts, :headers)
143+
log_event = if headers, do: "sending_http_request_with_headers", else: "sending_http_request"
144+
145+
log_data = %{
143146
url: URI.to_string(state.mcp_url),
144147
size: byte_size(message)
145-
})
148+
}
149+
150+
log_data = if headers, do: Map.put(log_data, :additional_headers, Map.keys(headers)), else: log_data
151+
152+
Logging.transport_event(log_event, log_data)
146153

147154
new_state = %{state | active_request: from}
148155

149-
case send_http_request(new_state, message) do
156+
case send_http_request(new_state, message, headers) do
150157
{:ok, response} ->
151158
Logging.transport_event("got_http_response", %{status: response.status})
152159
handle_response(response, new_state)
153160

154161
{:error, {:http_error, 404, _body}} when not is_nil(state.session_id) ->
155162
Logging.transport_event("session_expired", %{session_id: state.session_id})
156-
GenServer.cast(state.client, :session_expired)
157-
{:reply, {:error, :session_expired}, %{state | session_id: nil}}
163+
{:reply, {:error, :session_expired}, %{new_state | session_id: nil}}
158164

159165
{:error, reason} ->
160166
Logging.transport_event("http_request_error", %{reason: inspect(reason)}, level: :error)
161-
162167
log_error(reason)
163-
{:reply, {:error, reason}, state}
168+
{:reply, {:error, reason}, %{new_state | active_request: nil}}
164169
end
165170
end
166171

@@ -224,13 +229,31 @@ defmodule Hermes.Transport.StreamableHTTP do
224229

225230
# Private functions
226231

227-
defp send_http_request(state, message) do
228-
headers =
232+
defp send_http_request(state, message, additional_headers) do
233+
base_headers =
229234
state.headers
230235
|> Map.put("accept", "application/json, text/event-stream")
231236
|> Map.put("content-type", "application/json")
232237
|> put_session_header(state.session_id)
233238

239+
case additional_headers do
240+
nil ->
241+
do_send_http_request(state, message, base_headers)
242+
243+
headers when is_map(headers) ->
244+
case validate_headers(headers) do
245+
:ok ->
246+
final_headers = Map.merge(base_headers, headers)
247+
do_send_http_request(state, message, final_headers)
248+
249+
{:error, reason} ->
250+
{:error, reason}
251+
end
252+
end
253+
end
254+
255+
# Common HTTP request logic
256+
defp do_send_http_request(state, message, headers) do
234257
options = [transport_opts: state.transport_opts] ++ state.http_options
235258
url = URI.to_string(state.mcp_url)
236259

@@ -253,7 +276,6 @@ defmodule Hermes.Transport.StreamableHTTP do
253276

254277
{:error, reason} = error ->
255278
Logging.transport_event("http_request_failed", %{reason: reason}, level: :error)
256-
257279
error
258280
end
259281
end
@@ -324,6 +346,23 @@ defmodule Hermes.Transport.StreamableHTTP do
324346
Logging.transport_event("unknown_sse_event", event, level: :warning)
325347
end
326348

349+
# Header validation - fail fast with clear error messages
350+
defp validate_headers(headers) when is_map(headers) do
351+
case find_invalid_header(headers) do
352+
nil ->
353+
:ok
354+
355+
{key, value} ->
356+
{:error, "Invalid header value for '#{key}': HTTP headers must be strings, got #{inspect(value)}"}
357+
end
358+
end
359+
360+
defp find_invalid_header(headers) do
361+
Enum.find(headers, fn {_key, value} ->
362+
not is_binary(value) and not is_nil(value)
363+
end)
364+
end
365+
327366
defp put_session_header(headers, nil), do: headers
328367

329368
defp put_session_header(headers, session_id), do: Map.put(headers, "mcp-session-id", session_id)

test/hermes/transport/streamable_http_test.exs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,59 @@ defmodule Hermes.Transport.StreamableHTTPTest do
394394
StreamableHTTP.shutdown(transport)
395395
StubClient.clear_messages()
396396
end
397+
398+
test "sends per-request headers via opts parameter", %{bypass: bypass} do
399+
server_url = "http://localhost:#{bypass.port}"
400+
{:ok, stub_client} = StubClient.start_link()
401+
402+
Bypass.expect(bypass, "POST", "/mcp", fn conn ->
403+
# Verify transport-level header is present
404+
assert "Bearer base-token" ==
405+
conn |> Plug.Conn.get_req_header("authorization") |> List.first()
406+
407+
# Verify per-request header is present
408+
assert "request-123" ==
409+
conn |> Plug.Conn.get_req_header("x-request-id") |> List.first()
410+
411+
# Verify per-request header overrides transport-level header
412+
assert "request-specific" ==
413+
conn |> Plug.Conn.get_req_header("x-client-version") |> List.first()
414+
415+
conn = Plug.Conn.put_resp_header(conn, "content-type", "application/json")
416+
Plug.Conn.resp(conn, 200, ~s|{"jsonrpc":"2.0","id":"1","result":{}}|)
417+
end)
418+
419+
{:ok, transport} =
420+
StreamableHTTP.start_link(
421+
client: stub_client,
422+
base_url: server_url,
423+
mcp_path: "/mcp",
424+
headers: %{
425+
"authorization" => "Bearer base-token",
426+
"x-client-version" => "1.0.0"
427+
},
428+
transport_opts: @test_http_opts
429+
)
430+
431+
Process.sleep(100)
432+
433+
{:ok, ping_message} =
434+
Message.encode_request(%{"method" => "ping", "params" => %{}}, "1")
435+
436+
# Test new opts-based header functionality
437+
per_request_headers = %{
438+
"x-request-id" => "request-123",
439+
# This should override transport header
440+
"x-client-version" => "request-specific"
441+
}
442+
443+
assert :ok = StreamableHTTP.send_message(transport, ping_message, headers: per_request_headers)
444+
445+
Process.sleep(100)
446+
447+
StreamableHTTP.shutdown(transport)
448+
StubClient.clear_messages()
449+
end
397450
end
398451

399452
describe "error handling" do

0 commit comments

Comments
 (0)