Skip to content

Commit f0ec9a5

Browse files
committed
fix: Rework body augmentor to avoid error on empty POSTs
Was seeing `h11._util.LocalProtocolError: Too much data for declared Content-Length.`
1 parent 0b67176 commit f0ec9a5

File tree

3 files changed

+139
-114
lines changed

3 files changed

+139
-114
lines changed

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 91 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
import json
44
import re
5-
from dataclasses import dataclass, field
6-
from functools import partial
5+
from dataclasses import dataclass
76
from logging import getLogger
8-
from typing import Callable, Optional
7+
from typing import Optional
98

109
from cql2 import Expr
11-
from starlette.datastructures import MutableHeaders, State
10+
from starlette.datastructures import MutableHeaders
1211
from starlette.requests import Request
1312
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1413

@@ -39,32 +38,25 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3938

4039
request = Request(scope)
4140

42-
get_cql2_filter: Callable[[], Optional[Expr]] = partial(
43-
getattr, request.state, self.state_key, None
44-
)
41+
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
42+
43+
if not cql2_filter:
44+
return await self.app(scope, receive, send)
4545

4646
# Handle POST, PUT, PATCH
4747
if request.method in ["POST", "PUT", "PATCH"]:
48-
return await self.app(
49-
scope,
50-
Cql2RequestBodyAugmentor(
51-
receive=receive,
52-
state=request.state,
53-
get_cql2_filter=get_cql2_filter,
54-
),
55-
send,
48+
req_body_handler = Cql2RequestBodyAugmentor(
49+
app=self.app,
50+
cql2_filter=cql2_filter,
5651
)
57-
58-
cql2_filter = get_cql2_filter()
59-
if not cql2_filter:
60-
return await self.app(scope, receive, send)
52+
return await req_body_handler(scope, receive, send)
6153

6254
if re.match(r"^/collections/([^/]+)/items/([^/]+)$", request.url.path):
63-
return await self.app(
64-
scope,
65-
receive,
66-
Cql2ResponseBodyValidator(cql2_filter=cql2_filter, send=send),
55+
res_body_validator = Cql2ResponseBodyValidator(
56+
app=self.app,
57+
cql2_filter=cql2_filter,
6758
)
59+
return await res_body_validator(scope, send, receive)
6860

6961
scope["query_string"] = filters.append_qs_filter(request.url.query, cql2_filter)
7062
return await self.app(scope, receive, send)
@@ -74,88 +66,115 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
7466
class Cql2RequestBodyAugmentor:
7567
"""Handler to augment the request body with a CQL2 filter."""
7668

77-
receive: Receive
78-
state: State
79-
get_cql2_filter: Callable[[], Optional[Expr]]
80-
81-
async def __call__(self) -> Message:
82-
"""Process a request body and augment with a CQL2 filter if available."""
83-
message = await self.receive()
84-
if message["type"] != "http.request":
85-
return message
86-
87-
# NOTE: Can only get cql2 filter _after_ calling self.receive()
88-
cql2_filter = self.get_cql2_filter()
89-
if not cql2_filter:
90-
return message
69+
app: ASGIApp
70+
cql2_filter: Expr
9171

72+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
73+
"""Augment the request body with a CQL2 filter."""
74+
body = b""
75+
more_body = True
76+
77+
# Read the body
78+
while more_body:
79+
message = await receive()
80+
if message["type"] == "http.request":
81+
body += message.get("body", b"")
82+
more_body = message.get("more_body", False)
83+
84+
# Modify body
9285
try:
93-
body = json.loads(message.get("body", b"{}"))
86+
body = json.loads(body)
9487
except json.JSONDecodeError as e:
9588
logger.warning("Failed to parse request body as JSON")
9689
# TODO: Return a 400 error
9790
raise e
9891

99-
new_body = filters.append_body_filter(body, cql2_filter)
100-
message["body"] = json.dumps(new_body).encode("utf-8")
101-
return message
92+
# Augment the body
93+
assert isinstance(body, dict), "Request body must be a JSON object"
94+
new_body = json.dumps(
95+
filters.append_body_filter(body, self.cql2_filter)
96+
).encode("utf-8")
97+
98+
# Patch content-length in the headers
99+
headers = dict(scope["headers"])
100+
headers[b"content-length"] = str(len(new_body)).encode("latin1")
101+
scope["headers"] = list(headers.items())
102+
103+
async def new_receive():
104+
return {
105+
"type": "http.request",
106+
"body": new_body,
107+
"more_body": False,
108+
}
109+
110+
await self.app(scope, new_receive, send)
102111

103112

104113
@dataclass
105114
class Cql2ResponseBodyValidator:
106115
"""Handler to validate response body with CQL2."""
107116

108-
send: Send
117+
app: ASGIApp
109118
cql2_filter: Expr
110-
initial_message: Optional[Message] = field(init=False)
111-
body: bytes = field(init=False, default_factory=bytes)
112119

113-
async def __call__(self, message: Message) -> None:
120+
async def __call__(self, scope: Scope, send: Send, receive: Receive) -> None:
114121
"""Process a response message and apply filtering if needed."""
115-
if message["type"] == "http.response.start":
116-
self.initial_message = message
117-
return
122+
if scope["type"] != "http":
123+
return await self.app(scope, send, receive)
124+
125+
body = b""
126+
initial_message: Optional[Message] = None
127+
128+
async def _send_error_response(status: int, message: str) -> None:
129+
"""Send an error response with the given status and message."""
130+
assert initial_message, "Initial message not set"
131+
error_body = json.dumps({"message": message}).encode("utf-8")
132+
headers = MutableHeaders(scope=initial_message)
133+
headers["content-length"] = str(len(error_body))
134+
initial_message["status"] = status
135+
await send(initial_message)
136+
await send(
137+
{
138+
"type": "http.response.body",
139+
"body": error_body,
140+
"more_body": False,
141+
}
142+
)
118143

119-
if message["type"] == "http.response.body":
120-
assert self.initial_message, "Initial message not set"
144+
async def buffered_send(message: Message) -> None:
145+
"""Process a response message and apply filtering if needed."""
146+
nonlocal body
147+
nonlocal initial_message
121148

122-
self.body += message["body"]
149+
if message["type"] == "http.response.start":
150+
initial_message = message
151+
return
152+
153+
assert initial_message, "Initial message not set"
154+
155+
body += message["body"]
123156
if message.get("more_body"):
124157
return
125158

126159
try:
127-
body_json = json.loads(self.body)
160+
body_json = json.loads(body)
128161
except json.JSONDecodeError:
129162
logger.warning("Failed to parse response body as JSON")
130-
await self._send_error_response(502, "Not found")
163+
await _send_error_response(502, "Not found")
131164
return
132165

133166
logger.debug(
134167
"Applying %s filter to %s", self.cql2_filter.to_text(), body_json
135168
)
136169
if self.cql2_filter.matches(body_json):
137-
await self.send(self.initial_message)
138-
return await self.send(
170+
await send(initial_message)
171+
return await send(
139172
{
140173
"type": "http.response.body",
141174
"body": json.dumps(body_json).encode("utf-8"),
142175
"more_body": False,
143176
}
144177
)
145-
return await self._send_error_response(404, "Not found")
146-
147-
async def _send_error_response(self, status: int, message: str) -> None:
148-
"""Send an error response with the given status and message."""
149-
assert self.initial_message, "Initial message not set"
150-
error_body = json.dumps({"message": message}).encode("utf-8")
151-
headers = MutableHeaders(scope=self.initial_message)
152-
headers["content-length"] = str(len(error_body))
153-
self.initial_message["status"] = status
154-
await self.send(self.initial_message)
155-
await self.send(
156-
{
157-
"type": "http.response.body",
158-
"body": error_body,
159-
"more_body": False,
160-
}
161-
)
178+
return await _send_error_response(404, "Not found")
179+
180+
return await self.app(scope, receive, buffered_send)

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
"""Middleware to build the Cql2Filter."""
22

3-
import json
3+
import logging
44
import re
55
from dataclasses import dataclass
66
from typing import Any, Awaitable, Callable, Optional
77

8-
from cql2 import Expr
8+
from cql2 import Expr, ValidationError
99
from starlette.requests import Request
10-
from starlette.types import ASGIApp, Message, Receive, Scope, Send
10+
from starlette.responses import Response
11+
from starlette.types import ASGIApp, Receive, Scope, Send
1112

1213
from ..utils import requests
1314

15+
logger = logging.getLogger(__name__)
16+
1417

1518
@dataclass(frozen=True)
1619
class BuildCql2FilterMiddleware:
@@ -35,47 +38,27 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3538
if not filter_builder:
3639
return await self.app(scope, receive, send)
3740

38-
async def set_filter(body: Optional[dict] = None) -> None:
39-
assert filter_builder is not None
40-
filter_expr = await filter_builder(
41-
{
42-
"req": {
43-
"path": request.url.path,
44-
"method": request.method,
45-
"query_params": dict(request.query_params),
46-
"path_params": requests.extract_variables(request.url.path),
47-
"headers": dict(request.headers),
48-
"body": body,
49-
},
50-
**scope["state"],
51-
}
52-
)
53-
cql2_filter = Expr(filter_expr)
41+
filter_expr = await filter_builder(
42+
{
43+
"req": {
44+
"path": request.url.path,
45+
"method": request.method,
46+
"query_params": dict(request.query_params),
47+
"path_params": requests.extract_variables(request.url.path),
48+
"headers": dict(request.headers),
49+
},
50+
**scope["state"],
51+
}
52+
)
53+
cql2_filter = Expr(filter_expr)
54+
try:
5455
cql2_filter.validate()
55-
setattr(request.state, self.state_key, cql2_filter)
56-
57-
# For GET requests, we can build the filter immediately
58-
if request.method == "GET":
59-
await set_filter()
60-
return await self.app(scope, receive, send)
61-
62-
total_body = b""
63-
64-
async def receive_build_filter() -> Message:
65-
"""
66-
Receive the body of the request and build the filter.
67-
NOTE: This is not called for GET requests.
68-
"""
69-
nonlocal total_body
70-
71-
message = await receive()
72-
total_body += message.get("body", b"")
73-
74-
if not message.get("more_body"):
75-
await set_filter(json.loads(total_body) if total_body else None)
76-
return message
56+
except ValidationError:
57+
logger.exception("Invalid CQL2 filter: %s", filter_expr)
58+
return await Response(status_code=502, content="Invalid CQL2 filter")
59+
setattr(request.state, self.state_key, cql2_filter)
7760

78-
return await self.app(scope, receive_build_filter, send)
61+
return await self.app(scope, receive, send)
7962

8063
def _get_filter(
8164
self, path: str

tests/test_filters_jinja2.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,3 +314,26 @@ def test_item_get(
314314
else:
315315
assert response.status_code == 404
316316
assert response.json() == {"message": "Not found"}
317+
318+
319+
@pytest.mark.parametrize("is_authenticated", [True, False], ids=["auth", "anon"])
320+
async def test_search_post_empty_body(
321+
source_api_server,
322+
is_authenticated,
323+
token_builder,
324+
):
325+
"""Test that POST /search with empty body."""
326+
client = _build_client(
327+
src_api_server=source_api_server,
328+
template_expr="(properties.private = false)",
329+
is_authenticated=is_authenticated,
330+
token_builder=token_builder,
331+
)
332+
333+
# Send request with Content-Length header that doesn't match actual body size
334+
response = client.post(
335+
"/search",
336+
json={},
337+
)
338+
339+
assert response.status_code == 200

0 commit comments

Comments
 (0)