Skip to content

Commit f8e96bb

Browse files
puittenbroekPeter Uittenbroek
andauthored
Make validate_request pass on existing kwargs but remove those part of path (#227)
* Make validate_request pass on existing kwargs but remove those part of path * some more tests * cleanup comments * Add example in usage doc --------- Co-authored-by: Peter Uittenbroek <[email protected]>
1 parent 1b6a213 commit f8e96bb

File tree

3 files changed

+175
-20
lines changed

3 files changed

+175
-20
lines changed

docs/Usage/Request.md

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,17 @@ Sometimes you want to delay the verification request parameters, such as after l
111111
from flask_openapi3 import validate_request
112112

113113

114-
def login_required(func):
115-
@wraps(func)
116-
def wrapper(*args, **kwargs):
117-
print("login_required ...")
118-
return func(*args, **kwargs)
114+
def login_required():
115+
def decorator(func):
116+
@wraps(func)
117+
def wrapper(*args, **kwargs):
118+
if not request.headers.get("Authorization"):
119+
return {"error": "Unauthorized"}, 401
120+
return func(*args, **kwargs)
119121

120-
return wrapper
122+
return wrapper
123+
124+
return decorator
121125

122126

123127
@app.get("/book")
@@ -127,6 +131,42 @@ def get_book(query: BookQuery):
127131
...
128132
```
129133

134+
### Custom kwargs are maintained
135+
136+
When your 'auth decorator' injects custom kwargs, these will be passed on to the final function for you to use.
137+
138+
Any kwargs which are part of the 'path' will have been consumed at this point and can only be referenced using the `path`.
139+
140+
So avoid using kwarg-names which overlap with the path.
141+
142+
```python
143+
from flask_openapi3 import validate_request
144+
from functools import wraps
145+
146+
147+
def login_required():
148+
def decorator(func):
149+
@wraps(func)
150+
def wrapper(*args, **kwargs):
151+
if not request.headers.get("Authorization"):
152+
return {"error": "Unauthorized"}, 401
153+
kwargs["client_id"] = "client1234565"
154+
return func(*args, **kwargs)
155+
156+
return wrapper
157+
158+
return decorator
159+
160+
161+
162+
@app.get("/book")
163+
@login_required()
164+
@validate_request()
165+
def get_book(query: BookQuery, client_id:str = None):
166+
print(f"Current user identified as {client_id}")
167+
...
168+
```
169+
130170
## Request model
131171

132172
First, you need to define a [pydantic](https://github.com/pydantic/pydantic) model:

flask_openapi3/request.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import json
66
from functools import wraps
77
from json import JSONDecodeError
8-
from typing import Any, Type, Optional
8+
from typing import Any, Optional, Type
99

10-
from flask import request, current_app, abort
11-
from pydantic import ValidationError, BaseModel
10+
from flask import abort, current_app, request
11+
from pydantic import BaseModel, ValidationError
1212
from pydantic.fields import FieldInfo
1313
from werkzeug.datastructures.structures import MultiDict
1414

@@ -78,7 +78,11 @@ def _validate_cookie(cookie: Type[BaseModel], func_kwargs: dict):
7878

7979

8080
def _validate_path(path: Type[BaseModel], path_kwargs: dict, func_kwargs: dict):
81-
func_kwargs["path"] = path.model_validate(obj=path_kwargs)
81+
path_obj = path.model_validate(obj=path_kwargs)
82+
func_kwargs["path"] = path_obj
83+
# Consume path parameters to prevent from being passed to the function
84+
for field_name, _ in path_obj:
85+
path_kwargs.pop(field_name, None)
8286

8387

8488
def _validate_query(query: Type[BaseModel], func_kwargs: dict):
@@ -151,14 +155,14 @@ def _validate_body(body: Type[BaseModel], func_kwargs: dict):
151155

152156

153157
def _validate_request(
154-
header: Optional[Type[BaseModel]] = None,
155-
cookie: Optional[Type[BaseModel]] = None,
156-
path: Optional[Type[BaseModel]] = None,
157-
query: Optional[Type[BaseModel]] = None,
158-
form: Optional[Type[BaseModel]] = None,
159-
body: Optional[Type[BaseModel]] = None,
160-
raw: Optional[Type[BaseModel]] = None,
161-
path_kwargs: Optional[dict[Any, Any]] = None
158+
header: Optional[Type[BaseModel]] = None,
159+
cookie: Optional[Type[BaseModel]] = None,
160+
path: Optional[Type[BaseModel]] = None,
161+
query: Optional[Type[BaseModel]] = None,
162+
form: Optional[Type[BaseModel]] = None,
163+
body: Optional[Type[BaseModel]] = None,
164+
raw: Optional[Type[BaseModel]] = None,
165+
path_kwargs: Optional[dict[Any, Any]] = None,
162166
) -> dict:
163167
"""
164168
Validate requests and responses.
@@ -212,7 +216,6 @@ def validate_request():
212216
"""
213217

214218
def decorator(func):
215-
216219
setattr(func, "__delay_validate_request__", True)
217220

218221
is_coroutine_function = inspect.iscoroutinefunction(func)
@@ -223,6 +226,8 @@ def decorator(func):
223226
async def wrapper(*args, **kwargs):
224227
header, cookie, path, query, form, body, raw = parse_parameters(func)
225228
func_kwargs = _validate_request(header, cookie, path, query, form, body, raw, path_kwargs=kwargs)
229+
# Update func_kwargs with any additional keyword arguments passed from other decorators or calls.
230+
func_kwargs.update(kwargs)
226231

227232
return await func(*args, **func_kwargs)
228233

@@ -233,7 +238,8 @@ async def wrapper(*args, **kwargs):
233238
def wrapper(*args, **kwargs):
234239
header, cookie, path, query, form, body, raw = parse_parameters(func)
235240
func_kwargs = _validate_request(header, cookie, path, query, form, body, raw, path_kwargs=kwargs)
236-
241+
# Update func_kwargs with any additional keyword arguments passed from other decorators or calls.
242+
func_kwargs.update(kwargs)
237243
return func(*args, **func_kwargs)
238244

239245
return wrapper

tests/test_validate_request.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from functools import wraps
2+
from typing import Optional
3+
4+
import pytest
5+
from flask import request
6+
from pydantic import BaseModel, Field
7+
8+
from flask_openapi3 import APIView, Info, OpenAPI, Tag
9+
from flask_openapi3.request import validate_request
10+
11+
12+
class BookNamePath(BaseModel):
13+
name: str
14+
15+
16+
class BookBody(BaseModel):
17+
age: Optional[int] = Field(..., ge=2, le=4, description="Age")
18+
author: str = Field(None, min_length=2, max_length=4, description="Author")
19+
name: str
20+
21+
22+
def login_required():
23+
def decorator(func):
24+
@wraps(func)
25+
def wrapper(*args, **kwargs):
26+
if not request.headers.get("Authorization"):
27+
return {"error": "Unauthorized"}, 401
28+
kwargs["client_id"] = "client1234565"
29+
return func(*args, **kwargs)
30+
31+
return wrapper
32+
33+
return decorator
34+
35+
36+
@pytest.fixture
37+
def app():
38+
app = OpenAPI(__name__)
39+
app.config["TESTING"] = True
40+
41+
info = Info(title="book API", version="1.0.0")
42+
jwt = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
43+
security_schemes = {"jwt": jwt}
44+
45+
app = OpenAPI(__name__, info=info, security_schemes=security_schemes)
46+
app.config["TESTING"] = True
47+
security = [{"jwt": []}]
48+
49+
api_view = APIView(url_prefix="/v1/books", view_tags=[Tag(name="book")], view_security=security)
50+
51+
@api_view.route("")
52+
class BookListAPIView:
53+
@api_view.doc(summary="get book list", responses={204: None}, doc_ui=False)
54+
@login_required()
55+
@validate_request()
56+
def get(self, client_id: str):
57+
return {"books": ["book1", "book2"], "client_id": client_id}
58+
59+
@api_view.doc(summary="create book")
60+
@login_required()
61+
@validate_request()
62+
def post(self, body: BookBody, client_id):
63+
"""description for a created book"""
64+
return body.model_dump_json()
65+
66+
@api_view.route("/<name>")
67+
class BookNameAPIView:
68+
@api_view.doc(summary="get book by name")
69+
@login_required()
70+
@validate_request()
71+
def get(self, path: BookNamePath, client_id):
72+
return {"name": path.name, "client_id": client_id}
73+
74+
app.register_api_view(api_view)
75+
return app
76+
77+
78+
@pytest.fixture
79+
def client(app):
80+
client = app.test_client()
81+
82+
return client
83+
84+
85+
def test_get_book_list_happy(app, client):
86+
response = client.get("/v1/books", headers={"Authorization": "Bearer sometoken"})
87+
assert response.status_code == 200
88+
assert response.json == {"books": ["book1", "book2"], "client_id": "client1234565"}
89+
90+
91+
def test_get_book_list_not_auth(app, client):
92+
response = client.get("/v1/books", headers={"Nope": "Bearer sometoken"})
93+
assert response.status_code == 401
94+
assert response.json == {"error": "Unauthorized"}
95+
96+
97+
def test_create_book_happy(app, client):
98+
response = client.post(
99+
"/v1/books",
100+
json={"age": 3, "author": "John", "name": "some_book_name"},
101+
headers={"Authorization": "Bearer sometoken"},
102+
)
103+
assert response.status_code == 200
104+
105+
106+
def test_get_book_detail_happy(app, client):
107+
response = client.get("/v1/books/some_book_name", headers={"Authorization": "Bearer sometoken"})
108+
assert response.status_code == 200
109+
assert response.json == {"name": "some_book_name", "client_id": "client1234565"}

0 commit comments

Comments
 (0)