Skip to content

Commit 54bee01

Browse files
authored
[1/5] Typing cleanup (#13)
* Typing cleanup * Ignore errors resulting from lack of dataclass_transform
1 parent 5da147c commit 54bee01

26 files changed

+296
-215
lines changed

pyproject.toml

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ dependencies = [
1818

1919
[dependency-groups]
2020
dev = [
21+
"basedpyright>=1.30.1",
2122
"mypy>=1.15.0",
2223
"poethepoet>=0.35.0",
2324
"pydoctor>=25.4.0",
24-
"pyright>=1.1.402",
25+
"pyright>=1.1",
2526
"pytest>=8.3.5",
2627
"pytest-asyncio>=0.26.0",
2728
"pytest-cov>=6.1.1",
@@ -38,6 +39,7 @@ packages = ["src/nexusrpc"]
3839

3940
[tool.poe.tasks]
4041
lint = [
42+
{cmd = "uv run basedpyright src"},
4143
{cmd = "uv run pyright src"},
4244
{cmd = "uv run mypy --check-untyped-defs src"},
4345
{cmd = "uv run ruff check --select I"},
@@ -52,6 +54,21 @@ docs = [
5254
]
5355

5456
[tool.pyright]
57+
# https://docs.basedpyright.com/v1.30.0/configuration/config-files/#diagnostic-settings-defaults
58+
reportAny = "none"
59+
reportDeprecated = "none"
60+
reportExplicitAny = "none"
61+
reportIgnoreCommentWithoutRule = "none"
62+
reportImplicitOverride = "none"
63+
reportImplicitStringConcatenation = "none"
64+
reportImportCycles = "none"
65+
reportUnannotatedClassAttribute = "none"
66+
reportUnknownArgumentType = "none"
67+
reportUnknownMemberType = "none"
68+
reportUnknownVariableType = "none"
69+
reportUnnecessaryTypeIgnoreComment = "none"
70+
enableTypeIgnoreComments = true
71+
failOnWarnings = false # basedpyright setting
5572
include = ["src", "tests"]
5673

5774
[tool.mypy]

src/nexusrpc/_serializer.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
from __future__ import annotations
22

3+
from collections.abc import AsyncIterable, Awaitable, Mapping
34
from dataclasses import dataclass
45
from typing import (
56
Any,
6-
AsyncIterable,
7-
Awaitable,
8-
Mapping,
97
Optional,
108
Protocol,
11-
Type,
129
Union,
1310
)
1411

@@ -42,7 +39,7 @@ def serialize(self, value: Any) -> Union[Content, Awaitable[Content]]:
4239
...
4340

4441
def deserialize(
45-
self, content: Content, as_type: Optional[Type[Any]] = None
42+
self, content: Content, as_type: Optional[type[Any]] = None
4643
) -> Union[Any, Awaitable[Any]]:
4744
"""Deserialize decodes a Content into a value.
4845
@@ -56,7 +53,7 @@ def deserialize(
5653

5754
class LazyValueT(Protocol):
5855
def consume(
59-
self, as_type: Optional[Type[Any]] = None
56+
self, as_type: Optional[type[Any]] = None
6057
) -> Union[Any, Awaitable[Any]]: ...
6158

6259

@@ -96,22 +93,21 @@ def __init__(
9693
headers: Headers that include information on how to process the stream's content.
9794
Headers constructed by the framework always have lower case keys.
9895
User provided keys are treated case-insensitively.
99-
stream: Iterable that contains request or response data. None means empty data.
96+
stream: AsyncIterable of bytes that contains request or response data.
97+
None means empty data.
10098
"""
10199
self.serializer = serializer
102100
self.headers = headers
103101
self.stream = stream
104102

105-
async def consume(self, as_type: Optional[Type[Any]] = None) -> Any:
103+
async def consume(self, as_type: Optional[type[Any]] = None) -> Any:
106104
"""
107105
Consume the underlying reader stream, deserializing via the embedded serializer.
108106
"""
109107
if self.stream is None:
110108
return await self.serializer.deserialize(
111109
Content(headers=self.headers, data=b""), as_type=as_type
112110
)
113-
elif not isinstance(self.stream, AsyncIterable):
114-
raise ValueError("When using consume, stream must be an AsyncIterable")
115111

116112
return await self.serializer.deserialize(
117113
Content(

src/nexusrpc/_service.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88

99
import dataclasses
1010
import typing
11+
from collections.abc import Mapping
1112
from dataclasses import dataclass
1213
from typing import (
1314
Any,
1415
Callable,
1516
Generic,
16-
Mapping,
1717
Optional,
18-
Type,
1918
Union,
2019
overload,
2120
)
@@ -47,8 +46,8 @@ class MyNexusService:
4746
name: str
4847
# TODO(preview): they should not be able to set method_name in constructor
4948
method_name: Optional[str] = dataclasses.field(default=None)
50-
input_type: Optional[Type[InputT]] = dataclasses.field(default=None)
51-
output_type: Optional[Type[OutputT]] = dataclasses.field(default=None)
49+
input_type: Optional[type[InputT]] = dataclasses.field(default=None)
50+
output_type: Optional[type[OutputT]] = dataclasses.field(default=None)
5251

5352
def __post_init__(self):
5453
if not self.name:
@@ -70,22 +69,22 @@ def _validation_errors(self) -> list[str]:
7069

7170

7271
@overload
73-
def service(cls: Type[ServiceT]) -> Type[ServiceT]: ...
72+
def service(cls: type[ServiceT]) -> type[ServiceT]: ...
7473

7574

7675
@overload
7776
def service(
7877
*, name: Optional[str] = None
79-
) -> Callable[[Type[ServiceT]], Type[ServiceT]]: ...
78+
) -> Callable[[type[ServiceT]], type[ServiceT]]: ...
8079

8180

8281
def service(
83-
cls: Optional[Type[ServiceT]] = None,
82+
cls: Optional[type[ServiceT]] = None,
8483
*,
8584
name: Optional[str] = None,
8685
) -> Union[
87-
Type[ServiceT],
88-
Callable[[Type[ServiceT]], Type[ServiceT]],
86+
type[ServiceT],
87+
Callable[[type[ServiceT]], type[ServiceT]],
8988
]:
9089
"""
9190
Decorator marking a class as a Nexus service definition.
@@ -115,7 +114,7 @@ class AnotherService:
115114
# This will require forming a union of operations disovered via __annotations__
116115
# and __dict__
117116

118-
def decorator(cls: Type[ServiceT]) -> Type[ServiceT]:
117+
def decorator(cls: type[ServiceT]) -> type[ServiceT]:
119118
if name is not None and not name:
120119
raise ValueError("Service name must not be empty.")
121120
defn = ServiceDefinition.from_class(cls, name or cls.__name__)
@@ -161,7 +160,7 @@ def __post_init__(self):
161160
)
162161

163162
@staticmethod
164-
def from_class(user_class: Type[ServiceT], name: str) -> ServiceDefinition:
163+
def from_class(user_class: type[ServiceT], name: str) -> ServiceDefinition:
165164
"""Create a ServiceDefinition from a user service definition class.
166165
167166
The set of service definition operations returned is the union of operations
@@ -216,12 +215,12 @@ def _validation_errors(self) -> list[str]:
216215
if op.method_name in seen_method_names:
217216
errors.append(f"Operation method name '{op.method_name}' is not unique")
218217
seen_method_names.add(op.method_name)
219-
errors.extend(op._validation_errors())
218+
errors.extend(op._validation_errors()) # pyright: ignore[reportPrivateUsage]
220219
return errors
221220

222221
@staticmethod
223222
def _collect_operations(
224-
user_class: Type[ServiceT],
223+
user_class: type[ServiceT],
225224
) -> dict[str, Operation[Any, Any]]:
226225
"""Collect operations from a user service definition class.
227226

src/nexusrpc/_util.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import functools
44
import inspect
55
import typing
6-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Type
6+
from collections.abc import Awaitable
7+
from typing import TYPE_CHECKING, Any, Callable, Optional
78

89
from typing_extensions import TypeGuard
910

@@ -34,17 +35,15 @@ def get_service_definition(
3435

3536

3637
def set_service_definition(
37-
cls: Type[ServiceT], service_definition: nexusrpc.ServiceDefinition
38+
cls: type[ServiceT], service_definition: nexusrpc.ServiceDefinition
3839
) -> None:
39-
"""Set the :py:class:`nexusrpc.ServiceDefinition` for this object."""
40-
if not isinstance(cls, type):
41-
raise TypeError(f"Expected {cls} to be a class, but is {type(cls)}.")
40+
"""Set the :py:class:`nexusrpc.ServiceDefinition` for this class."""
4241
setattr(cls, "__nexus_service__", service_definition)
4342

4443

4544
def get_operation_definition(
4645
obj: Any,
47-
) -> Optional[nexusrpc.Operation]:
46+
) -> Optional[nexusrpc.Operation[Any, Any]]:
4847
"""Return the :py:class:`nexusrpc.Operation` for the object, or None
4948
5049
``obj`` should be a decorated operation start method.
@@ -54,7 +53,7 @@ def get_operation_definition(
5453

5554
def set_operation_definition(
5655
obj: Any,
57-
operation_definition: nexusrpc.Operation,
56+
operation_definition: nexusrpc.Operation[Any, Any],
5857
) -> None:
5958
"""Set the :py:class:`nexusrpc.Operation` for this object.
6059
@@ -137,7 +136,7 @@ def get_callable_name(fn: Callable[..., Any]) -> str:
137136
return method_name
138137

139138

140-
def is_subtype(type1: Type[Any], type2: Type[Any]) -> bool:
139+
def is_subtype(type1: type[Any], type2: type[Any]) -> bool:
141140
# Note that issubclass() argument 2 cannot be a parameterized generic
142141
# TODO(nexus-preview): review desired type compatibility logic
143142
if type1 == type2:
@@ -149,7 +148,9 @@ def is_subtype(type1: Type[Any], type2: Type[Any]) -> bool:
149148
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
150149

151150
try:
152-
from inspect import get_annotations # type: ignore
151+
from inspect import ( # type: ignore
152+
get_annotations as get_annotations, # type: ignore[reportAttributeAccessIssue]
153+
)
153154
except ImportError:
154155
import functools
155156
import sys
@@ -251,10 +252,10 @@ def get_annotations(obj, *, globals=None, locals=None, eval_str=False): # type:
251252
if unwrap is not None:
252253
while True:
253254
if hasattr(unwrap, "__wrapped__"):
254-
unwrap = unwrap.__wrapped__ # type: ignore
255+
unwrap = unwrap.__wrapped__ # type: ignore[reportFunctionMemberAccess,union-attr]
255256
continue
256257
if isinstance(unwrap, functools.partial):
257-
unwrap = unwrap.func # type: ignore
258+
unwrap = unwrap.func # type: ignore[reportGeneralTypeIssues,assignment]
258259
continue
259260
break
260261
if hasattr(unwrap, "__globals__"):

src/nexusrpc/handler/_common.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
from __future__ import annotations
22

33
from abc import ABC
4+
from collections.abc import Mapping, Sequence
45
from dataclasses import dataclass, field
56
from datetime import timedelta
6-
from typing import (
7-
Generic,
8-
Mapping,
9-
Optional,
10-
Sequence,
11-
)
7+
from typing import Any, Generic, Optional
128

139
from nexusrpc._common import Link, OutputT
1410

@@ -19,7 +15,7 @@ class OperationContext(ABC):
1915
2016
Includes information from the request."""
2117

22-
def __new__(cls, *args, **kwargs):
18+
def __new__(cls, *args: Any, **kwargs: Any):
2319
if cls is OperationContext:
2420
raise TypeError(
2521
"OperationContext is an abstract class and cannot be instantiated directly"

src/nexusrpc/handler/_core.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,9 @@ def my_op(...)
9393
import asyncio
9494
import concurrent.futures
9595
from abc import ABC, abstractmethod
96+
from collections.abc import Awaitable, Mapping, Sequence
9697
from dataclasses import dataclass
97-
from typing import (
98-
Any,
99-
Awaitable,
100-
Callable,
101-
Mapping,
102-
Optional,
103-
Sequence,
104-
Union,
105-
)
98+
from typing import Any, Callable, Optional, Union
10699

107100
from typing_extensions import Self, TypeGuard
108101

@@ -179,7 +172,7 @@ def cancel_operation(
179172
...
180173

181174

182-
class BaseServiceCollectionHandler(AbstractHandler):
175+
class BaseServiceCollectionHandler(AbstractHandler, ABC):
183176
"""
184177
A Nexus handler, managing a collection of Nexus service handlers.
185178
@@ -294,7 +287,7 @@ async def start_operation(
294287
input: The input to the operation, as a LazyValue.
295288
"""
296289
service_handler = self._get_service_handler(ctx.service)
297-
op_handler = service_handler._get_operation_handler(ctx.operation)
290+
op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage]
298291
op = service_handler.service.operations[ctx.operation]
299292
deserialized_input = await input.consume(as_type=op.input_type)
300293
# TODO(preview): apply middleware stack
@@ -314,7 +307,7 @@ async def cancel_operation(self, ctx: CancelOperationContext, token: str) -> Non
314307
token: The operation token.
315308
"""
316309
service_handler = self._get_service_handler(ctx.service)
317-
op_handler = service_handler._get_operation_handler(ctx.operation)
310+
op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage]
318311
if is_async_callable(op_handler.cancel):
319312
return await op_handler.cancel(ctx, token)
320313
else:
@@ -325,7 +318,7 @@ async def fetch_operation_info(
325318
self, ctx: FetchOperationInfoContext, token: str
326319
) -> OperationInfo:
327320
service_handler = self._get_service_handler(ctx.service)
328-
op_handler = service_handler._get_operation_handler(ctx.operation)
321+
op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage]
329322
if is_async_callable(op_handler.fetch_info):
330323
return await op_handler.fetch_info(ctx, token)
331324
else:
@@ -341,7 +334,7 @@ async def fetch_operation_result(
341334
"wait parameter or request-timeout header."
342335
)
343336
service_handler = self._get_service_handler(ctx.service)
344-
op_handler = service_handler._get_operation_handler(ctx.operation)
337+
op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage]
345338
if is_async_callable(op_handler.fetch_result):
346339
return await op_handler.fetch_result(ctx, token)
347340
else:
@@ -351,10 +344,10 @@ async def fetch_operation_result(
351344
def _validate_all_operation_handlers_are_async(self) -> None:
352345
for service_handler in self.service_handlers.values():
353346
for op_handler in service_handler.operation_handlers.values():
354-
self._assert_async_callable(op_handler.start)
355-
self._assert_async_callable(op_handler.cancel)
356-
self._assert_async_callable(op_handler.fetch_info)
357-
self._assert_async_callable(op_handler.fetch_result)
347+
_ = self._assert_async_callable(op_handler.start)
348+
_ = self._assert_async_callable(op_handler.cancel)
349+
_ = self._assert_async_callable(op_handler.fetch_info)
350+
_ = self._assert_async_callable(op_handler.fetch_result)
358351

359352
def _assert_async_callable(
360353
self, method: Callable[..., Any]

0 commit comments

Comments
 (0)