Skip to content

Commit 2f9f2c0

Browse files
authored
Clarify get_operation_factory (#17)
1 parent 8387369 commit 2f9f2c0

File tree

4 files changed

+42
-48
lines changed

4 files changed

+42
-48
lines changed

src/nexusrpc/_util.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,12 @@ def get_operation(
5151
) -> Optional[nexusrpc.Operation[Any, Any]]:
5252
"""Return the :py:class:`nexusrpc.Operation` for the object, or None
5353
54-
``obj`` should be a decorated operation start method.
54+
``obj`` should be a decorated operation start method, or a method that takes
55+
no arguments and returns an OperationHandler.
5556
"""
57+
if factory := getattr(obj, _NEXUS_OPERATION_FACTORY_ATTR_NAME, None):
58+
# obj was a decorated operation start method
59+
obj = factory
5660
op = getattr(obj, _NEXUS_OPERATION_ATTR_NAME, None)
5761
if op and not isinstance(op, nexusrpc.Operation):
5862
raise ValueError(f"{op} is not a nexusrpc.Operation")
@@ -74,23 +78,15 @@ def set_operation(
7478

7579
def get_operation_factory(
7680
obj: Any,
77-
) -> tuple[
78-
Optional[Callable[[Any], OperationHandler[InputT, OutputT]]],
79-
Optional[nexusrpc.Operation[InputT, OutputT]],
80-
]:
81-
"""Return the :py:class:`Operation` for the object along with the factory function.
82-
83-
``obj`` should be a decorated operation start method.
84-
"""
85-
op = get_operation(obj)
86-
if op:
87-
factory = obj
88-
else:
89-
if factory := getattr(obj, _NEXUS_OPERATION_FACTORY_ATTR_NAME, None):
90-
op = get_operation(factory)
91-
if not isinstance(op, nexusrpc.Operation):
92-
return None, None
93-
return factory, op
81+
) -> Optional[Callable[[Any], OperationHandler[Any, Any]]]:
82+
"""Return the :py:class:`OperationHandler` factory function for the object."""
83+
if factory := getattr(obj, _NEXUS_OPERATION_FACTORY_ATTR_NAME, None):
84+
# obj was a decorated operation start method
85+
return factory
86+
if get_operation(obj):
87+
# obj was the desired factory
88+
return obj
89+
return None
9490

9591

9692
def set_operation_factory(

src/nexusrpc/handler/_operation_handler.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from nexusrpc._common import InputT, OperationInfo, OutputT, ServiceHandlerT
99
from nexusrpc._service import Operation, OperationDefinition, ServiceDefinition
1010
from nexusrpc._util import (
11+
get_operation,
1112
get_operation_factory,
1213
is_async_callable,
1314
is_callable,
@@ -153,31 +154,28 @@ def collect_operation_handler_factories_by_method_name(
153154
)
154155
seen = set()
155156
for _, method in inspect.getmembers(user_service_cls, is_callable):
156-
factory, op = get_operation_factory(method) # type: ignore[var-annotated]
157-
if factory and isinstance(op, Operation):
158-
# This is a method decorated with one of the *operation_handler decorators
159-
if op.name in seen:
160-
raise RuntimeError(
161-
f"Operation '{op.name}' in service '{user_service_cls.__name__}' "
162-
f"is defined multiple times."
157+
if factory := get_operation_factory(method):
158+
if op := get_operation(factory):
159+
if op.name in seen:
160+
raise RuntimeError(
161+
f"Operation '{op.name}' in service '{user_service_cls.__name__}' "
162+
f"is defined multiple times."
163+
)
164+
if service and op.method_name not in service_method_names:
165+
_names = ", ".join(f"'{s}'" for s in sorted(service_method_names))
166+
msg = (
167+
f"Operation method name '{op.method_name}' in service handler {user_service_cls} "
168+
f"does not match an operation method name in the service definition. "
169+
f"Available method names in the service definition: "
170+
)
171+
msg += _names if _names else "[none]"
172+
raise TypeError(msg)
173+
174+
assert op.method_name, (
175+
f"Operation '{op}' method name should not be empty. Please report this as a bug."
163176
)
164-
if service and op.method_name not in service_method_names:
165-
_names = ", ".join(f"'{s}'" for s in sorted(service_method_names))
166-
msg = (
167-
f"Operation method name '{op.method_name}' in service handler {user_service_cls} "
168-
f"does not match an operation method name in the service definition. "
169-
f"Available method names in the service definition: "
170-
)
171-
msg += _names if _names else "[none]"
172-
msg += "."
173-
raise TypeError(msg)
174-
175-
# TODO(preview) op_defn.method name should be non-nullable
176-
assert op.method_name, (
177-
f"Operation '{op}' method name should not be None. This is an SDK bug."
178-
)
179-
factories[op.method_name] = factory
180-
seen.add(op.name)
177+
factories[op.method_name] = factory
178+
seen.add(op.name)
181179
return factories
182180

183181

@@ -212,7 +210,7 @@ def validate_operation_handler_methods(
212210
f"method name '{op_defn.method_name}'. But this operation is in service "
213211
f"definition '{service_definition}'."
214212
)
215-
_, op = get_operation_factory(factory)
213+
op = get_operation(factory)
216214
if not isinstance(op, Operation):
217215
raise ValueError(
218216
f"Method '{factory}' in class '{service_cls.__name__}' "
@@ -278,7 +276,7 @@ def service_definition_from_operation_handler_methods(
278276
"""
279277
op_defns: dict[str, OperationDefinition[Any, Any]] = {}
280278
for name, method in user_methods.items():
281-
_, op = get_operation_factory(method)
279+
op = get_operation(method)
282280
if not isinstance(op, Operation):
283281
raise ValueError(
284282
f"In service '{service_name}', could not locate operation definition for "

tests/handler/test_request_routing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import nexusrpc
77
from nexusrpc import LazyValue
8-
from nexusrpc._util import get_operation_factory, get_service_definition
8+
from nexusrpc._util import get_operation, get_service_definition
99
from nexusrpc.handler import (
1010
Handler,
1111
StartOperationContext,
@@ -33,7 +33,7 @@ class UserServiceHandler:
3333
async def _op_impl(self, ctx: StartOperationContext, input: None) -> bool:
3434
assert (service_defn := get_service_definition(self.__class__))
3535
assert ctx.service == service_defn.name
36-
_, op_handler_op_defn = get_operation_factory(self.op)
36+
op_handler_op_defn = get_operation(self.op)
3737
assert op_handler_op_defn
3838
assert service_defn.operation_definitions.get(ctx.operation)
3939
return True

tests/handler/test_sync_operation_handler_decorator_creates_valid_operation_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ async def my_async_def_op(self, ctx: StartOperationContext, input: int) -> int:
3535

3636
def test_def_sync_handler():
3737
user_instance = MyServiceHandler()
38-
op_handler_factory, _ = get_operation_factory(user_instance.my_def_op)
38+
op_handler_factory = get_operation_factory(user_instance.my_def_op)
3939
assert op_handler_factory
4040
op_handler = op_handler_factory(user_instance)
4141
assert not is_async_callable(op_handler.start)
@@ -54,7 +54,7 @@ def test_def_sync_handler():
5454
@pytest.mark.asyncio
5555
async def test_async_def_sync_handler():
5656
user_instance = MyServiceHandler()
57-
op_handler_factory, _ = get_operation_factory(user_instance.my_async_def_op)
57+
op_handler_factory = get_operation_factory(user_instance.my_async_def_op)
5858
assert op_handler_factory
5959
op_handler = op_handler_factory(user_instance)
6060
assert is_async_callable(op_handler.start)

0 commit comments

Comments
 (0)