Skip to content

Commit 8c66ba8

Browse files
committed
fix: only transform non-errors
1 parent f0ec9a5 commit 8c66ba8

File tree

5 files changed

+71
-46
lines changed

5 files changed

+71
-46
lines changed

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from starlette.datastructures import Headers
1111
from starlette.requests import Request
12-
from starlette.types import ASGIApp
12+
from starlette.types import ASGIApp, Scope
1313

1414
from ..config import EndpointMethods
1515
from ..utils.middleware import JsonResponseMiddleware
@@ -38,24 +38,27 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
3838

3939
state_key: str = "oidc_metadata"
4040

41-
def should_transform_response(
42-
self, request: Request, response_headers: Headers
43-
) -> bool:
41+
def should_transform_response(self, request: Request, scope: Scope) -> bool:
4442
"""Determine if the response should be transformed."""
4543
# Match STAC catalog, collection, or item URLs with a single regex
46-
return all(
47-
re.match(expr, val)
48-
for expr, val in [
44+
return (
45+
all(
4946
(
50-
# catalog, collections, collection, items, item, search
51-
r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
52-
request.url.path,
47+
re.match(expr, val)
48+
for expr, val in [
49+
(
50+
# catalog, collections, collection, items, item, search
51+
r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
52+
request.url.path,
53+
),
54+
(
55+
self.json_content_type_expr,
56+
Headers(scope=scope).get("content-type", ""),
57+
),
58+
]
5359
),
54-
(
55-
self.json_content_type_expr,
56-
response_headers.get("content-type", ""),
57-
),
58-
]
60+
)
61+
and 200 >= scope["status"] < 300
5962
)
6063

6164
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from starlette.datastructures import Headers
88
from starlette.requests import Request
9-
from starlette.types import ASGIApp
9+
from starlette.types import ASGIApp, Scope
1010

1111
from ..config import EndpointMethods
1212
from ..utils.middleware import JsonResponseMiddleware
@@ -27,19 +27,20 @@ class OpenApiMiddleware(JsonResponseMiddleware):
2727

2828
json_content_type_expr: str = r"application/(vnd\.oai\.openapi\+json?|json)"
2929

30-
def should_transform_response(
31-
self, request: Request, response_headers: Headers
32-
) -> bool:
30+
def should_transform_response(self, request: Request, scope: Scope) -> bool:
3331
"""Only transform responses for the OpenAPI spec path."""
34-
return all(
35-
re.match(expr, val)
36-
for expr, val in [
37-
(self.openapi_spec_path, request.url.path),
38-
(
39-
self.json_content_type_expr,
40-
response_headers.get("content-type", ""),
41-
),
42-
]
32+
return (
33+
all(
34+
re.match(expr, val)
35+
for expr, val in [
36+
(self.openapi_spec_path, request.url.path),
37+
(
38+
self.json_content_type_expr,
39+
Headers(scope=scope).get("content-type", ""),
40+
),
41+
]
42+
)
43+
and 200 >= scope["status"] < 300
4344
)
4445

4546
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:

src/stac_auth_proxy/utils/middleware.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from abc import ABC, abstractmethod
55
from typing import Any, Optional
66

7-
from starlette.datastructures import Headers, MutableHeaders
7+
from starlette.datastructures import MutableHeaders
88
from starlette.requests import Request
99
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1010

@@ -16,7 +16,7 @@ class JsonResponseMiddleware(ABC):
1616

1717
@abstractmethod
1818
def should_transform_response(
19-
self, request: Request, response_headers: Headers
19+
self, request: Request, scope: Scope
2020
) -> bool: # mypy: ignore
2121
"""
2222
Determine if this response should be transformed. At a minimum, this
@@ -60,7 +60,7 @@ async def transform_response(message: Message) -> None:
6060

6161
if not self.should_transform_response(
6262
request=request,
63-
response_headers=headers,
63+
scope=start_message,
6464
):
6565
# For non-JSON responses, send the start message immediately
6666
await send(message)

tests/test_auth_extension.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Tests for AuthenticationExtensionMiddleware."""
22

33
import pytest
4-
from starlette.datastructures import Headers
54
from starlette.requests import Request
65

76
from stac_auth_proxy.config import EndpointMethods
@@ -34,10 +33,20 @@ def request_scope():
3433
}
3534

3635

37-
@pytest.fixture(params=["application/json", "application/geo+json"])
38-
def json_headers(request):
36+
@pytest.fixture(params=[b"application/json", b"application/geo+json"])
37+
def initial_message(request):
3938
"""Create headers with JSON content type."""
40-
return Headers({"content-type": request.param})
39+
return {
40+
"type": "http.response.start",
41+
"status": 200,
42+
"headers": [
43+
(b"date", b"Mon, 07 Apr 2025 06:55:37 GMT"),
44+
(b"server", b"uvicorn"),
45+
(b"content-length", b"27642"),
46+
(b"content-type", request.param),
47+
(b"x-upstream-time", b"0.063"),
48+
],
49+
}
4150

4251

4352
@pytest.fixture
@@ -50,7 +59,9 @@ def oidc_metadata():
5059
}
5160

5261

53-
def test_should_transform_response_valid_paths(middleware, request_scope, json_headers):
62+
def test_should_transform_response_valid_paths(
63+
middleware, request_scope, initial_message
64+
):
5465
"""Test that valid STAC paths are transformed."""
5566
valid_paths = [
5667
"/",
@@ -64,11 +75,11 @@ def test_should_transform_response_valid_paths(middleware, request_scope, json_h
6475
for path in valid_paths:
6576
request_scope["path"] = path
6677
request = Request(request_scope)
67-
assert middleware.should_transform_response(request, json_headers)
78+
assert middleware.should_transform_response(request, initial_message)
6879

6980

7081
def test_should_transform_response_invalid_paths(
71-
middleware, request_scope, json_headers
82+
middleware, request_scope, initial_message
7283
):
7384
"""Test that invalid paths are not transformed."""
7485
invalid_paths = [
@@ -80,14 +91,26 @@ def test_should_transform_response_invalid_paths(
8091
for path in invalid_paths:
8192
request_scope["path"] = path
8293
request = Request(request_scope)
83-
assert not middleware.should_transform_response(request, json_headers)
94+
assert not middleware.should_transform_response(request, initial_message)
8495

8596

8697
def test_should_transform_response_invalid_content_type(middleware, request_scope):
8798
"""Test that non-JSON content types are not transformed."""
8899
request = Request(request_scope)
89-
headers = Headers({"content-type": "text/html"})
90-
assert not middleware.should_transform_response(request, headers)
100+
assert not middleware.should_transform_response(
101+
request,
102+
{
103+
"type": "http.response.start",
104+
"status": 200,
105+
"headers": [
106+
(b"date", b"Mon, 07 Apr 2025 06:55:37 GMT"),
107+
(b"server", b"uvicorn"),
108+
(b"content-length", b"27642"),
109+
(b"content-type", b"text/html"),
110+
(b"x-upstream-time", b"0.063"),
111+
],
112+
},
113+
)
91114

92115

93116
def test_transform_json_catalog(middleware, request_scope, oidc_metadata):

tests/test_middleware.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from starlette.datastructures import Headers
77
from starlette.requests import Request
88
from starlette.testclient import TestClient
9-
from starlette.types import ASGIApp
9+
from starlette.types import ASGIApp, Scope
1010

1111
from stac_auth_proxy.utils.middleware import JsonResponseMiddleware
1212

@@ -18,11 +18,9 @@ def __init__(self, app: ASGIApp):
1818
"""Initialize the middleware."""
1919
self.app = app
2020

21-
def should_transform_response(
22-
self, request: Request, response_headers: Headers
23-
) -> bool:
21+
def should_transform_response(self, request: Request, scope: Scope) -> bool:
2422
"""Transform JSON responses based on content type."""
25-
return response_headers.get("content-type", "") == "application/json"
23+
return Headers(scope=scope).get("content-type", "") == "application/json"
2624

2725
def transform_json(self, data: Any, request: Request) -> Any:
2826
"""Add a test field to the response."""

0 commit comments

Comments
 (0)