Skip to content

Commit 8387369

Browse files
authored
[4/5] Introduce OperationDefinition (#16)
1 parent 79baa15 commit 8387369

14 files changed

+261
-260
lines changed

src/nexusrpc/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@
2828
OutputT,
2929
)
3030
from ._serializer import Content, LazyValue
31-
from ._service import Operation, ServiceDefinition, service
31+
from ._service import Operation, OperationDefinition, ServiceDefinition, service
3232
from ._util import (
33-
get_operation_definition,
33+
get_operation,
3434
get_service_definition,
35-
set_operation_definition,
35+
set_operation,
3636
)
3737

3838
__all__ = [
3939
"Content",
40-
"get_operation_definition",
40+
"get_operation",
4141
"get_service_definition",
4242
"handler",
4343
"HandlerError",
@@ -46,12 +46,13 @@
4646
"LazyValue",
4747
"Link",
4848
"Operation",
49+
"OperationDefinition",
4950
"OperationError",
5051
"OperationErrorState",
5152
"OperationInfo",
5253
"OperationState",
5354
"OutputT",
5455
"service",
5556
"ServiceDefinition",
56-
"set_operation_definition",
57+
"set_operation",
5758
]

src/nexusrpc/_service.py

Lines changed: 76 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,38 @@ class MyNexusService:
5151
input_type: Optional[type[InputT]] = dataclasses.field(default=None)
5252
output_type: Optional[type[OutputT]] = dataclasses.field(default=None)
5353

54-
def __post_init__(self):
55-
if not self.name:
56-
raise ValueError("Operation name cannot be empty")
5754

58-
def _validation_errors(self) -> list[str]:
59-
errors = []
60-
if not self.name:
61-
errors.append(
62-
f"Operation has no name (method_name is '{self.method_name}')"
55+
@dataclass
56+
class OperationDefinition(Generic[InputT, OutputT]):
57+
"""
58+
Internal representation of a user's :py:class:`Operation` definition.
59+
"""
60+
61+
name: str
62+
method_name: str
63+
input_type: type[InputT]
64+
output_type: type[OutputT]
65+
66+
@classmethod
67+
def from_operation(
68+
cls, operation: Operation[InputT, OutputT]
69+
) -> OperationDefinition[InputT, OutputT]:
70+
if not operation.name:
71+
raise ValueError(
72+
f"Operation has no name (method_name is '{operation.method_name}')"
6373
)
64-
if not self.method_name:
65-
errors.append(f"Operation '{self.name}' has no method name")
66-
if not self.input_type:
67-
errors.append(f"Operation '{self.name}' has no input type")
68-
if not self.output_type:
69-
errors.append(f"Operation '{self.name}' has no output type")
70-
return errors
74+
if not operation.method_name:
75+
raise ValueError(f"Operation '{operation.name}' has no method name")
76+
if not operation.input_type:
77+
raise ValueError(f"Operation '{operation.name}' has no input type")
78+
if not operation.output_type:
79+
raise ValueError(f"Operation '{operation.name}' has no output type")
80+
return cls(
81+
name=operation.name,
82+
method_name=operation.method_name,
83+
input_type=operation.input_type,
84+
output_type=operation.output_type,
85+
)
7186

7287

7388
@overload
@@ -124,16 +139,23 @@ def decorator(cls: type[ServiceT]) -> type[ServiceT]:
124139
defn = ServiceDefinition.from_class(cls, name or cls.__name__)
125140
set_service_definition(cls, defn)
126141

127-
# In order for callers to refer to operations at run-time, a decorated user
142+
# In order for callers to refer to operation definitions at run-time, a decorated user
128143
# service class must itself have a class attribute for every operation, even if
129144
# declared only via a type annotation, and whether inherited from a parent class
130145
# or not.
131146
#
132147
# TODO(preview): it is sufficient to do this setattr only for the subset of
133148
# operations that were declared on *this* class. Currently however we are
134149
# setting all inherited operations.
135-
for op_name, op in defn.operations.items():
136-
setattr(cls, op_name, op)
150+
for op_name, op_defn in defn.operation_definitions.items():
151+
if not hasattr(cls, op_name):
152+
op = Operation(
153+
name=op_defn.name,
154+
method_name=op_defn.method_name,
155+
input_type=op_defn.input_type,
156+
output_type=op_defn.output_type,
157+
)
158+
setattr(cls, op_name, op)
137159

138160
return cls
139161

@@ -155,7 +177,7 @@ class ServiceDefinition:
155177
"""
156178

157179
name: str
158-
operations: Mapping[str, Operation[Any, Any]]
180+
operation_definitions: Mapping[str, OperationDefinition[Any, Any]]
159181

160182
def __post_init__(self):
161183
if errors := self._validation_errors():
@@ -173,7 +195,7 @@ def from_class(user_class: type[ServiceT], name: str) -> ServiceDefinition:
173195
If multiple service definitions define an operation with the same name, then the
174196
usual mro() precedence rules apply.
175197
"""
176-
operations = ServiceDefinition._collect_operations(user_class)
198+
operation_definitions = ServiceDefinition._collect_operations(user_class)
177199

178200
# Obtain the set of operations to be inherited from ancestral service
179201
# definitions. Operations are only inherited from classes that are also
@@ -186,47 +208,53 @@ def from_class(user_class: type[ServiceT], name: str) -> ServiceDefinition:
186208
# 2. No inherited operation has the same method name as that of an operation
187209
# defined here. If this were violated, there would be ambiguity in which
188210
# operation handler is dispatched to.
189-
parent_defns = (
190-
defn
191-
for defn in (get_service_definition(cls) for cls in user_class.mro()[1:])
192-
if defn
211+
parent_service_defn = next(
212+
(
213+
defn
214+
for defn in (
215+
get_service_definition(cls) for cls in user_class.mro()[1:]
216+
)
217+
if defn
218+
),
219+
None,
193220
)
194-
method_names = {op.method_name for op in operations.values() if op.method_name}
195-
if parent_defn := next(parent_defns, None):
196-
for op in parent_defn.operations.values():
197-
if op.method_name in method_names:
221+
if parent_service_defn:
222+
method_names = {op.method_name for op in operation_definitions.values()}
223+
for op_defn in parent_service_defn.operation_definitions.values():
224+
if op_defn.method_name in method_names:
198225
raise ValueError(
199-
f"Operation method name '{op.method_name}' in class '{user_class}' "
226+
f"Operation method name '{op_defn.method_name}' in class '{user_class}' "
200227
f"also occurs in a service definition inherited from a parent class: "
201-
f"'{parent_defn.name}'. This is not allowed."
228+
f"'{parent_service_defn.name}'. This is not allowed."
202229
)
203-
if op.name in operations:
230+
if op_defn.name in operation_definitions:
204231
raise ValueError(
205-
f"Operation name '{op.name}' in class '{user_class}' "
232+
f"Operation name '{op_defn.name}' in class '{user_class}' "
206233
f"also occurs in a service definition inherited from a parent class: "
207-
f"'{parent_defn.name}'. This is not allowed."
234+
f"'{parent_service_defn.name}'. This is not allowed."
208235
)
209-
operations[op.name] = op
236+
operation_definitions[op_defn.name] = op_defn
210237

211-
return ServiceDefinition(name=name, operations=operations)
238+
return ServiceDefinition(name=name, operation_definitions=operation_definitions)
212239

213240
def _validation_errors(self) -> list[str]:
214241
errors = []
215242
if not self.name:
216243
errors.append("Service has no name")
217244
seen_method_names = set()
218-
for op in self.operations.values():
219-
if op.method_name in seen_method_names:
220-
errors.append(f"Operation method name '{op.method_name}' is not unique")
221-
seen_method_names.add(op.method_name)
222-
errors.extend(op._validation_errors()) # pyright: ignore[reportPrivateUsage]
245+
for op_defn in self.operation_definitions.values():
246+
if op_defn.method_name in seen_method_names:
247+
errors.append(
248+
f"Operation method name '{op_defn.method_name}' is not unique"
249+
)
250+
seen_method_names.add(op_defn.method_name)
223251
return errors
224252

225253
@staticmethod
226254
def _collect_operations(
227255
user_class: type[ServiceT],
228-
) -> dict[str, Operation[Any, Any]]:
229-
"""Collect operations from a user service definition class.
256+
) -> dict[str, OperationDefinition[Any, Any]]:
257+
"""Collect operation definitions from a user service definition class.
230258
231259
Does not visit parent classes.
232260
"""
@@ -301,11 +329,11 @@ def _collect_operations(
301329
if op.method_name is None:
302330
op.method_name = key
303331

304-
operations_by_name = {}
332+
op_defns = {}
305333
for op in operations.values():
306-
if op.name in operations_by_name:
307-
raise ValueError(
308-
f"Operation '{op.name}' in class '{user_class}' is defined multiple times"
334+
if op.name in op_defns:
335+
raise RuntimeError(
336+
f"Operation '{op.name}' in service '{user_class}' is defined multiple times"
309337
)
310-
operations_by_name[op.name] = op
311-
return operations_by_name
338+
op_defns[op.name] = OperationDefinition.from_operation(op)
339+
return op_defns

src/nexusrpc/_util.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,25 @@
1616
from nexusrpc._common import ServiceT
1717
from nexusrpc.handler._operation_handler import OperationHandler
1818

19+
_NEXUS_SERVICE_DEFINITION_ATTR_NAME = "__nexus_service_definition__"
20+
_NEXUS_OPERATION_ATTR_NAME = "__nexus_operation__"
21+
_NEXUS_OPERATION_FACTORY_ATTR_NAME = "__nexus_operation_factory__"
22+
1923

2024
def get_service_definition(
2125
obj: Any,
2226
) -> Optional[nexusrpc.ServiceDefinition]:
2327
"""Return the :py:class:`nexusrpc.ServiceDefinition` for the object, or None"""
24-
# getattr would allow a non-decorated class to act as a service
28+
# Do not use getattr since it would allow a non-decorated class to act as a service
2529
# definition if it inherits from a decorated class.
2630
if isinstance(obj, type):
27-
defn = obj.__dict__.get("__nexus_service__")
31+
defn = obj.__dict__.get(_NEXUS_SERVICE_DEFINITION_ATTR_NAME)
2832
else:
29-
defn = getattr(obj, "__dict__", {}).get("__nexus_service__")
33+
defn = getattr(obj, "__dict__", {}).get(_NEXUS_SERVICE_DEFINITION_ATTR_NAME)
3034
if defn and not isinstance(defn, nexusrpc.ServiceDefinition):
3135
raise ValueError(
32-
f"Service definition {obj.__name__} has a __nexus_service__ attribute that is not a ServiceDefinition."
36+
f"{obj.__name__} has a {_NEXUS_SERVICE_DEFINITION_ATTR_NAME} attribute "
37+
f"that is not a nexusrpc.ServiceDefinition."
3338
)
3439
return defn
3540

@@ -38,28 +43,33 @@ def set_service_definition(
3843
cls: type[ServiceT], service_definition: nexusrpc.ServiceDefinition
3944
) -> None:
4045
"""Set the :py:class:`nexusrpc.ServiceDefinition` for this class."""
41-
setattr(cls, "__nexus_service__", service_definition)
46+
setattr(cls, _NEXUS_SERVICE_DEFINITION_ATTR_NAME, service_definition)
4247

4348

44-
def get_operation_definition(
49+
def get_operation(
4550
obj: Any,
4651
) -> Optional[nexusrpc.Operation[Any, Any]]:
4752
"""Return the :py:class:`nexusrpc.Operation` for the object, or None
4853
4954
``obj`` should be a decorated operation start method.
5055
"""
51-
return getattr(obj, "__nexus_operation__", None)
56+
op = getattr(obj, _NEXUS_OPERATION_ATTR_NAME, None)
57+
if op and not isinstance(op, nexusrpc.Operation):
58+
raise ValueError(f"{op} is not a nexusrpc.Operation")
59+
return op
5260

5361

54-
def set_operation_definition(
62+
def set_operation(
5563
obj: Any,
56-
operation_definition: nexusrpc.Operation[Any, Any],
64+
operation: nexusrpc.Operation[Any, Any],
5765
) -> None:
5866
"""Set the :py:class:`nexusrpc.Operation` for this object.
5967
6068
``obj`` should be an operation start method.
6169
"""
62-
setattr(obj, "__nexus_operation__", operation_definition)
70+
if not isinstance(operation, nexusrpc.Operation): # type: ignore
71+
raise ValueError(f"{operation} is not a nexusrpc.Operation") # type: ignore
72+
setattr(obj, _NEXUS_OPERATION_ATTR_NAME, operation)
6373

6474

6575
def get_operation_factory(
@@ -72,15 +82,15 @@ def get_operation_factory(
7282
7383
``obj`` should be a decorated operation start method.
7484
"""
75-
op_defn = get_operation_definition(obj)
76-
if op_defn:
85+
op = get_operation(obj)
86+
if op:
7787
factory = obj
7888
else:
79-
if factory := getattr(obj, "__nexus_operation_factory__", None):
80-
op_defn = get_operation_definition(factory)
81-
if not isinstance(op_defn, nexusrpc.Operation):
89+
if factory := getattr(obj, _NEXUS_OPERATION_FACTORY_ATTR_NAME, None):
90+
op = get_operation(factory)
91+
if not isinstance(op, nexusrpc.Operation):
8292
return None, None
83-
return factory, op_defn
93+
return factory, op
8494

8595

8696
def set_operation_factory(
@@ -91,7 +101,7 @@ def set_operation_factory(
91101
92102
``obj`` should be an operation start method.
93103
"""
94-
setattr(obj, "__nexus_operation_factory__", operation_factory)
104+
setattr(obj, _NEXUS_OPERATION_FACTORY_ATTR_NAME, operation_factory)
95105

96106

97107
# Copied from https://github.com/modelcontextprotocol/python-sdk

0 commit comments

Comments
 (0)