diff --git a/mypy.ini b/mypy.ini
index b472cad..10f236a 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -59,6 +59,7 @@ exclude = (?x)(
     | ^tests/test_sse_connection_manager\.py$
     | ^prefab_pb2.*\.pyi?$
     | ^examples/
+    | ^tests/test_api_client\.py$
   )
 
 # Strict typing options
diff --git a/prefab_cloud_python/_requests.py b/prefab_cloud_python/_requests.py
index e1e32f7..aea8e40 100644
--- a/prefab_cloud_python/_requests.py
+++ b/prefab_cloud_python/_requests.py
@@ -1,6 +1,9 @@
 import importlib
-from socket import socket
+import re
+import time
+from collections import OrderedDict
+from dataclasses import dataclass
 from typing import Optional
 
 from ._internal_logging import (
     InternalLogger,
@@ -72,51 +75,207 @@ def __next__(self):
         return host
 
 
+# --- Simple LRU Cache Implementation ---
+
+
+@dataclass
+class CacheEntry:
+    data: bytes
+    etag: Optional[str]  # None when the response carried no ETag
+    expires_at: float  # Unix timestamp after which the entry is stale
+    url: str  # The full URL from the successful response
+
+
+class LRUCache:
+    """Minimal least-recently-used cache backed by an OrderedDict."""
+
+    def __init__(self, max_size: int):
+        self.max_size = max_size
+        self.cache = OrderedDict()
+
+    def get(self, key):
+        """Return the value for key and mark it most recently used, or None."""
+        try:
+            self.cache.move_to_end(key)
+        except KeyError:
+            return None
+        return self.cache[key]
+
+    def set(self, key, value):
+        """Store value under key, evicting the least recently used entry if full."""
+        if self.max_size <= 0:
+            # A zero-capacity cache stores nothing; without this guard
+            # popitem() below would raise KeyError on the empty dict.
+            return
+        if key in self.cache:
+            self.cache.pop(key)
+        elif len(self.cache) >= self.max_size:
+            self.cache.popitem(last=False)
+        self.cache[key] = value
+
+    def clear(self):
+        self.cache.clear()
+
+    def __len__(self):
+        return len(self.cache)
+
+
 class ApiClient:
     def __init__(self, options):
+        """
+        :param options: An object exposing ``prefab_api_urls``, the list of API
+            host URLs (e.g. ["https://a.example.com", "https://b.example.com"])
+        """
         self.hosts = options.prefab_api_urls
         self.session = requests.Session()
         self.session.mount("https://", NoRetryAdapter())
         self.session.mount("http://", NoRetryAdapter())
         self.session.headers.update({VersionHeader: f"prefab-cloud-python-{Version}"})
+        # Small LRU cache of responses keyed by full request URL.
+        self.cache = LRUCache(max_size=2)
 
     def get_host(self, attempt_number, host_list):
         return host_list[attempt_number % len(host_list)]
 
+    def _get_attempt_number(self) -> int:
+        """
+        Retrieve the current attempt number from tenacity's statistics if available,
+        otherwise default to 1.
+        """
+        stats = getattr(self.resilient_request, "statistics", None)
+        if stats is None:
+            return 1
+        return stats.get("attempt_number", 1)
+
+    def _build_url(self, path, hosts: Optional[list[str]] = None) -> str:
+        """
+        Build the full URL using host-selection logic.
+        """
+        attempt_number = self._get_attempt_number()
+        host = self.get_host(attempt_number - 1, hosts or self.hosts)
+        return f"{host.rstrip('/')}/{path.lstrip('/')}"
+
+    @staticmethod
+    def _response_from_entry(entry: CacheEntry) -> Response:
+        """
+        Build a synthetic 200 response from a cache entry.
+        """
+        resp = Response()
+        resp._content = entry.data
+        resp.status_code = 200
+        if entry.etag:
+            resp.headers["ETag"] = entry.etag
+        resp.headers["X-Cache"] = "HIT"
+        resp.url = entry.url
+        return resp
+
+    def _get_cached_response(self, url: str) -> Optional[Response]:
+        """
+        If a fresh cache entry exists for the given URL, return a synthetic Response;
+        otherwise return None.
+        """
+        entry = self.cache.get(url)
+        if entry is not None and entry.expires_at > time.time():
+            return self._response_from_entry(entry)
+        return None
+
+    def _apply_cache_headers(self, url: str, kwargs: dict) -> dict:
+        """
+        If a cache entry with an ETag exists, add it as an 'If-None-Match' header.
+        """
+        entry = self.cache.get(url)
+        if entry is not None and entry.etag:
+            # Copy so the caller's dict is not mutated; tolerate headers=None.
+            headers = dict(kwargs.get("headers") or {})
+            headers["If-None-Match"] = entry.etag
+            kwargs["headers"] = headers
+        return kwargs
+
+    def _update_cache(self, url: str, response: Response) -> None:
+        """
+        Store a cacheable response (status 200, no 'no-store' directive).
+
+        A response with max-age > 0 is served from the cache until it expires.
+        A response with 'no-cache', or with an ETag but no max-age, is stored as
+        already expired so every subsequent request is revalidated via
+        'If-None-Match'.
+        """
+        if response.status_code != 200:
+            return
+        cache_control = response.headers.get("Cache-Control", "").lower()
+        if "no-store" in cache_control:
+            return
+
+        etag = response.headers.get("ETag")
+        m = re.search(r"max-age=(\d+)", cache_control)
+        max_age = int(m.group(1)) if m else 0
+
+        now = time.time()
+        if "no-cache" in cache_control:
+            expires_at = now  # Immediately expired: always revalidate.
+        else:
+            expires_at = now + max_age
+
+        # An already-expired entry is only useful if its ETag allows revalidation.
+        if etag is not None or expires_at > now:
+            self.cache.set(
+                url,
+                CacheEntry(
+                    data=response.content,
+                    etag=etag,
+                    expires_at=expires_at,
+                    url=response.url,
+                ),
+            )
+        response.headers["X-Cache"] = "MISS"
+
+    def _send_request(self, method: str, url: str, **kwargs) -> Response:
+        """
+        Hook method to perform the actual HTTP request.
+        """
+        return self.session.request(method, url, **kwargs)
+
     @retry(
         stop=stop_after_delay(8),
         wait=wait_exponential(multiplier=1, min=0.05, max=2),
         retry=retry_if_exception_type((RequestException, ConnectionError, OSError)),
     )
     def resilient_request(
-        self, path, method="GET", hosts: Optional[list[str]] = None, **kwargs
+        self,
+        path,
+        method="GET",
+        hosts: Optional[list[str]] = None,
+        allow_cache: bool = False,
+        **kwargs,
     ) -> Response:
-        # Get the current attempt number from tenacity's context
-        attempt_number = self.resilient_request.statistics["attempt_number"]
-        host = self.get_host(
-            attempt_number - 1, hosts or self.hosts
-        )  # Subtract 1 because attempt_number starts at 1
-        url = f"{host.rstrip('/')}/{path.lstrip('/')}"
+        """
+        Makes a resilient (retrying) request.
 
-        try:
-            logger.info(f"Attempt {attempt_number}: Requesting {url}")
-            response = self.session.request(method, url, **kwargs)
-            response.raise_for_status()
-            logger.info(f"Attempt {attempt_number}: Successful request to {url}")
-            return response
-        except (RequestException, ConnectionError) as e:
-            logger.warning(
-                f"Attempt {attempt_number}: Request to {url} failed: {str(e)}. Will retry"
-            )
-            raise
-        except OSError as e:
-            if isinstance(e, socket.gaierror):
-                logger.warning(
-                    f"Attempt {attempt_number}: DNS resolution failed for {url}: {str(e)}. Will retry"
-                )
-                raise
-            else:
-                logger.error(
-                    f"Attempt {attempt_number}: Non-retryable error occurred: {str(e)}"
-                )
-                raise
+        If allow_cache is True and the request method is GET, caching logic is applied.
+        This includes:
+          - Checking the cache and returning a synthetic response if valid.
+          - Adding an 'If-None-Match' header when a stale entry exists.
+          - Handling a 304 (Not Modified) response by returning the cached entry.
+          - Caching a 200 response if Cache-Control permits.
+
+        HTTP error statuses (4xx/5xx) raise, so the retry policy can move on to the
+        next host.
+        """
+        use_cache = allow_cache and method.upper() == "GET"
+        url = self._build_url(path, hosts)
+        if use_cache:
+            cached = self._get_cached_response(url)
+            if cached is not None:
+                return cached
+            kwargs = self._apply_cache_headers(url, kwargs)
+        response = self._send_request(method, url, **kwargs)
+        if use_cache and response.status_code == 304:
+            entry = self.cache.get(url)
+            if entry is not None:
+                return self._response_from_entry(entry)
+        # Raise on 4xx/5xx so tenacity retries (rotating hosts) instead of handing
+        # an error response back to the caller.
+        response.raise_for_status()
+        if use_cache:
+            self._update_cache(url, response)
+        return response
diff --git a/prefab_cloud_python/config_client.py b/prefab_cloud_python/config_client.py
index a7d20d4..856a304 100644
--- a/prefab_cloud_python/config_client.py
+++ b/prefab_cloud_python/config_client.py
@@ -157,8 +157,12 @@ def load_initial_data(self):
 
     def load_checkpoint_from_api_cdn(self):
         try:
+            hwm = self.config_loader.highwater_mark
             response = self.api_client.resilient_request(
-                "/api/v1/configs/0", auth=("authuser", self.options.api_key), timeout=4
+                f"/api/v1/configs/{hwm}",
+                auth=("authuser", self.options.api_key),
+                timeout=4,
+                allow_cache=True,
             )
             if response.ok:
                 configs = Prefab.Configs.FromString(response.content)
diff --git a/tests/test_api_client.py b/tests/test_api_client.py
new file mode 100644
index 0000000..8e18a18
--- /dev/null
+++ b/tests/test_api_client.py
@@ -0,0 +1,152 @@
+import unittest
+from unittest.mock import patch
+from requests import Response
+from prefab_cloud_python._requests import ApiClient, CacheEntry
+import time
+
+
+# Dummy options for testing.
+class DummyOptions:
+    prefab_api_urls = ["https://a.example.com", "https://b.example.com"]
+
+
+class TestApiClient(unittest.TestCase):
+    def setUp(self):
+        self.options = DummyOptions()
+        self.client = ApiClient(self.options)
+        # Instead of setting statistics on resilient_request,
+        # patch _get_attempt_number to always return 1.
+        self.client._get_attempt_number = lambda: 1
+
+    def create_response(
+        self,
+        status_code=200,
+        content=b"dummy",
+        headers=None,
+        url="https://a.example.com/api/v1/configs/0",
+    ):
+        resp = Response()
+        resp.status_code = status_code
+        resp._content = content
+        resp.url = url
+        resp.headers = headers or {}
+        return resp
+
+    @patch.object(ApiClient, "_send_request")
+    def test_no_cache(self, mock_send_request):
+        # Test that when allow_cache is False, caching is bypassed.
+        response = self.create_response(
+            status_code=200,
+            content=b"response_no_cache",
+            headers={"Cache-Control": "max-age=60", "ETag": "abc"},
+            url="https://a.example.com/api/v1/configs/0",
+        )
+        mock_send_request.return_value = response
+
+        resp = self.client.resilient_request("/api/v1/configs/0", allow_cache=False)
+        self.assertEqual(resp.status_code, 200)
+        self.assertEqual(resp.content, b"response_no_cache")
+        self.assertNotIn("X-Cache", resp.headers)
+
+    @patch.object(ApiClient, "_send_request")
+    def test_cache_miss_and_hit(self, mock_send_request):
+        # First call should cache the response (MISS).
+        response = self.create_response(
+            status_code=200,
+            content=b"cached_response",
+            headers={"Cache-Control": "max-age=60", "ETag": "abc"},
+            url="https://a.example.com/api/v1/configs/0",
+        )
+        mock_send_request.return_value = response
+
+        resp1 = self.client.resilient_request("/api/v1/configs/0", allow_cache=True)
+        self.assertEqual(resp1.status_code, 200)
+        self.assertEqual(resp1.content, b"cached_response")
+        self.assertEqual(resp1.headers.get("X-Cache"), "MISS")
+
+        # Change the mock so that a new network call would return different content.
+        new_response = self.create_response(
+            status_code=200,
+            content=b"new_response",
+            headers={"Cache-Control": "max-age=60", "ETag": "def"},
+            url="https://a.example.com/api/v1/configs/0",
+        )
+        mock_send_request.return_value = new_response
+
+        # Second call should return the cached response.
+        resp2 = self.client.resilient_request("/api/v1/configs/0", allow_cache=True)
+        self.assertEqual(resp2.status_code, 200)
+        self.assertEqual(resp2.content, b"cached_response")
+        self.assertEqual(resp2.headers.get("X-Cache"), "HIT")
+
+    @patch.object(ApiClient, "_send_request")
+    def test_304_returns_cached_response(self, mock_send_request):
+        # First, cache a 200 response that must always be revalidated (no-cache).
+        response = self.create_response(
+            status_code=200,
+            content=b"cached_response",
+            headers={"Cache-Control": "no-cache", "ETag": "abc"},
+            url="https://a.example.com/api/v1/configs/0",
+        )
+        mock_send_request.return_value = response
+        resp1 = self.client.resilient_request("/api/v1/configs/0", allow_cache=True)
+        self.assertEqual(resp1.status_code, 200)
+        self.assertEqual(resp1.content, b"cached_response")
+        self.assertEqual(resp1.headers.get("X-Cache"), "MISS")
+
+        # Now simulate a 304 Not Modified response.
+        response_304 = self.create_response(
+            status_code=304,
+            content=b"",
+            headers={},
+            url="https://a.example.com/api/v1/configs/0",
+        )
+        mock_send_request.return_value = response_304
+        resp2 = self.client.resilient_request("/api/v1/configs/0", allow_cache=True)
+        self.assertEqual(resp2.status_code, 200)
+        self.assertEqual(resp2.content, b"cached_response")
+        self.assertEqual(resp2.headers.get("X-Cache"), "HIT")
+        # The stale entry must be revalidated over the network, not served blindly.
+        _, kwargs = mock_send_request.call_args
+        self.assertEqual(mock_send_request.call_count, 2)
+        self.assertEqual(kwargs["headers"]["If-None-Match"], "abc")
+
+    @patch.object(ApiClient, "_send_request")
+    def test_if_none_match_header_added(self, mock_send_request):
+        # Pre-populate the cache with a stale entry.
+        # Set the expires_at to a time in the past so that it's considered stale.
+        stale_time = time.time() - 10
+        cache_url = "https://a.example.com/api/v1/configs/0"
+        self.client.cache.set(
+            cache_url,
+            CacheEntry(
+                data=b"old_response",
+                etag="abc123",
+                expires_at=stale_time,
+                url=cache_url,
+            ),
+        )
+
+        # Prepare a dummy 304 Not Modified response.
+        response_304 = self.create_response(
+            status_code=304, content=b"", headers={}, url=cache_url
+        )
+        mock_send_request.return_value = response_304
+
+        # Call resilient_request with caching enabled.
+        resp = self.client.resilient_request("/api/v1/configs/0", allow_cache=True)
+
+        # Verify that _send_request was called with an "If-None-Match" header.
+        args, kwargs = mock_send_request.call_args
+        headers = kwargs.get("headers", {})
+        self.assertIn("If-None-Match", headers)
+        self.assertEqual(headers["If-None-Match"], "abc123")
+
+        # Also, the response should be synthesized from the cache (a HIT).
+        self.assertEqual(resp.status_code, 200)
+        self.assertEqual(resp.content, b"old_response")
+        self.assertEqual(resp.headers.get("X-Cache"), "HIT")
+
+
+if __name__ == "__main__":
+    unittest.main()