|
25 | 25 |
|
26 | 26 | from fastapi import Body, FastAPI
|
27 | 27 | from fastapi.exceptions import RequestValidationError
|
| 28 | +from fastapi.middleware.cors import CORSMiddleware |
28 | 29 | from starlette.requests import Request
|
29 | 30 | from starlette.responses import Response
|
30 | 31 | from starlette.background import BackgroundTasks
|
|
55 | 56 | )
|
56 | 57 |
|
57 | 58 | app = FastAPI()
|
| 59 | + |
| 60 | +app.add_middleware( |
| 61 | + CORSMiddleware, |
| 62 | + allow_origins=["*"], |
| 63 | + allow_credentials=True, |
| 64 | + allow_methods=["*"], |
| 65 | + allow_headers=["*"], |
| 66 | +) |
| 67 | + |
58 | 68 | local_cache = {"api": None, "predictor_impl": None, "client": None, "class_set": set()}
|
59 | 69 |
|
60 | 70 |
|
@@ -90,21 +100,18 @@ def is_prediction_request(request):
|
90 | 100 | @app.exception_handler(StarletteHTTPException)
|
91 | 101 | async def http_exception_handler(request, e):
|
92 | 102 | response = Response(content=str(e.detail), status_code=e.status_code)
|
93 |
| - apply_cors_headers(request, response) |
94 | 103 | return response
|
95 | 104 |
|
96 | 105 |
|
97 | 106 | @app.exception_handler(RequestValidationError)
|
98 | 107 | async def validation_exception_handler(request, e):
|
99 | 108 | response = Response(content=str(e), status_code=400)
|
100 |
| - apply_cors_headers(request, response) |
101 | 109 | return response
|
102 | 110 |
|
103 | 111 |
|
104 | 112 | @app.exception_handler(Exception)
|
105 | 113 | async def uncaught_exception_handler(request, e):
|
106 | 114 | response = Response(content="internal server error", status_code=500)
|
107 |
| - apply_cors_headers(request, response) |
108 | 115 | return response
|
109 | 116 |
|
110 | 117 |
|
@@ -132,20 +139,12 @@ async def register_request(request: Request, call_next):
|
132 | 139 | status_code = 500
|
133 | 140 | if response is not None:
|
134 | 141 | status_code = response.status_code
|
135 |
| - apply_cors_headers(request, response) |
136 | 142 | api = local_cache["api"]
|
137 | 143 | api.post_request_metrics(status_code, time.time() - request.state.start_time)
|
138 | 144 |
|
139 | 145 | return response
|
140 | 146 |
|
141 | 147 |
|
142 |
| -def apply_cors_headers(request: Request, response: Response): |
143 |
| - response.headers["Access-Control-Allow-Origin"] = "*" |
144 |
| - response.headers["Access-Control-Allow-Headers"] = request.headers.get( |
145 |
| - "Access-Control-Request-Headers", "*" |
146 |
| - ) |
147 |
| - |
148 |
| - |
149 | 148 | @app.post("/predict")
|
150 | 149 | def predict(request: Any = Body(..., media_type="application/json"), debug=False):
|
151 | 150 | api = local_cache["api"]
|
|
0 commit comments