Commit df9330e

feat: enable collections filtering (#52)
* New env variable: `COLLECTIONS_FILTER`
* Pass collections filter to `BuildCql2FilterMiddleware`
* Update `BuildCql2FilterMiddleware` with dynamic conformance checks based on which filter generators are provided
* Update `ApplyCql2FilterMiddleware` to run record validation on collection details view
1 parent 1ce8ed5 commit df9330e

9 files changed, +265 -36 lines changed

README.md

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ STAC Auth Proxy is a proxy API that mediates between the client and your interna
1010
## ✨Features✨
1111

1212
- **🔐 Authentication:** Apply [OpenID Connect (OIDC)](https://openid.net/developers/how-connect-works/) token validation and optional scope checks to specified endpoints and methods
13-
- **🛂 Content Filtering:** Use CQL2 filters via the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to tailor API responses based on user context
13+
- **🛂 Content Filtering:** Use CQL2 filters via the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to tailor API responses based on request context (e.g. user role)
1414
- **🤝 External Policy Integration:** Integrate with external systems (e.g. [Open Policy Agent (OPA)](https://www.openpolicyagent.org/)) to generate CQL2 filters dynamically from policy decisions
1515
- **🧩 Authentication Extension:** Add the [Authentication Extension](https://github.com/stac-extensions/authentication) to API responses to expose auth-related metadata
1616
- **📘 OpenAPI Augmentation:** Enhance the [OpenAPI spec](https://swagger.io/specification/) with security details to keep auto-generated docs and UIs (e.g., [Swagger UI](https://swagger.io/tools/swagger-ui/)) accurate
@@ -158,6 +158,18 @@ The application is configurable via environment variables.
158158
- **Type:** Dictionary of keyword arguments used to initialize the class
159159
- **Required:** No, defaults to `{}`
160160
- **Example:** `{"field_name": "properties.organization"}`
161+
- **`COLLECTIONS_FILTER_CLS`**, CQL2 expression generator for collection-level filtering
162+
- **Type:** JSON object with class configuration
163+
- **Required:** No, defaults to `null` (disabled)
164+
- **Example:** `stac_auth_proxy.filters:Opa`, `stac_auth_proxy.filters:Template`, `my_package:OrganizationFilter`
165+
- **`COLLECTIONS_FILTER_ARGS`**, Positional arguments for CQL2 expression generator
166+
- **Type:** List of positional arguments used to initialize the class
167+
- **Required:** No, defaults to `[]`
168+
- **Example:** `["org1"]`
169+
- **`COLLECTIONS_FILTER_KWARGS`**, Keyword arguments for CQL2 expression generator
170+
- **Type:** Dictionary of keyword arguments used to initialize the class
171+
- **Required:** No, defaults to `{}`
172+
- **Example:** `{"field_name": "properties.organization"}`
161173

162174
### Tips
163175

@@ -227,7 +239,7 @@ The system supports generating CQL2 filters based on request context to provide
227239
228240
#### Filters
229241

230-
If enabled, filters are intended to be applied to the following endpoints:
242+
If enabled, filters are applied to the following endpoints:
231243

232244
- `GET /search`
233245
- **Supported:**
@@ -250,12 +262,12 @@ If enabled, filters are intended to be applied to the following endpoints:
250262
- **Applied Filter:** `ITEMS_FILTER`
251263
- **Strategy:** Validate response against CQL2 query.
252264
- `GET /collections`
253-
- **Supported:** [^23]
265+
- **Supported:**
254266
- **Action:** Read Collection
255267
- **Applied Filter:** `COLLECTIONS_FILTER`
256268
- **Strategy:** Append query params with generated CQL2 query.
257269
- `GET /collections/{collection_id}`
258-
- **Supported:** [^23]
270+
- **Supported:**
259271
- **Action:** Read Collection
260272
- **Applied Filter:** `COLLECTIONS_FILTER`
261273
- **Strategy:** Validate response against CQL2 query.
@@ -411,6 +423,5 @@ class ApprovedCollectionsFilter:
411423
412424
[^21]: https://github.com/developmentseed/stac-auth-proxy/issues/21
413425
[^22]: https://github.com/developmentseed/stac-auth-proxy/issues/22
414-
[^23]: https://github.com/developmentseed/stac-auth-proxy/issues/23
415426
[^30]: https://github.com/developmentseed/stac-auth-proxy/issues/30
416427
[^37]: https://github.com/developmentseed/stac-auth-proxy/issues/37
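The `COLLECTIONS_FILTER_CLS` and `COLLECTIONS_FILTER_KWARGS` settings documented above point at a class that generates a CQL2 expression. As a rough sketch only, a custom generator might look like the following. The class name `OrganizationFilter` echoes the README example, but its body, the `public_value` field, the `org` claim, and the shape of the `context` argument are all assumptions made for illustration and are not taken from this commit.

```python
import asyncio
from dataclasses import dataclass
from typing import Any


@dataclass
class OrganizationFilter:
    """Hypothetical CQL2 generator; fields and logic are illustrative only."""

    # Mirrors the README example for COLLECTIONS_FILTER_KWARGS.
    field_name: str = "properties.organization"
    # Assumed fallback value for requests without an organization claim.
    public_value: str = "public"

    async def __call__(self, context: dict[str, Any]) -> str:
        # Assumption: the decoded token is available under "payload",
        # similar to `input.payload.sub` in the OPA example policy.
        payload = context.get("payload") or {}
        org = payload.get("org")  # "org" is a made-up claim name
        if org is None:
            return f"{self.field_name} = '{self.public_value}'"
        # Escape single quotes so the value stays a valid CQL2 text literal.
        escaped = str(org).replace("'", "''")
        return f"{self.field_name} = '{escaped}'"


if __name__ == "__main__":
    generator = OrganizationFilter()
    print(asyncio.run(generator({"payload": {"org": "org1"}})))
```

Under these assumptions the class would be wired up with `COLLECTIONS_FILTER_CLS=my_package:OrganizationFilter` and `COLLECTIONS_FILTER_KWARGS={"field_name": "properties.organization"}`.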

examples/opa/docker-compose.yaml

Lines changed: 3 additions & 1 deletion
@@ -2,7 +2,9 @@ services:
22
proxy:
33
environment:
44
ITEMS_FILTER_CLS: stac_auth_proxy.filters:Opa
5-
ITEMS_FILTER_ARGS: '["http://opa:8181", "stac/cql2"]'
5+
ITEMS_FILTER_ARGS: '["http://opa:8181", "stac/items_cql2"]'
6+
COLLECTIONS_FILTER_CLS: stac_auth_proxy.filters:Opa
7+
COLLECTIONS_FILTER_ARGS: '["http://opa:8181", "stac/collections_cql2"]'
68

79
opa:
810
image: openpolicyagent/opa:latest
Lines changed: 8 additions & 2 deletions
@@ -1,7 +1,13 @@
11
package stac
22

3-
default cql2 := "\"naip:year\" = 2021"
3+
default items_cql2 := "\"naip:year\" = 2021"
44

5-
cql2 := "1=1" if {
5+
items_cql2 := "1=1" if {
6+
input.payload.sub != null
7+
}
8+
9+
default collections_cql2 := "id = 'naip'"
10+
11+
collections_cql2 := "1=1" if {
612
input.payload.sub != null
713
}

src/stac_auth_proxy/app.py

Lines changed: 5 additions & 2 deletions
@@ -119,13 +119,16 @@ async def lifespan(app: FastAPI):
119119
auth_scheme_override=settings.openapi_auth_scheme_override,
120120
)
121121

122-
if settings.items_filter:
122+
if settings.items_filter or settings.collections_filter:
123123
app.add_middleware(
124124
ApplyCql2FilterMiddleware,
125125
)
126126
app.add_middleware(
127127
BuildCql2FilterMiddleware,
128-
items_filter=settings.items_filter(),
128+
items_filter=settings.items_filter() if settings.items_filter else None,
129+
collections_filter=(
130+
settings.collections_filter() if settings.collections_filter else None
131+
),
129132
)
130133

131134
app.add_middleware(

src/stac_auth_proxy/config.py

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ class Settings(BaseSettings):
7171

7272
# Filters
7373
items_filter: Optional[ClassInput] = None
74+
collections_filter: Optional[ClassInput] = None
7475

7576
model_config = SettingsConfigDict(
7677
env_nested_delimiter="_",

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 39 additions & 17 deletions
@@ -21,8 +21,6 @@
2121
r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
2222
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
2323
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
24-
r"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
25-
r"https://api.stacspec.org/v1\.\d+\.\d+(?:-[\w\.]+)?/item-search#filter",
2624
)
2725
@dataclass(frozen=True)
2826
class ApplyCql2FilterMiddleware:
@@ -31,6 +29,11 @@ class ApplyCql2FilterMiddleware:
3129
app: ASGIApp
3230
state_key: str = "cql2_filter"
3331

32+
single_record_endpoints = [
33+
r"^/collections/([^/]+)/items/([^/]+)$",
34+
r"^/collections/([^/]+)$",
35+
]
36+
3437
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3538
"""Add the Cql2Filter to the request."""
3639
if scope["type"] != "http":
@@ -51,7 +54,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5154
)
5255
return await req_body_handler(scope, receive, send)
5356

54-
if re.match(r"^/collections/([^/]+)/items/([^/]+)$", request.url.path):
57+
# Handle single record requests (ie non-filterable endpoints)
58+
if any(
59+
re.match(expr, request.url.path) for expr in self.single_record_endpoints
60+
):
5561
res_body_validator = Cql2ResponseBodyValidator(
5662
app=self.app,
5763
cql2_filter=cql2_filter,
@@ -125,18 +131,22 @@ async def __call__(self, scope: Scope, send: Send, receive: Receive) -> None:
125131
body = b""
126132
initial_message: Optional[Message] = None
127133

128-
async def _send_error_response(status: int, message: str) -> None:
134+
async def _send_error_response(status: int, code: str, message: str) -> None:
129135
"""Send an error response with the given status and message."""
130136
assert initial_message, "Initial message not set"
131-
error_body = json.dumps({"message": message}).encode("utf-8")
137+
response_dict = {
138+
"code": code,
139+
"description": message,
140+
}
141+
response_bytes = json.dumps(response_dict).encode("utf-8")
132142
headers = MutableHeaders(scope=initial_message)
133-
headers["content-length"] = str(len(error_body))
143+
headers["content-length"] = str(len(response_bytes))
134144
initial_message["status"] = status
135145
await send(initial_message)
136146
await send(
137147
{
138148
"type": "http.response.body",
139-
"body": error_body,
149+
"body": response_bytes,
140150
"more_body": False,
141151
}
142152
)
@@ -145,28 +155,37 @@ async def buffered_send(message: Message) -> None:
145155
"""Process a response message and apply filtering if needed."""
146156
nonlocal body
147157
nonlocal initial_message
158+
initial_message = initial_message or message
159+
# NOTE: to avoid data-leak, we process 404s so their responses are the same as rejected 200s
160+
should_process = initial_message["status"] in [200, 404]
161+
162+
if not should_process:
163+
return await send(message)
148164

149165
if message["type"] == "http.response.start":
150-
initial_message = message
166+
# Hold off on sending response headers until we've validated the response body
151167
return
152168

153-
assert initial_message, "Initial message not set"
154-
155169
body += message["body"]
156170
if message.get("more_body"):
157171
return
158172

159173
try:
160174
body_json = json.loads(body)
161175
except json.JSONDecodeError:
162-
logger.warning("Failed to parse response body as JSON")
163-
await _send_error_response(502, "Not found")
176+
msg = "Failed to parse response body as JSON"
177+
logger.warning(msg)
178+
await _send_error_response(status=502, code="ParseError", message=msg)
164179
return
165180

166-
logger.debug(
167-
"Applying %s filter to %s", self.cql2_filter.to_text(), body_json
168-
)
169-
if self.cql2_filter.matches(body_json):
181+
try:
182+
cql2_matches = self.cql2_filter.matches(body_json)
183+
except Exception as e:
184+
cql2_matches = False
185+
logger.warning("Failed to apply filter: %s", e)
186+
187+
if cql2_matches:
188+
logger.debug("Response matches filter, returning record")
170189
await send(initial_message)
171190
return await send(
172191
{
@@ -175,6 +194,9 @@ async def buffered_send(message: Message) -> None:
175194
"more_body": False,
176195
}
177196
)
178-
return await _send_error_response(404, "Not found")
197+
logger.debug("Response did not match filter, returning 404")
198+
return await _send_error_response(
199+
status=404, code="NotFoundError", message="Record not found."
200+
)
179201

180202
return await self.app(scope, receive, buffered_send)
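The single-record handling in the diff above can be hard to follow once ASGI message buffering is mixed in. The sketch below isolates the decision logic as plain functions: it reuses the two path patterns and the `code`/`description` error shape from the diff, but `is_single_record`, `error_body`, and `resolve_response` are hypothetical helper names, and the `matches` callable stands in for `cql2_filter.matches`.

```python
import json
import re
from typing import Any, Callable

# Patterns copied from `single_record_endpoints` in the diff.
SINGLE_RECORD_ENDPOINTS = [
    r"^/collections/([^/]+)/items/([^/]+)$",
    r"^/collections/([^/]+)$",
]


def is_single_record(path: str) -> bool:
    """Return True when the path addresses one item or one collection."""
    return any(re.match(expr, path) for expr in SINGLE_RECORD_ENDPOINTS)


def error_body(code: str, message: str) -> bytes:
    """Build the JSON error payload used for rejected responses."""
    return json.dumps({"code": code, "description": message}).encode("utf-8")


def resolve_response(
    status: int, body: bytes, matches: Callable[[Any], bool]
) -> tuple[int, bytes]:
    """Decide which status and body the client should receive."""
    # Only 200s and 404s are inspected; everything else passes through.
    if status not in (200, 404):
        return status, body
    try:
        record = json.loads(body)
    except json.JSONDecodeError:
        return 502, error_body("ParseError", "Failed to parse response body as JSON")
    try:
        ok = matches(record)
    except Exception:
        # A filter that cannot be evaluated is treated as a non-match.
        ok = False
    if ok:
        return status, body
    # Rejected 200s and upstream 404s end up with the same response.
    return 404, error_body("NotFoundError", "Record not found.")
```

Because rejected records and genuine upstream 404s resolve to the same payload, a client cannot tell a hidden collection from a missing one, which is the data-leak concern noted in the diff's comment.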

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 38 additions & 2 deletions
@@ -11,10 +11,16 @@
1111
from starlette.types import ASGIApp, Receive, Scope, Send
1212

1313
from ..utils import requests
14+
from ..utils.middleware import required_conformance
1415

1516
logger = logging.getLogger(__name__)
1617

1718

19+
@required_conformance(
20+
"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
21+
"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
22+
"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
23+
)
1824
@dataclass(frozen=True)
1925
class BuildCql2FilterMiddleware:
2026
"""Middleware to build the Cql2Filter."""
@@ -25,7 +31,37 @@ class BuildCql2FilterMiddleware:
2531

2632
# Filters
2733
collections_filter: Optional[Callable] = None
34+
collections_filter_path: str = r"^/collections(/[^/]+)?$"
2835
items_filter: Optional[Callable] = None
36+
items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"
37+
38+
def __post_init__(self):
39+
"""Set required conformances based on the filter functions."""
40+
required_conformances = set()
41+
if self.collections_filter:
42+
logger.debug("Appending required conformance for collections filter")
43+
required_conformances.update(
44+
[
45+
"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/filter",
46+
"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
47+
r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/item-search#filter",
48+
"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
49+
]
50+
)
51+
if self.items_filter:
52+
logger.debug("Appending required conformance for items filter")
53+
required_conformances.update(
54+
[
55+
"https://api.stacspec.org/v1.0.0/core",
56+
r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/collection-search#filter",
57+
"http://www.opengis.net/spec/ogcapi-common-2/1.0/conf/simple-query",
58+
]
59+
)
60+
61+
# Must set required conformances on class
62+
self.__class__.__required_conformances__ = required_conformances.union(
63+
getattr(self.__class__, "__required_conformances__", [])
64+
)
2965

3066
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3167
"""Build the CQL2 filter, place on the request state."""
@@ -65,8 +101,8 @@ def _get_filter(
65101
) -> Optional[Callable[..., Awaitable[str | dict[str, Any]]]]:
66102
"""Get the CQL2 filter builder for the given path."""
67103
endpoint_filters = [
68-
(r"^/collections(/[^/]+)?$", self.collections_filter),
69-
(r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)", self.items_filter),
104+
(self.collections_filter_path, self.collections_filter),
105+
(self.items_filter_path, self.items_filter),
70106
]
71107
for expr, builder in endpoint_filters:
72108
if re.match(expr, path):
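To make the path-based dispatch in `_get_filter` easier to reason about, here is a standalone sketch using the two patterns introduced in this diff. The function name `get_filter` and the string placeholders used as builders are assumptions for illustration; the real method lives on the middleware and returns the configured filter callables.

```python
import re
from typing import Any, Optional

# Default values of `collections_filter_path` and `items_filter_path` in the diff.
COLLECTIONS_FILTER_PATH = r"^/collections(/[^/]+)?$"
ITEMS_FILTER_PATH = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"


def get_filter(
    path: str,
    collections_filter: Optional[Any] = None,
    items_filter: Optional[Any] = None,
) -> Optional[Any]:
    """Return the filter builder registered for the first matching path pattern."""
    endpoint_filters = [
        (COLLECTIONS_FILTER_PATH, collections_filter),
        (ITEMS_FILTER_PATH, items_filter),
    ]
    for expr, builder in endpoint_filters:
        if re.match(expr, path):
            # May be None when that filter type is not configured.
            return builder
    return None


if __name__ == "__main__":
    for p in ["/collections", "/collections/naip", "/collections/naip/items", "/search", "/"]:
        print(p, "->", get_filter(p, collections_filter="collections", items_filter="items"))
```

Keeping the patterns as dataclass fields, as the diff does, lets a deployment override them without subclassing the dispatch logic.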

tests/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -207,6 +207,7 @@ def mock_env():
207207
@pytest.fixture
208208
async def mock_upstream() -> AsyncGenerator[MagicMock, None]:
209209
"""Mock the HTTPX send method. Useful when we want to inspect the request that is sent to the upstream API."""
210+
# NOTE: This fixture will interfere with the source_api_responses fixture
210211

211212
async def store_body(request, **kwargs):
212213
"""Exhaust and store the request body."""
