2
2
3
3
import json
4
4
import re
5
- from dataclasses import dataclass , field
6
- from functools import partial
5
+ from dataclasses import dataclass
7
6
from logging import getLogger
8
- from typing import Callable , Optional
7
+ from typing import Optional
9
8
10
9
from cql2 import Expr
11
- from starlette .datastructures import MutableHeaders , State
10
+ from starlette .datastructures import MutableHeaders
12
11
from starlette .requests import Request
13
12
from starlette .types import ASGIApp , Message , Receive , Scope , Send
14
13
@@ -39,32 +38,25 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
39
38
40
39
request = Request (scope )
41
40
42
- get_cql2_filter : Callable [[], Optional [Expr ]] = partial (
43
- getattr , request .state , self .state_key , None
44
- )
41
+ cql2_filter : Optional [Expr ] = getattr (request .state , self .state_key , None )
42
+
43
+ if not cql2_filter :
44
+ return await self .app (scope , receive , send )
45
45
46
46
# Handle POST, PUT, PATCH
47
47
if request .method in ["POST" , "PUT" , "PATCH" ]:
48
- return await self .app (
49
- scope ,
50
- Cql2RequestBodyAugmentor (
51
- receive = receive ,
52
- state = request .state ,
53
- get_cql2_filter = get_cql2_filter ,
54
- ),
55
- send ,
48
+ req_body_handler = Cql2RequestBodyAugmentor (
49
+ app = self .app ,
50
+ cql2_filter = cql2_filter ,
56
51
)
57
-
58
- cql2_filter = get_cql2_filter ()
59
- if not cql2_filter :
60
- return await self .app (scope , receive , send )
52
+ return await req_body_handler (scope , receive , send )
61
53
62
54
if re .match (r"^/collections/([^/]+)/items/([^/]+)$" , request .url .path ):
63
- return await self .app (
64
- scope ,
65
- receive ,
66
- Cql2ResponseBodyValidator (cql2_filter = cql2_filter , send = send ),
55
+ res_body_validator = Cql2ResponseBodyValidator (
56
+ app = self .app ,
57
+ cql2_filter = cql2_filter ,
67
58
)
59
+ return await res_body_validator (scope , send , receive )
68
60
69
61
scope ["query_string" ] = filters .append_qs_filter (request .url .query , cql2_filter )
70
62
return await self .app (scope , receive , send )
@@ -74,88 +66,115 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
74
66
class Cql2RequestBodyAugmentor :
75
67
"""Handler to augment the request body with a CQL2 filter."""
76
68
77
- receive : Receive
78
- state : State
79
- get_cql2_filter : Callable [[], Optional [Expr ]]
80
-
81
- async def __call__ (self ) -> Message :
82
- """Process a request body and augment with a CQL2 filter if available."""
83
- message = await self .receive ()
84
- if message ["type" ] != "http.request" :
85
- return message
86
-
87
- # NOTE: Can only get cql2 filter _after_ calling self.receive()
88
- cql2_filter = self .get_cql2_filter ()
89
- if not cql2_filter :
90
- return message
69
+ app : ASGIApp
70
+ cql2_filter : Expr
91
71
72
+ async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
73
+ """Augment the request body with a CQL2 filter."""
74
+ body = b""
75
+ more_body = True
76
+
77
+ # Read the body
78
+ while more_body :
79
+ message = await receive ()
80
+ if message ["type" ] == "http.request" :
81
+ body += message .get ("body" , b"" )
82
+ more_body = message .get ("more_body" , False )
83
+
84
+ # Modify body
92
85
try :
93
- body = json .loads (message . get ( " body" , b"{}" ) )
86
+ body = json .loads (body )
94
87
except json .JSONDecodeError as e :
95
88
logger .warning ("Failed to parse request body as JSON" )
96
89
# TODO: Return a 400 error
97
90
raise e
98
91
99
- new_body = filters .append_body_filter (body , cql2_filter )
100
- message ["body" ] = json .dumps (new_body ).encode ("utf-8" )
101
- return message
92
+ # Augment the body
93
+ assert isinstance (body , dict ), "Request body must be a JSON object"
94
+ new_body = json .dumps (
95
+ filters .append_body_filter (body , self .cql2_filter )
96
+ ).encode ("utf-8" )
97
+
98
+ # Patch content-length in the headers
99
+ headers = dict (scope ["headers" ])
100
+ headers [b"content-length" ] = str (len (new_body )).encode ("latin1" )
101
+ scope ["headers" ] = list (headers .items ())
102
+
103
+ async def new_receive ():
104
+ return {
105
+ "type" : "http.request" ,
106
+ "body" : new_body ,
107
+ "more_body" : False ,
108
+ }
109
+
110
+ await self .app (scope , new_receive , send )
102
111
103
112
104
113
@dataclass
105
114
class Cql2ResponseBodyValidator :
106
115
"""Handler to validate response body with CQL2."""
107
116
108
- send : Send
117
+ app : ASGIApp
109
118
cql2_filter : Expr
110
- initial_message : Optional [Message ] = field (init = False )
111
- body : bytes = field (init = False , default_factory = bytes )
112
119
113
- async def __call__ (self , message : Message ) -> None :
120
+ async def __call__ (self , scope : Scope , send : Send , receive : Receive ) -> None :
114
121
"""Process a response message and apply filtering if needed."""
115
- if message ["type" ] == "http.response.start" :
116
- self .initial_message = message
117
- return
122
+ if scope ["type" ] != "http" :
123
+ return await self .app (scope , send , receive )
124
+
125
+ body = b""
126
+ initial_message : Optional [Message ] = None
127
+
128
+ async def _send_error_response (status : int , message : str ) -> None :
129
+ """Send an error response with the given status and message."""
130
+ assert initial_message , "Initial message not set"
131
+ error_body = json .dumps ({"message" : message }).encode ("utf-8" )
132
+ headers = MutableHeaders (scope = initial_message )
133
+ headers ["content-length" ] = str (len (error_body ))
134
+ initial_message ["status" ] = status
135
+ await send (initial_message )
136
+ await send (
137
+ {
138
+ "type" : "http.response.body" ,
139
+ "body" : error_body ,
140
+ "more_body" : False ,
141
+ }
142
+ )
118
143
119
- if message ["type" ] == "http.response.body" :
120
- assert self .initial_message , "Initial message not set"
144
+ async def buffered_send (message : Message ) -> None :
145
+ """Process a response message and apply filtering if needed."""
146
+ nonlocal body
147
+ nonlocal initial_message
121
148
122
- self .body += message ["body" ]
149
+ if message ["type" ] == "http.response.start" :
150
+ initial_message = message
151
+ return
152
+
153
+ assert initial_message , "Initial message not set"
154
+
155
+ body += message ["body" ]
123
156
if message .get ("more_body" ):
124
157
return
125
158
126
159
try :
127
- body_json = json .loads (self . body )
160
+ body_json = json .loads (body )
128
161
except json .JSONDecodeError :
129
162
logger .warning ("Failed to parse response body as JSON" )
130
- await self . _send_error_response (502 , "Not found" )
163
+ await _send_error_response (502 , "Not found" )
131
164
return
132
165
133
166
logger .debug (
134
167
"Applying %s filter to %s" , self .cql2_filter .to_text (), body_json
135
168
)
136
169
if self .cql2_filter .matches (body_json ):
137
- await self . send (self . initial_message )
138
- return await self . send (
170
+ await send (initial_message )
171
+ return await send (
139
172
{
140
173
"type" : "http.response.body" ,
141
174
"body" : json .dumps (body_json ).encode ("utf-8" ),
142
175
"more_body" : False ,
143
176
}
144
177
)
145
- return await self ._send_error_response (404 , "Not found" )
146
-
147
- async def _send_error_response (self , status : int , message : str ) -> None :
148
- """Send an error response with the given status and message."""
149
- assert self .initial_message , "Initial message not set"
150
- error_body = json .dumps ({"message" : message }).encode ("utf-8" )
151
- headers = MutableHeaders (scope = self .initial_message )
152
- headers ["content-length" ] = str (len (error_body ))
153
- self .initial_message ["status" ] = status
154
- await self .send (self .initial_message )
155
- await self .send (
156
- {
157
- "type" : "http.response.body" ,
158
- "body" : error_body ,
159
- "more_body" : False ,
160
- }
161
- )
178
+ return await _send_error_response (404 , "Not found" )
179
+
180
+ return await self .app (scope , receive , buffered_send )
0 commit comments