diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 140f2b88914..b63270ef7c1 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -3,11 +3,11 @@ from ssl import SSLContext from types import TracebackType from typing import Any, Dict, List, Optional, Type, Union +from uuid import uuid4 -from httpx import Timeout +from httpx import Timeout, codes from firebolt.async_db.cursor import Cursor, CursorV1, CursorV2 -from firebolt.async_db.util import _get_system_engine_url_and_params from firebolt.client import DEFAULT_API_URL from firebolt.client.auth import Auth from firebolt.client.auth.base import FireboltAuthVersion @@ -20,21 +20,30 @@ AsyncQueryInfo, BaseConnection, _parse_async_query_info_results, + get_cached_system_engine_info, + get_user_agent_for_connection, + set_cached_system_engine_info, ) -from firebolt.common.cache import _firebolt_system_engine_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( + AccountNotFoundOrNoAccessError, ConfigurationError, ConnectionClosedError, FireboltError, + InterfaceError, ) from firebolt.utils.firebolt_core import ( get_core_certificate_context, parse_firebolt_core_url, validate_firebolt_core_parameters, ) -from firebolt.utils.usage_tracker import get_user_agent_header -from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME +from firebolt.utils.util import ( + fix_url_schema, + parse_url_and_params, + validate_engine_name_and_url_v1, +) class Connection(BaseConnection): @@ -71,6 +80,7 @@ class Connection(BaseConnection): "_is_closed", "client_class", "cursor_type", + "id", ) def __init__( @@ -81,12 +91,14 @@ def __init__( cursor_type: Type[Cursor], api_endpoint: str, init_parameters: Optional[Dict[str, Any]] = None, + id: str = uuid4().hex, ): super().__init__(cursor_type) self.api_endpoint = api_endpoint self.engine_url = engine_url self._cursors: List[Cursor] = [] self._client = client + self.id = id self.init_parameters = init_parameters or {} if database: self.init_parameters["database"] = database @@ -225,13 +237,13 @@ async def connect( if not auth: raise ConfigurationError("auth is required to connect.") + api_endpoint = fix_url_schema(api_endpoint) # Type checks assert auth is not None - user_drivers = additional_parameters.get("user_drivers", []) - user_clients = additional_parameters.get("user_clients", []) - user_agent_header = get_user_agent_header(user_drivers, user_clients) - if disable_cache: - _firebolt_system_engine_cache.disable() + connection_id = uuid4().hex + user_agent_header = get_user_agent_for_connection( + auth, connection_id, account_name, additional_parameters, disable_cache + ) # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials # Use V1 if auth is ServiceAccount or UsernamePassword @@ -254,6 +266,8 @@ async def connect( database=database, engine_name=engine_name, api_endpoint=api_endpoint, + connection_id=connection_id, + disable_cache=disable_cache, ) elif auth_version == FireboltAuthVersion.V1: return await connect_v1( @@ -264,6 +278,7 @@ async def connect( engine_name=engine_name, engine_url=engine_url, api_endpoint=api_endpoint, + connection_id=connection_id, ) else: raise ConfigurationError(f"Unsupported auth type: {type(auth)}") @@ -272,10 +287,12 @@ async def connect( async def connect_v2( auth: Auth, user_agent_header: str, + 
connection_id: str, account_name: Optional[str] = None, database: Optional[str] = None, engine_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, + disable_cache: bool = False, ) -> Connection: """Connect to Firebolt. @@ -301,10 +318,8 @@ async def connect_v2( assert auth is not None assert account_name is not None - api_endpoint = fix_url_schema(api_endpoint) - - system_engine_url, system_engine_params = await _get_system_engine_url_and_params( - auth, account_name, api_endpoint + system_engine_info = await _get_system_engine_url_and_params( + auth, account_name, api_endpoint, connection_id, disable_cache ) client = AsyncClientV2( @@ -316,19 +331,21 @@ async def connect_v2( ) async with Connection( - system_engine_url, + system_engine_info.url, None, client, CursorV2, api_endpoint, - system_engine_params, + system_engine_info.params, + connection_id, ) as system_engine_connection: cursor = system_engine_connection.cursor() + if database: - await cursor.execute(f'USE DATABASE "{database}"') + await cursor.use_database(database, cache=not disable_cache) if engine_name: - await cursor.execute(f'USE ENGINE "{engine_name}"') + await cursor.use_engine(engine_name, cache=not disable_cache) # Ensure cursors created from this connection are using the same starting # database and engine return Connection( @@ -338,12 +355,14 @@ async def connect_v2( CursorV2, api_endpoint, cursor.parameters, + connection_id, ) async def connect_v1( auth: Auth, user_agent_header: str, + connection_id: str, database: Optional[str] = None, account_name: Optional[str] = None, engine_name: Optional[str] = None, @@ -358,8 +377,6 @@ async def connect_v1( validate_engine_name_and_url_v1(engine_name, engine_url) - api_endpoint = fix_url_schema(api_endpoint) - no_engine_client = AsyncClientV1( auth=auth, base_url=api_endpoint, @@ -397,11 +414,7 @@ async def connect_v1( headers={"User-Agent": user_agent_header}, ) return Connection( - engine_url, - database, - client, - CursorV1, - api_endpoint, + engine_url, database, client, CursorV1, api_endpoint, id=connection_id ) @@ -448,3 +461,39 @@ def connect_core( cursor_type=CursorV2, api_endpoint=verified_url, ) + + +async def _get_system_engine_url_and_params( + auth: Auth, + account_name: str, + api_endpoint: str, + connection_id: str, + disable_cache: bool = False, +) -> EngineInfo: + cache_key, cached_result = get_cached_system_engine_info( + auth, account_name, disable_cache + ) + if cached_result: + return cached_result + + async with AsyncClientV2( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) + response = await client.get(url=url) + if response.status_code == codes.NOT_FOUND: + raise AccountNotFoundOrNoAccessError(account_name) + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve system engine endpoint {url}: " + f"{response.status_code} {response.content.decode()}" + ) + url, params = parse_url_and_params(response.json()["engineUrl"]) + + return set_cached_system_engine_info( + cache_key, connection_id, url, params, disable_cache + ) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 20fa2c1a7e8..dafc9a239fe 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -58,6 +58,7 @@ from firebolt.async_db.connection import Connection from firebolt.utils.async_util import anext, async_islice 
+from firebolt.utils.cache import ConnectionInfo, DatabaseInfo, EngineInfo from firebolt.utils.util import Timer, raise_error_from_response logger = logging.getLogger(__name__) @@ -85,8 +86,8 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - self._client = client self.connection = connection + self._client: AsyncClient = client self.engine_url = connection.engine_url self._row_set: Optional[BaseAsyncRowSet] = None if connection.init_parameters: @@ -332,6 +333,41 @@ async def _handle_query_execution( await self._parse_response_headers(resp.headers) await self._append_row_set_from_response(resp) + async def use_database(self, database: str, cache: bool = True) -> None: + """Switch the current database context with caching.""" + if cache: + cache_record = self.get_cache_record() + cache_record = ( + cache_record if cache_record else ConnectionInfo(id=self.connection.id) + ) + if cache_record.databases.get(database): + # If database is cached, use it + self.database = database + else: + await self.execute(f'USE DATABASE "{database}"') + cache_record.databases[database] = DatabaseInfo(database) + self.set_cache_record(cache_record) + else: + await self.execute(f'USE DATABASE "{database}"') + + async def use_engine(self, engine: str, cache: bool = True) -> None: + """Switch the current engine context with caching.""" + if cache: + cache_obj = self.get_cache_record() + cache_obj = ( + cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) + ) + if cache_obj.engines.get(engine): + # If engine is cached, use it + self.engine_url = cache_obj.engines[engine].url + self._update_set_parameters(cache_obj.engines[engine].params) + else: + await self.execute(f'USE ENGINE "{engine}"') + cache_obj.engines[engine] = EngineInfo(self.engine_url, self.parameters) + self.set_cache_record(cache_obj) + else: + await self.execute(f'USE ENGINE "{engine}"') + @check_not_closed async def execute( self, diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py deleted file mode 100644 index fb39c85a89d..00000000000 --- a/src/firebolt/async_db/util.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -from typing import Dict, Tuple - -from httpx import Timeout, codes - -from firebolt.client.auth import Auth -from firebolt.client.client import AsyncClientV2 -from firebolt.common.cache import _firebolt_system_engine_cache -from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.exception import ( - AccountNotFoundOrNoAccessError, - InterfaceError, -) -from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import parse_url_and_params - - -async def _get_system_engine_url_and_params( - auth: Auth, - account_name: str, - api_endpoint: str, -) -> Tuple[str, Dict[str, str]]: - if result := _firebolt_system_engine_cache.get([account_name, api_endpoint]): - return result - async with AsyncClientV2( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) - response = await client.get(url=url) - if response.status_code == codes.NOT_FOUND: - raise AccountNotFoundOrNoAccessError(account_name) - if response.status_code != codes.OK: - raise InterfaceError( - f"Unable to retrieve system engine endpoint {url}: " - f"{response.status_code} {response.content.decode()}" - ) - result = parse_url_and_params(response.json()["engineUrl"]) - 
_firebolt_system_engine_cache.set( - key=[account_name, api_endpoint], value=result - ) - return result diff --git a/src/firebolt/client/auth/base.py b/src/firebolt/client/auth/base.py index 74df42202db..56af8026eb7 100644 --- a/src/firebolt/client/auth/base.py +++ b/src/firebolt/client/auth/base.py @@ -65,6 +65,24 @@ def token(self) -> Optional[str]: """ return self._token + @property + @abstractmethod + def principal(self) -> str: + """Get the principal (username or id) associated with the token. + + Returns: + str: Principal string + """ + + @property + @abstractmethod + def secret(self) -> str: + """Get the secret (password or secret key) associated with the token. + + Returns: + str: Secret string + """ + @abstractmethod def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. diff --git a/src/firebolt/client/auth/client_credentials.py b/src/firebolt/client/auth/client_credentials.py index 92181fb9188..d729a581fb3 100644 --- a/src/firebolt/client/auth/client_credentials.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -53,6 +53,24 @@ def copy(self) -> "ClientCredentials": self.client_id, self.client_secret, self._use_token_cache ) + @property + def principal(self) -> str: + """Get the principal (client id) associated with this auth. + + Returns: + str: Principal client id + """ + return self.client_id + + @property + def secret(self) -> str: + """Get the secret (secret key) associated with this auth. + + Returns: + str: Secret + """ + return self.client_secret + def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. diff --git a/src/firebolt/client/auth/firebolt_core.py b/src/firebolt/client/auth/firebolt_core.py index 8cffe64aa4f..039082caef0 100644 --- a/src/firebolt/client/auth/firebolt_core.py +++ b/src/firebolt/client/auth/firebolt_core.py @@ -27,6 +27,28 @@ def __init__(self) -> None: self._token = "" self._expires = None + @property + def principal(self) -> str: + """Get the principal associated with the auth. + + For FireboltCore, this returns an empty string since no auth is needed. + + Returns: + str: Placeholder string for principal (no auth needed) + """ + return "core" + + @property + def secret(self) -> str: + """Get the secret associated with the auth. + + For FireboltCore, this returns an empty string since no auth is needed. + + Returns: + str: Placeholder string for secret (no auth needed) + """ + return "core" + def copy(self) -> "FireboltCore": """Make another auth object with same URL. diff --git a/src/firebolt/client/auth/service_account.py b/src/firebolt/client/auth/service_account.py index 6d670c00d0a..7e5b91c4a4b 100644 --- a/src/firebolt/client/auth/service_account.py +++ b/src/firebolt/client/auth/service_account.py @@ -43,6 +43,24 @@ def __init__( self.client_secret = client_secret super().__init__(use_token_cache) + @property + def principal(self) -> str: + """Get the principal (client id) associated with the auth. + + Returns: + str: client id + """ + return self.client_id + + @property + def secret(self) -> str: + """Get the secret (client secret) associated with the auth. + + Returns: + str: client secret + """ + return self.client_secret + def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. 
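The principal and secret properties added to each Auth implementation above exist to feed the new SecureCacheKey used by the connection cache. A minimal sketch of that relationship, mirroring get_cached_system_engine_info later in this diff (the client id, client secret and account name below are placeholders, not values from the patch):

from firebolt.client.auth import ClientCredentials
from firebolt.utils.cache import SecureCacheKey, _firebolt_cache

# Placeholder credentials; any Auth subclass now exposes principal/secret.
auth = ClientCredentials("my_client_id", "my_client_secret")

# The key combines principal, secret and account name so that different
# credentials or accounts never share cached connection info.
cache_key = SecureCacheKey([auth.principal, auth.secret, "my_account"], auth.secret)

cached = _firebolt_cache.get(cache_key)  # None until a connection populates the cache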
diff --git a/src/firebolt/client/auth/token.py b/src/firebolt/client/auth/token.py index 4ad7bd13ce3..223d4ec37ed 100644 --- a/src/firebolt/client/auth/token.py +++ b/src/firebolt/client/auth/token.py @@ -24,6 +24,24 @@ def __init__(self, token: str): super().__init__(use_token_cache=False) self._token = token + @property + def principal(self) -> str: + """Get the principal (placeholder) associated with the auth. + + Returns: + str: Principal (placeholder) + """ + return "token_principal" + + @property + def secret(self) -> str: + """Get the secret (token) associated with the auth. + + Returns: + str: Secret, which is the token itself + """ + return self._token or "token" + def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. diff --git a/src/firebolt/client/auth/username_password.py b/src/firebolt/client/auth/username_password.py index 79f1548294a..29050641d16 100644 --- a/src/firebolt/client/auth/username_password.py +++ b/src/firebolt/client/auth/username_password.py @@ -43,6 +43,24 @@ def __init__( self.password = password super().__init__(use_token_cache) + @property + def principal(self) -> str: + """Get the principal (username) associated with the auth. + + Returns: + str: Principal username + """ + return self.username + + @property + def secret(self) -> str: + """Get the secret (password) associated with the auth. + + Returns: + str: Secret password + """ + return self.password + def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. diff --git a/src/firebolt/common/base_connection.py b/src/firebolt/common/base_connection.py index 24880c60bd5..a39459be238 100644 --- a/src/firebolt/common/base_connection.py +++ b/src/firebolt/common/base_connection.py @@ -1,8 +1,19 @@ from collections import namedtuple -from typing import Any, List, Type +from typing import Any, Dict, List, Optional, Tuple, Type +from firebolt.client.auth.base import Auth from firebolt.common._types import ColType +from firebolt.utils.cache import ( + ConnectionInfo, + EngineInfo, + SecureCacheKey, + _firebolt_cache, +) from firebolt.utils.exception import ConnectionClosedError, FireboltError +from firebolt.utils.usage_tracker import ( + get_cache_tracking_params, + get_user_agent_header, +) ASYNC_QUERY_STATUS_RUNNING = "RUNNING" ASYNC_QUERY_STATUS_SUCCESSFUL = "ENDED_SUCCESSFULLY" @@ -75,3 +86,76 @@ def commit(self) -> None: if self.closed: raise ConnectionClosedError("Unable to commit: Connection closed.") + + +def get_cached_system_engine_info( + auth: Auth, + account_name: str, + disable_cache: bool = False, +) -> Tuple[SecureCacheKey, Optional[EngineInfo]]: + """ + Common cache retrieval logic for system engine info. + + Returns: + tuple: (cache_key, cached_engine_info_or_none) + """ + cache_key = SecureCacheKey([auth.principal, auth.secret, account_name], auth.secret) + + if disable_cache: + return cache_key, None + + cache = _firebolt_cache.get(cache_key) + cached_result = cache.system_engine if cache else None + + return cache_key, cached_result + + +def set_cached_system_engine_info( + cache_key: SecureCacheKey, + connection_id: str, + url: str, + params: dict, + disable_cache: bool = False, +) -> EngineInfo: + """ + Common cache setting logic for system engine info. 
+ + Returns: + EngineInfo: The engine info that was cached (or created) + """ + + engine_info = EngineInfo(url=url, params=params) + + if not disable_cache: + cache = _firebolt_cache.get(cache_key) + if not cache: + cache = ConnectionInfo(id=connection_id) + cache.system_engine = engine_info + _firebolt_cache.set(cache_key, cache) + + return engine_info + + +def get_user_agent_for_connection( + auth: Auth, + connection_id: str, + account_name: Optional[str] = None, + additional_parameters: Dict[str, Any] = {}, + disable_cache: bool = False, +) -> str: + """ + Get the user agent string for the Firebolt connection. + + Returns: + str: The user agent string. + """ + user_drivers = additional_parameters.get("user_drivers", []) + user_clients = additional_parameters.get("user_clients", []) + ua_parameters = [] + if not disable_cache: + cache_key = SecureCacheKey( + [auth.principal, auth.secret, account_name], auth.secret + ) + ua_parameters = get_cache_tracking_params(cache_key, connection_id) + user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) + return user_agent_header diff --git a/src/firebolt/common/cache.py b/src/firebolt/common/cache.py deleted file mode 100644 index 66a13df0ede..00000000000 --- a/src/firebolt/common/cache.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -from typing import ( - Any, - Callable, - Dict, - Generic, - Optional, - Protocol, - Tuple, - TypeVar, -) - -T = TypeVar("T") - - -class ReprCacheable(Protocol): - def __repr__(self) -> str: - ... - - -def noop_if_disabled(func: Callable) -> Callable: - """Decorator to make function do nothing if the cache is disabled.""" - - def wrapper(self: "UtilCache", *args: Any, **kwargs: Any) -> Any: - if not self.disabled: - return func(self, *args, **kwargs) - - return wrapper - - -class UtilCache(Generic[T]): - """ - Generic cache implementation to store key-value pairs. - Created to abstract the cache implementation in case we find a better - solution in the future. 
- """ - - def __init__(self, cache_name: str = "") -> None: - self._cache: Dict[str, T] = {} - # Allow disabling cache if we have no direct access to the constructor - self.disabled = os.getenv("FIREBOLT_SDK_DISABLE_CACHE", False) or os.getenv( - f"FIREBOLT_SDK_DISABLE_CACHE_${cache_name}", False - ) - - def disable(self) -> None: - self.disabled = True - - def enable(self) -> None: - self.disabled = False - - def get(self, key: ReprCacheable) -> Optional[T]: - if self.disabled: - return None - s_key = self.create_key(key) - return self._cache.get(s_key) - - @noop_if_disabled - def set(self, key: ReprCacheable, value: T) -> None: - if not self.disabled: - s_key = self.create_key(key) - self._cache[s_key] = value - - @noop_if_disabled - def delete(self, key: ReprCacheable) -> None: - s_key = self.create_key(key) - if s_key in self._cache: - del self._cache[s_key] - - @noop_if_disabled - def clear(self) -> None: - self._cache.clear() - - def create_key(self, obj: ReprCacheable) -> str: - return repr(obj) - - def __contains__(self, key: str) -> bool: - """Support for 'in' operator to check if key is present in cache.""" - if self.disabled: - return False - return key in self._cache - - -_firebolt_system_engine_cache = UtilCache[Tuple[str, Dict[str, str]]]( - cache_name="system_engine" -) diff --git a/src/firebolt/common/cursor/base_cursor.py b/src/firebolt/common/cursor/base_cursor.py index da89fc0f404..381e43a90f6 100644 --- a/src/firebolt/common/cursor/base_cursor.py +++ b/src/firebolt/common/cursor/base_cursor.py @@ -8,6 +8,8 @@ from httpx import URL, Response +from firebolt.client.auth.base import Auth +from firebolt.client.client import AsyncClient, Client from firebolt.common._types import ParameterType, RawColType, SetParameter from firebolt.common.constants import ( DISALLOWED_PARAMETER_LIST, @@ -21,6 +23,11 @@ from firebolt.common.row_set.base import BaseRowSet from firebolt.common.row_set.types import AsyncResponse, Column, Statistics from firebolt.common.statement_formatter import StatementFormatter +from firebolt.utils.cache import ( + ConnectionInfo, + SecureCacheKey, + _firebolt_cache, +) from firebolt.utils.exception import ConfigurationError, FireboltError from firebolt.utils.util import fix_url_schema @@ -89,7 +96,10 @@ class BaseCursor: streaming_row_set_type: Type = BaseRowSet def __init__( - self, *args: Any, formatter: StatementFormatter, **kwargs: Any + self, + *args: Any, + formatter: StatementFormatter, + **kwargs: Any, ) -> None: self._arraysize = self.default_arraysize # These fields initialized here for type annotations purpose @@ -101,6 +111,7 @@ def __init__( self.engine_url = "" self._query_id = "" # not used self._query_token = "" + self._client: Optional[Union[Client, AsyncClient]] = None self._row_set: Optional[BaseRowSet] = None self._reset() @@ -314,3 +325,32 @@ def _get_output_format(is_streaming: bool) -> str: if is_streaming: return JSON_LINES_OUTPUT_FORMAT return JSON_OUTPUT_FORMAT + + def get_cache_record(self) -> Optional[ConnectionInfo]: + if not self._client or not self._client.auth: + return None + assert isinstance(self._client.auth, Auth) # Type check + cache_key = SecureCacheKey( + [ + self._client.auth.principal, + self._client.auth.secret, + self._client.account_name, + ], + self._client.auth.secret, + ) + record = _firebolt_cache.get(cache_key) + return record + + def set_cache_record(self, record: ConnectionInfo) -> None: + if not self._client or not self._client.auth: + return + assert isinstance(self._client.auth, Auth) # Type check + 
cache_key = SecureCacheKey( + [ + self._client.auth.principal, + self._client.auth.secret, + self._client.account_name, + ], + self._client.auth.secret, + ) + _firebolt_cache.set(cache_key, record) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index d678cacd2fa..72021ce07ff 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -4,9 +4,10 @@ from ssl import SSLContext from types import TracebackType from typing import Any, Dict, List, Optional, Type, Union +from uuid import uuid4 from warnings import warn -from httpx import Timeout +from httpx import Timeout, codes from firebolt.client import DEFAULT_API_URL, Client, ClientV1, ClientV2 from firebolt.client.auth import Auth @@ -19,23 +20,31 @@ AsyncQueryInfo, BaseConnection, _parse_async_query_info_results, + get_cached_system_engine_info, + get_user_agent_for_connection, + set_cached_system_engine_info, ) -from firebolt.common.cache import _firebolt_system_engine_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.db.cursor import Cursor, CursorV1, CursorV2 -from firebolt.db.util import _get_system_engine_url_and_params +from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( + AccountNotFoundOrNoAccessError, ConfigurationError, ConnectionClosedError, FireboltError, + InterfaceError, ) from firebolt.utils.firebolt_core import ( get_core_certificate_context, parse_firebolt_core_url, validate_firebolt_core_parameters, ) -from firebolt.utils.usage_tracker import get_user_agent_header -from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME +from firebolt.utils.util import ( + fix_url_schema, + parse_url_and_params, + validate_engine_name_and_url_v1, +) logger = logging.getLogger(__name__) @@ -57,14 +66,14 @@ def connect( if not auth: raise ConfigurationError("auth is required to connect.") + api_endpoint = fix_url_schema(api_endpoint) # Type checks assert auth is not None - user_drivers = additional_parameters.get("user_drivers", []) - user_clients = additional_parameters.get("user_clients", []) - user_agent_header = get_user_agent_header(user_drivers, user_clients) + connection_id = uuid4().hex + user_agent_header = get_user_agent_for_connection( + auth, connection_id, account_name, additional_parameters, disable_cache + ) auth_version = auth.get_firebolt_version() - if disable_cache: - _firebolt_system_engine_cache.disable() # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials # Use V1 if auth is ServiceAccount or UsernamePassword @@ -87,6 +96,8 @@ def connect( database=database, engine_name=engine_name, api_endpoint=api_endpoint, + connection_id=connection_id, + disable_cache=disable_cache, ) elif auth_version == FireboltAuthVersion.V1: return connect_v1( @@ -97,6 +108,7 @@ def connect( engine_name=engine_name, engine_url=engine_url, api_endpoint=api_endpoint, + connection_id=connection_id, ) else: raise ConfigurationError(f"Unsupported auth type: {type(auth)}") @@ -105,10 +117,12 @@ def connect( def connect_v2( auth: Auth, user_agent_header: str, + connection_id: str, account_name: Optional[str] = None, database: Optional[str] = None, engine_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, + disable_cache: bool = False, ) -> Connection: """Connect to Firebolt. 
@@ -134,10 +148,8 @@ def connect_v2( assert auth is not None assert account_name is not None - api_endpoint = fix_url_schema(api_endpoint) - - system_engine_url, system_engine_params = _get_system_engine_url_and_params( - auth, account_name, api_endpoint + system_engine_info = _get_system_engine_url_and_params( + auth, account_name, api_endpoint, connection_id, disable_cache ) client = ClientV2( @@ -149,19 +161,20 @@ def connect_v2( ) with Connection( - system_engine_url, + system_engine_info.url, None, client, CursorV2, api_endpoint, - system_engine_params, + system_engine_info.params, + connection_id, ) as system_engine_connection: cursor = system_engine_connection.cursor() if database: - cursor.execute(f'USE DATABASE "{database}"') + cursor.use_database(database, cache=not disable_cache) if engine_name: - cursor.execute(f'USE ENGINE "{engine_name}"') + cursor.use_engine(engine_name, cache=not disable_cache) # Ensure cursors created from this connection are using the same starting # database and engine return Connection( @@ -171,6 +184,7 @@ def connect_v2( CursorV2, api_endpoint, cursor.parameters, + connection_id, ) @@ -201,6 +215,7 @@ class Connection(BaseConnection): "_is_closed", "client_class", "cursor_type", + "id", ) def __init__( @@ -211,12 +226,14 @@ def __init__( cursor_type: Type[Cursor], api_endpoint: str = DEFAULT_API_URL, init_parameters: Optional[Dict[str, Any]] = None, + id: str = uuid4().hex, ): super().__init__(cursor_type) self.api_endpoint = api_endpoint self.engine_url = engine_url self._cursors: List[Cursor] = [] self._client = client + self.id = id self.init_parameters = init_parameters or {} if database: self.init_parameters["database"] = database @@ -348,6 +365,7 @@ def __del__(self) -> None: def connect_v1( auth: Auth, user_agent_header: str, + connection_id: str, database: Optional[str] = None, account_name: Optional[str] = None, engine_name: Optional[str] = None, @@ -362,8 +380,6 @@ def connect_v1( validate_engine_name_and_url_v1(engine_name, engine_url) - api_endpoint = fix_url_schema(api_endpoint) - # Override tcp keepalive settings for connection no_engine_client = ClientV1( auth=auth, @@ -401,7 +417,9 @@ def connect_v1( timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), headers={"User-Agent": user_agent_header}, ) - return Connection(engine_url, database, client, CursorV1, api_endpoint) + return Connection( + engine_url, database, client, CursorV1, api_endpoint, id=connection_id + ) def connect_core( @@ -448,3 +466,39 @@ def connect_core( cursor_type=CursorV2, api_endpoint=verified_url, ) + + +def _get_system_engine_url_and_params( + auth: Auth, + account_name: str, + api_endpoint: str, + connection_id: str, + disable_cache: bool = False, +) -> EngineInfo: + cache_key, cached_result = get_cached_system_engine_info( + auth, account_name, disable_cache + ) + if cached_result: + return cached_result + + with ClientV2( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) + response = client.get(url=url) + if response.status_code == codes.NOT_FOUND: + raise AccountNotFoundOrNoAccessError(account_name) + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve system engine endpoint {url}: " + f"{response.status_code} {response.content.decode()}" + ) + url, params = parse_url_and_params(response.json()["engineUrl"]) + + return set_cached_system_engine_info( + 
cache_key, connection_id, url, params, disable_cache + ) diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index d8d913894e7..b8e1be97840 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -50,6 +50,7 @@ from firebolt.common.row_set.synchronous.in_memory import InMemoryRowSet from firebolt.common.row_set.synchronous.streaming import StreamingRowSet from firebolt.common.statement_formatter import create_statement_formatter +from firebolt.utils.cache import ConnectionInfo, DatabaseInfo, EngineInfo from firebolt.utils.exception import ( EngineNotRunningError, FireboltDatabaseError, @@ -91,7 +92,7 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - self._client = client + self._client: Client = client self.connection = connection self.engine_url = connection.engine_url self._row_set: Optional[BaseSyncRowSet] = None @@ -338,6 +339,41 @@ def _handle_query_execution( self._parse_response_headers(resp.headers) self._append_row_set_from_response(resp) + def use_database(self, database: str, cache: bool = True) -> None: + """Switch the current database context with caching.""" + if cache: + cache_record = self.get_cache_record() + cache_record = ( + cache_record if cache_record else ConnectionInfo(id=self.connection.id) + ) + if cache_record.databases.get(database): + # If database is cached, use it + self.database = database + else: + self.execute(f'USE DATABASE "{database}"') + cache_record.databases[database] = DatabaseInfo(database) + self.set_cache_record(cache_record) + else: + self.execute(f'USE DATABASE "{database}"') + + def use_engine(self, engine: str, cache: bool = True) -> None: + """Switch the current engine context with caching.""" + if cache: + cache_obj = self.get_cache_record() + cache_obj = ( + cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) + ) + if cache_obj.engines.get(engine): + # If engine is cached, use it + self.engine_url = cache_obj.engines[engine].url + self._update_set_parameters(cache_obj.engines[engine].params) + else: + self.execute(f'USE ENGINE "{engine}"') + cache_obj.engines[engine] = EngineInfo(self.engine_url, self.parameters) + self.set_cache_record(cache_obj) + else: + self.execute(f'USE ENGINE "{engine}"') + @check_not_closed def execute( self, diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py deleted file mode 100644 index 8cc12be6b2d..00000000000 --- a/src/firebolt/db/util.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -from typing import Dict, Tuple - -from httpx import Timeout, codes - -from firebolt.client import ClientV2 -from firebolt.client.auth import Auth -from firebolt.common.cache import _firebolt_system_engine_cache -from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.exception import ( - AccountNotFoundOrNoAccessError, - InterfaceError, -) -from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import parse_url_and_params - - -def _get_system_engine_url_and_params( - auth: Auth, - account_name: str, - api_endpoint: str, -) -> Tuple[str, Dict[str, str]]: - if result := _firebolt_system_engine_cache.get([account_name, api_endpoint]): - return result - with ClientV2( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) - response = client.get(url=url) - if response.status_code == codes.NOT_FOUND: 
- raise AccountNotFoundOrNoAccessError(account_name) - if response.status_code != codes.OK: - raise InterfaceError( - f"Unable to retrieve system engine endpoint {url}: " - f"{response.status_code} {response.content.decode()}" - ) - result = parse_url_and_params(response.json()["engineUrl"]) - _firebolt_system_engine_cache.set( - key=[account_name, api_endpoint], value=result - ) - return result diff --git a/src/firebolt/utils/cache.py b/src/firebolt/utils/cache.py new file mode 100644 index 00000000000..e8f0d8bd9d1 --- /dev/null +++ b/src/firebolt/utils/cache.py @@ -0,0 +1,153 @@ +import os +import time +from dataclasses import dataclass, field +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Protocol, + TypeVar, +) + +T = TypeVar("T") + +# Cache expiry configuration +CACHE_EXPIRY_SECONDS = 3600 # 1 hour + + +class ReprCacheable(Protocol): + def __repr__(self) -> str: + ... + + +@dataclass +class EngineInfo: + """Class to hold engine information for caching.""" + + url: str + params: Dict[str, str] + + +@dataclass +class DatabaseInfo: + """Class to hold database information for caching.""" + + name: str + + +@dataclass +class ConnectionInfo: + """Class to hold connection information for caching.""" + + id: str + expiry_time: Optional[int] = None + system_engine: Optional[EngineInfo] = None + databases: Dict[str, DatabaseInfo] = field(default_factory=dict) + engines: Dict[str, EngineInfo] = field(default_factory=dict) + + +def noop_if_disabled(func: Callable) -> Callable: + """Decorator to make function do nothing if the cache is disabled.""" + + def wrapper(self: "UtilCache", *args: Any, **kwargs: Any) -> Any: + if not self.disabled: + return func(self, *args, **kwargs) + + return wrapper + + +class UtilCache(Generic[T]): + """ + Generic cache implementation to store key-value pairs. + Created to abstract the cache implementation in case we find a better + solution in the future. 
+ """ + + def __init__(self, cache_name: str = "") -> None: + self._cache: Dict[str, T] = {} + # Allow disabling cache if we have no direct access to the constructor + self.disabled = os.getenv("FIREBOLT_SDK_DISABLE_CACHE", False) or os.getenv( + f"FIREBOLT_SDK_DISABLE_CACHE_${cache_name}", False + ) + + def disable(self) -> None: + self.disabled = True + + def enable(self) -> None: + self.disabled = False + + def get(self, key: ReprCacheable) -> Optional[T]: + if self.disabled: + return None + s_key = self.create_key(key) + value = self._cache.get(s_key) + + if value is not None and self._is_expired(value): + # Cache miss due to expiry - delete the expired item + del self._cache[s_key] + return None + + return value + + def _is_expired(self, value: T) -> bool: + """Check if a cached value has expired.""" + # Only check expiry for ConnectionInfo objects that have expiry_time + if hasattr(value, "expiry_time") and value.expiry_time is not None: + current_time = int(time.time()) + return current_time >= value.expiry_time + return False + + @noop_if_disabled + def set(self, key: ReprCacheable, value: T) -> None: + if not self.disabled: + # Set expiry_time for ConnectionInfo objects + if hasattr(value, "expiry_time"): + current_time = int(time.time()) + value.expiry_time = current_time + CACHE_EXPIRY_SECONDS + + s_key = self.create_key(key) + self._cache[s_key] = value + + @noop_if_disabled + def delete(self, key: ReprCacheable) -> None: + s_key = self.create_key(key) + if s_key in self._cache: + del self._cache[s_key] + + @noop_if_disabled + def clear(self) -> None: + self._cache.clear() + + def create_key(self, obj: ReprCacheable) -> str: + return repr(obj) + + def __contains__(self, key: str) -> bool: + """Support for 'in' operator to check if key is present in cache.""" + if self.disabled: + return False + return key in self._cache + + +class SecureCacheKey(ReprCacheable): + """A secure cache key that can be used for caching sensitive information.""" + + def __init__(self, key_elements: List[Optional[str]], encryption_key: str): + self.key = "#".join(str(e) for e in key_elements) + self.encryption_key = encryption_key + + def __repr__(self) -> str: + return f"SecureCacheKey({self.key})" + + def __eq__(self, other: object) -> bool: + if isinstance(other, SecureCacheKey): + return self.key == other.key + return False + + def __hash__(self) -> int: + return hash(self.key) + + +_firebolt_cache = UtilCache[ConnectionInfo](cache_name="connection_info") diff --git a/src/firebolt/utils/usage_tracker.py b/src/firebolt/utils/usage_tracker.py index d56a7c071a2..83b30b6e9cf 100644 --- a/src/firebolt/utils/usage_tracker.py +++ b/src/firebolt/utils/usage_tracker.py @@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Tuple from firebolt import __version__ +from firebolt.utils.cache import ConnectionInfo, ReprCacheable, _firebolt_cache @dataclass @@ -161,7 +162,11 @@ def detect_connectors( return connectors -def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> str: +def format_as_user_agent( + drivers: Dict[str, str], + clients: Dict[str, str], + additional_properties: List[Tuple[str, str]], +) -> str: """ Return a representation of a stored tracking data as a user-agent header. @@ -172,7 +177,12 @@ def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> st String of the current detected connector stack. 
""" py, sdk, os, ciso = get_sdk_properties() - sdk_format = f"PythonSDK/{sdk} (Python {py}; {os}; {ciso})" + formatted_properties = "; ".join( + [f"{key}:{value}" for key, value in (additional_properties)] + ) + if formatted_properties: + formatted_properties = f"; {formatted_properties}" + sdk_format = f"PythonSDK/{sdk} (Python {py}; {os}; {ciso}{formatted_properties})" driver_format = "".join( [f" {connector}/{version}" for connector, version in drivers.items()] ) @@ -185,6 +195,7 @@ def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> st def get_user_agent_header( user_drivers: Optional[List[Tuple[str, str]]] = None, user_clients: Optional[List[Tuple[str, str]]] = None, + additional_properties: Optional[List[Tuple[str, str]]] = None, ) -> str: """ Return a user agent header with connector stack and system information. @@ -213,4 +224,18 @@ def get_user_agent_header( clients[name] = version for name, version in versions.drivers: drivers[name] = version - return format_as_user_agent(drivers, clients) + return format_as_user_agent(drivers, clients, additional_properties or []) + + +def get_cache_tracking_params( + cache_key: ReprCacheable, conn_id: str +) -> List[Tuple[str, str]]: + ua_parameters = [] + ua_parameters.append(("connId", conn_id)) + cache = _firebolt_cache.get(cache_key) + if cache: + ua_parameters.append(("cachedConnId", cache.id + "-memory")) + else: + _firebolt_cache.set(cache_key, ConnectionInfo(id=conn_id)) + + return ua_parameters diff --git a/tests/integration/dbapi/async/V2/conftest.py b/tests/integration/dbapi/async/V2/conftest.py index 81ac291233a..2167b8edc22 100644 --- a/tests/integration/dbapi/async/V2/conftest.py +++ b/tests/integration/dbapi/async/V2/conftest.py @@ -24,12 +24,9 @@ async def connection( ) -> Connection: if request.param == "core": kwargs = { - "engine_name": None, "database": "firebolt", "auth": core_auth, "url": core_url, - "account_name": None, - "api_endpoint": None, } else: kwargs = { diff --git a/tests/integration/dbapi/async/V2/test_queries_async.py b/tests/integration/dbapi/async/V2/test_queries_async.py index e16c7808f02..e2042e41834 100644 --- a/tests/integration/dbapi/async/V2/test_queries_async.py +++ b/tests/integration/dbapi/async/V2/test_queries_async.py @@ -600,6 +600,80 @@ async def test_fb_numeric_paramstyle_incorrect_params( ) +async def test_engine_switch( + database_name: str, + connection_system_engine: Connection, + auth: Auth, + account_name: str, + api_endpoint: str, + engine_name: str, +) -> None: + system_cursor = connection_system_engine.cursor() + await system_cursor.execute("SELECT current_engine()") + result = await system_cursor.fetchone() + assert ( + result[0] == "system" + ), f"Incorrect setup - system engine cursor points at {result[0]}" + async with await connect( + engine_name=engine_name, + database=database_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: + cursor = connection.cursor() + await cursor.execute("SELECT current_engine()") + result = await cursor.fetchone() + assert result[0] == engine_name, "Engine switch failed" + # Test switching back to system engine + await cursor.execute("USE ENGINE system") + await cursor.execute("SELECT current_engine()") + result = await cursor.fetchone() + assert result[0] == "system", "Switching back to system engine failed" + + +async def test_database_switch( + database_name: str, + connection_system_engine_no_db: Connection, + auth: Auth, + account_name: str, + api_endpoint: str, + engine_name: 
str, +) -> None: + system_cursor = connection_system_engine_no_db.cursor() + await system_cursor.execute("SELECT current_database()") + result = await system_cursor.fetchone() + assert ( + result[0] == "account_db" + ), f"Incorrect setup - system engine cursor points at {result[0]}" + async with await connect( + engine_name=engine_name, + database=database_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: + cursor = connection.cursor() + await cursor.execute("SELECT current_database()") + result = await cursor.fetchone() + assert result[0] == database_name, "Database switch failed" + try: + # Test switching back to system database + await system_cursor.execute( + f"CREATE DATABASE IF NOT EXISTS {database_name}_switch" + ) + await cursor.execute(f"USE DATABASE {database_name}_switch") + await cursor.execute("SELECT current_database()") + result = await cursor.fetchone() + assert ( + result[0] == f"{database_name}_switch" + ), "Switching back to switch database failed" + finally: + await system_cursor.execute( + f"DROP DATABASE IF EXISTS {database_name}_switch" + ) + + async def test_select_quoted_decimal( connection: Connection, long_decimal_value: str, long_value_decimal_sql: str ): diff --git a/tests/integration/dbapi/sync/V2/conftest.py b/tests/integration/dbapi/sync/V2/conftest.py index 63ec51d5e9d..6c9addb21d3 100644 --- a/tests/integration/dbapi/sync/V2/conftest.py +++ b/tests/integration/dbapi/sync/V2/conftest.py @@ -24,12 +24,9 @@ def connection( ) -> Connection: if request.param == "core": kwargs = { - "engine_name": None, "database": "firebolt", "auth": core_auth, "url": core_url, - "account_name": None, - "api_endpoint": None, } else: kwargs = { diff --git a/tests/unit/V1/async_db/test_connection.py b/tests/unit/V1/async_db/test_connection.py index 4e1bd4da938..a3deac21994 100644 --- a/tests/unit/V1/async_db/test_connection.py +++ b/tests/unit/V1/async_db/test_connection.py @@ -1,6 +1,7 @@ from asyncio import run from re import Pattern from typing import Callable, List +from unittest.mock import ANY as AnyValue from unittest.mock import patch from httpx import codes @@ -401,7 +402,7 @@ async def test_connect_with_user_agent( query_url: str, access_token: str, ) -> None: - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" httpx_mock.add_callback( query_callback, @@ -420,7 +421,9 @@ async def test_connect_with_user_agent( }, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) + ut.assert_called_once_with( + [("DriverA", "1.1")], [("MyConnector", "1.0")], AnyValue + ) async def test_connect_no_user_agent( @@ -432,7 +435,7 @@ async def test_connect_no_user_agent( query_url: str, access_token: str, ) -> None: - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" httpx_mock.add_callback( query_callback, url=query_url, match_headers={"User-Agent": "Python/3.0"} @@ -445,7 +448,7 @@ async def test_connect_no_user_agent( api_endpoint=api_endpoint, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([], []) + ut.assert_called_once_with([], [], AnyValue) def test_from_asyncio( diff --git a/tests/unit/async_db/test_caching.py 
b/tests/unit/async_db/test_caching.py new file mode 100644 index 00000000000..cd035a9e456 --- /dev/null +++ b/tests/unit/async_db/test_caching.py @@ -0,0 +1,499 @@ +import time +from typing import Callable, Generator +from unittest.mock import patch + +from httpx import URL +from pytest import fixture, mark +from pytest_httpx import HTTPXMock + +from firebolt.async_db import connect +from firebolt.client.auth import Auth +from firebolt.utils.cache import CACHE_EXPIRY_SECONDS, _firebolt_cache + + +@fixture(autouse=True) +def use_cache(enable_cache) -> Generator[None, None, None]: + _firebolt_cache.clear() + yield # This fixture is used to ensure cache is enabled for all tests by default + _firebolt_cache.clear() + + +@fixture +async def connection_test( + api_endpoint: str, + auth: Auth, + account_name: str, +): + """Fixture to create a connection factory for testing.""" + + async def factory(db_name: str, engine_name: str, caching: bool) -> Callable: + + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not caching, + ) as connection: + await connection.cursor().execute("select*") + + return factory + + +@mark.parametrize("cache_enabled", [True, False]) +async def test_connect_caching( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, + connection_test: Callable, +): + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + for _ in range(3): + await connection_test(db_name, engine_name, cache_enabled) + + if cache_enabled: + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_database_call_counter == 1, "Use database URL was not cached" + assert use_engine_call_counter == 1, "Use engine URL was not cached" + else: + assert system_engine_call_counter != 1, "System engine URL was cached" + assert use_database_call_counter != 1, "Use database URL was cached" + assert use_engine_call_counter != 1, "Use engine URL was cached" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + 
+@mark.parametrize("cache_enabled", [True, False]) +async def test_connect_db_switching_caching( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, + connection_test: Callable, +): + """Test caching when switching between different databases.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + second_db_name = f"{db_name}_second" + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + + # First database + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + # Second database + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{second_db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # Connect to first database + await connection_test(db_name, engine_name, cache_enabled) + + first_db_calls = use_database_call_counter + + # Connect to second database + await connection_test(second_db_name, engine_name, cache_enabled) + + second_db_calls = use_database_call_counter - first_db_calls + + # Connect to first database again + await connection_test(db_name, engine_name, cache_enabled) + + third_db_calls = use_database_call_counter - first_db_calls - second_db_calls + + if cache_enabled: + assert second_db_calls == 1, "Second database call was not made" + assert third_db_calls == 0, "First database was not cached" + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_engine_call_counter == 1, "Use engine URL was not cached" + else: + assert second_db_calls == 1, "Second database call was not made" + assert third_db_calls == 1, "First database was cached when cache disabled" + assert ( + system_engine_call_counter == 3 + ), "System engine URL was cached when cache disabled" + assert ( + use_engine_call_counter == 3 + ), "Use engine URL was cached when cache disabled" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +async def test_connect_engine_switching_caching( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + 
system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, + connection_test: Callable, +): + """Test caching when switching between different engines.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + second_engine_name = f"{engine_name}_second" + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + # First engine + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + + # Second engine + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{second_engine_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback(query_callback, url=query_url) + + # Connect to first engine + await connection_test(db_name, engine_name, cache_enabled) + + first_engine_calls = use_engine_call_counter + + # Connect to second engine + await connection_test(db_name, second_engine_name, cache_enabled) + + second_engine_calls = use_engine_call_counter - first_engine_calls + + # Connect to first engine again + await connection_test(db_name, engine_name, cache_enabled) + + third_engine_calls = ( + use_engine_call_counter - first_engine_calls - second_engine_calls + ) + + if cache_enabled: + assert second_engine_calls == 1, "Second engine call was not made" + assert third_engine_calls == 0, "First engine was not cached" + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_database_call_counter == 1, "Use database URL was not cached" + else: + assert second_engine_calls == 1, "Second engine call was not made" + assert third_engine_calls == 1, "First engine was cached when cache disabled" + assert ( + system_engine_call_counter == 3 + ), "System engine URL was cached when cache disabled" + assert ( + use_database_call_counter == 3 + ), "Use database URL was cached when cache disabled" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +async def test_connect_db_different_accounts( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: URL, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + api_endpoint: str, + auth: Auth, + account_name: str, + cache_enabled: bool, +): + """Test caching when switching between different databases.""" + 
system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + get_system_engine_url_new_account = str(get_system_engine_url).replace( + account_name, account_name + "_second" + ) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + system_engine_callback_counter, url=get_system_engine_url_new_account + ) + + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # First connection + + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + assert system_engine_call_counter == 1, "System engine URL was not called" + assert use_engine_call_counter == 1, "Use engine URL was not called" + assert use_database_call_counter == 1, "Use database URL was not called" + + # Second connection against different account + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name + "_second", + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + # This should trigger additional calls to the system engine URL and engine/database + assert ( + system_engine_call_counter == 2 + ), "System engine URL was not called for second account" + assert ( + use_engine_call_counter == 2 + ), "Use engine URL was not called for second account" + assert ( + use_database_call_counter == 2 + ), "Use database URL was not called for second account" + + +async def test_calls_when_cache_expired( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + connection_test: Callable, +): + """Test that expired cache entries trigger new backend requests.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, 
**kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # First connection - should populate cache + await connection_test(db_name, engine_name, True) # cache_enabled=True + + # Verify initial calls were made + assert system_engine_call_counter == 1, "System engine URL was not called initially" + assert use_database_call_counter == 1, "Use database URL was not called initially" + assert use_engine_call_counter == 1, "Use engine URL was not called initially" + + # Second connection immediately - should use cache + await connection_test(db_name, engine_name, True) + + # Verify no additional calls were made (cache hit) + assert ( + system_engine_call_counter == 1 + ), "System engine URL was called when cache should hit" + assert ( + use_database_call_counter == 1 + ), "Use database URL was called when cache should hit" + assert ( + use_engine_call_counter == 1 + ), "Use engine URL was called when cache should hit" + + # Mock time to simulate cache expiry (1 hour + 1 second past current time) + current_time = int(time.time()) + expired_time = current_time + CACHE_EXPIRY_SECONDS + 1 + + with patch("firebolt.utils.cache.time.time", return_value=expired_time): + # Third connection after cache expiry - should trigger new backend calls + await connection_test(db_name, engine_name, True) + + # Verify additional calls were made due to cache expiry + assert ( + system_engine_call_counter == 2 + ), "System engine URL was not called after cache expiry" + assert ( + use_database_call_counter == 2 + ), "Use database URL was not called after cache expiry" + assert ( + use_engine_call_counter == 2 + ), "Use engine URL was not called after cache expiry" diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index f1d357672b3..e3cc6915fb2 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -1,5 +1,6 @@ -from typing import Callable, List, Optional, Tuple -from unittest.mock import patch +from typing import Callable, Generator, List, Optional, Tuple +from unittest.mock import ANY as AnyValue +from unittest.mock import MagicMock, patch from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises @@ -8,7 +9,7 @@ from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType -from firebolt.common.cache import _firebolt_system_engine_cache +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, ConfigurationError, @@ -142,51 +143,8 @@ async def test_connect_database_failed( httpx_mock.reset(False) -async def test_connect_engine_failed( - db_name: str, - account_name: str, - engine_name: str, - auth: Auth, - api_endpoint: str, - python_query_data: List[List[ColType]], - httpx_mock: HTTPXMock, - 
system_engine_no_db_query_url: str, - use_database_callback: Callable, - system_engine_query_url: str, - use_engine_failed_callback: Callable, - mock_system_engine_connection_flow: Callable, - mock_query: Callable, -): - """connect properly handles use engine errors""" - mock_system_engine_connection_flow() - - httpx_mock.add_callback( - use_database_callback, - url=system_engine_no_db_query_url, - match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), - ) - - httpx_mock.add_callback( - use_engine_failed_callback, - url=system_engine_query_url, - match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), - ) - with raises(FireboltError): - async with await connect( - database=db_name, - auth=auth, - engine_name=engine_name, - account_name=account_name, - api_endpoint=api_endpoint, - ): - pass - - # Account id endpoint was not used since we didn't get to that point - httpx_mock.reset(False) - - @mark.parametrize("cache_enabled", [True, False]) -async def test_connect_caching( +async def test_connect_system_engine_caching( db_name: str, engine_name: str, auth_url: str, @@ -203,6 +161,7 @@ async def test_connect_caching( use_database_callback: Callable, use_engine_callback: Callable, query_callback: Callable, + enable_cache: Generator, cache_enabled: bool, ): system_engine_call_counter = 0 @@ -243,8 +202,48 @@ def system_engine_callback_counter(request, **kwargs): else: assert system_engine_call_counter != 1, "System engine URL was cached" - # Reset caches for the next test iteration - _firebolt_system_engine_cache.enable() + +async def test_connect_engine_failed( + db_name: str, + account_name: str, + engine_name: str, + auth: Auth, + api_endpoint: str, + python_query_data: List[List[ColType]], + httpx_mock: HTTPXMock, + system_engine_no_db_query_url: str, + use_database_callback: Callable, + system_engine_query_url: str, + use_engine_failed_callback: Callable, + mock_system_engine_connection_flow: Callable, + mock_query: Callable, +): + """connect properly handles use engine errors""" + mock_system_engine_connection_flow() + + httpx_mock.add_callback( + use_database_callback, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_failed_callback, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + with raises(FireboltError): + async with await connect( + database=db_name, + auth=auth, + engine_name=engine_name, + account_name=account_name, + api_endpoint=api_endpoint, + ): + pass + + # Account id endpoint was not used since we didn't get to that point + httpx_mock.reset(False) async def test_connect_system_engine_404( @@ -354,7 +353,7 @@ async def test_connect_with_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" mock_connection_flow() httpx_mock.add_callback( @@ -375,7 +374,7 @@ async def test_connect_with_user_agent( }, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) + ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")], AnyValue) async def test_connect_no_user_agent( @@ -390,7 +389,7 @@ async def test_connect_no_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with 
patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" mock_connection_flow() httpx_mock.add_callback( @@ -405,7 +404,81 @@ async def test_connect_no_user_agent( api_endpoint=api_endpoint, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_with([], []) + ut.assert_called_with([], [], AnyValue) + + +async def test_connect_caching_headers( + engine_name: str, + account_name: str, + api_endpoint: str, + db_name: str, + auth: Auth, + httpx_mock: HTTPXMock, + query_callback: Callable, + query_url: str, + mock_connection_flow: Callable, + enable_cache: Generator, +) -> None: + async def do_connect(): + async with await connect( + auth=auth, + database=db_name, + engine_name=engine_name, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: + await connection.cursor().execute("select*") + + _firebolt_cache.clear() + mock_id = "12345" + mock_id2 = "67890" + mock_id3 = "54321" + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: + ut.side_effect = [ + f"connId:{mock_id}", + f"connId:{mock_id2}; cachedConnId:{mock_id}-memory", + f"connId:{mock_id3}", + ] + with patch("firebolt.async_db.connection.uuid4") as uuid4: + uuid4.side_effect = [ + MagicMock(hex=mock_id), + MagicMock(hex=mock_id2), + MagicMock(hex=mock_id3), + ] + mock_connection_flow() + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": f"connId:{mock_id}"}, + ) + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={ + "User-Agent": f"connId:{mock_id2}; cachedConnId:{mock_id}-memory" + }, + ) + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": f"connId:{mock_id3}"}, + ) + + await do_connect() + ut.assert_called_with(AnyValue, AnyValue, [("connId", mock_id)]) + + # Second call should use cached connection info + await do_connect() + ut.assert_called_with( + AnyValue, + AnyValue, + [("connId", mock_id2), ("cachedConnId", f"{mock_id}-memory")], + ) + _firebolt_cache.clear() + + # Third call should have a new connection id + await do_connect() + ut.assert_called_with(AnyValue, AnyValue, [("connId", mock_id3)]) @mark.parametrize( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7d77ca906a9..bf2713a5b58 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,5 +1,5 @@ import functools -from typing import Callable +from typing import Callable, Generator import httpx from httpx import Request @@ -8,8 +8,8 @@ from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.client import ClientV2 -from firebolt.common.cache import _firebolt_system_engine_cache from firebolt.common.settings import Settings +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( DatabaseError, DataError, @@ -43,7 +43,20 @@ def global_fake_fs(request) -> None: @fixture(autouse=True) def clear_cache() -> None: - _firebolt_system_engine_cache.clear() + _firebolt_cache.clear() + + +@fixture(autouse=True) +def disable_cache() -> None: + _firebolt_cache.disable() + + +@fixture +def enable_cache() -> Generator[None, None, None]: + """Fixture to enable cache for tests that require it.""" + _firebolt_cache.enable() + yield + _firebolt_cache.disable() @fixture diff --git a/tests/unit/db/test_caching.py b/tests/unit/db/test_caching.py new file mode 100644 index 00000000000..d4b57da91ce --- /dev/null +++ 
b/tests/unit/db/test_caching.py @@ -0,0 +1,499 @@ +import time +from typing import Callable, Generator +from unittest.mock import patch + +from httpx import URL +from pytest import fixture, mark +from pytest_httpx import HTTPXMock + +from firebolt.client.auth import Auth +from firebolt.db import connect +from firebolt.utils.cache import CACHE_EXPIRY_SECONDS, _firebolt_cache + + +@fixture(autouse=True) +def use_cache(enable_cache) -> Generator[None, None, None]: + _firebolt_cache.clear() + yield # This fixture is used to ensure cache is enabled for all tests by default + _firebolt_cache.clear() + + +@fixture +def connection_test( + api_endpoint: str, + auth: Auth, + account_name: str, +): + """Fixture to create a connection factory for testing.""" + + def factory(db_name: str, engine_name: str, caching: bool) -> Callable: + + with connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not caching, + ) as connection: + connection.cursor().execute("select*") + + return factory + + +@mark.parametrize("cache_enabled", [True, False]) +def test_connect_caching( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, + connection_test: Callable, +): + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + for _ in range(3): + connection_test(db_name, engine_name, cache_enabled) + + if cache_enabled: + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_database_call_counter == 1, "Use database URL was not cached" + assert use_engine_call_counter == 1, "Use engine URL was not cached" + else: + assert system_engine_call_counter != 1, "System engine URL was cached" + assert use_database_call_counter != 1, "Use database URL was cached" + assert use_engine_call_counter != 1, "Use engine URL was cached" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +def test_connect_db_switching_caching( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + 
check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, + connection_test: Callable, +): + """Test caching when switching between different databases.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + second_db_name = f"{db_name}_second" + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + + # First database + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + # Second database + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{second_db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # Connect to first database + connection_test(db_name, engine_name, cache_enabled) + + first_db_calls = use_database_call_counter + + # Connect to second database + connection_test(second_db_name, engine_name, cache_enabled) + + second_db_calls = use_database_call_counter - first_db_calls + + # Connect to first database again + connection_test(db_name, engine_name, cache_enabled) + + third_db_calls = use_database_call_counter - first_db_calls - second_db_calls + + if cache_enabled: + assert second_db_calls == 1, "Second database call was not made" + assert third_db_calls == 0, "First database was not cached" + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_engine_call_counter == 1, "Use engine URL was not cached" + else: + assert second_db_calls == 1, "Second database call was not made" + assert third_db_calls == 1, "First database was cached when cache disabled" + assert ( + system_engine_call_counter == 3 + ), "System engine URL was cached when cache disabled" + assert ( + use_engine_call_counter == 3 + ), "Use engine URL was cached when cache disabled" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +def test_connect_engine_switching_caching( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, + connection_test: Callable, +): + """Test 
caching when switching between different engines.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + second_engine_name = f"{engine_name}_second" + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + # First engine + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + + # Second engine + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{second_engine_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback(query_callback, url=query_url) + + # Connect to first engine + connection_test(db_name, engine_name, cache_enabled) + + first_engine_calls = use_engine_call_counter + + # Connect to second engine + connection_test(db_name, second_engine_name, cache_enabled) + + second_engine_calls = use_engine_call_counter - first_engine_calls + + # Connect to first engine again + connection_test(db_name, engine_name, cache_enabled) + + third_engine_calls = ( + use_engine_call_counter - first_engine_calls - second_engine_calls + ) + + if cache_enabled: + assert second_engine_calls == 1, "Second engine call was not made" + assert third_engine_calls == 0, "First engine was not cached" + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_database_call_counter == 1, "Use database URL was not cached" + else: + assert second_engine_calls == 1, "Second engine call was not made" + assert third_engine_calls == 1, "First engine was cached when cache disabled" + assert ( + system_engine_call_counter == 3 + ), "System engine URL was cached when cache disabled" + assert ( + use_database_call_counter == 3 + ), "Use database URL was cached when cache disabled" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +def test_connect_db_different_accounts( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: URL, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + api_endpoint: str, + auth: Auth, + account_name: str, + cache_enabled: bool, +): + """Test caching when switching between different accounts.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return 
get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + get_system_engine_url_new_account = str(get_system_engine_url).replace( + account_name, account_name + "_second" + ) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + system_engine_callback_counter, url=get_system_engine_url_new_account + ) + + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # First connection + + with connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + connection.cursor().execute("select*") + + assert system_engine_call_counter == 1, "System engine URL was not called" + assert use_engine_call_counter == 1, "Use engine URL was not called" + assert use_database_call_counter == 1, "Use database URL was not called" + + # Second connection against different account + with connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name + "_second", + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + connection.cursor().execute("select*") + + # This should trigger additional calls to the system engine URL and engine/database + assert ( + system_engine_call_counter == 2 + ), "System engine URL was not called for second account" + assert ( + use_engine_call_counter == 2 + ), "Use engine URL was not called for second account" + assert ( + use_database_call_counter == 2 + ), "Use database URL was not called for second account" + + +def test_calls_when_cache_expired( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + connection_test: Callable, +): + """Test that expired cache entries trigger new backend requests.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + 
httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # First connection - should populate cache + connection_test(db_name, engine_name, True) # cache_enabled=True + + # Verify initial calls were made + assert system_engine_call_counter == 1, "System engine URL was not called initially" + assert use_database_call_counter == 1, "Use database URL was not called initially" + assert use_engine_call_counter == 1, "Use engine URL was not called initially" + + # Second connection immediately - should use cache + connection_test(db_name, engine_name, True) + + # Verify no additional calls were made (cache hit) + assert ( + system_engine_call_counter == 1 + ), "System engine URL was called when cache should hit" + assert ( + use_database_call_counter == 1 + ), "Use database URL was called when cache should hit" + assert ( + use_engine_call_counter == 1 + ), "Use engine URL was called when cache should hit" + + # Mock time to simulate cache expiry (1 hour + 1 second past current time) + current_time = int(time.time()) + expired_time = current_time + CACHE_EXPIRY_SECONDS + 1 + + with patch("firebolt.utils.cache.time.time", return_value=expired_time): + # Third connection after cache expiry - should trigger new backend calls + connection_test(db_name, engine_name, True) + + # Verify additional calls were made due to cache expiry + assert ( + system_engine_call_counter == 2 + ), "System engine URL was not called after cache expiry" + assert ( + use_database_call_counter == 2 + ), "Use database URL was not called after cache expiry" + assert ( + use_engine_call_counter == 2 + ), "Use engine URL was not called after cache expiry" diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index c6e94ea1231..060a6d764fc 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -1,7 +1,8 @@ import gc import warnings -from typing import Callable, List, Optional, Tuple -from unittest.mock import patch +from typing import Callable, Generator, List, Optional, Tuple +from unittest.mock import ANY as AnyValue +from unittest.mock import MagicMock, patch from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises, warns @@ -10,9 +11,9 @@ from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.client import ClientV2 from firebolt.common._types import ColType -from firebolt.common.cache import _firebolt_system_engine_cache from firebolt.db import Connection, connect from firebolt.db.cursor import CursorV2 +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, ConfigurationError, @@ -190,7 +191,7 @@ def test_connect_engine_failed( @mark.parametrize("cache_enabled", [True, False]) -def test_connect_caching( +def test_connect_system_engine_caching( db_name: str, engine_name: str, auth_url: str, @@ -207,6 +208,7 @@ def test_connect_caching( use_database_callback: Callable, use_engine_callback: Callable, query_callback: Callable, + enable_cache: Generator, cache_enabled: bool, ): system_engine_call_counter = 0 @@ -247,9 +249,6 @@ def 
system_engine_callback_counter(request, **kwargs): else: assert system_engine_call_counter != 1, "System engine URL was cached" - # Reset caches for the next test iteration - _firebolt_system_engine_cache.enable() - def test_connect_system_engine_404( db_name: str, @@ -372,7 +371,7 @@ def test_connect_with_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with patch("firebolt.db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" mock_connection_flow() httpx_mock.add_callback( @@ -393,7 +392,7 @@ def test_connect_with_user_agent( }, ) as connection: connection.cursor().execute("select*") - ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) + ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")], AnyValue) def test_connect_no_user_agent( @@ -408,7 +407,7 @@ def test_connect_no_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with patch("firebolt.db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" mock_connection_flow() httpx_mock.add_callback( @@ -423,7 +422,81 @@ def test_connect_no_user_agent( api_endpoint=api_endpoint, ) as connection: connection.cursor().execute("select*") - ut.assert_called_with([], []) + ut.assert_called_with([], [], AnyValue) + + +def test_connect_caching_headers( + engine_name: str, + account_name: str, + api_endpoint: str, + db_name: str, + auth: Auth, + httpx_mock: HTTPXMock, + query_callback: Callable, + query_url: str, + mock_connection_flow: Callable, + enable_cache: Generator, +) -> None: + def do_connect(): + with connect( + auth=auth, + database=db_name, + engine_name=engine_name, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: + connection.cursor().execute("select*") + + _firebolt_cache.clear() + mock_id = "12345" + mock_id2 = "67890" + mock_id3 = "54321" + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: + ut.side_effect = [ + f"connId:{mock_id}", + f"connId:{mock_id2}; cachedConnId:{mock_id}-memory", + f"connId:{mock_id3}", + ] + with patch("firebolt.db.connection.uuid4") as uuid4: + uuid4.side_effect = [ + MagicMock(hex=mock_id), + MagicMock(hex=mock_id2), + MagicMock(hex=mock_id3), + ] + mock_connection_flow() + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": f"connId:{mock_id}"}, + ) + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={ + "User-Agent": f"connId:{mock_id2}; cachedConnId:{mock_id}-memory" + }, + ) + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": f"connId:{mock_id3}"}, + ) + + do_connect() + ut.assert_called_with(AnyValue, AnyValue, [("connId", mock_id)]) + + # Second call should use cached connection info + do_connect() + ut.assert_called_with( + AnyValue, + AnyValue, + [("connId", mock_id2), ("cachedConnId", f"{mock_id}-memory")], + ) + _firebolt_cache.clear() + + # Third call should have a new connection id + do_connect() + ut.assert_called_with(AnyValue, AnyValue, [("connId", mock_id3)]) @mark.parametrize( diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py new file mode 100644 index 00000000000..6fce309fbe3 --- /dev/null +++ b/tests/unit/utils/test_cache.py @@ -0,0 +1,353 @@ +import time +from typing import Generator +from unittest.mock import patch + 
+from pytest import fixture, mark + +from firebolt.utils.cache import ( + CACHE_EXPIRY_SECONDS, + ConnectionInfo, + SecureCacheKey, + UtilCache, +) + + +@fixture +def cache() -> Generator[UtilCache[ConnectionInfo], None, None]: + """Create a fresh cache instance for testing.""" + cache = UtilCache[ConnectionInfo](cache_name="test_cache") + cache.enable() # Ensure cache is enabled for tests + yield cache + cache.clear() + + +@fixture +def string_cache() -> Generator[UtilCache[str], None, None]: + """Create a fresh string cache instance for testing.""" + cache = UtilCache[str](cache_name="string_test_cache") + cache.enable() # Ensure cache is enabled for tests + yield cache + cache.clear() + + +@fixture +def disabled_cache() -> UtilCache[ConnectionInfo]: + """Create a disabled cache instance for testing.""" + cache = UtilCache[ConnectionInfo](cache_name="test_disabled_cache") + cache.disable() + return cache + + +@fixture +def sample_connection_info() -> ConnectionInfo: + """Create a sample ConnectionInfo for testing.""" + return ConnectionInfo(id="test_connection") + + +@fixture +def sample_connection_info_with_expiry() -> ConnectionInfo: + """Create a sample ConnectionInfo with explicit None expiry_time.""" + return ConnectionInfo(id="test_connection_with_expiry", expiry_time=None) + + +@fixture +def sample_cache_key() -> SecureCacheKey: + """Create a sample cache key for testing.""" + return SecureCacheKey(["test", "key"], "secret") + + +@fixture +def additional_cache_keys(): + """Create additional cache keys for multi-key tests.""" + return { + "key1": SecureCacheKey(["key1"], "secret"), + "key2": SecureCacheKey(["key2"], "secret"), + "key3": SecureCacheKey(["user", "other"], "secret"), + } + + +@fixture +def fixed_time(): + """Provide a fixed timestamp for consistent testing.""" + return 1000000 + + +@fixture +def test_string(): + """Provide a test string for non-ConnectionInfo cache tests.""" + return "test_value" + + +def test_cache_set_and_get(cache, sample_cache_key, sample_connection_info): + """Test basic cache set and get operations.""" + # Test cache miss initially + assert cache.get(sample_cache_key) is None + + # Set value and verify it's cached + cache.set(sample_cache_key, sample_connection_info) + cached_value = cache.get(sample_cache_key) + + assert cached_value is not None + assert cached_value.id == "test_connection" + assert cached_value.expiry_time is not None + + # Verify expiry_time is set to current time + CACHE_EXPIRY_SECONDS + current_time = int(time.time()) + expected_expiry = current_time + CACHE_EXPIRY_SECONDS + # Allow for small time difference due to test execution time + assert abs(cached_value.expiry_time - expected_expiry) <= 2 + + +def test_cache_expiry(cache, sample_cache_key, sample_connection_info, fixed_time): + """Test that cache entries expire after the specified time.""" + with patch("time.time", return_value=fixed_time): + # Set a value in the cache + cache.set(sample_cache_key, sample_connection_info) + cached_value = cache.get(sample_cache_key) + + # Verify it's cached and expiry_time is set + assert cached_value is not None + assert cached_value.expiry_time == fixed_time + CACHE_EXPIRY_SECONDS + + # Simulate time passing but not enough to expire (59 minutes) + with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS - 60): + cached_value = cache.get(sample_cache_key) + assert cached_value is not None # Should still be cached + + # Simulate time passing to exactly the expiry time + with patch("time.time", return_value=fixed_time + 
CACHE_EXPIRY_SECONDS): + cached_value = cache.get(sample_cache_key) + assert cached_value is None # Should be expired and removed + + # Verify the item is actually deleted from cache + assert cache.create_key(sample_cache_key) not in cache._cache + + +def test_cache_expiry_past_expiry_time( + cache, sample_cache_key, sample_connection_info, fixed_time +): + """Test that cache entries are removed when accessed after expiry time.""" + with patch("time.time", return_value=fixed_time): + cache.set(sample_cache_key, sample_connection_info) + + # Simulate time passing beyond expiry (2 hours) + with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS + 3600): + cached_value = cache.get(sample_cache_key) + assert cached_value is None # Should be expired + + # Verify the item is removed from the internal cache + cache_key_str = cache.create_key(sample_cache_key) + assert cache_key_str not in cache._cache + + +def test_cache_disabled_behavior( + disabled_cache, sample_cache_key, sample_connection_info +): + """Test that disabled cache doesn't store or retrieve values.""" + # Try to set a value + disabled_cache.set(sample_cache_key, sample_connection_info) + + # Should return None even though we set a value + assert disabled_cache.get(sample_cache_key) is None + + # Internal cache should be empty + assert len(disabled_cache._cache) == 0 + + +def test_cache_clear( + cache, sample_cache_key, sample_connection_info, additional_cache_keys +): + """Test that cache.clear() removes all entries.""" + # Add some entries + cache.set(additional_cache_keys["key1"], sample_connection_info) + cache.set(additional_cache_keys["key2"], sample_connection_info) + + # Verify entries exist + assert cache.get(additional_cache_keys["key1"]) is not None + assert cache.get(additional_cache_keys["key2"]) is not None + + # Clear cache + cache.clear() + + # Verify all entries are gone + assert cache.get(additional_cache_keys["key1"]) is None + assert cache.get(additional_cache_keys["key2"]) is None + assert len(cache._cache) == 0 + + +def test_cache_delete(cache, sample_cache_key, sample_connection_info): + """Test manual deletion of cache entries.""" + cache.set(sample_cache_key, sample_connection_info) + assert cache.get(sample_cache_key) is not None + + cache.delete(sample_cache_key) + assert cache.get(sample_cache_key) is None + + +def test_cache_contains_operator(cache, sample_cache_key, sample_connection_info): + """Test the 'in' operator for cache.""" + cache_key_str = cache.create_key(sample_cache_key) + + # Initially not in cache + assert cache_key_str not in cache + + # Add to cache + cache.set(sample_cache_key, sample_connection_info) + assert cache_key_str in cache + + # Test with disabled cache + cache.disable() + assert cache_key_str not in cache # Should return False when disabled + + +def test_non_connection_info_objects(string_cache, sample_cache_key, test_string): + """Test that non-ConnectionInfo objects don't get expiry_time set.""" + string_cache.set(sample_cache_key, test_string) + + # Should retrieve the string without expiry logic + cached_value = string_cache.get(sample_cache_key) + assert cached_value == test_string + + +def test_expiry_time_none_handling( + cache, sample_cache_key, sample_connection_info_with_expiry, fixed_time +): + """Test handling of ConnectionInfo with expiry_time set to None.""" + with patch("time.time", return_value=fixed_time): + cache.set(sample_cache_key, sample_connection_info_with_expiry) + + # Should set expiry_time during set operation + cached_value = 
cache.get(sample_cache_key) + assert cached_value is not None + assert cached_value.expiry_time is not None + + +def test_secure_cache_key_creation(): + """Test SecureCacheKey creation and repr.""" + key = SecureCacheKey(["user", "db", "engine"], "secret_key") + assert key.key == "user#db#engine" + assert key.encryption_key == "secret_key" + assert repr(key) == "SecureCacheKey(user#db#engine)" + + +def test_secure_cache_key_equality(): + """Test SecureCacheKey equality comparison.""" + key1 = SecureCacheKey(["user", "db"], "secret1") + key2 = SecureCacheKey(["user", "db"], "secret2") + key3 = SecureCacheKey(["user", "other"], "secret1") + + assert key1 == key2 # Same key content, different encryption key + assert key1 != key3 # Different key content + assert key1 != "not_a_key" # Different type + + +def test_secure_cache_key_hash(): + """Test SecureCacheKey hash functionality.""" + key1 = SecureCacheKey(["user", "db"], "secret1") + key2 = SecureCacheKey(["user", "db"], "secret2") + + # Same key content should have same hash + assert hash(key1) == hash(key2) + + # Should be usable in sets and dicts + key_set = {key1, key2} + assert len(key_set) == 1 # Should be treated as same key + + +def test_secure_cache_key_with_none_elements(): + """Test SecureCacheKey handling of None elements.""" + key = SecureCacheKey(["user", None, "engine"], "secret") + assert key.key == "user#None#engine" + + +@mark.parametrize("cache_expiry_time_offset", [-60, 0, 3600]) +def test_cache_expiry_parametrized( + cache, + sample_cache_key, + sample_connection_info, + fixed_time, + cache_expiry_time_offset, +): + """Test cache expiry behavior with different time offsets.""" + with patch("time.time", return_value=fixed_time): + cache.set(sample_cache_key, sample_connection_info) + + # Test at different time offsets relative to expiry time + check_time = fixed_time + CACHE_EXPIRY_SECONDS + cache_expiry_time_offset + + with patch("time.time", return_value=check_time): + cached_value = cache.get(sample_cache_key) + + if cache_expiry_time_offset < 0: + # Before expiry - should be cached + assert cached_value is not None + else: + # At or after expiry - should be None + assert cached_value is None + + +def test_cache_expiry_multiple_entries(cache, additional_cache_keys, fixed_time): + """Test that expiry works correctly with multiple cache entries.""" + # Create separate ConnectionInfo objects to avoid shared state + connection_info_1 = ConnectionInfo(id="test_connection_1") + connection_info_2 = ConnectionInfo(id="test_connection_2") + + # Set multiple entries at different times + with patch("time.time", return_value=fixed_time): + cache.set(additional_cache_keys["key1"], connection_info_1) + + with patch("time.time", return_value=fixed_time + 1800): # 30 minutes later + cache.set(additional_cache_keys["key2"], connection_info_2) + + # Check expiry of first entry while second is still valid + with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS): + assert cache.get(additional_cache_keys["key1"]) is None # Expired + assert cache.get(additional_cache_keys["key2"]) is not None # Still valid + + # Check both are expired after sufficient time + with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS + 1800): + assert cache.get(additional_cache_keys["key1"]) is None + assert cache.get(additional_cache_keys["key2"]) is None + + +def test_cache_set_updates_expiry_time( + cache, sample_cache_key, sample_connection_info, fixed_time +): + """Test that setting a value again updates the expiry time.""" 
+ # Set initial value + with patch("time.time", return_value=fixed_time): + cache.set(sample_cache_key, sample_connection_info) + initial_cached = cache.get(sample_cache_key) + initial_expiry = initial_cached.expiry_time + + # Set same key again later + with patch("time.time", return_value=fixed_time + 1800): # 30 minutes later + cache.set(sample_cache_key, sample_connection_info) + updated_cached = cache.get(sample_cache_key) + updated_expiry = updated_cached.expiry_time + + # Expiry time should be updated + assert updated_expiry > initial_expiry + assert updated_expiry == fixed_time + 1800 + CACHE_EXPIRY_SECONDS + + +@mark.parametrize("disable_cache_during_operation", [True, False]) +def test_cache_disable_enable_behavior( + cache, sample_cache_key, sample_connection_info, disable_cache_during_operation +): + """Test cache behavior when disabled and re-enabled.""" + # Set initial value + cache.set(sample_cache_key, sample_connection_info) + assert cache.get(sample_cache_key) is not None + + if disable_cache_during_operation: + # Disable cache - should return None even if value exists + cache.disable() + assert cache.get(sample_cache_key) is None + + # Re-enable cache - should work again + cache.enable() + assert cache.get(sample_cache_key) is not None + else: + # Keep cache enabled - should continue working + assert cache.get(sample_cache_key) is not None diff --git a/tests/unit/utils/test_usage_tracker.py b/tests/unit/utils/test_usage_tracker.py index 2949b49757e..363fd1ecf72 100644 --- a/tests/unit/utils/test_usage_tracker.py +++ b/tests/unit/utils/test_usage_tracker.py @@ -146,37 +146,62 @@ def test_detect_connectors(stack, map, expected): @mark.parametrize( - "drivers,clients,expected_string", + "drivers,clients,additional_parameters,expected_string", [ - ([], [], "PythonSDK/2 (Python 1; Win; ciso)"), + ([], [], None, "PythonSDK/2 (Python 1; Win; ciso)"), ( [("ConnectorA", "0.1.1")], [], + None, "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1", ), ( (("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")), (), + None, "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", ), ( [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], [], + None, "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", ), ( [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], [("ClientA", "1.0.1")], + None, "ClientA/1.0.1 PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", ), + ( + [], + [], + [("connId", "12345"), ("cachedConnId", "67890-memory")], + "PythonSDK/2 (Python 1; Win; ciso; connId:12345; cachedConnId:67890-memory)", + ), + ( + [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], + [("ClientA", "1.0.1")], + [("connId", "12345")], + "ClientA/1.0.1 PythonSDK/2 (Python 1; Win; ciso; connId:12345) ConnectorA/0.1.1 ConnectorB/0.2.0", + ), + ( + [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], + [("ClientA", "1.0.1")], + [("connId", "12345"), ("cachedConnId", "67890-memory")], + "ClientA/1.0.1 PythonSDK/2 (Python 1; Win; ciso; connId:12345; cachedConnId:67890-memory) ConnectorA/0.1.1 ConnectorB/0.2.0", + ), ], ) @patch( "firebolt.utils.usage_tracker.get_sdk_properties", MagicMock(return_value=("1", "2", "Win", "ciso")), ) -def test_user_agent(drivers, clients, expected_string): - assert get_user_agent_header(drivers, clients) == expected_string +def test_user_agent(drivers, clients, additional_parameters, expected_string): + assert ( + get_user_agent_header(drivers, clients, additional_parameters) + == expected_string + ) @mark.parametrize(