28
28
GetTaskPushNotificationConfigRequest ,
29
29
GetTaskRequest ,
30
30
InternalError ,
31
+ InvalidParamsError ,
31
32
InvalidRequestError ,
32
33
JSONParseError ,
33
34
JSONRPCError ,
34
35
JSONRPCErrorResponse ,
35
36
JSONRPCRequest ,
36
37
JSONRPCResponse ,
37
38
ListTaskPushNotificationConfigRequest ,
39
+ MethodNotFoundError ,
38
40
SendMessageRequest ,
39
41
SendStreamingMessageRequest ,
40
42
SendStreamingMessageResponse ,
89
91
Response = Any
90
92
HTTP_413_REQUEST_ENTITY_TOO_LARGE = Any
91
93
94
+ MAX_CONTENT_LENGTH = 1_000_000
95
+
92
96
93
97
class StarletteUserProxy (A2AUser ):
94
98
"""Adapts the Starlette User class to the A2A user representation."""
@@ -151,6 +155,25 @@ class JSONRPCApplication(ABC):
151
155
(SSE).
152
156
"""
153
157
158
+ # Method-to-model mapping for centralized routing
159
+ A2ARequestModel = (
160
+ SendMessageRequest
161
+ | SendStreamingMessageRequest
162
+ | GetTaskRequest
163
+ | CancelTaskRequest
164
+ | SetTaskPushNotificationConfigRequest
165
+ | GetTaskPushNotificationConfigRequest
166
+ | ListTaskPushNotificationConfigRequest
167
+ | DeleteTaskPushNotificationConfigRequest
168
+ | TaskResubscriptionRequest
169
+ | GetAuthenticatedExtendedCardRequest
170
+ )
171
+
172
+ METHOD_TO_MODEL : dict [str , type [A2ARequestModel ]] = {
173
+ model .model_fields ['method' ].default : model
174
+ for model in A2ARequestModel .__args__
175
+ }
176
+
154
177
def __init__ ( # noqa: PLR0913
155
178
self ,
156
179
agent_card : AgentCard ,
@@ -271,17 +294,60 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
271
294
body = await request .json ()
272
295
if isinstance (body , dict ):
273
296
request_id = body .get ('id' )
297
+ # Ensure request_id is valid for JSON-RPC response (str/int/None only)
298
+ if request_id is not None and not isinstance (
299
+ request_id , str | int
300
+ ):
301
+ request_id = None
302
+ # Treat very large payloads as invalid request (-32600) before routing
303
+ with contextlib .suppress (Exception ):
304
+ content_length = int (request .headers .get ('content-length' , '0' ))
305
+ if content_length and content_length > MAX_CONTENT_LENGTH :
306
+ return self ._generate_error_response (
307
+ request_id ,
308
+ A2AError (
309
+ root = InvalidRequestError (
310
+ message = 'Payload too large'
311
+ )
312
+ ),
313
+ )
314
+ logger .debug ('Request body: %s' , body )
315
+ # 1) Validate base JSON-RPC structure only (-32600 on failure)
316
+ try :
317
+ base_request = JSONRPCRequest .model_validate (body )
318
+ except ValidationError as e :
319
+ logger .exception ('Failed to validate base JSON-RPC request' )
320
+ return self ._generate_error_response (
321
+ request_id ,
322
+ A2AError (
323
+ root = InvalidRequestError (data = json .loads (e .json ()))
324
+ ),
325
+ )
274
326
275
- # First, validate the basic JSON-RPC structure. This is crucial
276
- # because the A2ARequest model is a discriminated union where some
277
- # request types have default values for the 'method' field
278
- JSONRPCRequest .model_validate (body )
327
+ # 2) Route by method name; unknown -> -32601, known -> validate params (-32602 on failure)
328
+ method = base_request .method
279
329
280
- a2a_request = A2ARequest .model_validate (body )
330
+ model_class = self .METHOD_TO_MODEL .get (method )
331
+ if not model_class :
332
+ return self ._generate_error_response (
333
+ request_id , A2AError (root = MethodNotFoundError ())
334
+ )
335
+ try :
336
+ specific_request = model_class .model_validate (body )
337
+ except ValidationError as e :
338
+ logger .exception ('Failed to validate base JSON-RPC request' )
339
+ return self ._generate_error_response (
340
+ request_id ,
341
+ A2AError (
342
+ root = InvalidParamsError (data = json .loads (e .json ()))
343
+ ),
344
+ )
281
345
346
+ # 3) Build call context and wrap the request for downstream handling
282
347
call_context = self ._context_builder .build (request )
283
348
284
- request_id = a2a_request .root .id
349
+ request_id = specific_request .id
350
+ a2a_request = A2ARequest (root = specific_request )
285
351
request_obj = a2a_request .root
286
352
287
353
if isinstance (
@@ -305,12 +371,6 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
305
371
return self ._generate_error_response (
306
372
None , A2AError (root = JSONParseError (message = str (e )))
307
373
)
308
- except ValidationError as e :
309
- traceback .print_exc ()
310
- return self ._generate_error_response (
311
- request_id ,
312
- A2AError (root = InvalidRequestError (data = json .loads (e .json ()))),
313
- )
314
374
except HTTPException as e :
315
375
if e .status_code == HTTP_413_REQUEST_ENTITY_TOO_LARGE :
316
376
return self ._generate_error_response (
0 commit comments