Skip to content

Commit 999ce7a

Browse files
Daverballdavidism
authored andcommitted
Improve generic typing further
1 parent 52890d7 commit 999ce7a

File tree

3 files changed

+92
-21
lines changed

3 files changed

+92
-21
lines changed

src/itsdangerous/_json.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ class _CompactJSON:
88
"""Wrapper around json module that strips whitespace."""
99

1010
@staticmethod
11-
def loads(s: str | bytes) -> t.Any:
12-
return _json.loads(s)
11+
def loads(payload: str | bytes) -> t.Any:
12+
return _json.loads(payload)
1313

1414
@staticmethod
15-
def dumps(obj: t.Any, *args: t.Any, **kwargs: t.Any) -> str:
15+
def dumps(obj: t.Any, **kwargs: t.Any) -> str:
1616
kwargs.setdefault("ensure_ascii", False)
1717
kwargs.setdefault("separators", (",", ":"))
1818
return _json.dumps(obj, **kwargs)

src/itsdangerous/serializer.py

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,36 @@
1010
from .signer import _make_keys_list
1111
from .signer import Signer
1212

13-
14-
class _PDataSerializer(t.Protocol[t.AnyStr]):
15-
def loads(self, s: t.AnyStr) -> t.Any: ...
16-
def dumps(self, obj: t.Any, *args: t.Any, **kwargs: t.Any) -> t.AnyStr: ...
17-
18-
19-
def is_text_serializer(serializer: _PDataSerializer[t.Any]) -> bool:
13+
if t.TYPE_CHECKING:
14+
import typing_extensions as te
15+
16+
# This should be either be str or bytes. To avoid having to specify the
17+
# bound type, it falls back to a union if structural matching fails.
18+
_TSerialized = te.TypeVar(
19+
"_TSerialized", bound=t.Union[str, bytes], default=t.Union[str, bytes]
20+
)
21+
else:
22+
# Still available at runtime on Python < 3.13, but without the default.
23+
_TSerialized = t.TypeVar("_TSerialized", bound=t.Union[str, bytes])
24+
25+
26+
class _PDataSerializer(t.Protocol[_TSerialized]):
27+
def loads(self, payload: _TSerialized, /) -> t.Any: ...
28+
# A signature with additional arguments is not handled correctly by type
29+
# checkers right now, so an overload is used below for serializers that
30+
# don't match this strict protocol.
31+
def dumps(self, obj: t.Any, /) -> _TSerialized: ...
32+
33+
34+
# Use TypeIs once it's available in typing_extensions or 3.13.
35+
def is_text_serializer(
36+
serializer: _PDataSerializer[t.Any],
37+
) -> te.TypeGuard[_PDataSerializer[str]]:
2038
"""Checks whether a serializer generates text or binary."""
2139
return isinstance(serializer.dumps({}), str)
2240

2341

24-
class Serializer(t.Generic[t.AnyStr]):
42+
class Serializer(t.Generic[_TSerialized]):
2543
"""A serializer wraps a :class:`~itsdangerous.signer.Signer` to
2644
enable serializing and securely signing data other than bytes. It
2745
can unsign to verify that the data hasn't been changed.
@@ -76,7 +94,7 @@ class Serializer(t.Generic[t.AnyStr]):
7694
#: The default serialization module to use to serialize data to a
7795
#: string internally. The default is :mod:`json`, but can be changed
7896
#: to any object that provides ``dumps`` and ``loads`` methods.
79-
default_serializer: _PDataSerializer[t.Any] = json # pyright: ignore
97+
default_serializer: _PDataSerializer[t.Any] = json
8098

8199
#: The default ``Signer`` class to instantiate when signing data.
82100
#: The default is :class:`itsdangerous.signer.Signer`.
@@ -87,14 +105,64 @@ class Serializer(t.Generic[t.AnyStr]):
87105
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
88106
] = []
89107

90-
# Tell type checkers that the default type is Serializer[str] if no
91-
# data serializer is provided.
108+
# Serializer[str] if no data serializer is provided, or if it returns str.
92109
@t.overload
93110
def __init__(
94111
self: Serializer[str],
95112
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
96113
salt: str | bytes | None = b"itsdangerous",
97-
serializer: None = None,
114+
serializer: None | _PDataSerializer[str] = None,
115+
serializer_kwargs: dict[str, t.Any] | None = None,
116+
signer: type[Signer] | None = None,
117+
signer_kwargs: dict[str, t.Any] | None = None,
118+
fallback_signers: list[
119+
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
120+
]
121+
| None = None,
122+
): ...
123+
124+
# Serializer[bytes] with a bytes data serializer positional argument.
125+
@t.overload
126+
def __init__(
127+
self: Serializer[bytes],
128+
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
129+
salt: str | bytes | None,
130+
serializer: _PDataSerializer[bytes],
131+
serializer_kwargs: dict[str, t.Any] | None = None,
132+
signer: type[Signer] | None = None,
133+
signer_kwargs: dict[str, t.Any] | None = None,
134+
fallback_signers: list[
135+
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
136+
]
137+
| None = None,
138+
): ...
139+
140+
# Serializer[bytes] with a bytes data serializer keyword argument.
141+
@t.overload
142+
def __init__(
143+
self: Serializer[bytes],
144+
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
145+
salt: str | bytes | None = b"itsdangerous",
146+
*,
147+
serializer: _PDataSerializer[bytes],
148+
serializer_kwargs: dict[str, t.Any] | None = None,
149+
signer: type[Signer] | None = None,
150+
signer_kwargs: dict[str, t.Any] | None = None,
151+
fallback_signers: list[
152+
dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer]
153+
]
154+
| None = None,
155+
): ...
156+
157+
# Fall back with a positional argument. If the strict signature of
158+
# _PDataSerializer doesn't match, fall back to a union, requiring the user
159+
# to specify the type.
160+
@t.overload
161+
def __init__(
162+
self,
163+
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
164+
salt: str | bytes | None,
165+
serializer: t.Any,
98166
serializer_kwargs: dict[str, t.Any] | None = None,
99167
signer: type[Signer] | None = None,
100168
signer_kwargs: dict[str, t.Any] | None = None,
@@ -104,12 +172,14 @@ def __init__(
104172
| None = None,
105173
): ...
106174

175+
# Fall back with a keyword argument.
107176
@t.overload
108177
def __init__(
109-
self: Serializer[t.AnyStr],
178+
self,
110179
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
111180
salt: str | bytes | None = b"itsdangerous",
112-
serializer: _PDataSerializer[t.AnyStr] = ...,
181+
*,
182+
serializer: t.Any,
113183
serializer_kwargs: dict[str, t.Any] | None = None,
114184
signer: type[Signer] | None = None,
115185
signer_kwargs: dict[str, t.Any] | None = None,
@@ -123,7 +193,7 @@ def __init__(
123193
self,
124194
secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes],
125195
salt: str | bytes | None = b"itsdangerous",
126-
serializer: _PDataSerializer[t.AnyStr] | None = None,
196+
serializer: t.Any | None = None,
127197
serializer_kwargs: dict[str, t.Any] | None = None,
128198
signer: type[Signer] | None = None,
129199
signer_kwargs: dict[str, t.Any] | None = None,
@@ -148,7 +218,7 @@ def __init__(
148218
if serializer is None:
149219
serializer = self.default_serializer
150220

151-
self.serializer: _PDataSerializer[t.AnyStr] = serializer
221+
self.serializer: _PDataSerializer[_TSerialized] = serializer
152222
self.is_text_serializer: bool = is_text_serializer(serializer)
153223

154224
if signer is None:
@@ -238,7 +308,7 @@ def iter_unsigners(self, salt: str | bytes | None = None) -> cabc.Iterator[Signe
238308
for secret_key in self.secret_keys:
239309
yield fallback(secret_key, salt=salt, **kwargs)
240310

241-
def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> t.AnyStr:
311+
def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> _TSerialized:
242312
"""Returns a signed string serialized with the internal
243313
serializer. The return value can be either a byte or unicode
244314
string depending on the format of the internal serializer.

src/itsdangerous/timed.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .exc import BadSignature
1515
from .exc import BadTimeSignature
1616
from .exc import SignatureExpired
17+
from .serializer import _TSerialized
1718
from .serializer import Serializer
1819
from .signer import Signer
1920

@@ -166,7 +167,7 @@ def validate(self, signed_value: str | bytes, max_age: int | None = None) -> boo
166167
return False
167168

168169

169-
class TimedSerializer(Serializer[t.AnyStr]):
170+
class TimedSerializer(Serializer[_TSerialized]):
170171
"""Uses :class:`TimestampSigner` instead of the default
171172
:class:`.Signer`.
172173
"""

0 commit comments

Comments
 (0)