|
1 | 1 | import re
|
2 |
| -from typing import Any, List, NamedTuple, Optional, Type, Union |
| 2 | +import sys |
| 3 | +from typing import Any, List, NamedTuple, Optional, Tuple, Type, Union |
3 | 4 |
|
4 | 5 | import pytest
|
5 | 6 | from flask import jsonify
|
|
15 | 16 | from ..util import assert_matches
|
16 | 17 |
|
17 | 18 |
|
| 19 | +class EmptyModel(BaseModel): |
| 20 | + pass |
| 21 | + |
| 22 | + |
18 | 23 | class ValidateParams(NamedTuple):
|
19 |
| - body_model: Optional[Type[BaseModel]] = None |
20 |
| - query_model: Optional[Type[BaseModel]] = None |
21 |
| - form_model: Optional[Type[BaseModel]] = None |
22 |
| - response_model: Type[BaseModel] = None |
| 24 | + body_model: Type[BaseModel] = EmptyModel |
| 25 | + query_model: Type[BaseModel] = EmptyModel |
| 26 | + form_model: Type[BaseModel] = EmptyModel |
| 27 | + response_model: Type[BaseModel] = EmptyModel |
23 | 28 | on_success_status: int = 200
|
24 | 29 | request_query: ImmutableMultiDict = ImmutableMultiDict({})
|
| 30 | + flat_request_query: bool = True |
25 | 31 | request_body: Union[dict, List[dict]] = {}
|
26 | 32 | request_form: ImmutableMultiDict = ImmutableMultiDict({})
|
27 | 33 | expected_response_body: Optional[dict] = None
|
@@ -50,7 +56,25 @@ class RequestBodyModel(BaseModel):
|
50 | 56 |
|
51 | 57 | class FormModel(BaseModel):
|
52 | 58 | f1: int
|
53 |
| - f2: str = None |
| 59 | + f2: Optional[str] = None |
| 60 | + |
| 61 | + |
| 62 | +class RequestWithIterableModel(BaseModel): |
| 63 | + b1: List |
| 64 | + b2: List[str] |
| 65 | + b3: Tuple[str, int] |
| 66 | + b4: Optional[List[int]] = None |
| 67 | + b5: Union[Tuple[str, int], None] = None |
| 68 | + |
| 69 | + |
| 70 | +if sys.version_info >= (3, 10): |
| 71 | + # New Python(>=3.10) syntax tests |
| 72 | + class RequestWithIterableModelPy310(BaseModel): |
| 73 | + b1: list |
| 74 | + b2: list[str] |
| 75 | + b3: tuple[str, int] |
| 76 | + b4: list[int] | None = None |
| 77 | + b5: tuple[str, int] | None = None |
54 | 78 |
|
55 | 79 |
|
56 | 80 | class RequestBodyModelRoot(RootModel):
|
@@ -195,8 +219,76 @@ class RequestBodyModelRoot(RootModel):
|
195 | 219 | ),
|
196 | 220 | id="invalid form param",
|
197 | 221 | ),
|
| 222 | + pytest.param( |
| 223 | + ValidateParams( |
| 224 | + request_query=ImmutableMultiDict( |
| 225 | + [ |
| 226 | + ("b1", "str1"), |
| 227 | + ("b1", "str2"), |
| 228 | + ("b2", "str1"), |
| 229 | + ("b2", "str2"), |
| 230 | + ("b3", "str"), |
| 231 | + ("b3", 123), |
| 232 | + ("b4", 1), |
| 233 | + ("b4", 2), |
| 234 | + ("b4", 3), |
| 235 | + ("b5", "str"), |
| 236 | + ("b5", 321), |
| 237 | + ] |
| 238 | + ), |
| 239 | + flat_request_query=False, |
| 240 | + expected_response_body={ |
| 241 | + "b1": ["str1", "str2"], |
| 242 | + "b2": ["str1", "str2"], |
| 243 | + "b3": ("str", 123), |
| 244 | + "b4": [1, 2, 3], |
| 245 | + "b5": ("str", 321), |
| 246 | + }, |
| 247 | + query_model=RequestWithIterableModel, |
| 248 | + response_model=RequestWithIterableModel, |
| 249 | + expected_status_code=200, |
| 250 | + ), |
| 251 | + id="iterable and Optional[Iterable] fields in pydantic model in query", |
| 252 | + ), |
198 | 253 | ]
|
199 | 254 |
|
| 255 | +if sys.version_info >= (3, 10): |
| 256 | + validate_test_cases.extend( |
| 257 | + [ |
| 258 | + pytest.param( |
| 259 | + ValidateParams( |
| 260 | + request_query=ImmutableMultiDict( |
| 261 | + [ |
| 262 | + ("b1", "str1"), |
| 263 | + ("b1", "str2"), |
| 264 | + ("b2", "str1"), |
| 265 | + ("b2", "str2"), |
| 266 | + ("b3", "str"), |
| 267 | + ("b3", 123), |
| 268 | + ("b4", 1), |
| 269 | + ("b4", 2), |
| 270 | + ("b4", 3), |
| 271 | + ("b5", "str"), |
| 272 | + ("b5", 321), |
| 273 | + ] |
| 274 | + ), |
| 275 | + flat_request_query=False, |
| 276 | + expected_response_body={ |
| 277 | + "b1": ["str1", "str2"], |
| 278 | + "b2": ["str1", "str2"], |
| 279 | + "b3": ("str", 123), |
| 280 | + "b4": [1, 2, 3], |
| 281 | + "b5": ("str", 321), |
| 282 | + }, |
| 283 | + query_model=RequestWithIterableModelPy310, |
| 284 | + response_model=RequestWithIterableModelPy310, |
| 285 | + expected_status_code=200, |
| 286 | + ), |
| 287 | + id="iterable and Iterable | None fields in pydantic model in query (Python 3.10+)", |
| 288 | + ), |
| 289 | + ] |
| 290 | + ) |
| 291 | + |
200 | 292 |
|
201 | 293 | class TestValidate:
|
202 | 294 | @pytest.mark.parametrize("parameters", validate_test_cases)
|
@@ -230,17 +322,17 @@ def f():
|
230 | 322 | assert response.status_code == parameters.expected_status_code
|
231 | 323 | assert_matches(parameters.expected_response_body, response.json)
|
232 | 324 | if 200 <= response.status_code < 300:
|
233 |
| - assert ( |
| 325 | + assert_matches( |
| 326 | + parameters.request_body, |
234 | 327 | mock_request.body_params.model_dump(
|
235 | 328 | exclude_none=True, exclude_defaults=True
|
236 |
| - ) |
237 |
| - == parameters.request_body |
| 329 | + ), |
238 | 330 | )
|
239 |
| - assert ( |
| 331 | + assert_matches( |
| 332 | + parameters.request_query.to_dict(flat=parameters.flat_request_query), |
240 | 333 | mock_request.query_params.model_dump(
|
241 | 334 | exclude_none=True, exclude_defaults=True
|
242 |
| - ) |
243 |
| - == parameters.request_query.to_dict() |
| 335 | + ), |
244 | 336 | )
|
245 | 337 |
|
246 | 338 | @pytest.mark.parametrize("parameters", validate_test_cases)
|
@@ -269,17 +361,17 @@ def f(
|
269 | 361 | assert_matches(parameters.expected_response_body, response.json)
|
270 | 362 | assert response.status_code == parameters.expected_status_code
|
271 | 363 | if 200 <= response.status_code < 300:
|
272 |
| - assert ( |
| 364 | + assert_matches( |
| 365 | + parameters.request_body, |
273 | 366 | mock_request.body_params.model_dump(
|
274 | 367 | exclude_none=True, exclude_defaults=True
|
275 |
| - ) |
276 |
| - == parameters.request_body |
| 368 | + ), |
277 | 369 | )
|
278 |
| - assert ( |
| 370 | + assert_matches( |
| 371 | + parameters.request_query.to_dict(flat=parameters.flat_request_query), |
279 | 372 | mock_request.query_params.model_dump(
|
280 | 373 | exclude_none=True, exclude_defaults=True
|
281 |
| - ) |
282 |
| - == parameters.request_query.to_dict() |
| 374 | + ), |
283 | 375 | )
|
284 | 376 |
|
285 | 377 | @pytest.mark.usefixtures("request_ctx")
|
@@ -468,17 +560,17 @@ def f() -> Any:
|
468 | 560 | assert response.status_code == parameters.expected_status_code
|
469 | 561 | assert_matches(parameters.expected_response_body, response.json)
|
470 | 562 | if 200 <= response.status_code < 300:
|
471 |
| - assert ( |
| 563 | + assert_matches( |
| 564 | + parameters.request_body, |
472 | 565 | mock_request.body_params.model_dump(
|
473 | 566 | exclude_none=True, exclude_defaults=True
|
474 |
| - ) |
475 |
| - == parameters.request_body |
| 567 | + ), |
476 | 568 | )
|
477 |
| - assert ( |
| 569 | + assert_matches( |
| 570 | + parameters.request_query.to_dict(flat=parameters.flat_request_query), |
478 | 571 | mock_request.query_params.model_dump(
|
479 | 572 | exclude_none=True, exclude_defaults=True
|
480 |
| - ) |
481 |
| - == parameters.request_query.to_dict() |
| 573 | + ), |
482 | 574 | )
|
483 | 575 |
|
484 | 576 | def test_fail_validation_custom_status_code(self, app, request_ctx, mocker):
|
|
0 commit comments