Skip to content

Commit 2639c9e

Browse files
vishalbolludeliahu
authored andcommitted
Fix cors (#942)
(cherry picked from commit a012c68)
1 parent 4716bef commit 2639c9e

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

pkg/workloads/cortex/serve/serve.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from fastapi import Body, FastAPI
2727
from fastapi.exceptions import RequestValidationError
28+
from fastapi.middleware.cors import CORSMiddleware
2829
from starlette.requests import Request
2930
from starlette.responses import Response
3031
from starlette.background import BackgroundTasks
@@ -55,6 +56,15 @@
5556
)
5657

5758
app = FastAPI()
59+
60+
app.add_middleware(
61+
CORSMiddleware,
62+
allow_origins=["*"],
63+
allow_credentials=True,
64+
allow_methods=["*"],
65+
allow_headers=["*"],
66+
)
67+
5868
local_cache = {"api": None, "predictor_impl": None, "client": None, "class_set": set()}
5969

6070

@@ -90,21 +100,18 @@ def is_prediction_request(request):
90100
@app.exception_handler(StarletteHTTPException)
91101
async def http_exception_handler(request, e):
92102
response = Response(content=str(e.detail), status_code=e.status_code)
93-
apply_cors_headers(request, response)
94103
return response
95104

96105

97106
@app.exception_handler(RequestValidationError)
98107
async def validation_exception_handler(request, e):
99108
response = Response(content=str(e), status_code=400)
100-
apply_cors_headers(request, response)
101109
return response
102110

103111

104112
@app.exception_handler(Exception)
105113
async def uncaught_exception_handler(request, e):
106114
response = Response(content="internal server error", status_code=500)
107-
apply_cors_headers(request, response)
108115
return response
109116

110117

@@ -132,20 +139,12 @@ async def register_request(request: Request, call_next):
132139
status_code = 500
133140
if response is not None:
134141
status_code = response.status_code
135-
apply_cors_headers(request, response)
136142
api = local_cache["api"]
137143
api.post_request_metrics(status_code, time.time() - request.state.start_time)
138144

139145
return response
140146

141147

142-
def apply_cors_headers(request: Request, response: Response):
143-
response.headers["Access-Control-Allow-Origin"] = "*"
144-
response.headers["Access-Control-Allow-Headers"] = request.headers.get(
145-
"Access-Control-Request-Headers", "*"
146-
)
147-
148-
149148
@app.post("/predict")
150149
def predict(request: Any = Body(..., media_type="application/json"), debug=False):
151150
api = local_cache["api"]

0 commit comments

Comments
 (0)