Skip to content

Commit dbf411b

Browse files
committed
Support multi content type in request body and responses
1 parent ea1d884 commit dbf411b

File tree

10 files changed

+315
-139
lines changed

10 files changed

+315
-139
lines changed

examples/multi_content_type.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# -*- coding: utf-8 -*-
2+
# @Author : llc
3+
# @Time : 2024/12/27 15:30
4+
from pydantic import BaseModel
5+
6+
from flask_openapi3 import OpenAPI
7+
8+
app = OpenAPI(__name__)
9+
10+
11+
class DogBody(BaseModel):
12+
a: int = None
13+
b: str = None
14+
15+
model_config = {
16+
"openapi_extra": {
17+
"content_type": "application/vnd.dog+json"
18+
}
19+
}
20+
21+
22+
class CatBody(BaseModel):
23+
c: int = None
24+
d: str = None
25+
26+
model_config = {
27+
"openapi_extra": {
28+
"content_type": "application/vnd.cat+json"
29+
}
30+
}
31+
32+
33+
class ContentTypeModel(BaseModel):
34+
model_config = {
35+
"openapi_extra": {
36+
"content_type": "text/csv"
37+
}
38+
}
39+
40+
41+
@app.post("/a", responses={200: DogBody | CatBody | ContentTypeModel})
42+
def index_a(body: DogBody | CatBody | ContentTypeModel):
43+
print(body)
44+
return {"hello": "world"}
45+
46+
47+
if __name__ == '__main__':
48+
app.run(debug=True)

flask_openapi3/blueprint.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def _collect_openapi_info(
121121
security: Optional[list[dict[str, list[Any]]]] = None,
122122
servers: Optional[list[Server]] = None,
123123
openapi_extensions: Optional[dict[str, Any]] = None,
124+
request_body_description: Optional[str] = None,
125+
request_body_required: Optional[bool] = True,
124126
doc_ui: bool = True,
125127
method: str = HTTPMethod.GET
126128
) -> ParametersTuple:
@@ -140,6 +142,8 @@ def _collect_openapi_info(
140142
security: A declaration of which security mechanisms can be used for this operation.
141143
servers: An alternative server array to service this operation.
142144
openapi_extensions: Allows extensions to the OpenAPI Schema.
145+
request_body_description: A brief description of the request body.
146+
request_body_required: Determines if the request body is required in the request.
143147
doc_ui: Declares this operation to be shown. Default to True.
144148
"""
145149
if self.doc_ui is True and doc_ui is True:
@@ -193,6 +197,11 @@ def _collect_openapi_info(
193197
parse_method(uri, method, self.paths, operation)
194198

195199
# Parse parameters
196-
return parse_parameters(func, components_schemas=self.components_schemas, operation=operation)
200+
return parse_parameters(
201+
func, components_schemas=self.components_schemas,
202+
operation=operation,
203+
request_body_description=request_body_description,
204+
request_body_required=request_body_required
205+
)
197206
else:
198207
return parse_parameters(func, doc_ui=False)

flask_openapi3/openapi.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ def _collect_openapi_info(
380380
security: Optional[list[dict[str, list[Any]]]] = None,
381381
servers: Optional[list[Server]] = None,
382382
openapi_extensions: Optional[dict[str, Any]] = None,
383+
request_body_description: Optional[str] = None,
384+
request_body_required: Optional[bool] = True,
383385
doc_ui: bool = True,
384386
method: str = HTTPMethod.GET
385387
) -> ParametersTuple:
@@ -399,6 +401,8 @@ def _collect_openapi_info(
399401
security: A declaration of which security mechanisms can be used for this operation.
400402
servers: An alternative server array to service this operation.
401403
openapi_extensions: Allows extensions to the OpenAPI Schema.
404+
request_body_description: A brief description of the request body.
405+
request_body_required: Determines if the request body is required in the request.
402406
doc_ui: Declares this operation to be shown. Default to True.
403407
method: HTTP method for the operation. Defaults to GET.
404408
"""
@@ -450,6 +454,11 @@ def _collect_openapi_info(
450454
parse_method(uri, method, self.paths, operation)
451455

452456
# Parse parameters
453-
return parse_parameters(func, components_schemas=self.components_schemas, operation=operation)
457+
return parse_parameters(
458+
func, components_schemas=self.components_schemas,
459+
operation=operation,
460+
request_body_description=request_body_description,
461+
request_body_required=request_body_required
462+
)
454463
else:
455464
return parse_parameters(func, doc_ui=False)

flask_openapi3/request.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,22 @@
33
# @Time : 2022/4/1 16:54
44
import json
55
from json import JSONDecodeError
6-
from typing import Any, Type, Optional
6+
7+
from typing import Any, Type, Optional, get_origin, get_args, Union
8+
9+
try:
10+
from types import UnionType # type: ignore
11+
except ImportError:
12+
# python < 3.9
13+
UnionType = type(Union) # type: ignore
714

815
from flask import request, current_app, abort
9-
from pydantic import ValidationError, BaseModel
16+
from pydantic import ValidationError, BaseModel, RootModel
1017
from pydantic.fields import FieldInfo
1118
from werkzeug.datastructures.structures import MultiDict
1219

20+
from flask_openapi3.utils import is_application_json
21+
1322

1423
def _get_list_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, model_field_value: FieldInfo):
1524
if model_field_value.alias and model.model_config.get("populate_by_name"):
@@ -138,12 +147,20 @@ def _validate_form(form: Type[BaseModel], func_kwargs: dict):
138147

139148

140149
def _validate_body(body: Type[BaseModel], func_kwargs: dict):
141-
obj = request.get_json(silent=True)
142-
if isinstance(obj, str):
143-
body_model = body.model_validate_json(json_data=obj)
150+
if is_application_json(request.mimetype):
151+
if get_origin(body) == UnionType:
152+
root_model_list = [model for model in get_args(body)]
153+
Body = RootModel[Union[tuple(root_model_list)]] # type: ignore
154+
else:
155+
Body = body # type: ignore
156+
obj = request.get_json(silent=True)
157+
if isinstance(obj, str):
158+
body_model = Body.model_validate_json(json_data=obj)
159+
else:
160+
body_model = Body.model_validate(obj=obj)
161+
func_kwargs["body"] = body_model
144162
else:
145-
body_model = body.model_validate(obj=obj)
146-
func_kwargs["body"] = body_model
163+
func_kwargs["body"] = request
147164

148165

149166
def _validate_request(

flask_openapi3/scaffold.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ def _collect_openapi_info(
3232
security: Optional[list[dict[str, list[Any]]]] = None,
3333
servers: Optional[list[Server]] = None,
3434
openapi_extensions: Optional[dict[str, Any]] = None,
35+
request_body_description: Optional[str] = None,
36+
request_body_required: Optional[bool] = True,
3537
doc_ui: bool = True,
3638
method: str = HTTPMethod.GET
3739
) -> ParametersTuple:
38-
raise NotImplementedError # pragma: no cover
40+
raise NotImplementedError # pragma: no cover
3941

4042
def register_api(self, api) -> None:
41-
raise NotImplementedError # pragma: no cover
43+
raise NotImplementedError # pragma: no cover
4244

4345
def _add_url_rule(
4446
self,
@@ -48,7 +50,7 @@ def _add_url_rule(
4850
provide_automatic_options=None,
4951
**options,
5052
) -> None:
51-
raise NotImplementedError # pragma: no cover
53+
raise NotImplementedError # pragma: no cover
5254

5355
@staticmethod
5456
def create_view_func(
@@ -199,6 +201,8 @@ def post(
199201
security: Optional[list[dict[str, list[Any]]]] = None,
200202
servers: Optional[list[Server]] = None,
201203
openapi_extensions: Optional[dict[str, Any]] = None,
204+
request_body_description: Optional[str] = None,
205+
request_body_required: Optional[bool] = True,
202206
doc_ui: bool = True,
203207
**options: Any
204208
) -> Callable:
@@ -218,6 +222,8 @@ def post(
218222
security: A declaration of which security mechanisms can be used for this operation.
219223
servers: An alternative server array to service this operation.
220224
openapi_extensions: Allows extensions to the OpenAPI Schema.
225+
request_body_description: A brief description of the request body.
226+
request_body_required: Determines if the request body is required in the request.
221227
doc_ui: Declares this operation to be shown. Default to True.
222228
"""
223229

@@ -236,6 +242,8 @@ def decorator(func) -> Callable:
236242
security=security,
237243
servers=servers,
238244
openapi_extensions=openapi_extensions,
245+
request_body_description=request_body_description,
246+
request_body_required=request_body_required,
239247
doc_ui=doc_ui,
240248
method=HTTPMethod.POST
241249
)
@@ -262,6 +270,8 @@ def put(
262270
security: Optional[list[dict[str, list[Any]]]] = None,
263271
servers: Optional[list[Server]] = None,
264272
openapi_extensions: Optional[dict[str, Any]] = None,
273+
request_body_description: Optional[str] = None,
274+
request_body_required: Optional[bool] = True,
265275
doc_ui: bool = True,
266276
**options: Any
267277
) -> Callable:
@@ -281,6 +291,8 @@ def put(
281291
security: A declaration of which security mechanisms can be used for this operation.
282292
servers: An alternative server array to service this operation.
283293
openapi_extensions: Allows extensions to the OpenAPI Schema.
294+
request_body_description: A brief description of the request body.
295+
request_body_required: Determines if the request body is required in the request.
284296
doc_ui: Declares this operation to be shown. Default to True.
285297
"""
286298

@@ -299,6 +311,8 @@ def decorator(func) -> Callable:
299311
security=security,
300312
servers=servers,
301313
openapi_extensions=openapi_extensions,
314+
request_body_description=request_body_description,
315+
request_body_required=request_body_required,
302316
doc_ui=doc_ui,
303317
method=HTTPMethod.PUT
304318
)
@@ -325,6 +339,8 @@ def delete(
325339
security: Optional[list[dict[str, list[Any]]]] = None,
326340
servers: Optional[list[Server]] = None,
327341
openapi_extensions: Optional[dict[str, Any]] = None,
342+
request_body_description: Optional[str] = None,
343+
request_body_required: Optional[bool] = True,
328344
doc_ui: bool = True,
329345
**options: Any
330346
) -> Callable:
@@ -344,6 +360,8 @@ def delete(
344360
security: A declaration of which security mechanisms can be used for this operation.
345361
servers: An alternative server array to service this operation.
346362
openapi_extensions: Allows extensions to the OpenAPI Schema.
363+
request_body_description: A brief description of the request body.
364+
request_body_required: Determines if the request body is required in the request.
347365
doc_ui: Declares this operation to be shown. Default to True.
348366
"""
349367

@@ -362,6 +380,8 @@ def decorator(func) -> Callable:
362380
security=security,
363381
servers=servers,
364382
openapi_extensions=openapi_extensions,
383+
request_body_description=request_body_description,
384+
request_body_required=request_body_required,
365385
doc_ui=doc_ui,
366386
method=HTTPMethod.DELETE
367387
)
@@ -388,6 +408,8 @@ def patch(
388408
security: Optional[list[dict[str, list[Any]]]] = None,
389409
servers: Optional[list[Server]] = None,
390410
openapi_extensions: Optional[dict[str, Any]] = None,
411+
request_body_description: Optional[str] = None,
412+
request_body_required: Optional[bool] = True,
391413
doc_ui: bool = True,
392414
**options: Any
393415
) -> Callable:
@@ -407,6 +429,8 @@ def patch(
407429
security: A declaration of which security mechanisms can be used for this operation.
408430
servers: An alternative server array to service this operation.
409431
openapi_extensions: Allows extensions to the OpenAPI Schema.
432+
request_body_description: A brief description of the request body.
433+
request_body_required: Determines if the request body is required in the request.
410434
doc_ui: Declares this operation to be shown. Default to True.
411435
"""
412436

@@ -425,6 +449,8 @@ def decorator(func) -> Callable:
425449
security=security,
426450
servers=servers,
427451
openapi_extensions=openapi_extensions,
452+
request_body_description=request_body_description,
453+
request_body_required=request_body_required,
428454
doc_ui=doc_ui,
429455
method=HTTPMethod.PATCH
430456
)

flask_openapi3/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
# @Author : llc
33
# @Time : 2023/7/9 15:25
44
from http import HTTPStatus
5-
from typing import Union, Type, Any, Optional
5+
from typing import Union, Type, Any, Optional, TypeVar
66

77
from pydantic import BaseModel
88

99
from .models import RawModel
1010
from .models import SecurityScheme
1111

12-
_ResponseDictValue = Union[Type[BaseModel], dict[Any, Any], None]
12+
_MultiBaseModel = TypeVar("_MultiBaseModel", bound=Type[BaseModel])
13+
14+
_ResponseDictValue = Union[Type[BaseModel], _MultiBaseModel, dict[Any, Any], None]
1315

1416
ResponseDict = dict[Union[str, int, HTTPStatus], _ResponseDictValue]
1517

0 commit comments

Comments
 (0)