Skip to content

Commit f0e1510

Browse files
authored
Merge pull request #105 from oda02/convert-query-params-fix
fix: added _is_sequence() to support more sequence types in models and tests for new cases
2 parents 21a4ea5 + 25c3b00 commit f0e1510

File tree

4 files changed

+143
-35
lines changed

4 files changed

+143
-35
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
strategy:
1313
fail-fast: false
1414
matrix:
15-
python-version: ["3.7", "3.8", "3.9", "3.10"]
15+
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
1616
os: [ubuntu-latest, macOS-latest]
1717
# Python 3.7 is not supported on Apple ARM64,
1818
# or the latest Ubuntu 2404

flask_pydantic/converters.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Type, Union
1+
import types
2+
from collections import deque
3+
from typing import Deque, FrozenSet, List, Sequence, Set, Tuple, Type, Union
24

35
try:
46
from typing import get_args, get_origin
@@ -10,15 +12,29 @@
1012
from werkzeug.datastructures import ImmutableMultiDict
1113

1214
V1OrV2BaseModel = Union[BaseModel, V1BaseModel]
15+
UnionType = getattr(types, "UnionType", Union)
1316

17+
sequence_types = {
18+
Sequence,
19+
List,
20+
list,
21+
Tuple,
22+
tuple,
23+
Set,
24+
set,
25+
FrozenSet,
26+
frozenset,
27+
Deque,
28+
deque,
29+
}
1430

15-
def _is_list(type_: Type) -> bool:
16-
origin = get_origin(type_)
17-
if origin is list:
18-
return True
19-
if origin is Union:
20-
return any(_is_list(t) for t in get_args(type_))
21-
return False
31+
32+
def _is_sequence(type_: Type) -> bool:
33+
origin = get_origin(type_) or type_
34+
if origin is Union or origin is UnionType:
35+
return any(_is_sequence(t) for t in get_args(type_))
36+
37+
return origin in sequence_types and origin not in (str, bytes)
2238

2339

2440
def convert_query_params(
@@ -38,7 +54,7 @@ def convert_query_params(
3854
key: value
3955
for key, value in query_params.to_dict(flat=False).items()
4056
if key in model.model_fields
41-
and _is_list(model.model_fields[key].annotation)
57+
and _is_sequence(model.model_fields[key].annotation)
4258
},
4359
}
4460
else:

tests/unit/test_core.py

Lines changed: 116 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
2-
from typing import Any, List, NamedTuple, Optional, Type, Union
2+
import sys
3+
from typing import Any, List, NamedTuple, Optional, Tuple, Type, Union
34

45
import pytest
56
from flask import jsonify
@@ -15,13 +16,18 @@
1516
from ..util import assert_matches
1617

1718

19+
class EmptyModel(BaseModel):
20+
pass
21+
22+
1823
class ValidateParams(NamedTuple):
19-
body_model: Optional[Type[BaseModel]] = None
20-
query_model: Optional[Type[BaseModel]] = None
21-
form_model: Optional[Type[BaseModel]] = None
22-
response_model: Type[BaseModel] = None
24+
body_model: Type[BaseModel] = EmptyModel
25+
query_model: Type[BaseModel] = EmptyModel
26+
form_model: Type[BaseModel] = EmptyModel
27+
response_model: Type[BaseModel] = EmptyModel
2328
on_success_status: int = 200
2429
request_query: ImmutableMultiDict = ImmutableMultiDict({})
30+
flat_request_query: bool = True
2531
request_body: Union[dict, List[dict]] = {}
2632
request_form: ImmutableMultiDict = ImmutableMultiDict({})
2733
expected_response_body: Optional[dict] = None
@@ -50,7 +56,25 @@ class RequestBodyModel(BaseModel):
5056

5157
class FormModel(BaseModel):
5258
f1: int
53-
f2: str = None
59+
f2: Optional[str] = None
60+
61+
62+
class RequestWithIterableModel(BaseModel):
63+
b1: List
64+
b2: List[str]
65+
b3: Tuple[str, int]
66+
b4: Optional[List[int]] = None
67+
b5: Union[Tuple[str, int], None] = None
68+
69+
70+
if sys.version_info >= (3, 10):
71+
# New Python(>=3.10) syntax tests
72+
class RequestWithIterableModelPy310(BaseModel):
73+
b1: list
74+
b2: list[str]
75+
b3: tuple[str, int]
76+
b4: list[int] | None = None
77+
b5: tuple[str, int] | None = None
5478

5579

5680
class RequestBodyModelRoot(RootModel):
@@ -195,8 +219,76 @@ class RequestBodyModelRoot(RootModel):
195219
),
196220
id="invalid form param",
197221
),
222+
pytest.param(
223+
ValidateParams(
224+
request_query=ImmutableMultiDict(
225+
[
226+
("b1", "str1"),
227+
("b1", "str2"),
228+
("b2", "str1"),
229+
("b2", "str2"),
230+
("b3", "str"),
231+
("b3", 123),
232+
("b4", 1),
233+
("b4", 2),
234+
("b4", 3),
235+
("b5", "str"),
236+
("b5", 321),
237+
]
238+
),
239+
flat_request_query=False,
240+
expected_response_body={
241+
"b1": ["str1", "str2"],
242+
"b2": ["str1", "str2"],
243+
"b3": ("str", 123),
244+
"b4": [1, 2, 3],
245+
"b5": ("str", 321),
246+
},
247+
query_model=RequestWithIterableModel,
248+
response_model=RequestWithIterableModel,
249+
expected_status_code=200,
250+
),
251+
id="iterable and Optional[Iterable] fields in pydantic model in query",
252+
),
198253
]
199254

255+
if sys.version_info >= (3, 10):
256+
validate_test_cases.extend(
257+
[
258+
pytest.param(
259+
ValidateParams(
260+
request_query=ImmutableMultiDict(
261+
[
262+
("b1", "str1"),
263+
("b1", "str2"),
264+
("b2", "str1"),
265+
("b2", "str2"),
266+
("b3", "str"),
267+
("b3", 123),
268+
("b4", 1),
269+
("b4", 2),
270+
("b4", 3),
271+
("b5", "str"),
272+
("b5", 321),
273+
]
274+
),
275+
flat_request_query=False,
276+
expected_response_body={
277+
"b1": ["str1", "str2"],
278+
"b2": ["str1", "str2"],
279+
"b3": ("str", 123),
280+
"b4": [1, 2, 3],
281+
"b5": ("str", 321),
282+
},
283+
query_model=RequestWithIterableModelPy310,
284+
response_model=RequestWithIterableModelPy310,
285+
expected_status_code=200,
286+
),
287+
id="iterable and Iterable | None fields in pydantic model in query (Python 3.10+)",
288+
),
289+
]
290+
)
291+
200292

201293
class TestValidate:
202294
@pytest.mark.parametrize("parameters", validate_test_cases)
@@ -230,17 +322,17 @@ def f():
230322
assert response.status_code == parameters.expected_status_code
231323
assert_matches(parameters.expected_response_body, response.json)
232324
if 200 <= response.status_code < 300:
233-
assert (
325+
assert_matches(
326+
parameters.request_body,
234327
mock_request.body_params.model_dump(
235328
exclude_none=True, exclude_defaults=True
236-
)
237-
== parameters.request_body
329+
),
238330
)
239-
assert (
331+
assert_matches(
332+
parameters.request_query.to_dict(flat=parameters.flat_request_query),
240333
mock_request.query_params.model_dump(
241334
exclude_none=True, exclude_defaults=True
242-
)
243-
== parameters.request_query.to_dict()
335+
),
244336
)
245337

246338
@pytest.mark.parametrize("parameters", validate_test_cases)
@@ -269,17 +361,17 @@ def f(
269361
assert_matches(parameters.expected_response_body, response.json)
270362
assert response.status_code == parameters.expected_status_code
271363
if 200 <= response.status_code < 300:
272-
assert (
364+
assert_matches(
365+
parameters.request_body,
273366
mock_request.body_params.model_dump(
274367
exclude_none=True, exclude_defaults=True
275-
)
276-
== parameters.request_body
368+
),
277369
)
278-
assert (
370+
assert_matches(
371+
parameters.request_query.to_dict(flat=parameters.flat_request_query),
279372
mock_request.query_params.model_dump(
280373
exclude_none=True, exclude_defaults=True
281-
)
282-
== parameters.request_query.to_dict()
374+
),
283375
)
284376

285377
@pytest.mark.usefixtures("request_ctx")
@@ -468,17 +560,17 @@ def f() -> Any:
468560
assert response.status_code == parameters.expected_status_code
469561
assert_matches(parameters.expected_response_body, response.json)
470562
if 200 <= response.status_code < 300:
471-
assert (
563+
assert_matches(
564+
parameters.request_body,
472565
mock_request.body_params.model_dump(
473566
exclude_none=True, exclude_defaults=True
474-
)
475-
== parameters.request_body
567+
),
476568
)
477-
assert (
569+
assert_matches(
570+
parameters.request_query.to_dict(flat=parameters.flat_request_query),
478571
mock_request.query_params.model_dump(
479572
exclude_none=True, exclude_defaults=True
480-
)
481-
== parameters.request_query.to_dict()
573+
),
482574
)
483575

484576
def test_fail_validation_custom_status_code(self, app, request_ctx, mocker):

tests/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def assert_matches(expected: ExpectedType, actual: ActualType):
2121
assert set(expected.keys()) == set(actual.keys())
2222
for key, value in expected.items():
2323
assert_matches(value, actual[key])
24-
elif isinstance(expected, list):
24+
elif isinstance(expected, (list, tuple)):
2525
assert len(expected) == len(actual)
2626
for a, b in zip(expected, actual):
2727
assert_matches(a, b)

0 commit comments

Comments
 (0)