From 0247670e324e3a04b82b99dd9a5f12b154fd468a Mon Sep 17 00:00:00 2001 From: ptiurin Date: Thu, 31 Jul 2025 13:54:07 +0100 Subject: [PATCH 01/21] extend usage tracker --- src/firebolt/utils/usage_tracker.py | 16 ++++++++++--- tests/unit/utils/test_usage_tracker.py | 33 ++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/firebolt/utils/usage_tracker.py b/src/firebolt/utils/usage_tracker.py index d56a7c071a2..5eb656146ca 100644 --- a/src/firebolt/utils/usage_tracker.py +++ b/src/firebolt/utils/usage_tracker.py @@ -161,7 +161,11 @@ def detect_connectors( return connectors -def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> str: +def format_as_user_agent( + drivers: Dict[str, str], + clients: Dict[str, str], + additional_properties: List[Tuple[str, str]], +) -> str: """ Return a representation of a stored tracking data as a user-agent header. @@ -172,7 +176,12 @@ def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> st String of the current detected connector stack. """ py, sdk, os, ciso = get_sdk_properties() - sdk_format = f"PythonSDK/{sdk} (Python {py}; {os}; {ciso})" + formatted_properties = "; ".join( + [f"{key}:{value}" for key, value in (additional_properties)] + ) + if formatted_properties: + formatted_properties = f"; {formatted_properties}" + sdk_format = f"PythonSDK/{sdk} (Python {py}; {os}; {ciso}{formatted_properties})" driver_format = "".join( [f" {connector}/{version}" for connector, version in drivers.items()] ) @@ -185,6 +194,7 @@ def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> st def get_user_agent_header( user_drivers: Optional[List[Tuple[str, str]]] = None, user_clients: Optional[List[Tuple[str, str]]] = None, + additional_properties: Optional[List[Tuple[str, str]]] = None, ) -> str: """ Return a user agent header with connector stack and system information. @@ -213,4 +223,4 @@ def get_user_agent_header( clients[name] = version for name, version in versions.drivers: drivers[name] = version - return format_as_user_agent(drivers, clients) + return format_as_user_agent(drivers, clients, additional_properties or []) diff --git a/tests/unit/utils/test_usage_tracker.py b/tests/unit/utils/test_usage_tracker.py index 2949b49757e..363fd1ecf72 100644 --- a/tests/unit/utils/test_usage_tracker.py +++ b/tests/unit/utils/test_usage_tracker.py @@ -146,37 +146,62 @@ def test_detect_connectors(stack, map, expected): @mark.parametrize( - "drivers,clients,expected_string", + "drivers,clients,additional_parameters,expected_string", [ - ([], [], "PythonSDK/2 (Python 1; Win; ciso)"), + ([], [], None, "PythonSDK/2 (Python 1; Win; ciso)"), ( [("ConnectorA", "0.1.1")], [], + None, "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1", ), ( (("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")), (), + None, "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", ), ( [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], [], + None, "PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", ), ( [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], [("ClientA", "1.0.1")], + None, "ClientA/1.0.1 PythonSDK/2 (Python 1; Win; ciso) ConnectorA/0.1.1 ConnectorB/0.2.0", ), + ( + [], + [], + [("connId", "12345"), ("cachedConnId", "67890-memory")], + "PythonSDK/2 (Python 1; Win; ciso; connId:12345; cachedConnId:67890-memory)", + ), + ( + [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], + [("ClientA", "1.0.1")], + [("connId", "12345")], + "ClientA/1.0.1 PythonSDK/2 (Python 1; Win; ciso; connId:12345) ConnectorA/0.1.1 ConnectorB/0.2.0", + ), + ( + [("ConnectorA", "0.1.1"), ("ConnectorB", "0.2.0")], + [("ClientA", "1.0.1")], + [("connId", "12345"), ("cachedConnId", "67890-memory")], + "ClientA/1.0.1 PythonSDK/2 (Python 1; Win; ciso; connId:12345; cachedConnId:67890-memory) ConnectorA/0.1.1 ConnectorB/0.2.0", + ), ], ) @patch( "firebolt.utils.usage_tracker.get_sdk_properties", MagicMock(return_value=("1", "2", "Win", "ciso")), ) -def test_user_agent(drivers, clients, expected_string): - assert get_user_agent_header(drivers, clients) == expected_string +def test_user_agent(drivers, clients, additional_parameters, expected_string): + assert ( + get_user_agent_header(drivers, clients, additional_parameters) + == expected_string + ) @mark.parametrize( From 1c2fc5a13f32289d0c9ec0c0dc11b2adce974261 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Thu, 31 Jul 2025 14:18:01 +0100 Subject: [PATCH 02/21] WIP: caching extension --- src/firebolt/async_db/connection.py | 4 +- src/firebolt/async_db/util.py | 8 +- src/firebolt/common/cache.py | 110 ++++++++++++++++++++++++- src/firebolt/db/connection.py | 27 ++++-- src/firebolt/db/util.py | 8 +- tests/unit/async_db/test_connection.py | 4 +- tests/unit/conftest.py | 4 +- tests/unit/db/test_connection.py | 4 +- 8 files changed, 145 insertions(+), 24 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 140f2b88914..da7a4d8cf0a 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -21,7 +21,7 @@ BaseConnection, _parse_async_query_info_results, ) -from firebolt.common.cache import _firebolt_system_engine_cache +from firebolt.common.cache import _firebolt_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.utils.exception import ( ConfigurationError, @@ -231,7 +231,7 @@ async def connect( user_clients = additional_parameters.get("user_clients", []) user_agent_header = get_user_agent_header(user_drivers, user_clients) if disable_cache: - _firebolt_system_engine_cache.disable() + _firebolt_cache.disable() # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials # Use V1 if auth is ServiceAccount or UsernamePassword diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index fb39c85a89d..8fe104e65e5 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -6,7 +6,7 @@ from firebolt.client.auth import Auth from firebolt.client.client import AsyncClientV2 -from firebolt.common.cache import _firebolt_system_engine_cache +from firebolt.common.cache import _firebolt_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, @@ -21,7 +21,7 @@ async def _get_system_engine_url_and_params( account_name: str, api_endpoint: str, ) -> Tuple[str, Dict[str, str]]: - if result := _firebolt_system_engine_cache.get([account_name, api_endpoint]): + if result := _firebolt_cache.get_system_engine_url([account_name, api_endpoint]): return result async with AsyncClientV2( auth=auth, @@ -40,7 +40,7 @@ async def _get_system_engine_url_and_params( f"{response.status_code} {response.content.decode()}" ) result = parse_url_and_params(response.json()["engineUrl"]) - _firebolt_system_engine_cache.set( - key=[account_name, api_endpoint], value=result + _firebolt_cache.set_system_engine_url( + key=[account_name, api_endpoint], url=result ) return result diff --git a/src/firebolt/common/cache.py b/src/firebolt/common/cache.py index 66a13df0ede..4c3ed219b62 100644 --- a/src/firebolt/common/cache.py +++ b/src/firebolt/common/cache.py @@ -1,4 +1,5 @@ import os +from dataclasses import dataclass, field from typing import ( Any, Callable, @@ -27,6 +28,26 @@ def wrapper(self: "UtilCache", *args: Any, **kwargs: Any) -> Any: return wrapper +@dataclass +class EngineInfo: + """Class to hold engine information for caching.""" + url: str + params: Dict[str, str] + +@dataclass +class DatabaseInfo: + """Class to hold database information for caching.""" + name: str + +@dataclass +class ConnectionInfo: + """Class to hold connection information for caching.""" + id: Optional[str] = None + expiry_time: Optional[int] = None + system_engine_url: Optional[str] = None + databases: Dict[str, DatabaseInfo] = field(default_factory=dict) + engines: Dict[str, EngineInfo] = field(default_factory=dict) + class UtilCache(Generic[T]): """ @@ -80,6 +101,89 @@ def __contains__(self, key: str) -> bool: return key in self._cache -_firebolt_system_engine_cache = UtilCache[Tuple[str, Dict[str, str]]]( - cache_name="system_engine" -) +class ConnectionInfoCache: + """ + A wrapper around UtilCache to provide granular access to ConnectionInfo. + """ + + def __init__(self, cache_name: str = "") -> None: + self._cache = UtilCache[ConnectionInfo](cache_name) + + def get(self, key: ReprCacheable) -> Optional[ConnectionInfo]: + return self._cache.get(key) + + def set(self, key: ReprCacheable, value: ConnectionInfo) -> None: + self._cache.set(key, value) + + def delete(self, key: ReprCacheable) -> None: + self._cache.delete(key) + + def clear(self) -> None: + self._cache.clear() + + def disable(self) -> None: + self._cache.disable() + + def enable(self) -> None: + self._cache.enable() + + def set_id(self, key: ReprCacheable, id: str) -> None: + conn_info = self.get(key) or ConnectionInfo(id=id) + conn_info.id = id + self.set(key, conn_info) + + def get_id(self, key: ReprCacheable) -> Optional[str]: + conn_info = self.get(key) + return conn_info.id if conn_info else None + + def get_system_engine_url(self, key: ReprCacheable) -> Optional[str]: + conn_info = self.get(key) + return conn_info.system_engine_url if conn_info else None + + def set_system_engine_url(self, key: ReprCacheable, url: str) -> None: + conn_info = self.get(key) or ConnectionInfo() + conn_info.system_engine_url = url + self.set(key, conn_info) + + def get_expiry_time(self, key: ReprCacheable) -> Optional[int]: + conn_info = self.get(key) + return conn_info.expiry_time if conn_info else None + + def get_engines(self, key: ReprCacheable) -> Optional[Dict[str, EngineInfo]]: + conn_info = self.get(key) + return conn_info.engines if conn_info else None + + def get_engine_by_name( + self, key: ReprCacheable, engine_name: str + ) -> Optional[EngineInfo]: + engines = self.get_engines(key) + return engines.get(engine_name) if engines else None + + def add_engine(self, key: ReprCacheable, engine_name: str, engine: EngineInfo) -> None: + conn_info = self.get(key) or ConnectionInfo() + conn_info.engines[engine_name] = engine + self.set(key, conn_info) + + def get_databases(self, key: ReprCacheable) -> Optional[Dict[str, DatabaseInfo]]: + conn_info = self.get(key) + return conn_info.databases if conn_info else None + + def get_database_by_name( + self, key: ReprCacheable, db_name: str + ) -> Optional[DatabaseInfo]: + databases = self.get_databases(key) + return databases.get(db_name) if databases else None + + def add_database( + self, key: ReprCacheable, db_name: str, database: DatabaseInfo + ) -> None: + conn_info = self.get(key) or ConnectionInfo() + conn_info.databases[db_name] = database + self.set(key, conn_info) + + +# _firebolt_system_engine_cache = UtilCache[Tuple[str, Dict[str, str]]]( +# cache_name="system_engine" +# ) + +_firebolt_cache = ConnectionInfoCache(cache_name="connection_info") \ No newline at end of file diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index d678cacd2fa..7011d4a1137 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -3,7 +3,8 @@ import logging from ssl import SSLContext from types import TracebackType -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union +from uuid import uuid4 from warnings import warn from httpx import Timeout @@ -20,7 +21,7 @@ BaseConnection, _parse_async_query_info_results, ) -from firebolt.common.cache import _firebolt_system_engine_cache +from firebolt.common.cache import _firebolt_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.db.cursor import Cursor, CursorV1, CursorV2 from firebolt.db.util import _get_system_engine_url_and_params @@ -40,6 +41,19 @@ logger = logging.getLogger(__name__) +def prepare_ua_parameters(account_name: Optional[str], api_endpoint: str) -> List[Tuple[str, str]]: + ua_parameters = [] + + cached_id = _firebolt_cache.get_id([account_name, api_endpoint]) + conn_uuid = uuid4().hex + ua_parameters.append(("connId", conn_uuid)) + if cached_id: + ua_parameters.append(("cachedConnId", cached_id + "-memory")) + _firebolt_cache.set_id([account_name, api_endpoint], conn_uuid) + + return ua_parameters + + def connect( auth: Optional[Auth] = None, account_name: Optional[str] = None, @@ -61,10 +75,13 @@ def connect( assert auth is not None user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) - user_agent_header = get_user_agent_header(user_drivers, user_clients) - auth_version = auth.get_firebolt_version() + ua_parameters = [] if disable_cache: - _firebolt_system_engine_cache.disable() + _firebolt_cache.disable() + else: + ua_parameters = prepare_ua_parameters(account_name, api_endpoint) + user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) + auth_version = auth.get_firebolt_version() # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials # Use V1 if auth is ServiceAccount or UsernamePassword diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index 8cc12be6b2d..5fc6a2b8152 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -6,7 +6,7 @@ from firebolt.client import ClientV2 from firebolt.client.auth import Auth -from firebolt.common.cache import _firebolt_system_engine_cache +from firebolt.common.cache import _firebolt_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, @@ -21,7 +21,7 @@ def _get_system_engine_url_and_params( account_name: str, api_endpoint: str, ) -> Tuple[str, Dict[str, str]]: - if result := _firebolt_system_engine_cache.get([account_name, api_endpoint]): + if result := _firebolt_cache.get_system_engine_url([account_name, api_endpoint]): return result with ClientV2( auth=auth, @@ -40,7 +40,7 @@ def _get_system_engine_url_and_params( f"{response.status_code} {response.content.decode()}" ) result = parse_url_and_params(response.json()["engineUrl"]) - _firebolt_system_engine_cache.set( - key=[account_name, api_endpoint], value=result + _firebolt_cache.set_system_engine_url( + key=[account_name, api_endpoint], url=result ) return result diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index f1d357672b3..7bb432d21ff 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -8,7 +8,7 @@ from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType -from firebolt.common.cache import _firebolt_system_engine_cache +from firebolt.common.cache import _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, ConfigurationError, @@ -244,7 +244,7 @@ def system_engine_callback_counter(request, **kwargs): assert system_engine_call_counter != 1, "System engine URL was cached" # Reset caches for the next test iteration - _firebolt_system_engine_cache.enable() + _firebolt_cache.enable() async def test_connect_system_engine_404( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7d77ca906a9..c5d69c6de67 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -8,7 +8,7 @@ from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.client import ClientV2 -from firebolt.common.cache import _firebolt_system_engine_cache +from firebolt.common.cache import _firebolt_cache from firebolt.common.settings import Settings from firebolt.utils.exception import ( DatabaseError, @@ -43,7 +43,7 @@ def global_fake_fs(request) -> None: @fixture(autouse=True) def clear_cache() -> None: - _firebolt_system_engine_cache.clear() + _firebolt_cache.clear() @fixture diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index c6e94ea1231..c934984bd5c 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -10,7 +10,7 @@ from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.client import ClientV2 from firebolt.common._types import ColType -from firebolt.common.cache import _firebolt_system_engine_cache +from firebolt.common.cache import _firebolt_cache from firebolt.db import Connection, connect from firebolt.db.cursor import CursorV2 from firebolt.utils.exception import ( @@ -248,7 +248,7 @@ def system_engine_callback_counter(request, **kwargs): assert system_engine_call_counter != 1, "System engine URL was cached" # Reset caches for the next test iteration - _firebolt_system_engine_cache.enable() + _firebolt_cache.enable() def test_connect_system_engine_404( From f7352e2a97726cd90b78d1196a63782a66291965 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Thu, 31 Jul 2025 15:18:05 +0100 Subject: [PATCH 03/21] cache controller --- src/firebolt/async_db/connection.py | 6 +- src/firebolt/async_db/util.py | 12 +-- src/firebolt/common/cache.py | 149 ++++++++-------------------- src/firebolt/db/connection.py | 18 ++-- src/firebolt/db/util.py | 12 +-- src/firebolt/utils/util.py | 42 +++++--- tests/unit/db/test_connection.py | 5 +- tests/unit/utils/test_utils.py | 6 +- 8 files changed, 98 insertions(+), 152 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index da7a4d8cf0a..443a23f2952 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -303,7 +303,7 @@ async def connect_v2( api_endpoint = fix_url_schema(api_endpoint) - system_engine_url, system_engine_params = await _get_system_engine_url_and_params( + system_engine_info = await _get_system_engine_url_and_params( auth, account_name, api_endpoint ) @@ -316,12 +316,12 @@ async def connect_v2( ) async with Connection( - system_engine_url, + system_engine_info.url, None, client, CursorV2, api_endpoint, - system_engine_params, + system_engine_info.params, ) as system_engine_connection: cursor = system_engine_connection.cursor() diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index 8fe104e65e5..f51c20fee86 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Dict, Tuple - from httpx import Timeout, codes from firebolt.client.auth import Auth @@ -13,15 +11,15 @@ InterfaceError, ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import parse_url_and_params +from firebolt.utils.util import EngineInfo, parse_url_and_params async def _get_system_engine_url_and_params( auth: Auth, account_name: str, api_endpoint: str, -) -> Tuple[str, Dict[str, str]]: - if result := _firebolt_cache.get_system_engine_url([account_name, api_endpoint]): +) -> EngineInfo: + if result := _firebolt_cache.system_engine_cache.get([account_name, api_endpoint]): return result async with AsyncClientV2( auth=auth, @@ -40,7 +38,7 @@ async def _get_system_engine_url_and_params( f"{response.status_code} {response.content.decode()}" ) result = parse_url_and_params(response.json()["engineUrl"]) - _firebolt_cache.set_system_engine_url( - key=[account_name, api_endpoint], url=result + _firebolt_cache.system_engine_cache.set( + key=[account_name, api_endpoint], value=result ) return result diff --git a/src/firebolt/common/cache.py b/src/firebolt/common/cache.py index 4c3ed219b62..08cb455e9f1 100644 --- a/src/firebolt/common/cache.py +++ b/src/firebolt/common/cache.py @@ -1,15 +1,7 @@ import os -from dataclasses import dataclass, field -from typing import ( - Any, - Callable, - Dict, - Generic, - Optional, - Protocol, - Tuple, - TypeVar, -) +from typing import Any, Callable, Dict, Generic, Optional, Protocol, TypeVar + +from firebolt.utils.util import DatabaseInfo, EngineInfo T = TypeVar("T") @@ -28,26 +20,6 @@ def wrapper(self: "UtilCache", *args: Any, **kwargs: Any) -> Any: return wrapper -@dataclass -class EngineInfo: - """Class to hold engine information for caching.""" - url: str - params: Dict[str, str] - -@dataclass -class DatabaseInfo: - """Class to hold database information for caching.""" - name: str - -@dataclass -class ConnectionInfo: - """Class to hold connection information for caching.""" - id: Optional[str] = None - expiry_time: Optional[int] = None - system_engine_url: Optional[str] = None - databases: Dict[str, DatabaseInfo] = field(default_factory=dict) - engines: Dict[str, EngineInfo] = field(default_factory=dict) - class UtilCache(Generic[T]): """ @@ -101,89 +73,46 @@ def __contains__(self, key: str) -> bool: return key in self._cache -class ConnectionInfoCache: - """ - A wrapper around UtilCache to provide granular access to ConnectionInfo. - """ +class CacheController: + def __init__(self) -> None: + self._engine_cache = UtilCache[EngineInfo](cache_name="engine_info") - def __init__(self, cache_name: str = "") -> None: - self._cache = UtilCache[ConnectionInfo](cache_name) + self._system_engine_cache = UtilCache[EngineInfo](cache_name="system_engine") - def get(self, key: ReprCacheable) -> Optional[ConnectionInfo]: - return self._cache.get(key) + self._database_cache = UtilCache[DatabaseInfo](cache_name="database_info") - def set(self, key: ReprCacheable, value: ConnectionInfo) -> None: - self._cache.set(key, value) + @property + def engine_cache(self) -> UtilCache[EngineInfo]: + """Get the engine cache.""" + return self._engine_cache - def delete(self, key: ReprCacheable) -> None: - self._cache.delete(key) + @property + def database_cache(self) -> UtilCache[DatabaseInfo]: + """Get the database cache.""" + return self._database_cache - def clear(self) -> None: - self._cache.clear() + @property + def system_engine_cache(self) -> UtilCache[EngineInfo]: + """Get the system engine cache.""" + return self._system_engine_cache - def disable(self) -> None: - self._cache.disable() - def enable(self) -> None: - self._cache.enable() - - def set_id(self, key: ReprCacheable, id: str) -> None: - conn_info = self.get(key) or ConnectionInfo(id=id) - conn_info.id = id - self.set(key, conn_info) - - def get_id(self, key: ReprCacheable) -> Optional[str]: - conn_info = self.get(key) - return conn_info.id if conn_info else None - - def get_system_engine_url(self, key: ReprCacheable) -> Optional[str]: - conn_info = self.get(key) - return conn_info.system_engine_url if conn_info else None - - def set_system_engine_url(self, key: ReprCacheable, url: str) -> None: - conn_info = self.get(key) or ConnectionInfo() - conn_info.system_engine_url = url - self.set(key, conn_info) - - def get_expiry_time(self, key: ReprCacheable) -> Optional[int]: - conn_info = self.get(key) - return conn_info.expiry_time if conn_info else None - - def get_engines(self, key: ReprCacheable) -> Optional[Dict[str, EngineInfo]]: - conn_info = self.get(key) - return conn_info.engines if conn_info else None - - def get_engine_by_name( - self, key: ReprCacheable, engine_name: str - ) -> Optional[EngineInfo]: - engines = self.get_engines(key) - return engines.get(engine_name) if engines else None - - def add_engine(self, key: ReprCacheable, engine_name: str, engine: EngineInfo) -> None: - conn_info = self.get(key) or ConnectionInfo() - conn_info.engines[engine_name] = engine - self.set(key, conn_info) - - def get_databases(self, key: ReprCacheable) -> Optional[Dict[str, DatabaseInfo]]: - conn_info = self.get(key) - return conn_info.databases if conn_info else None - - def get_database_by_name( - self, key: ReprCacheable, db_name: str - ) -> Optional[DatabaseInfo]: - databases = self.get_databases(key) - return databases.get(db_name) if databases else None - - def add_database( - self, key: ReprCacheable, db_name: str, database: DatabaseInfo - ) -> None: - conn_info = self.get(key) or ConnectionInfo() - conn_info.databases[db_name] = database - self.set(key, conn_info) - - -# _firebolt_system_engine_cache = UtilCache[Tuple[str, Dict[str, str]]]( -# cache_name="system_engine" -# ) - -_firebolt_cache = ConnectionInfoCache(cache_name="connection_info") \ No newline at end of file + """Enable the cache.""" + self._engine_cache.enable() + self._system_engine_cache.enable() + self._database_cache.enable() + + def disable(self) -> None: + """Disable the cache.""" + self._engine_cache.disable() + self._system_engine_cache.disable() + self._database_cache.disable() + + def clear(self) -> None: + """Clear all caches.""" + self._engine_cache.clear() + self._system_engine_cache.clear() + self._database_cache.clear() + + +_firebolt_cache = CacheController() diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 7011d4a1137..f847c7dff86 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -41,15 +41,17 @@ logger = logging.getLogger(__name__) -def prepare_ua_parameters(account_name: Optional[str], api_endpoint: str) -> List[Tuple[str, str]]: +def prepare_ua_parameters( + account_name: Optional[str], api_endpoint: str +) -> List[Tuple[str, str]]: ua_parameters = [] - cached_id = _firebolt_cache.get_id([account_name, api_endpoint]) + # cached_id = _firebolt_cache.get_id([account_name, api_endpoint]) conn_uuid = uuid4().hex ua_parameters.append(("connId", conn_uuid)) - if cached_id: - ua_parameters.append(("cachedConnId", cached_id + "-memory")) - _firebolt_cache.set_id([account_name, api_endpoint], conn_uuid) + # if cached_id: + # ua_parameters.append(("cachedConnId", cached_id + "-memory")) + # _firebolt_cache.set_id([account_name, api_endpoint], conn_uuid) return ua_parameters @@ -153,7 +155,7 @@ def connect_v2( api_endpoint = fix_url_schema(api_endpoint) - system_engine_url, system_engine_params = _get_system_engine_url_and_params( + system_engine_info = _get_system_engine_url_and_params( auth, account_name, api_endpoint ) @@ -166,12 +168,12 @@ def connect_v2( ) with Connection( - system_engine_url, + system_engine_info.url, None, client, CursorV2, api_endpoint, - system_engine_params, + system_engine_info.params, ) as system_engine_connection: cursor = system_engine_connection.cursor() diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index 5fc6a2b8152..dee2ef1b5dd 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Dict, Tuple - from httpx import Timeout, codes from firebolt.client import ClientV2 @@ -13,15 +11,15 @@ InterfaceError, ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import parse_url_and_params +from firebolt.utils.util import EngineInfo, parse_url_and_params def _get_system_engine_url_and_params( auth: Auth, account_name: str, api_endpoint: str, -) -> Tuple[str, Dict[str, str]]: - if result := _firebolt_cache.get_system_engine_url([account_name, api_endpoint]): +) -> EngineInfo: + if result := _firebolt_cache.system_engine_cache.get([account_name, api_endpoint]): return result with ClientV2( auth=auth, @@ -40,7 +38,7 @@ def _get_system_engine_url_and_params( f"{response.status_code} {response.content.decode()}" ) result = parse_url_and_params(response.json()["engineUrl"]) - _firebolt_cache.set_system_engine_url( - key=[account_name, api_endpoint], url=result + _firebolt_cache.system_engine_cache.set( + key=[account_name, api_endpoint], value=result ) return result diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index d296055f0d1..e55d6095f8e 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -1,18 +1,10 @@ import logging +from dataclasses import dataclass, field from functools import lru_cache from os import environ from time import time from types import TracebackType -from typing import ( - TYPE_CHECKING, - Callable, - Dict, - List, - Optional, - Tuple, - Type, - TypeVar, -) +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, TypeVar from urllib.parse import parse_qs, urljoin, urlparse from httpx import URL, Response, codes @@ -26,6 +18,32 @@ logger = logging.getLogger(__name__) +@dataclass +class EngineInfo: + """Class to hold engine information for caching.""" + + url: str + params: Dict[str, str] + + +@dataclass +class DatabaseInfo: + """Class to hold database information for caching.""" + + name: str + + +@dataclass +class ConnectionInfo: + """Class to hold connection information for caching.""" + + id: Optional[str] = None + expiry_time: Optional[int] = None + system_engine_url: Optional[str] = None + databases: Dict[str, DatabaseInfo] = field(default_factory=dict) + engines: Dict[str, EngineInfo] = field(default_factory=dict) + + def cached_property(func: Callable[..., T]) -> T: """cached_property implementation for 3.7 backward compatibility. @@ -212,7 +230,7 @@ def __exit__( logger.debug(log_message) -def parse_url_and_params(url: str) -> Tuple[str, Dict[str, str]]: +def parse_url_and_params(url: str) -> EngineInfo: """Extract URL and query parameters separately from a URL.""" url = fix_url_schema(url) parsed_url = urlparse(url) @@ -228,7 +246,7 @@ def parse_url_and_params(url: str) -> Tuple[str, Dict[str, str]]: if len(values) > 1: raise ValueError(f"Multiple values found for key '{key}'") query_params_dict[key] = values[0] - return result_url, query_params_dict + return EngineInfo(url=result_url, params=query_params_dict) class _ExceptionGroup(Exception): diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index c934984bd5c..d67f51abfa1 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -1,6 +1,7 @@ import gc import warnings from typing import Callable, List, Optional, Tuple +from unittest.mock import ANY as AnyValue from unittest.mock import patch from pyfakefs.fake_filesystem_unittest import Patcher @@ -393,7 +394,7 @@ def test_connect_with_user_agent( }, ) as connection: connection.cursor().execute("select*") - ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) + ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")], AnyValue) def test_connect_no_user_agent( @@ -423,7 +424,7 @@ def test_connect_no_user_agent( api_endpoint=api_endpoint, ) as connection: connection.cursor().execute("select*") - ut.assert_called_with([], []) + ut.assert_called_with([], [], AnyValue) @mark.parametrize( diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 5271a7ae047..08101988f76 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -50,9 +50,9 @@ def test_get_internal_error_code(status_code, content, expected_error_code): ], ) def test_parse_url_and_params(url, expected_url, expected_params): - parsed_url, parsed_params = parse_url_and_params(url) - assert parsed_url == expected_url - assert parsed_params == expected_params + parsed_info = parse_url_and_params(url) + assert parsed_info.url == expected_url + assert parsed_info.params == expected_params @pytest.mark.parametrize( From 57cd62f6d182bb8e42e0d5acf739e9b4c2409dfb Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 1 Aug 2025 15:07:25 +0100 Subject: [PATCH 04/21] simplified cache --- src/firebolt/async_db/connection.py | 44 ++++++++++++++++++---- src/firebolt/async_db/util.py | 16 +++++--- src/firebolt/common/cache.py | 46 +---------------------- src/firebolt/db/connection.py | 19 +++++++--- src/firebolt/db/util.py | 16 +++++--- src/firebolt/utils/util.py | 2 +- tests/unit/V1/async_db/test_connection.py | 7 +++- tests/unit/async_db/test_connection.py | 5 ++- 8 files changed, 84 insertions(+), 71 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 443a23f2952..faed35b560c 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -2,7 +2,8 @@ from ssl import SSLContext from types import TracebackType -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union +from uuid import uuid4 from httpx import Timeout @@ -34,7 +35,34 @@ validate_firebolt_core_parameters, ) from firebolt.utils.usage_tracker import get_user_agent_header -from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 +from firebolt.utils.util import ( + ConnectionInfo, + fix_url_schema, + validate_engine_name_and_url_v1, +) + + +def prepare_ua_parameters( + account_name: Optional[str], api_endpoint: str +) -> List[Tuple[str, str]]: + ua_parameters = [] + + # cached_id = _firebolt_cache.get_id([account_name, api_endpoint]) + conn_uuid = uuid4().hex + ua_parameters.append(("connId", conn_uuid)) + prepare_cache_if_needed(account_name, api_endpoint, conn_uuid) + # if cached_id: + # ua_parameters.append(("cachedConnId", cached_id + "-memory")) + # _firebolt_cache.set_id([account_name, api_endpoint], conn_uuid) + + return ua_parameters + + +def prepare_cache_if_needed( + account_name: Optional[str], api_endpoint: str, conn_id: str +) -> None: + if not _firebolt_cache.get([account_name, api_endpoint]): + _firebolt_cache.set([account_name, api_endpoint], ConnectionInfo(conn_id)) class Connection(BaseConnection): @@ -225,11 +253,17 @@ async def connect( if not auth: raise ConfigurationError("auth is required to connect.") + api_endpoint = fix_url_schema(api_endpoint) # Type checks assert auth is not None user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) - user_agent_header = get_user_agent_header(user_drivers, user_clients) + ua_parameters = [] + if disable_cache: + _firebolt_cache.disable() + else: + ua_parameters = prepare_ua_parameters(account_name, api_endpoint) + user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) if disable_cache: _firebolt_cache.disable() # Use CORE if auth is FireboltCore @@ -301,8 +335,6 @@ async def connect_v2( assert auth is not None assert account_name is not None - api_endpoint = fix_url_schema(api_endpoint) - system_engine_info = await _get_system_engine_url_and_params( auth, account_name, api_endpoint ) @@ -358,8 +390,6 @@ async def connect_v1( validate_engine_name_and_url_v1(engine_name, engine_url) - api_endpoint = fix_url_schema(api_endpoint) - no_engine_client = AsyncClientV1( auth=auth, base_url=api_endpoint, diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index f51c20fee86..4e6543266e4 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -11,7 +11,11 @@ InterfaceError, ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import EngineInfo, parse_url_and_params +from firebolt.utils.util import ( + ConnectionInfo, + EngineInfo, + parse_url_and_params, +) async def _get_system_engine_url_and_params( @@ -19,7 +23,8 @@ async def _get_system_engine_url_and_params( account_name: str, api_endpoint: str, ) -> EngineInfo: - if result := _firebolt_cache.system_engine_cache.get([account_name, api_endpoint]): + cache = _firebolt_cache.get([account_name, api_endpoint]) + if cache and (result := cache.system_engine): return result async with AsyncClientV2( auth=auth, @@ -38,7 +43,8 @@ async def _get_system_engine_url_and_params( f"{response.status_code} {response.content.decode()}" ) result = parse_url_and_params(response.json()["engineUrl"]) - _firebolt_cache.system_engine_cache.set( - key=[account_name, api_endpoint], value=result - ) + if not cache: + cache = ConnectionInfo() + cache.system_engine = result + _firebolt_cache.set([account_name, api_endpoint], cache) return result diff --git a/src/firebolt/common/cache.py b/src/firebolt/common/cache.py index 08cb455e9f1..8afcb9d9cab 100644 --- a/src/firebolt/common/cache.py +++ b/src/firebolt/common/cache.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, Generic, Optional, Protocol, TypeVar -from firebolt.utils.util import DatabaseInfo, EngineInfo +from firebolt.utils.util import ConnectionInfo T = TypeVar("T") @@ -73,46 +73,4 @@ def __contains__(self, key: str) -> bool: return key in self._cache -class CacheController: - def __init__(self) -> None: - self._engine_cache = UtilCache[EngineInfo](cache_name="engine_info") - - self._system_engine_cache = UtilCache[EngineInfo](cache_name="system_engine") - - self._database_cache = UtilCache[DatabaseInfo](cache_name="database_info") - - @property - def engine_cache(self) -> UtilCache[EngineInfo]: - """Get the engine cache.""" - return self._engine_cache - - @property - def database_cache(self) -> UtilCache[DatabaseInfo]: - """Get the database cache.""" - return self._database_cache - - @property - def system_engine_cache(self) -> UtilCache[EngineInfo]: - """Get the system engine cache.""" - return self._system_engine_cache - - def enable(self) -> None: - """Enable the cache.""" - self._engine_cache.enable() - self._system_engine_cache.enable() - self._database_cache.enable() - - def disable(self) -> None: - """Disable the cache.""" - self._engine_cache.disable() - self._system_engine_cache.disable() - self._database_cache.disable() - - def clear(self) -> None: - """Clear all caches.""" - self._engine_cache.clear() - self._system_engine_cache.clear() - self._database_cache.clear() - - -_firebolt_cache = CacheController() +_firebolt_cache = UtilCache[ConnectionInfo](cache_name="connection_info") diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index f847c7dff86..f6a909d80bd 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -36,7 +36,11 @@ validate_firebolt_core_parameters, ) from firebolt.utils.usage_tracker import get_user_agent_header -from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 +from firebolt.utils.util import ( + ConnectionInfo, + fix_url_schema, + validate_engine_name_and_url_v1, +) logger = logging.getLogger(__name__) @@ -49,6 +53,7 @@ def prepare_ua_parameters( # cached_id = _firebolt_cache.get_id([account_name, api_endpoint]) conn_uuid = uuid4().hex ua_parameters.append(("connId", conn_uuid)) + prepare_cache_if_needed(account_name, api_endpoint, conn_uuid) # if cached_id: # ua_parameters.append(("cachedConnId", cached_id + "-memory")) # _firebolt_cache.set_id([account_name, api_endpoint], conn_uuid) @@ -56,6 +61,13 @@ def prepare_ua_parameters( return ua_parameters +def prepare_cache_if_needed( + account_name: Optional[str], api_endpoint: str, conn_id: str +) -> None: + if not _firebolt_cache.get([account_name, api_endpoint]): + _firebolt_cache.set([account_name, api_endpoint], ConnectionInfo(conn_id)) + + def connect( auth: Optional[Auth] = None, account_name: Optional[str] = None, @@ -73,6 +85,7 @@ def connect( if not auth: raise ConfigurationError("auth is required to connect.") + api_endpoint = fix_url_schema(api_endpoint) # Type checks assert auth is not None user_drivers = additional_parameters.get("user_drivers", []) @@ -153,8 +166,6 @@ def connect_v2( assert auth is not None assert account_name is not None - api_endpoint = fix_url_schema(api_endpoint) - system_engine_info = _get_system_engine_url_and_params( auth, account_name, api_endpoint ) @@ -381,8 +392,6 @@ def connect_v1( validate_engine_name_and_url_v1(engine_name, engine_url) - api_endpoint = fix_url_schema(api_endpoint) - # Override tcp keepalive settings for connection no_engine_client = ClientV1( auth=auth, diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index dee2ef1b5dd..63384fd34b8 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -11,7 +11,11 @@ InterfaceError, ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import EngineInfo, parse_url_and_params +from firebolt.utils.util import ( + ConnectionInfo, + EngineInfo, + parse_url_and_params, +) def _get_system_engine_url_and_params( @@ -19,7 +23,8 @@ def _get_system_engine_url_and_params( account_name: str, api_endpoint: str, ) -> EngineInfo: - if result := _firebolt_cache.system_engine_cache.get([account_name, api_endpoint]): + cache = _firebolt_cache.get([account_name, api_endpoint]) + if cache and (result := cache.system_engine): return result with ClientV2( auth=auth, @@ -38,7 +43,8 @@ def _get_system_engine_url_and_params( f"{response.status_code} {response.content.decode()}" ) result = parse_url_and_params(response.json()["engineUrl"]) - _firebolt_cache.system_engine_cache.set( - key=[account_name, api_endpoint], value=result - ) + if not cache: + cache = ConnectionInfo() + cache.system_engine = result + _firebolt_cache.set([account_name, api_endpoint], cache) return result diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index e55d6095f8e..f2017e5fc31 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -39,7 +39,7 @@ class ConnectionInfo: id: Optional[str] = None expiry_time: Optional[int] = None - system_engine_url: Optional[str] = None + system_engine: Optional[EngineInfo] = None databases: Dict[str, DatabaseInfo] = field(default_factory=dict) engines: Dict[str, EngineInfo] = field(default_factory=dict) diff --git a/tests/unit/V1/async_db/test_connection.py b/tests/unit/V1/async_db/test_connection.py index 4e1bd4da938..d74bf59ee46 100644 --- a/tests/unit/V1/async_db/test_connection.py +++ b/tests/unit/V1/async_db/test_connection.py @@ -1,6 +1,7 @@ from asyncio import run from re import Pattern from typing import Callable, List +from unittest.mock import ANY as AnyValue from unittest.mock import patch from httpx import codes @@ -420,7 +421,9 @@ async def test_connect_with_user_agent( }, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) + ut.assert_called_once_with( + [("DriverA", "1.1")], [("MyConnector", "1.0")], AnyValue + ) async def test_connect_no_user_agent( @@ -445,7 +448,7 @@ async def test_connect_no_user_agent( api_endpoint=api_endpoint, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([], []) + ut.assert_called_once_with([], [], AnyValue) def test_from_asyncio( diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 7bb432d21ff..9e1390ebbe4 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -1,4 +1,5 @@ from typing import Callable, List, Optional, Tuple +from unittest.mock import ANY as AnyValue from unittest.mock import patch from pyfakefs.fake_filesystem_unittest import Patcher @@ -375,7 +376,7 @@ async def test_connect_with_user_agent( }, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) + ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")], AnyValue) async def test_connect_no_user_agent( @@ -405,7 +406,7 @@ async def test_connect_no_user_agent( api_endpoint=api_endpoint, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_with([], []) + ut.assert_called_with([], [], AnyValue) @mark.parametrize( From c5baa760cc850af8777c8d8a521b246d907ff3c2 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 1 Aug 2025 17:37:15 +0100 Subject: [PATCH 05/21] poc on async --- src/firebolt/async_db/connection.py | 27 +- .../dbapi/async/V2/test_queries_async.py | 74 ++++ tests/unit/async_db/test_caching.py | 335 ++++++++++++++++++ tests/unit/async_db/test_connection.py | 63 ---- 4 files changed, 434 insertions(+), 65 deletions(-) create mode 100644 tests/unit/async_db/test_caching.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index faed35b560c..646f1b1f99a 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -37,6 +37,8 @@ from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import ( ConnectionInfo, + DatabaseInfo, + EngineInfo, fix_url_schema, validate_engine_name_and_url_v1, ) @@ -357,10 +359,31 @@ async def connect_v2( ) as system_engine_connection: cursor = system_engine_connection.cursor() + + # TODO: rework this, this is prototyping right now if database: - await cursor.execute(f'USE DATABASE "{database}"') + cache = _firebolt_cache.get([account_name, api_endpoint]) + cache = cache if cache else ConnectionInfo() + if cache.databases.get(database): + # If database is cached, use it + cursor.database = database + else: + await cursor.execute(f'USE DATABASE "{database}"') + cache.databases[database] = DatabaseInfo(database) + _firebolt_cache.set([account_name, api_endpoint], cache) if engine_name: - await cursor.execute(f'USE ENGINE "{engine_name}"') + cache = _firebolt_cache.get([account_name, api_endpoint]) + cache = cache if cache else ConnectionInfo() + if cache.engines.get(engine_name): + # If engine is cached, use it + cursor.engine_url = cache.engines[engine_name].url + cursor._update_set_parameters(cache.engines[engine_name].params) + else: + await cursor.execute(f'USE ENGINE "{engine_name}"') + cache.engines[engine_name] = EngineInfo( + cursor.engine_url, cursor.parameters + ) # ?? + _firebolt_cache.set([account_name, api_endpoint], cache) # Ensure cursors created from this connection are using the same starting # database and engine return Connection( diff --git a/tests/integration/dbapi/async/V2/test_queries_async.py b/tests/integration/dbapi/async/V2/test_queries_async.py index fa8dc0ab123..1490e24d3f3 100644 --- a/tests/integration/dbapi/async/V2/test_queries_async.py +++ b/tests/integration/dbapi/async/V2/test_queries_async.py @@ -598,3 +598,77 @@ async def test_fb_numeric_paramstyle_incorrect_params( assert "Query referenced positional parameter $34, but it was not set" in str( exc_info.value ) + + +async def test_engine_switch( + database_name: str, + connection_system_engine: Connection, + auth: Auth, + account_name: str, + api_endpoint: str, + engine_name: str, +) -> None: + system_cursor = connection_system_engine.cursor() + await system_cursor.execute("SELECT current_engine()") + result = await system_cursor.fetchone() + assert ( + result[0] == "system" + ), f"Incorrect setup - system engine cursor points at {result[0]}" + async with await connect( + engine_name=engine_name, + database=database_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: + cursor = connection.cursor() + await cursor.execute("SELECT current_engine()") + result = await cursor.fetchone() + assert result[0] == engine_name, "Engine switch failed" + # Test switching back to system engine + await cursor.execute("USE ENGINE system") + await cursor.execute("SELECT current_engine()") + result = await cursor.fetchone() + assert result[0] == "system", "Switching back to system engine failed" + + +async def test_database_switch( + database_name: str, + connection_system_engine_no_db: Connection, + auth: Auth, + account_name: str, + api_endpoint: str, + engine_name: str, +) -> None: + system_cursor = connection_system_engine_no_db.cursor() + await system_cursor.execute("SELECT current_database()") + result = await system_cursor.fetchone() + assert ( + result[0] == "account_db" + ), f"Incorrect setup - system engine cursor points at {result[0]}" + async with await connect( + engine_name=engine_name, + database=database_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: + cursor = connection.cursor() + await cursor.execute("SELECT current_database()") + result = await cursor.fetchone() + assert result[0] == database_name, "Database switch failed" + try: + # Test switching back to system database + await system_cursor.execute( + f"CREATE DATABASE IF NOT EXISTS {database_name}_switch" + ) + await cursor.execute(f"USE DATABASE {database_name}_switch") + await cursor.execute("SELECT current_database()") + result = await cursor.fetchone() + assert ( + result[0] == f"{database_name}_switch" + ), "Switching back to switch database failed" + finally: + await system_cursor.execute( + f"DROP DATABASE IF EXISTS {database_name}_switch" + ) diff --git a/tests/unit/async_db/test_caching.py b/tests/unit/async_db/test_caching.py new file mode 100644 index 00000000000..173d6f7ea0b --- /dev/null +++ b/tests/unit/async_db/test_caching.py @@ -0,0 +1,335 @@ +from typing import Callable + +from pytest import mark +from pytest_httpx import HTTPXMock + +from firebolt.async_db import connect +from firebolt.client.auth import Auth +from firebolt.common.cache import _firebolt_cache + + +@mark.parametrize("cache_enabled", [True, False]) +async def test_connect_caching( + db_name: str, + engine_name: str, + auth_url: str, + api_endpoint: str, + auth: Auth, + account_name: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, +): + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + for _ in range(3): + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + if cache_enabled: + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_database_call_counter == 1, "Use database URL was not cached" + assert use_engine_call_counter == 1, "Use engine URL was not cached" + else: + assert system_engine_call_counter != 1, "System engine URL was cached" + assert use_database_call_counter != 1, "Use database URL was cached" + assert use_engine_call_counter != 1, "Use engine URL was cached" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +async def test_connect_db_switching_caching( + db_name: str, + engine_name: str, + auth_url: str, + api_endpoint: str, + auth: Auth, + account_name: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, +): + """Test caching when switching between different databases.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + second_db_name = f"{db_name}_second" + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + + # First database + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + # Second database + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{second_db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # Connect to first database + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + first_db_calls = use_database_call_counter + + # Connect to second database + async with await connect( + database=second_db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + second_db_calls = use_database_call_counter - first_db_calls + + # Connect to first database again + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + third_db_calls = use_database_call_counter - first_db_calls - second_db_calls + + if cache_enabled: + assert second_db_calls == 1, "Second database call was not made" + assert third_db_calls == 0, "First database was not cached" + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_engine_call_counter == 1, "Use engine URL was not cached" + else: + assert second_db_calls == 1, "Second database call was not made" + assert third_db_calls == 1, "First database was cached when cache disabled" + assert ( + system_engine_call_counter == 3 + ), "System engine URL was cached when cache disabled" + assert ( + use_engine_call_counter == 3 + ), "Use engine URL was cached when cache disabled" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +async def test_connect_engine_switching_caching( + db_name: str, + engine_name: str, + auth_url: str, + api_endpoint: str, + auth: Auth, + account_name: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, +): + """Test caching when switching between different engines.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + second_engine_name = f"{engine_name}_second" + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + # First engine + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + + # Second engine + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{second_engine_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback(query_callback, url=query_url) + + # Connect to first engine + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + first_engine_calls = use_engine_call_counter + + # Connect to second engine + async with await connect( + database=db_name, + engine_name=second_engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + second_engine_calls = use_engine_call_counter - first_engine_calls + + # Connect to first engine again + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + third_engine_calls = ( + use_engine_call_counter - first_engine_calls - second_engine_calls + ) + + if cache_enabled: + assert second_engine_calls == 1, "Second engine call was not made" + assert third_engine_calls == 0, "First engine was not cached" + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_database_call_counter == 1, "Use database URL was not cached" + else: + assert second_engine_calls == 1, "Second engine call was not made" + assert third_engine_calls == 1, "First engine was cached when cache disabled" + assert ( + system_engine_call_counter == 3 + ), "System engine URL was cached when cache disabled" + assert ( + use_database_call_counter == 3 + ), "Use database URL was cached when cache disabled" + + # Reset caches for the next test iteration + _firebolt_cache.enable() diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 9e1390ebbe4..9947031f453 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -9,7 +9,6 @@ from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType -from firebolt.common.cache import _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, ConfigurationError, @@ -186,68 +185,6 @@ async def test_connect_engine_failed( httpx_mock.reset(False) -@mark.parametrize("cache_enabled", [True, False]) -async def test_connect_caching( - db_name: str, - engine_name: str, - auth_url: str, - api_endpoint: str, - auth: Auth, - account_name: str, - httpx_mock: HTTPXMock, - check_credentials_callback: Callable, - get_system_engine_url: str, - get_system_engine_callback: Callable, - system_engine_query_url: str, - system_engine_no_db_query_url: str, - query_url: str, - use_database_callback: Callable, - use_engine_callback: Callable, - query_callback: Callable, - cache_enabled: bool, -): - system_engine_call_counter = 0 - - def system_engine_callback_counter(request, **kwargs): - nonlocal system_engine_call_counter - system_engine_call_counter += 1 - return get_system_engine_callback(request, **kwargs) - - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) - httpx_mock.add_callback( - use_database_callback, - url=system_engine_no_db_query_url, - match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), - ) - - httpx_mock.add_callback( - use_engine_callback, - url=system_engine_query_url, - match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), - ) - httpx_mock.add_callback(query_callback, url=query_url) - - for _ in range(3): - async with await connect( - database=db_name, - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - disable_cache=not cache_enabled, - ) as connection: - await connection.cursor().execute("select*") - - if cache_enabled: - assert system_engine_call_counter == 1, "System engine URL was not cached" - else: - assert system_engine_call_counter != 1, "System engine URL was cached" - - # Reset caches for the next test iteration - _firebolt_cache.enable() - - async def test_connect_system_engine_404( db_name: str, auth_url: str, From 04a33cffa344cb6b7b46c1785c1bb89640390c8b Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 1 Aug 2025 17:59:48 +0100 Subject: [PATCH 06/21] refactor test --- tests/unit/async_db/test_caching.py | 107 +++++++++------------------- 1 file changed, 34 insertions(+), 73 deletions(-) diff --git a/tests/unit/async_db/test_caching.py b/tests/unit/async_db/test_caching.py index 173d6f7ea0b..fa6decf8f7b 100644 --- a/tests/unit/async_db/test_caching.py +++ b/tests/unit/async_db/test_caching.py @@ -1,6 +1,6 @@ from typing import Callable -from pytest import mark +from pytest import fixture, mark from pytest_httpx import HTTPXMock from firebolt.async_db import connect @@ -8,14 +8,34 @@ from firebolt.common.cache import _firebolt_cache +@fixture +async def connection_test( + api_endpoint: str, + auth: Auth, + account_name: str, +): + """Fixture to create a connection factory for testing.""" + + async def factory(db_name: str, engine_name: str, caching: bool) -> Callable: + + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not caching, + ) as connection: + await connection.cursor().execute("select*") + + return factory + + @mark.parametrize("cache_enabled", [True, False]) async def test_connect_caching( db_name: str, engine_name: str, auth_url: str, - api_endpoint: str, - auth: Auth, - account_name: str, httpx_mock: HTTPXMock, check_credentials_callback: Callable, get_system_engine_url: str, @@ -27,6 +47,7 @@ async def test_connect_caching( use_engine_callback: Callable, query_callback: Callable, cache_enabled: bool, + connection_test: Callable, ): system_engine_call_counter = 0 use_database_call_counter = 0 @@ -63,15 +84,7 @@ def use_engine_callback_counter(request, **kwargs): httpx_mock.add_callback(query_callback, url=query_url) for _ in range(3): - async with await connect( - database=db_name, - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - disable_cache=not cache_enabled, - ) as connection: - await connection.cursor().execute("select*") + await connection_test(db_name, engine_name, cache_enabled) if cache_enabled: assert system_engine_call_counter == 1, "System engine URL was not cached" @@ -91,9 +104,6 @@ async def test_connect_db_switching_caching( db_name: str, engine_name: str, auth_url: str, - api_endpoint: str, - auth: Auth, - account_name: str, httpx_mock: HTTPXMock, check_credentials_callback: Callable, get_system_engine_url: str, @@ -105,6 +115,7 @@ async def test_connect_db_switching_caching( use_engine_callback: Callable, query_callback: Callable, cache_enabled: bool, + connection_test: Callable, ): """Test caching when switching between different databases.""" system_engine_call_counter = 0 @@ -152,41 +163,17 @@ def use_engine_callback_counter(request, **kwargs): httpx_mock.add_callback(query_callback, url=query_url) # Connect to first database - async with await connect( - database=db_name, - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - disable_cache=not cache_enabled, - ) as connection: - await connection.cursor().execute("select*") + await connection_test(db_name, engine_name, cache_enabled) first_db_calls = use_database_call_counter # Connect to second database - async with await connect( - database=second_db_name, - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - disable_cache=not cache_enabled, - ) as connection: - await connection.cursor().execute("select*") + await connection_test(second_db_name, engine_name, cache_enabled) second_db_calls = use_database_call_counter - first_db_calls # Connect to first database again - async with await connect( - database=db_name, - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - disable_cache=not cache_enabled, - ) as connection: - await connection.cursor().execute("select*") + await connection_test(db_name, engine_name, cache_enabled) third_db_calls = use_database_call_counter - first_db_calls - second_db_calls @@ -214,9 +201,6 @@ async def test_connect_engine_switching_caching( db_name: str, engine_name: str, auth_url: str, - api_endpoint: str, - auth: Auth, - account_name: str, httpx_mock: HTTPXMock, check_credentials_callback: Callable, get_system_engine_url: str, @@ -228,6 +212,7 @@ async def test_connect_engine_switching_caching( use_engine_callback: Callable, query_callback: Callable, cache_enabled: bool, + connection_test: Callable, ): """Test caching when switching between different engines.""" system_engine_call_counter = 0 @@ -276,41 +261,17 @@ def use_engine_callback_counter(request, **kwargs): httpx_mock.add_callback(query_callback, url=query_url) # Connect to first engine - async with await connect( - database=db_name, - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - disable_cache=not cache_enabled, - ) as connection: - await connection.cursor().execute("select*") + await connection_test(db_name, engine_name, cache_enabled) first_engine_calls = use_engine_call_counter # Connect to second engine - async with await connect( - database=db_name, - engine_name=second_engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - disable_cache=not cache_enabled, - ) as connection: - await connection.cursor().execute("select*") + await connection_test(db_name, second_engine_name, cache_enabled) second_engine_calls = use_engine_call_counter - first_engine_calls # Connect to first engine again - async with await connect( - database=db_name, - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - disable_cache=not cache_enabled, - ) as connection: - await connection.cursor().execute("select*") + await connection_test(db_name, engine_name, cache_enabled) third_engine_calls = ( use_engine_call_counter - first_engine_calls - second_engine_calls From 3f591d605e4df933677d0690e5ace739fd62eefa Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 4 Aug 2025 11:25:33 +0100 Subject: [PATCH 07/21] fix cache and more tests --- src/firebolt/async_db/connection.py | 47 +++------------- src/firebolt/async_db/cursor.py | 36 +++++++++++- tests/unit/async_db/test_caching.py | 8 +++ tests/unit/async_db/test_connection.py | 76 +++++++++++++++++++++++++- 4 files changed, 126 insertions(+), 41 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 646f1b1f99a..e1655da7173 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -37,8 +37,6 @@ from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import ( ConnectionInfo, - DatabaseInfo, - EngineInfo, fix_url_schema, validate_engine_name_and_url_v1, ) @@ -48,25 +46,18 @@ def prepare_ua_parameters( account_name: Optional[str], api_endpoint: str ) -> List[Tuple[str, str]]: ua_parameters = [] - - # cached_id = _firebolt_cache.get_id([account_name, api_endpoint]) conn_uuid = uuid4().hex + cache_key = [account_name, api_endpoint] ua_parameters.append(("connId", conn_uuid)) - prepare_cache_if_needed(account_name, api_endpoint, conn_uuid) - # if cached_id: - # ua_parameters.append(("cachedConnId", cached_id + "-memory")) - # _firebolt_cache.set_id([account_name, api_endpoint], conn_uuid) + cache = _firebolt_cache.get(cache_key) + if cache and cache.id: + ua_parameters.append(("cachedConnId", cache.id + "-memory")) + else: + _firebolt_cache.set(cache_key, ConnectionInfo(id=conn_uuid)) return ua_parameters -def prepare_cache_if_needed( - account_name: Optional[str], api_endpoint: str, conn_id: str -) -> None: - if not _firebolt_cache.get([account_name, api_endpoint]): - _firebolt_cache.set([account_name, api_endpoint], ConnectionInfo(conn_id)) - - class Connection(BaseConnection): """ Firebolt asynchronous database connection class. Implements `PEP 249`_. @@ -266,8 +257,6 @@ async def connect( else: ua_parameters = prepare_ua_parameters(account_name, api_endpoint) user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) - if disable_cache: - _firebolt_cache.disable() # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials # Use V1 if auth is ServiceAccount or UsernamePassword @@ -360,30 +349,10 @@ async def connect_v2( cursor = system_engine_connection.cursor() - # TODO: rework this, this is prototyping right now if database: - cache = _firebolt_cache.get([account_name, api_endpoint]) - cache = cache if cache else ConnectionInfo() - if cache.databases.get(database): - # If database is cached, use it - cursor.database = database - else: - await cursor.execute(f'USE DATABASE "{database}"') - cache.databases[database] = DatabaseInfo(database) - _firebolt_cache.set([account_name, api_endpoint], cache) + await cursor.use_database(database) if engine_name: - cache = _firebolt_cache.get([account_name, api_endpoint]) - cache = cache if cache else ConnectionInfo() - if cache.engines.get(engine_name): - # If engine is cached, use it - cursor.engine_url = cache.engines[engine_name].url - cursor._update_set_parameters(cache.engines[engine_name].params) - else: - await cursor.execute(f'USE ENGINE "{engine_name}"') - cache.engines[engine_name] = EngineInfo( - cursor.engine_url, cursor.parameters - ) # ?? - _firebolt_cache.set([account_name, api_endpoint], cache) + await cursor.use_engine(engine_name) # Ensure cursors created from this connection are using the same starting # database and engine return Connection( diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 20fa2c1a7e8..8f26d9cb5d7 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -57,8 +57,15 @@ if TYPE_CHECKING: from firebolt.async_db.connection import Connection +from firebolt.common.cache import _firebolt_cache from firebolt.utils.async_util import anext, async_islice -from firebolt.utils.util import Timer, raise_error_from_response +from firebolt.utils.util import ( + ConnectionInfo, + DatabaseInfo, + EngineInfo, + Timer, + raise_error_from_response, +) logger = logging.getLogger(__name__) @@ -332,6 +339,33 @@ async def _handle_query_execution( await self._parse_response_headers(resp.headers) await self._append_row_set_from_response(resp) + async def use_database(self, database: str) -> None: + """Switch the current database context with caching.""" + cache_key = [self._client.account_name, self.connection.api_endpoint] + cache = _firebolt_cache.get(cache_key) + cache = cache if cache else ConnectionInfo() + if cache.databases.get(database): + # If database is cached, use it + self.database = database + else: + await self.execute(f'USE DATABASE "{database}"') + cache.databases[database] = DatabaseInfo(database) + _firebolt_cache.set(cache_key, cache) + + async def use_engine(self, engine: str) -> None: + """Switch the current engine context with caching.""" + cache_key = [self._client.account_name, self.connection.api_endpoint] + cache = _firebolt_cache.get(cache_key) + cache = cache if cache else ConnectionInfo() + if cache.engines.get(engine): + # If engine is cached, use it + self.engine_url = cache.engines[engine].url + self._update_set_parameters(cache.engines[engine].params) + else: + await self.execute(f'USE ENGINE "{engine}"') + cache.engines[engine] = EngineInfo(self.engine_url, self.parameters) # ?? + _firebolt_cache.set(cache_key, cache) + @check_not_closed async def execute( self, diff --git a/tests/unit/async_db/test_caching.py b/tests/unit/async_db/test_caching.py index fa6decf8f7b..783613404cd 100644 --- a/tests/unit/async_db/test_caching.py +++ b/tests/unit/async_db/test_caching.py @@ -31,6 +31,14 @@ async def factory(db_name: str, engine_name: str, caching: bool) -> Callable: return factory +@fixture(autouse=True) +async def enable_cache(): + _firebolt_cache.enable() + _firebolt_cache.clear() + yield + _firebolt_cache.clear() + + @mark.parametrize("cache_enabled", [True, False]) async def test_connect_caching( db_name: str, diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 9947031f453..ed49370ae9e 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -1,6 +1,6 @@ from typing import Callable, List, Optional, Tuple from unittest.mock import ANY as AnyValue -from unittest.mock import patch +from unittest.mock import MagicMock, patch from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises @@ -9,6 +9,7 @@ from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType +from firebolt.common.cache import _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, ConfigurationError, @@ -346,6 +347,79 @@ async def test_connect_no_user_agent( ut.assert_called_with([], [], AnyValue) +async def test_connect_caching( + engine_name: str, + account_name: str, + api_endpoint: str, + db_name: str, + auth: Auth, + httpx_mock: HTTPXMock, + query_callback: Callable, + query_url: str, + mock_connection_flow: Callable, +) -> None: + async def do_connect(): + async with await connect( + auth=auth, + database=db_name, + engine_name=engine_name, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: + await connection.cursor().execute("select*") + + _firebolt_cache.clear() + mock_id = "12345" + mock_id2 = "67890" + mock_id3 = "54321" + with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + ut.side_effect = [ + f"connId:{mock_id}", + f"connId:{mock_id2}; cachedConnId:{mock_id}-memory", + f"connId:{mock_id3}", + ] + with patch("firebolt.async_db.connection.uuid4") as uuid4: + uuid4.side_effect = [ + MagicMock(hex=mock_id), + MagicMock(hex=mock_id2), + MagicMock(hex=mock_id3), + ] + mock_connection_flow() + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": f"connId:{mock_id}"}, + ) + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={ + "User-Agent": f"connId:{mock_id2}; cachedConnId:{mock_id}-memory" + }, + ) + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": f"connId:{mock_id3}"}, + ) + + await do_connect() + ut.assert_called_with(AnyValue, AnyValue, [("connId", mock_id)]) + + # Second call should use cached connection info + await do_connect() + ut.assert_called_with( + AnyValue, + AnyValue, + [("connId", mock_id2), ("cachedConnId", f"{mock_id}-memory")], + ) + _firebolt_cache.clear() + + # Third call should have a new connection id + await do_connect() + ut.assert_called_with(AnyValue, AnyValue, [("connId", mock_id3)]) + + @mark.parametrize( "server_status,expected_running,expected_success", [ From 892d6354c3992051bdc90b10c2c11fef345c8826 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 4 Aug 2025 18:43:21 +0100 Subject: [PATCH 08/21] refactor and tests --- src/firebolt/async_db/connection.py | 49 ++-- src/firebolt/async_db/cursor.py | 11 +- src/firebolt/async_db/util.py | 19 +- src/firebolt/db/connection.py | 58 ++--- src/firebolt/db/cursor.py | 33 +++ src/firebolt/db/util.py | 19 +- src/firebolt/{common => utils}/cache.py | 29 ++- src/firebolt/utils/usage_tracker.py | 13 + src/firebolt/utils/util.py | 42 +--- tests/unit/async_db/test_caching.py | 2 +- tests/unit/async_db/test_connection.py | 2 +- tests/unit/conftest.py | 2 +- tests/unit/db/test_caching.py | 304 ++++++++++++++++++++++++ tests/unit/db/test_connection.py | 2 +- tests/unit/utils/test_utils.py | 6 +- 15 files changed, 459 insertions(+), 132 deletions(-) rename src/firebolt/{common => utils}/cache.py (77%) create mode 100644 tests/unit/db/test_caching.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index e1655da7173..0ee10ef9519 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -2,7 +2,7 @@ from ssl import SSLContext from types import TracebackType -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Type, Union from uuid import uuid4 from httpx import Timeout @@ -22,8 +22,8 @@ BaseConnection, _parse_async_query_info_results, ) -from firebolt.common.cache import _firebolt_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -34,28 +34,11 @@ parse_firebolt_core_url, validate_firebolt_core_parameters, ) -from firebolt.utils.usage_tracker import get_user_agent_header -from firebolt.utils.util import ( - ConnectionInfo, - fix_url_schema, - validate_engine_name_and_url_v1, +from firebolt.utils.usage_tracker import ( + get_cache_tracking, + get_user_agent_header, ) - - -def prepare_ua_parameters( - account_name: Optional[str], api_endpoint: str -) -> List[Tuple[str, str]]: - ua_parameters = [] - conn_uuid = uuid4().hex - cache_key = [account_name, api_endpoint] - ua_parameters.append(("connId", conn_uuid)) - cache = _firebolt_cache.get(cache_key) - if cache and cache.id: - ua_parameters.append(("cachedConnId", cache.id + "-memory")) - else: - _firebolt_cache.set(cache_key, ConnectionInfo(id=conn_uuid)) - - return ua_parameters +from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 class Connection(BaseConnection): @@ -92,6 +75,7 @@ class Connection(BaseConnection): "_is_closed", "client_class", "cursor_type", + "id", ) def __init__( @@ -102,12 +86,14 @@ def __init__( cursor_type: Type[Cursor], api_endpoint: str, init_parameters: Optional[Dict[str, Any]] = None, + id: str = uuid4().hex, ): super().__init__(cursor_type) self.api_endpoint = api_endpoint self.engine_url = engine_url self._cursors: List[Cursor] = [] self._client = client + self.id = id self.init_parameters = init_parameters or {} if database: self.init_parameters["database"] = database @@ -251,11 +237,12 @@ async def connect( assert auth is not None user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) + connection_id = uuid4().hex ua_parameters = [] if disable_cache: _firebolt_cache.disable() else: - ua_parameters = prepare_ua_parameters(account_name, api_endpoint) + ua_parameters = get_cache_tracking([account_name, api_endpoint], connection_id) user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials @@ -279,6 +266,7 @@ async def connect( database=database, engine_name=engine_name, api_endpoint=api_endpoint, + connection_id=connection_id, ) elif auth_version == FireboltAuthVersion.V1: return await connect_v1( @@ -289,6 +277,7 @@ async def connect( engine_name=engine_name, engine_url=engine_url, api_endpoint=api_endpoint, + connection_id=connection_id, ) else: raise ConfigurationError(f"Unsupported auth type: {type(auth)}") @@ -297,6 +286,7 @@ async def connect( async def connect_v2( auth: Auth, user_agent_header: str, + connection_id: str, account_name: Optional[str] = None, database: Optional[str] = None, engine_name: Optional[str] = None, @@ -327,7 +317,7 @@ async def connect_v2( assert account_name is not None system_engine_info = await _get_system_engine_url_and_params( - auth, account_name, api_endpoint + auth, account_name, api_endpoint, connection_id ) client = AsyncClientV2( @@ -345,6 +335,7 @@ async def connect_v2( CursorV2, api_endpoint, system_engine_info.params, + connection_id, ) as system_engine_connection: cursor = system_engine_connection.cursor() @@ -362,12 +353,14 @@ async def connect_v2( CursorV2, api_endpoint, cursor.parameters, + connection_id, ) async def connect_v1( auth: Auth, user_agent_header: str, + connection_id: str, database: Optional[str] = None, account_name: Optional[str] = None, engine_name: Optional[str] = None, @@ -419,11 +412,7 @@ async def connect_v1( headers={"User-Agent": user_agent_header}, ) return Connection( - engine_url, - database, - client, - CursorV1, - api_endpoint, + engine_url, database, client, CursorV1, api_endpoint, id=connection_id ) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 8f26d9cb5d7..6b4c339ebe3 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -57,15 +57,14 @@ if TYPE_CHECKING: from firebolt.async_db.connection import Connection -from firebolt.common.cache import _firebolt_cache from firebolt.utils.async_util import anext, async_islice -from firebolt.utils.util import ( +from firebolt.utils.cache import ( ConnectionInfo, DatabaseInfo, EngineInfo, - Timer, - raise_error_from_response, + _firebolt_cache, ) +from firebolt.utils.util import Timer, raise_error_from_response logger = logging.getLogger(__name__) @@ -343,7 +342,7 @@ async def use_database(self, database: str) -> None: """Switch the current database context with caching.""" cache_key = [self._client.account_name, self.connection.api_endpoint] cache = _firebolt_cache.get(cache_key) - cache = cache if cache else ConnectionInfo() + cache = cache if cache else ConnectionInfo(id=self.connection.id) if cache.databases.get(database): # If database is cached, use it self.database = database @@ -356,7 +355,7 @@ async def use_engine(self, engine: str) -> None: """Switch the current engine context with caching.""" cache_key = [self._client.account_name, self.connection.api_endpoint] cache = _firebolt_cache.get(cache_key) - cache = cache if cache else ConnectionInfo() + cache = cache if cache else ConnectionInfo(id=self.connection.id) if cache.engines.get(engine): # If engine is cached, use it self.engine_url = cache.engines[engine].url diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index 4e6543266e4..4028ac399db 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -4,24 +4,21 @@ from firebolt.client.auth import Auth from firebolt.client.client import AsyncClientV2 -from firebolt.common.cache import _firebolt_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.utils.cache import ConnectionInfo, EngineInfo, _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, InterfaceError, ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import ( - ConnectionInfo, - EngineInfo, - parse_url_and_params, -) +from firebolt.utils.util import parse_url_and_params async def _get_system_engine_url_and_params( auth: Auth, account_name: str, api_endpoint: str, + connection_id: str, ) -> EngineInfo: cache = _firebolt_cache.get([account_name, api_endpoint]) if cache and (result := cache.system_engine): @@ -42,9 +39,9 @@ async def _get_system_engine_url_and_params( f"Unable to retrieve system engine endpoint {url}: " f"{response.status_code} {response.content.decode()}" ) - result = parse_url_and_params(response.json()["engineUrl"]) + url, params = parse_url_and_params(response.json()["engineUrl"]) if not cache: - cache = ConnectionInfo() - cache.system_engine = result - _firebolt_cache.set([account_name, api_endpoint], cache) - return result + cache = ConnectionInfo(id=connection_id) + _firebolt_cache.set([account_name, api_endpoint], cache) + cache.system_engine = EngineInfo(url=url, params=params) + return cache.system_engine diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index f6a909d80bd..2e66c290ae7 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -3,7 +3,7 @@ import logging from ssl import SSLContext from types import TracebackType -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Type, Union from uuid import uuid4 from warnings import warn @@ -21,10 +21,10 @@ BaseConnection, _parse_async_query_info_results, ) -from firebolt.common.cache import _firebolt_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.db.cursor import Cursor, CursorV1, CursorV2 from firebolt.db.util import _get_system_engine_url_and_params +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -35,39 +35,15 @@ parse_firebolt_core_url, validate_firebolt_core_parameters, ) -from firebolt.utils.usage_tracker import get_user_agent_header -from firebolt.utils.util import ( - ConnectionInfo, - fix_url_schema, - validate_engine_name_and_url_v1, +from firebolt.utils.usage_tracker import ( + get_cache_tracking, + get_user_agent_header, ) +from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 logger = logging.getLogger(__name__) -def prepare_ua_parameters( - account_name: Optional[str], api_endpoint: str -) -> List[Tuple[str, str]]: - ua_parameters = [] - - # cached_id = _firebolt_cache.get_id([account_name, api_endpoint]) - conn_uuid = uuid4().hex - ua_parameters.append(("connId", conn_uuid)) - prepare_cache_if_needed(account_name, api_endpoint, conn_uuid) - # if cached_id: - # ua_parameters.append(("cachedConnId", cached_id + "-memory")) - # _firebolt_cache.set_id([account_name, api_endpoint], conn_uuid) - - return ua_parameters - - -def prepare_cache_if_needed( - account_name: Optional[str], api_endpoint: str, conn_id: str -) -> None: - if not _firebolt_cache.get([account_name, api_endpoint]): - _firebolt_cache.set([account_name, api_endpoint], ConnectionInfo(conn_id)) - - def connect( auth: Optional[Auth] = None, account_name: Optional[str] = None, @@ -90,11 +66,12 @@ def connect( assert auth is not None user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) + connection_id = uuid4().hex ua_parameters = [] if disable_cache: _firebolt_cache.disable() else: - ua_parameters = prepare_ua_parameters(account_name, api_endpoint) + ua_parameters = get_cache_tracking([account_name, api_endpoint], connection_id) user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) auth_version = auth.get_firebolt_version() # Use CORE if auth is FireboltCore @@ -119,6 +96,7 @@ def connect( database=database, engine_name=engine_name, api_endpoint=api_endpoint, + connection_id=connection_id, ) elif auth_version == FireboltAuthVersion.V1: return connect_v1( @@ -129,6 +107,7 @@ def connect( engine_name=engine_name, engine_url=engine_url, api_endpoint=api_endpoint, + connection_id=connection_id, ) else: raise ConfigurationError(f"Unsupported auth type: {type(auth)}") @@ -137,6 +116,7 @@ def connect( def connect_v2( auth: Auth, user_agent_header: str, + connection_id: str, account_name: Optional[str] = None, database: Optional[str] = None, engine_name: Optional[str] = None, @@ -167,7 +147,7 @@ def connect_v2( assert account_name is not None system_engine_info = _get_system_engine_url_and_params( - auth, account_name, api_endpoint + auth, account_name, api_endpoint, connection_id ) client = ClientV2( @@ -185,13 +165,14 @@ def connect_v2( CursorV2, api_endpoint, system_engine_info.params, + connection_id, ) as system_engine_connection: cursor = system_engine_connection.cursor() if database: - cursor.execute(f'USE DATABASE "{database}"') + cursor.use_database(database) if engine_name: - cursor.execute(f'USE ENGINE "{engine_name}"') + cursor.use_engine(engine_name) # Ensure cursors created from this connection are using the same starting # database and engine return Connection( @@ -201,6 +182,7 @@ def connect_v2( CursorV2, api_endpoint, cursor.parameters, + connection_id, ) @@ -231,6 +213,7 @@ class Connection(BaseConnection): "_is_closed", "client_class", "cursor_type", + "id", ) def __init__( @@ -241,12 +224,14 @@ def __init__( cursor_type: Type[Cursor], api_endpoint: str = DEFAULT_API_URL, init_parameters: Optional[Dict[str, Any]] = None, + id: str = uuid4().hex, ): super().__init__(cursor_type) self.api_endpoint = api_endpoint self.engine_url = engine_url self._cursors: List[Cursor] = [] self._client = client + self.id = id self.init_parameters = init_parameters or {} if database: self.init_parameters["database"] = database @@ -378,6 +363,7 @@ def __del__(self) -> None: def connect_v1( auth: Auth, user_agent_header: str, + connection_id: str, database: Optional[str] = None, account_name: Optional[str] = None, engine_name: Optional[str] = None, @@ -429,7 +415,9 @@ def connect_v1( timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), headers={"User-Agent": user_agent_header}, ) - return Connection(engine_url, database, client, CursorV1, api_endpoint) + return Connection( + engine_url, database, client, CursorV1, api_endpoint, id=connection_id + ) def connect_core( diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index d8d913894e7..f3ce9a25217 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -50,6 +50,12 @@ from firebolt.common.row_set.synchronous.in_memory import InMemoryRowSet from firebolt.common.row_set.synchronous.streaming import StreamingRowSet from firebolt.common.statement_formatter import create_statement_formatter +from firebolt.utils.cache import ( + ConnectionInfo, + DatabaseInfo, + EngineInfo, + _firebolt_cache, +) from firebolt.utils.exception import ( EngineNotRunningError, FireboltDatabaseError, @@ -338,6 +344,33 @@ def _handle_query_execution( self._parse_response_headers(resp.headers) self._append_row_set_from_response(resp) + def use_database(self, database: str) -> None: + """Switch the current database context with caching.""" + cache_key = [self._client.account_name, self.connection.api_endpoint] + cache = _firebolt_cache.get(cache_key) + cache = cache if cache else ConnectionInfo(id=self.connection.id) + if cache.databases.get(database): + # If database is cached, use it + self.database = database + else: + self.execute(f'USE DATABASE "{database}"') + cache.databases[database] = DatabaseInfo(database) + _firebolt_cache.set(cache_key, cache) + + def use_engine(self, engine: str) -> None: + """Switch the current engine context with caching.""" + cache_key = [self._client.account_name, self.connection.api_endpoint] + cache = _firebolt_cache.get(cache_key) + cache = cache if cache else ConnectionInfo(id=self.connection.id) + if cache.engines.get(engine): + # If engine is cached, use it + self.engine_url = cache.engines[engine].url + self._update_set_parameters(cache.engines[engine].params) + else: + self.execute(f'USE ENGINE "{engine}"') + cache.engines[engine] = EngineInfo(self.engine_url, self.parameters) # ?? + _firebolt_cache.set(cache_key, cache) + @check_not_closed def execute( self, diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index 63384fd34b8..111cd6a1ba6 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -4,24 +4,21 @@ from firebolt.client import ClientV2 from firebolt.client.auth import Auth -from firebolt.common.cache import _firebolt_cache from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.utils.cache import ConnectionInfo, EngineInfo, _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, InterfaceError, ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import ( - ConnectionInfo, - EngineInfo, - parse_url_and_params, -) +from firebolt.utils.util import parse_url_and_params def _get_system_engine_url_and_params( auth: Auth, account_name: str, api_endpoint: str, + connection_id: str, ) -> EngineInfo: cache = _firebolt_cache.get([account_name, api_endpoint]) if cache and (result := cache.system_engine): @@ -42,9 +39,9 @@ def _get_system_engine_url_and_params( f"Unable to retrieve system engine endpoint {url}: " f"{response.status_code} {response.content.decode()}" ) - result = parse_url_and_params(response.json()["engineUrl"]) + url, params = parse_url_and_params(response.json()["engineUrl"]) if not cache: - cache = ConnectionInfo() - cache.system_engine = result - _firebolt_cache.set([account_name, api_endpoint], cache) - return result + cache = ConnectionInfo(id=connection_id) + _firebolt_cache.set([account_name, api_endpoint], cache) + cache.system_engine = EngineInfo(url=url, params=params) + return cache.system_engine diff --git a/src/firebolt/common/cache.py b/src/firebolt/utils/cache.py similarity index 77% rename from src/firebolt/common/cache.py rename to src/firebolt/utils/cache.py index 8afcb9d9cab..46587e36937 100644 --- a/src/firebolt/common/cache.py +++ b/src/firebolt/utils/cache.py @@ -1,8 +1,7 @@ import os +from dataclasses import dataclass, field from typing import Any, Callable, Dict, Generic, Optional, Protocol, TypeVar -from firebolt.utils.util import ConnectionInfo - T = TypeVar("T") @@ -11,6 +10,32 @@ def __repr__(self) -> str: ... +@dataclass +class EngineInfo: + """Class to hold engine information for caching.""" + + url: str + params: Dict[str, str] + + +@dataclass +class DatabaseInfo: + """Class to hold database information for caching.""" + + name: str + + +@dataclass +class ConnectionInfo: + """Class to hold connection information for caching.""" + + id: str + expiry_time: Optional[int] = None + system_engine: Optional[EngineInfo] = None + databases: Dict[str, DatabaseInfo] = field(default_factory=dict) + engines: Dict[str, EngineInfo] = field(default_factory=dict) + + def noop_if_disabled(func: Callable) -> Callable: """Decorator to make function do nothing if the cache is disabled.""" diff --git a/src/firebolt/utils/usage_tracker.py b/src/firebolt/utils/usage_tracker.py index 5eb656146ca..79c55352c35 100644 --- a/src/firebolt/utils/usage_tracker.py +++ b/src/firebolt/utils/usage_tracker.py @@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Tuple from firebolt import __version__ +from firebolt.utils.cache import ConnectionInfo, ReprCacheable, _firebolt_cache @dataclass @@ -224,3 +225,15 @@ def get_user_agent_header( for name, version in versions.drivers: drivers[name] = version return format_as_user_agent(drivers, clients, additional_properties or []) + + +def get_cache_tracking(cache_key: ReprCacheable, conn_id: str) -> List[Tuple[str, str]]: + ua_parameters = [] + ua_parameters.append(("connId", conn_id)) + cache = _firebolt_cache.get(cache_key) + if cache: + ua_parameters.append(("cachedConnId", cache.id + "-memory")) + else: + _firebolt_cache.set(cache_key, ConnectionInfo(id=conn_id)) + + return ua_parameters diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index f2017e5fc31..9a867301f93 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -1,10 +1,18 @@ import logging -from dataclasses import dataclass, field from functools import lru_cache from os import environ from time import time from types import TracebackType -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, +) from urllib.parse import parse_qs, urljoin, urlparse from httpx import URL, Response, codes @@ -18,32 +26,6 @@ logger = logging.getLogger(__name__) -@dataclass -class EngineInfo: - """Class to hold engine information for caching.""" - - url: str - params: Dict[str, str] - - -@dataclass -class DatabaseInfo: - """Class to hold database information for caching.""" - - name: str - - -@dataclass -class ConnectionInfo: - """Class to hold connection information for caching.""" - - id: Optional[str] = None - expiry_time: Optional[int] = None - system_engine: Optional[EngineInfo] = None - databases: Dict[str, DatabaseInfo] = field(default_factory=dict) - engines: Dict[str, EngineInfo] = field(default_factory=dict) - - def cached_property(func: Callable[..., T]) -> T: """cached_property implementation for 3.7 backward compatibility. @@ -230,7 +212,7 @@ def __exit__( logger.debug(log_message) -def parse_url_and_params(url: str) -> EngineInfo: +def parse_url_and_params(url: str) -> Tuple[str, Dict]: """Extract URL and query parameters separately from a URL.""" url = fix_url_schema(url) parsed_url = urlparse(url) @@ -246,7 +228,7 @@ def parse_url_and_params(url: str) -> EngineInfo: if len(values) > 1: raise ValueError(f"Multiple values found for key '{key}'") query_params_dict[key] = values[0] - return EngineInfo(url=result_url, params=query_params_dict) + return (result_url, query_params_dict) class _ExceptionGroup(Exception): diff --git a/tests/unit/async_db/test_caching.py b/tests/unit/async_db/test_caching.py index 783613404cd..2ba0ff39a05 100644 --- a/tests/unit/async_db/test_caching.py +++ b/tests/unit/async_db/test_caching.py @@ -5,7 +5,7 @@ from firebolt.async_db import connect from firebolt.client.auth import Auth -from firebolt.common.cache import _firebolt_cache +from firebolt.utils.cache import _firebolt_cache @fixture diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index ed49370ae9e..19df76dfe3c 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -9,7 +9,7 @@ from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType -from firebolt.common.cache import _firebolt_cache +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, ConfigurationError, diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index c5d69c6de67..1eb81c7dc78 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -8,8 +8,8 @@ from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.client import ClientV2 -from firebolt.common.cache import _firebolt_cache from firebolt.common.settings import Settings +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( DatabaseError, DataError, diff --git a/tests/unit/db/test_caching.py b/tests/unit/db/test_caching.py new file mode 100644 index 00000000000..931fd66636f --- /dev/null +++ b/tests/unit/db/test_caching.py @@ -0,0 +1,304 @@ +from typing import Callable + +from pytest import fixture, mark +from pytest_httpx import HTTPXMock + +from firebolt.client.auth import Auth +from firebolt.db import connect +from firebolt.utils.cache import _firebolt_cache + + +@fixture +def connection_test( + api_endpoint: str, + auth: Auth, + account_name: str, +): + """Fixture to create a connection factory for testing.""" + + def factory(db_name: str, engine_name: str, caching: bool) -> Callable: + + with connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not caching, + ) as connection: + connection.cursor().execute("select*") + + return factory + + +@fixture(autouse=True) +def enable_cache(): + _firebolt_cache.enable() + _firebolt_cache.clear() + yield + _firebolt_cache.clear() + + +@mark.parametrize("cache_enabled", [True, False]) +def test_connect_caching( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, + connection_test: Callable, +): + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + for _ in range(3): + connection_test(db_name, engine_name, cache_enabled) + + if cache_enabled: + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_database_call_counter == 1, "Use database URL was not cached" + assert use_engine_call_counter == 1, "Use engine URL was not cached" + else: + assert system_engine_call_counter != 1, "System engine URL was cached" + assert use_database_call_counter != 1, "Use database URL was cached" + assert use_engine_call_counter != 1, "Use engine URL was cached" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +def test_connect_db_switching_caching( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, + connection_test: Callable, +): + """Test caching when switching between different databases.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + second_db_name = f"{db_name}_second" + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + + # First database + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + # Second database + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{second_db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # Connect to first database + connection_test(db_name, engine_name, cache_enabled) + + first_db_calls = use_database_call_counter + + # Connect to second database + connection_test(second_db_name, engine_name, cache_enabled) + + second_db_calls = use_database_call_counter - first_db_calls + + # Connect to first database again + connection_test(db_name, engine_name, cache_enabled) + + third_db_calls = use_database_call_counter - first_db_calls - second_db_calls + + if cache_enabled: + assert second_db_calls == 1, "Second database call was not made" + assert third_db_calls == 0, "First database was not cached" + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_engine_call_counter == 1, "Use engine URL was not cached" + else: + assert second_db_calls == 1, "Second database call was not made" + assert third_db_calls == 1, "First database was cached when cache disabled" + assert ( + system_engine_call_counter == 3 + ), "System engine URL was cached when cache disabled" + assert ( + use_engine_call_counter == 3 + ), "Use engine URL was cached when cache disabled" + + # Reset caches for the next test iteration + _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +def test_connect_engine_switching_caching( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + cache_enabled: bool, + connection_test: Callable, +): + """Test caching when switching between different engines.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + second_engine_name = f"{engine_name}_second" + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + # First engine + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + + # Second engine + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{second_engine_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback(query_callback, url=query_url) + + # Connect to first engine + connection_test(db_name, engine_name, cache_enabled) + + first_engine_calls = use_engine_call_counter + + # Connect to second engine + connection_test(db_name, second_engine_name, cache_enabled) + + second_engine_calls = use_engine_call_counter - first_engine_calls + + # Connect to first engine again + connection_test(db_name, engine_name, cache_enabled) + + third_engine_calls = ( + use_engine_call_counter - first_engine_calls - second_engine_calls + ) + + if cache_enabled: + assert second_engine_calls == 1, "Second engine call was not made" + assert third_engine_calls == 0, "First engine was not cached" + assert system_engine_call_counter == 1, "System engine URL was not cached" + assert use_database_call_counter == 1, "Use database URL was not cached" + else: + assert second_engine_calls == 1, "Second engine call was not made" + assert third_engine_calls == 1, "First engine was cached when cache disabled" + assert ( + system_engine_call_counter == 3 + ), "System engine URL was cached when cache disabled" + assert ( + use_database_call_counter == 3 + ), "Use database URL was cached when cache disabled" + + # Reset caches for the next test iteration + _firebolt_cache.enable() diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index d67f51abfa1..b361c834162 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -11,9 +11,9 @@ from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.client import ClientV2 from firebolt.common._types import ColType -from firebolt.common.cache import _firebolt_cache from firebolt.db import Connection, connect from firebolt.db.cursor import CursorV2 +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, ConfigurationError, diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 08101988f76..2fa525d4628 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -50,9 +50,9 @@ def test_get_internal_error_code(status_code, content, expected_error_code): ], ) def test_parse_url_and_params(url, expected_url, expected_params): - parsed_info = parse_url_and_params(url) - assert parsed_info.url == expected_url - assert parsed_info.params == expected_params + url, params = parse_url_and_params(url) + assert url == expected_url + assert params == expected_params @pytest.mark.parametrize( From 1cecdea0b701897d749a905359e6a20020eb7ae1 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 4 Aug 2025 19:02:57 +0100 Subject: [PATCH 09/21] fix core --- tests/integration/dbapi/async/V2/conftest.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/integration/dbapi/async/V2/conftest.py b/tests/integration/dbapi/async/V2/conftest.py index 81ac291233a..2167b8edc22 100644 --- a/tests/integration/dbapi/async/V2/conftest.py +++ b/tests/integration/dbapi/async/V2/conftest.py @@ -24,12 +24,9 @@ async def connection( ) -> Connection: if request.param == "core": kwargs = { - "engine_name": None, "database": "firebolt", "auth": core_auth, "url": core_url, - "account_name": None, - "api_endpoint": None, } else: kwargs = { From daa3bc3f0662ddb2ac79b238b7116f379d3d27e1 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 4 Aug 2025 19:08:58 +0100 Subject: [PATCH 10/21] fix core --- tests/integration/dbapi/sync/V2/conftest.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/integration/dbapi/sync/V2/conftest.py b/tests/integration/dbapi/sync/V2/conftest.py index 63ec51d5e9d..6c9addb21d3 100644 --- a/tests/integration/dbapi/sync/V2/conftest.py +++ b/tests/integration/dbapi/sync/V2/conftest.py @@ -24,12 +24,9 @@ def connection( ) -> Connection: if request.param == "core": kwargs = { - "engine_name": None, "database": "firebolt", "auth": core_auth, "url": core_url, - "account_name": None, - "api_endpoint": None, } else: kwargs = { From ec404147d197346d83e74200228dbf4051a70240 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 5 Aug 2025 11:25:40 +0100 Subject: [PATCH 11/21] using new connection key --- src/firebolt/async_db/connection.py | 7 +++- src/firebolt/async_db/cursor.py | 19 +++------ src/firebolt/async_db/util.py | 12 ++++-- src/firebolt/client/auth/base.py | 18 ++++++++ .../client/auth/client_credentials.py | 18 ++++++++ src/firebolt/client/auth/firebolt_core.py | 22 ++++++++++ src/firebolt/client/auth/service_account.py | 18 ++++++++ src/firebolt/client/auth/token.py | 18 ++++++++ src/firebolt/client/auth/username_password.py | 18 ++++++++ src/firebolt/common/cursor/base_cursor.py | 42 ++++++++++++++++++- src/firebolt/db/connection.py | 7 +++- src/firebolt/db/cursor.py | 19 +++------ src/firebolt/db/util.py | 12 ++++-- src/firebolt/utils/cache.py | 30 ++++++++++++- 14 files changed, 222 insertions(+), 38 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 0ee10ef9519..b2eacdfa051 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -23,7 +23,7 @@ _parse_async_query_info_results, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.cache import _firebolt_cache +from firebolt.utils.cache import SecureCacheKey, _firebolt_cache from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -242,7 +242,10 @@ async def connect( if disable_cache: _firebolt_cache.disable() else: - ua_parameters = get_cache_tracking([account_name, api_endpoint], connection_id) + cache_key = SecureCacheKey( + [auth.principal, auth.secret, account_name], auth.secret + ) + ua_parameters = get_cache_tracking(cache_key, connection_id) user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 6b4c339ebe3..42bbe26ccd0 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -58,12 +58,7 @@ from firebolt.async_db.connection import Connection from firebolt.utils.async_util import anext, async_islice -from firebolt.utils.cache import ( - ConnectionInfo, - DatabaseInfo, - EngineInfo, - _firebolt_cache, -) +from firebolt.utils.cache import ConnectionInfo, DatabaseInfo, EngineInfo from firebolt.utils.util import Timer, raise_error_from_response logger = logging.getLogger(__name__) @@ -91,8 +86,8 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - self._client = client self.connection = connection + self._client: AsyncClient = client self.engine_url = connection.engine_url self._row_set: Optional[BaseAsyncRowSet] = None if connection.init_parameters: @@ -340,8 +335,7 @@ async def _handle_query_execution( async def use_database(self, database: str) -> None: """Switch the current database context with caching.""" - cache_key = [self._client.account_name, self.connection.api_endpoint] - cache = _firebolt_cache.get(cache_key) + cache = self.get_cache() cache = cache if cache else ConnectionInfo(id=self.connection.id) if cache.databases.get(database): # If database is cached, use it @@ -349,12 +343,11 @@ async def use_database(self, database: str) -> None: else: await self.execute(f'USE DATABASE "{database}"') cache.databases[database] = DatabaseInfo(database) - _firebolt_cache.set(cache_key, cache) + self.set_cache(cache) async def use_engine(self, engine: str) -> None: """Switch the current engine context with caching.""" - cache_key = [self._client.account_name, self.connection.api_endpoint] - cache = _firebolt_cache.get(cache_key) + cache = self.get_cache() cache = cache if cache else ConnectionInfo(id=self.connection.id) if cache.engines.get(engine): # If engine is cached, use it @@ -363,7 +356,7 @@ async def use_engine(self, engine: str) -> None: else: await self.execute(f'USE ENGINE "{engine}"') cache.engines[engine] = EngineInfo(self.engine_url, self.parameters) # ?? - _firebolt_cache.set(cache_key, cache) + self.set_cache(cache) @check_not_closed async def execute( diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index 4028ac399db..98aae3e017d 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -5,7 +5,12 @@ from firebolt.client.auth import Auth from firebolt.client.client import AsyncClientV2 from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.cache import ConnectionInfo, EngineInfo, _firebolt_cache +from firebolt.utils.cache import ( + ConnectionInfo, + EngineInfo, + SecureCacheKey, + _firebolt_cache, +) from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, InterfaceError, @@ -20,7 +25,8 @@ async def _get_system_engine_url_and_params( api_endpoint: str, connection_id: str, ) -> EngineInfo: - cache = _firebolt_cache.get([account_name, api_endpoint]) + cache_key = SecureCacheKey([auth.principal, auth.secret, account_name], auth.secret) + cache = _firebolt_cache.get(cache_key) if cache and (result := cache.system_engine): return result async with AsyncClientV2( @@ -42,6 +48,6 @@ async def _get_system_engine_url_and_params( url, params = parse_url_and_params(response.json()["engineUrl"]) if not cache: cache = ConnectionInfo(id=connection_id) - _firebolt_cache.set([account_name, api_endpoint], cache) + _firebolt_cache.set(cache_key, cache) cache.system_engine = EngineInfo(url=url, params=params) return cache.system_engine diff --git a/src/firebolt/client/auth/base.py b/src/firebolt/client/auth/base.py index 74df42202db..098a475ddad 100644 --- a/src/firebolt/client/auth/base.py +++ b/src/firebolt/client/auth/base.py @@ -65,6 +65,24 @@ def token(self) -> Optional[str]: """ return self._token + @property + @abstractmethod + def principal(self) -> str: + """Get the principal (username) associated with the token. + + Returns: + Optional[str]: Principal username if available, None otherwise + """ + + @property + @abstractmethod + def secret(self) -> str: + """Get the secret (password) associated with the token. + + Returns: + Optional[str]: Secret if available, None otherwise + """ + @abstractmethod def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. diff --git a/src/firebolt/client/auth/client_credentials.py b/src/firebolt/client/auth/client_credentials.py index 92181fb9188..39f34a19c22 100644 --- a/src/firebolt/client/auth/client_credentials.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -53,6 +53,24 @@ def copy(self) -> "ClientCredentials": self.client_id, self.client_secret, self._use_token_cache ) + @property + def principal(self) -> str: + """Get the principal (username) associated with the token. + + Returns: + str: Principal username + """ + return self.client_id + + @property + def secret(self) -> str: + """Get the secret (password) associated with the token. + + Returns: + str: Secret + """ + return self.client_secret + def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. diff --git a/src/firebolt/client/auth/firebolt_core.py b/src/firebolt/client/auth/firebolt_core.py index 8cffe64aa4f..3f6c1183a9b 100644 --- a/src/firebolt/client/auth/firebolt_core.py +++ b/src/firebolt/client/auth/firebolt_core.py @@ -27,6 +27,28 @@ def __init__(self) -> None: self._token = "" self._expires = None + @property + def principal(self) -> str: + """Get the principal (username) associated with the token. + + For FireboltCore, this returns an empty string since no auth is needed. + + Returns: + str: Empty string (no principal needed) + """ + return "core" + + @property + def secret(self) -> str: + """Get the secret (password) associated with the token. + + For FireboltCore, this returns an empty string since no auth is needed. + + Returns: + str: Empty string (no secret needed) + """ + return "core" + def copy(self) -> "FireboltCore": """Make another auth object with same URL. diff --git a/src/firebolt/client/auth/service_account.py b/src/firebolt/client/auth/service_account.py index 6d670c00d0a..ffd557eec2b 100644 --- a/src/firebolt/client/auth/service_account.py +++ b/src/firebolt/client/auth/service_account.py @@ -43,6 +43,24 @@ def __init__( self.client_secret = client_secret super().__init__(use_token_cache) + @property + def principal(self) -> str: + """Get the principal (username) associated with the token. + + Returns: + str: Principal username + """ + return self.client_id + + @property + def secret(self) -> str: + """Get the secret (password) associated with the token. + + Returns: + str: Secret, which is the client secret itself + """ + return self.client_secret + def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. diff --git a/src/firebolt/client/auth/token.py b/src/firebolt/client/auth/token.py index 4ad7bd13ce3..965dc2fb9b4 100644 --- a/src/firebolt/client/auth/token.py +++ b/src/firebolt/client/auth/token.py @@ -24,6 +24,24 @@ def __init__(self, token: str): super().__init__(use_token_cache=False) self._token = token + @property + def principal(self) -> str: + """Get the principal (username) associated with the token. + + Returns: + str: Principal username + """ + return "token_principal" + + @property + def secret(self) -> str: + """Get the secret (password) associated with the token. + + Returns: + str: Secret, which is the token itself + """ + return self._token or "token" + def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. diff --git a/src/firebolt/client/auth/username_password.py b/src/firebolt/client/auth/username_password.py index 79f1548294a..646401f99ac 100644 --- a/src/firebolt/client/auth/username_password.py +++ b/src/firebolt/client/auth/username_password.py @@ -43,6 +43,24 @@ def __init__( self.password = password super().__init__(use_token_cache) + @property + def principal(self) -> str: + """Get the principal (username) associated with the token. + + Returns: + str: Principal username + """ + return self.username + + @property + def secret(self) -> str: + """Get the secret (password) associated with the token. + + Returns: + str: Secret password + """ + return self.password + def get_firebolt_version(self) -> FireboltAuthVersion: """Get Firebolt version from auth. diff --git a/src/firebolt/common/cursor/base_cursor.py b/src/firebolt/common/cursor/base_cursor.py index da89fc0f404..260cad8d1a1 100644 --- a/src/firebolt/common/cursor/base_cursor.py +++ b/src/firebolt/common/cursor/base_cursor.py @@ -8,6 +8,8 @@ from httpx import URL, Response +from firebolt.client.auth.base import Auth +from firebolt.client.client import AsyncClient, Client from firebolt.common._types import ParameterType, RawColType, SetParameter from firebolt.common.constants import ( DISALLOWED_PARAMETER_LIST, @@ -21,6 +23,11 @@ from firebolt.common.row_set.base import BaseRowSet from firebolt.common.row_set.types import AsyncResponse, Column, Statistics from firebolt.common.statement_formatter import StatementFormatter +from firebolt.utils.cache import ( + ConnectionInfo, + SecureCacheKey, + _firebolt_cache, +) from firebolt.utils.exception import ConfigurationError, FireboltError from firebolt.utils.util import fix_url_schema @@ -89,7 +96,10 @@ class BaseCursor: streaming_row_set_type: Type = BaseRowSet def __init__( - self, *args: Any, formatter: StatementFormatter, **kwargs: Any + self, + *args: Any, + formatter: StatementFormatter, + **kwargs: Any, ) -> None: self._arraysize = self.default_arraysize # These fields initialized here for type annotations purpose @@ -101,6 +111,7 @@ def __init__( self.engine_url = "" self._query_id = "" # not used self._query_token = "" + self._client: Optional[Union[Client, AsyncClient]] = None self._row_set: Optional[BaseRowSet] = None self._reset() @@ -314,3 +325,32 @@ def _get_output_format(is_streaming: bool) -> str: if is_streaming: return JSON_LINES_OUTPUT_FORMAT return JSON_OUTPUT_FORMAT + + def get_cache(self) -> Optional[ConnectionInfo]: + if not self._client or not self._client.auth: + return None + assert isinstance(self._client.auth, Auth) # Type check + cache_key = SecureCacheKey( + [ + self._client.auth.principal, + self._client.auth.secret, + self._client.account_name, + ], + self._client.auth.secret, + ) + cache = _firebolt_cache.get(cache_key) + return cache + + def set_cache(self, cache: ConnectionInfo) -> None: + if not self._client or not self._client.auth: + return + assert isinstance(self._client.auth, Auth) # Type check + cache_key = SecureCacheKey( + [ + self._client.auth.principal, + self._client.auth.secret, + self._client.account_name, + ], + self._client.auth.secret, + ) + _firebolt_cache.set(cache_key, cache) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 2e66c290ae7..0b374842d5e 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -24,7 +24,7 @@ from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.db.cursor import Cursor, CursorV1, CursorV2 from firebolt.db.util import _get_system_engine_url_and_params -from firebolt.utils.cache import _firebolt_cache +from firebolt.utils.cache import SecureCacheKey, _firebolt_cache from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -71,7 +71,10 @@ def connect( if disable_cache: _firebolt_cache.disable() else: - ua_parameters = get_cache_tracking([account_name, api_endpoint], connection_id) + cache_key = SecureCacheKey( + [auth.principal, auth.secret, account_name], auth.secret + ) + ua_parameters = get_cache_tracking(cache_key, connection_id) user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) auth_version = auth.get_firebolt_version() # Use CORE if auth is FireboltCore diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index f3ce9a25217..648fa18d9f4 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -50,12 +50,7 @@ from firebolt.common.row_set.synchronous.in_memory import InMemoryRowSet from firebolt.common.row_set.synchronous.streaming import StreamingRowSet from firebolt.common.statement_formatter import create_statement_formatter -from firebolt.utils.cache import ( - ConnectionInfo, - DatabaseInfo, - EngineInfo, - _firebolt_cache, -) +from firebolt.utils.cache import ConnectionInfo, DatabaseInfo, EngineInfo from firebolt.utils.exception import ( EngineNotRunningError, FireboltDatabaseError, @@ -97,7 +92,7 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - self._client = client + self._client: Client = client self.connection = connection self.engine_url = connection.engine_url self._row_set: Optional[BaseSyncRowSet] = None @@ -346,8 +341,7 @@ def _handle_query_execution( def use_database(self, database: str) -> None: """Switch the current database context with caching.""" - cache_key = [self._client.account_name, self.connection.api_endpoint] - cache = _firebolt_cache.get(cache_key) + cache = self.get_cache() cache = cache if cache else ConnectionInfo(id=self.connection.id) if cache.databases.get(database): # If database is cached, use it @@ -355,12 +349,11 @@ def use_database(self, database: str) -> None: else: self.execute(f'USE DATABASE "{database}"') cache.databases[database] = DatabaseInfo(database) - _firebolt_cache.set(cache_key, cache) + self.set_cache(cache) def use_engine(self, engine: str) -> None: """Switch the current engine context with caching.""" - cache_key = [self._client.account_name, self.connection.api_endpoint] - cache = _firebolt_cache.get(cache_key) + cache = self.get_cache() cache = cache if cache else ConnectionInfo(id=self.connection.id) if cache.engines.get(engine): # If engine is cached, use it @@ -369,7 +362,7 @@ def use_engine(self, engine: str) -> None: else: self.execute(f'USE ENGINE "{engine}"') cache.engines[engine] = EngineInfo(self.engine_url, self.parameters) # ?? - _firebolt_cache.set(cache_key, cache) + self.set_cache(cache) @check_not_closed def execute( diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index 111cd6a1ba6..ae1fc00b778 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -5,7 +5,12 @@ from firebolt.client import ClientV2 from firebolt.client.auth import Auth from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.cache import ConnectionInfo, EngineInfo, _firebolt_cache +from firebolt.utils.cache import ( + ConnectionInfo, + EngineInfo, + SecureCacheKey, + _firebolt_cache, +) from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, InterfaceError, @@ -20,7 +25,8 @@ def _get_system_engine_url_and_params( api_endpoint: str, connection_id: str, ) -> EngineInfo: - cache = _firebolt_cache.get([account_name, api_endpoint]) + cache_key = SecureCacheKey([auth.principal, auth.secret, account_name], auth.secret) + cache = _firebolt_cache.get(cache_key) if cache and (result := cache.system_engine): return result with ClientV2( @@ -42,6 +48,6 @@ def _get_system_engine_url_and_params( url, params = parse_url_and_params(response.json()["engineUrl"]) if not cache: cache = ConnectionInfo(id=connection_id) - _firebolt_cache.set([account_name, api_endpoint], cache) + _firebolt_cache.set(cache_key, cache) cache.system_engine = EngineInfo(url=url, params=params) return cache.system_engine diff --git a/src/firebolt/utils/cache.py b/src/firebolt/utils/cache.py index 46587e36937..3315fb8c119 100644 --- a/src/firebolt/utils/cache.py +++ b/src/firebolt/utils/cache.py @@ -1,6 +1,15 @@ import os from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Generic, Optional, Protocol, TypeVar +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Protocol, + TypeVar, +) T = TypeVar("T") @@ -98,4 +107,23 @@ def __contains__(self, key: str) -> bool: return key in self._cache +class SecureCacheKey(ReprCacheable): + """A secure cache key that can be used for caching sensitive information.""" + + def __init__(self, key_elements: List[Optional[str]], encryption_key: str): + self.key = "#".join(str(e) for e in key_elements) + self.encryption_key = encryption_key + + def __repr__(self) -> str: + return f"SecureCacheKey({self.key})" + + def __eq__(self, other: object) -> bool: + if isinstance(other, SecureCacheKey): + return self.key == other.key + return False + + def __hash__(self) -> int: + return hash(self.key) + + _firebolt_cache = UtilCache[ConnectionInfo](cache_name="connection_info") From b0af5174c7491d70a80f56debeaf891b247e809f Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 5 Aug 2025 13:43:50 +0100 Subject: [PATCH 12/21] fix docstring --- src/firebolt/client/auth/base.py | 8 ++++---- src/firebolt/client/auth/client_credentials.py | 6 +++--- src/firebolt/client/auth/service_account.py | 8 ++++---- src/firebolt/client/auth/token.py | 6 +++--- src/firebolt/client/auth/username_password.py | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/firebolt/client/auth/base.py b/src/firebolt/client/auth/base.py index 098a475ddad..56af8026eb7 100644 --- a/src/firebolt/client/auth/base.py +++ b/src/firebolt/client/auth/base.py @@ -68,19 +68,19 @@ def token(self) -> Optional[str]: @property @abstractmethod def principal(self) -> str: - """Get the principal (username) associated with the token. + """Get the principal (username or id) associated with the token. Returns: - Optional[str]: Principal username if available, None otherwise + str: Principal string """ @property @abstractmethod def secret(self) -> str: - """Get the secret (password) associated with the token. + """Get the secret (password or secret key) associated with the token. Returns: - Optional[str]: Secret if available, None otherwise + str: Secret string """ @abstractmethod diff --git a/src/firebolt/client/auth/client_credentials.py b/src/firebolt/client/auth/client_credentials.py index 39f34a19c22..d729a581fb3 100644 --- a/src/firebolt/client/auth/client_credentials.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -55,16 +55,16 @@ def copy(self) -> "ClientCredentials": @property def principal(self) -> str: - """Get the principal (username) associated with the token. + """Get the principal (client id) associated with this auth. Returns: - str: Principal username + str: Principal client id """ return self.client_id @property def secret(self) -> str: - """Get the secret (password) associated with the token. + """Get the secret (secret key) associated with this auth. Returns: str: Secret diff --git a/src/firebolt/client/auth/service_account.py b/src/firebolt/client/auth/service_account.py index ffd557eec2b..7e5b91c4a4b 100644 --- a/src/firebolt/client/auth/service_account.py +++ b/src/firebolt/client/auth/service_account.py @@ -45,19 +45,19 @@ def __init__( @property def principal(self) -> str: - """Get the principal (username) associated with the token. + """Get the principal (client id) associated with the auth. Returns: - str: Principal username + str: client id """ return self.client_id @property def secret(self) -> str: - """Get the secret (password) associated with the token. + """Get the secret (client secret) associated with the auth. Returns: - str: Secret, which is the client secret itself + str: client secret """ return self.client_secret diff --git a/src/firebolt/client/auth/token.py b/src/firebolt/client/auth/token.py index 965dc2fb9b4..223d4ec37ed 100644 --- a/src/firebolt/client/auth/token.py +++ b/src/firebolt/client/auth/token.py @@ -26,16 +26,16 @@ def __init__(self, token: str): @property def principal(self) -> str: - """Get the principal (username) associated with the token. + """Get the principal (placeholder) associated with the auth. Returns: - str: Principal username + str: Principal (placeholder) """ return "token_principal" @property def secret(self) -> str: - """Get the secret (password) associated with the token. + """Get the secret (token) associated with the auth. Returns: str: Secret, which is the token itself diff --git a/src/firebolt/client/auth/username_password.py b/src/firebolt/client/auth/username_password.py index 646401f99ac..29050641d16 100644 --- a/src/firebolt/client/auth/username_password.py +++ b/src/firebolt/client/auth/username_password.py @@ -45,7 +45,7 @@ def __init__( @property def principal(self) -> str: - """Get the principal (username) associated with the token. + """Get the principal (username) associated with the auth. Returns: str: Principal username @@ -54,7 +54,7 @@ def principal(self) -> str: @property def secret(self) -> str: - """Get the secret (password) associated with the token. + """Get the secret (password) associated with the auth. Returns: str: Secret password From 6ebe968108a3f49507c0bc281f2fdcf8b6e51b1d Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 5 Aug 2025 14:42:43 +0100 Subject: [PATCH 13/21] additional test --- src/firebolt/async_db/connection.py | 4 +- src/firebolt/async_db/util.py | 2 +- src/firebolt/client/auth/firebolt_core.py | 8 +- src/firebolt/db/connection.py | 4 +- src/firebolt/db/util.py | 2 +- src/firebolt/utils/usage_tracker.py | 4 +- src/firebolt/utils/util.py | 4 +- tests/unit/async_db/test_caching.py | 102 ++++++++++++++++++++++ tests/unit/db/test_caching.py | 102 ++++++++++++++++++++++ tests/unit/utils/test_utils.py | 6 +- 10 files changed, 222 insertions(+), 16 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index b2eacdfa051..cea2552fd97 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -35,7 +35,7 @@ validate_firebolt_core_parameters, ) from firebolt.utils.usage_tracker import ( - get_cache_tracking, + get_cache_tracking_params, get_user_agent_header, ) from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 @@ -245,7 +245,7 @@ async def connect( cache_key = SecureCacheKey( [auth.principal, auth.secret, account_name], auth.secret ) - ua_parameters = get_cache_tracking(cache_key, connection_id) + ua_parameters = get_cache_tracking_params(cache_key, connection_id) user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index 98aae3e017d..fe22256356c 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -48,6 +48,6 @@ async def _get_system_engine_url_and_params( url, params = parse_url_and_params(response.json()["engineUrl"]) if not cache: cache = ConnectionInfo(id=connection_id) - _firebolt_cache.set(cache_key, cache) cache.system_engine = EngineInfo(url=url, params=params) + _firebolt_cache.set(cache_key, cache) return cache.system_engine diff --git a/src/firebolt/client/auth/firebolt_core.py b/src/firebolt/client/auth/firebolt_core.py index 3f6c1183a9b..039082caef0 100644 --- a/src/firebolt/client/auth/firebolt_core.py +++ b/src/firebolt/client/auth/firebolt_core.py @@ -29,23 +29,23 @@ def __init__(self) -> None: @property def principal(self) -> str: - """Get the principal (username) associated with the token. + """Get the principal associated with the auth. For FireboltCore, this returns an empty string since no auth is needed. Returns: - str: Empty string (no principal needed) + str: Placeholder string for principal (no auth needed) """ return "core" @property def secret(self) -> str: - """Get the secret (password) associated with the token. + """Get the secret associated with the auth. For FireboltCore, this returns an empty string since no auth is needed. Returns: - str: Empty string (no secret needed) + str: Placeholder string for secret (no auth needed) """ return "core" diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 0b374842d5e..da359aa6ab2 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -36,7 +36,7 @@ validate_firebolt_core_parameters, ) from firebolt.utils.usage_tracker import ( - get_cache_tracking, + get_cache_tracking_params, get_user_agent_header, ) from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 @@ -74,7 +74,7 @@ def connect( cache_key = SecureCacheKey( [auth.principal, auth.secret, account_name], auth.secret ) - ua_parameters = get_cache_tracking(cache_key, connection_id) + ua_parameters = get_cache_tracking_params(cache_key, connection_id) user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) auth_version = auth.get_firebolt_version() # Use CORE if auth is FireboltCore diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index ae1fc00b778..a71549b6c8b 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -48,6 +48,6 @@ def _get_system_engine_url_and_params( url, params = parse_url_and_params(response.json()["engineUrl"]) if not cache: cache = ConnectionInfo(id=connection_id) - _firebolt_cache.set(cache_key, cache) cache.system_engine = EngineInfo(url=url, params=params) + _firebolt_cache.set(cache_key, cache) return cache.system_engine diff --git a/src/firebolt/utils/usage_tracker.py b/src/firebolt/utils/usage_tracker.py index 79c55352c35..83b30b6e9cf 100644 --- a/src/firebolt/utils/usage_tracker.py +++ b/src/firebolt/utils/usage_tracker.py @@ -227,7 +227,9 @@ def get_user_agent_header( return format_as_user_agent(drivers, clients, additional_properties or []) -def get_cache_tracking(cache_key: ReprCacheable, conn_id: str) -> List[Tuple[str, str]]: +def get_cache_tracking_params( + cache_key: ReprCacheable, conn_id: str +) -> List[Tuple[str, str]]: ua_parameters = [] ua_parameters.append(("connId", conn_id)) cache = _firebolt_cache.get(cache_key) diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index 9a867301f93..d296055f0d1 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -212,7 +212,7 @@ def __exit__( logger.debug(log_message) -def parse_url_and_params(url: str) -> Tuple[str, Dict]: +def parse_url_and_params(url: str) -> Tuple[str, Dict[str, str]]: """Extract URL and query parameters separately from a URL.""" url = fix_url_schema(url) parsed_url = urlparse(url) @@ -228,7 +228,7 @@ def parse_url_and_params(url: str) -> Tuple[str, Dict]: if len(values) > 1: raise ValueError(f"Multiple values found for key '{key}'") query_params_dict[key] = values[0] - return (result_url, query_params_dict) + return result_url, query_params_dict class _ExceptionGroup(Exception): diff --git a/tests/unit/async_db/test_caching.py b/tests/unit/async_db/test_caching.py index 2ba0ff39a05..81fe060bd9e 100644 --- a/tests/unit/async_db/test_caching.py +++ b/tests/unit/async_db/test_caching.py @@ -1,5 +1,6 @@ from typing import Callable +from httpx import URL from pytest import fixture, mark from pytest_httpx import HTTPXMock @@ -302,3 +303,104 @@ def use_engine_callback_counter(request, **kwargs): # Reset caches for the next test iteration _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +async def test_connect_db_different_accounts( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: URL, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + api_endpoint: str, + auth: Auth, + account_name: str, + cache_enabled: bool, +): + """Test caching when switching between different databases.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + get_system_engine_url_new_account = str(get_system_engine_url).replace( + account_name, account_name + "_second" + ) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + system_engine_callback_counter, url=get_system_engine_url_new_account + ) + + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # First connection + + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + assert system_engine_call_counter == 1, "System engine URL was not called" + assert use_engine_call_counter == 1, "Use engine URL was not called" + assert use_database_call_counter == 1, "Use database URL was not called" + + # Second connection against different account + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name + "_second", + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + # This should trigger additional calls to the system engine URL and engine/database + assert ( + system_engine_call_counter == 2 + ), "System engine URL was not called for second account" + assert ( + use_engine_call_counter == 2 + ), "Use engine URL was not called for second account" + assert ( + use_database_call_counter == 2 + ), "Use database URL was not called for second account" diff --git a/tests/unit/db/test_caching.py b/tests/unit/db/test_caching.py index 931fd66636f..964b98f80cd 100644 --- a/tests/unit/db/test_caching.py +++ b/tests/unit/db/test_caching.py @@ -1,5 +1,6 @@ from typing import Callable +from httpx import URL from pytest import fixture, mark from pytest_httpx import HTTPXMock @@ -302,3 +303,104 @@ def use_engine_callback_counter(request, **kwargs): # Reset caches for the next test iteration _firebolt_cache.enable() + + +@mark.parametrize("cache_enabled", [True, False]) +def test_connect_db_different_accounts( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: URL, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + api_endpoint: str, + auth: Auth, + account_name: str, + cache_enabled: bool, +): + """Test caching when switching between different databases.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + get_system_engine_url_new_account = str(get_system_engine_url).replace( + account_name, account_name + "_second" + ) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + system_engine_callback_counter, url=get_system_engine_url_new_account + ) + + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # First connection + + with connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + connection.cursor().execute("select*") + + assert system_engine_call_counter == 1, "System engine URL was not called" + assert use_engine_call_counter == 1, "Use engine URL was not called" + assert use_database_call_counter == 1, "Use database URL was not called" + + # Second connection against different account + with connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name + "_second", + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + connection.cursor().execute("select*") + + # This should trigger additional calls to the system engine URL and engine/database + assert ( + system_engine_call_counter == 2 + ), "System engine URL was not called for second account" + assert ( + use_engine_call_counter == 2 + ), "Use engine URL was not called for second account" + assert ( + use_database_call_counter == 2 + ), "Use database URL was not called for second account" diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 2fa525d4628..5271a7ae047 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -50,9 +50,9 @@ def test_get_internal_error_code(status_code, content, expected_error_code): ], ) def test_parse_url_and_params(url, expected_url, expected_params): - url, params = parse_url_and_params(url) - assert url == expected_url - assert params == expected_params + parsed_url, parsed_params = parse_url_and_params(url) + assert parsed_url == expected_url + assert parsed_params == expected_params @pytest.mark.parametrize( From bd865d42034e5b2672e8129517606e6e715378c0 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 5 Aug 2025 14:59:38 +0100 Subject: [PATCH 14/21] fix formatting --- tests/integration/dbapi/async/V2/test_queries_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/dbapi/async/V2/test_queries_async.py b/tests/integration/dbapi/async/V2/test_queries_async.py index aa6be974293..e2042e41834 100644 --- a/tests/integration/dbapi/async/V2/test_queries_async.py +++ b/tests/integration/dbapi/async/V2/test_queries_async.py @@ -673,7 +673,7 @@ async def test_database_switch( f"DROP DATABASE IF EXISTS {database_name}_switch" ) - + async def test_select_quoted_decimal( connection: Connection, long_decimal_value: str, long_value_decimal_sql: str ): From a50b865c72004109434055f2107c1efed8f93186 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 5 Aug 2025 16:41:21 +0100 Subject: [PATCH 15/21] fix tests --- tests/unit/async_db/test_caching.py | 17 +++--- tests/unit/async_db/test_connection.py | 65 +++++++++++++++++++- tests/unit/conftest.py | 15 ++++- tests/unit/db/test_caching.py | 17 +++--- tests/unit/db/test_connection.py | 84 ++++++++++++++++++++++++-- 5 files changed, 171 insertions(+), 27 deletions(-) diff --git a/tests/unit/async_db/test_caching.py b/tests/unit/async_db/test_caching.py index 81fe060bd9e..10fa04b6e99 100644 --- a/tests/unit/async_db/test_caching.py +++ b/tests/unit/async_db/test_caching.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Generator from httpx import URL from pytest import fixture, mark @@ -9,6 +9,13 @@ from firebolt.utils.cache import _firebolt_cache +@fixture(autouse=True) +def use_cache(enable_cache) -> Generator[None, None, None]: + _firebolt_cache.clear() + yield # This fixture is used to ensure cache is enabled for all tests by default + _firebolt_cache.clear() + + @fixture async def connection_test( api_endpoint: str, @@ -32,14 +39,6 @@ async def factory(db_name: str, engine_name: str, caching: bool) -> Callable: return factory -@fixture(autouse=True) -async def enable_cache(): - _firebolt_cache.enable() - _firebolt_cache.clear() - yield - _firebolt_cache.clear() - - @mark.parametrize("cache_enabled", [True, False]) async def test_connect_caching( db_name: str, diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 19df76dfe3c..42381cd280a 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Tuple +from typing import Callable, Generator, List, Optional, Tuple from unittest.mock import ANY as AnyValue from unittest.mock import MagicMock, patch @@ -143,6 +143,66 @@ async def test_connect_database_failed( httpx_mock.reset(False) +@mark.parametrize("cache_enabled", [True, False]) +async def test_connect_system_engine_caching( + db_name: str, + engine_name: str, + auth_url: str, + api_endpoint: str, + auth: Auth, + account_name: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + enable_cache: Generator, + cache_enabled: bool, +): + system_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + use_database_callback, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + + httpx_mock.add_callback( + use_engine_callback, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + for _ in range(3): + async with await connect( + database=db_name, + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + disable_cache=not cache_enabled, + ) as connection: + await connection.cursor().execute("select*") + + if cache_enabled: + assert system_engine_call_counter == 1, "System engine URL was not cached" + else: + assert system_engine_call_counter != 1, "System engine URL was cached" + + async def test_connect_engine_failed( db_name: str, account_name: str, @@ -347,7 +407,7 @@ async def test_connect_no_user_agent( ut.assert_called_with([], [], AnyValue) -async def test_connect_caching( +async def test_connect_caching_headers( engine_name: str, account_name: str, api_endpoint: str, @@ -357,6 +417,7 @@ async def test_connect_caching( query_callback: Callable, query_url: str, mock_connection_flow: Callable, + enable_cache: Generator, ) -> None: async def do_connect(): async with await connect( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1eb81c7dc78..bf2713a5b58 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,5 +1,5 @@ import functools -from typing import Callable +from typing import Callable, Generator import httpx from httpx import Request @@ -46,6 +46,19 @@ def clear_cache() -> None: _firebolt_cache.clear() +@fixture(autouse=True) +def disable_cache() -> None: + _firebolt_cache.disable() + + +@fixture +def enable_cache() -> Generator[None, None, None]: + """Fixture to enable cache for tests that require it.""" + _firebolt_cache.enable() + yield + _firebolt_cache.disable() + + @fixture def client_id() -> str: return "client_id" diff --git a/tests/unit/db/test_caching.py b/tests/unit/db/test_caching.py index 964b98f80cd..2b9253b254f 100644 --- a/tests/unit/db/test_caching.py +++ b/tests/unit/db/test_caching.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Generator from httpx import URL from pytest import fixture, mark @@ -9,6 +9,13 @@ from firebolt.utils.cache import _firebolt_cache +@fixture(autouse=True) +def use_cache(enable_cache) -> Generator[None, None, None]: + _firebolt_cache.clear() + yield # This fixture is used to ensure cache is enabled for all tests by default + _firebolt_cache.clear() + + @fixture def connection_test( api_endpoint: str, @@ -32,14 +39,6 @@ def factory(db_name: str, engine_name: str, caching: bool) -> Callable: return factory -@fixture(autouse=True) -def enable_cache(): - _firebolt_cache.enable() - _firebolt_cache.clear() - yield - _firebolt_cache.clear() - - @mark.parametrize("cache_enabled", [True, False]) def test_connect_caching( db_name: str, diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index b361c834162..1ac70b33dd6 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -1,8 +1,8 @@ import gc import warnings -from typing import Callable, List, Optional, Tuple +from typing import Callable, Generator, List, Optional, Tuple from unittest.mock import ANY as AnyValue -from unittest.mock import patch +from unittest.mock import MagicMock, patch from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises, warns @@ -191,7 +191,7 @@ def test_connect_engine_failed( @mark.parametrize("cache_enabled", [True, False]) -def test_connect_caching( +def test_connect_system_engine_caching( db_name: str, engine_name: str, auth_url: str, @@ -208,6 +208,7 @@ def test_connect_caching( use_database_callback: Callable, use_engine_callback: Callable, query_callback: Callable, + enable_cache: Generator, cache_enabled: bool, ): system_engine_call_counter = 0 @@ -248,9 +249,6 @@ def system_engine_callback_counter(request, **kwargs): else: assert system_engine_call_counter != 1, "System engine URL was cached" - # Reset caches for the next test iteration - _firebolt_cache.enable() - def test_connect_system_engine_404( db_name: str, @@ -427,6 +425,80 @@ def test_connect_no_user_agent( ut.assert_called_with([], [], AnyValue) +def test_connect_caching_headers( + engine_name: str, + account_name: str, + api_endpoint: str, + db_name: str, + auth: Auth, + httpx_mock: HTTPXMock, + query_callback: Callable, + query_url: str, + mock_connection_flow: Callable, + enable_cache: Generator, +) -> None: + def do_connect(): + with connect( + auth=auth, + database=db_name, + engine_name=engine_name, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: + connection.cursor().execute("select*") + + _firebolt_cache.clear() + mock_id = "12345" + mock_id2 = "67890" + mock_id3 = "54321" + with patch("firebolt.db.connection.get_user_agent_header") as ut: + ut.side_effect = [ + f"connId:{mock_id}", + f"connId:{mock_id2}; cachedConnId:{mock_id}-memory", + f"connId:{mock_id3}", + ] + with patch("firebolt.db.connection.uuid4") as uuid4: + uuid4.side_effect = [ + MagicMock(hex=mock_id), + MagicMock(hex=mock_id2), + MagicMock(hex=mock_id3), + ] + mock_connection_flow() + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": f"connId:{mock_id}"}, + ) + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={ + "User-Agent": f"connId:{mock_id2}; cachedConnId:{mock_id}-memory" + }, + ) + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": f"connId:{mock_id3}"}, + ) + + do_connect() + ut.assert_called_with(AnyValue, AnyValue, [("connId", mock_id)]) + + # Second call should use cached connection info + do_connect() + ut.assert_called_with( + AnyValue, + AnyValue, + [("connId", mock_id2), ("cachedConnId", f"{mock_id}-memory")], + ) + _firebolt_cache.clear() + + # Third call should have a new connection id + do_connect() + ut.assert_called_with(AnyValue, AnyValue, [("connId", mock_id3)]) + + @mark.parametrize( "server_status,expected_running,expected_success", [ From 67fad7fbc889c9712b39e2677573a34ecf3925ea Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 5 Aug 2025 17:37:02 +0100 Subject: [PATCH 16/21] improve cache disabling --- src/firebolt/async_db/connection.py | 14 ++++----- src/firebolt/async_db/cursor.py | 44 ++++++++++++++++++----------- src/firebolt/async_db/util.py | 24 ++++++++++------ src/firebolt/db/connection.py | 14 ++++----- src/firebolt/db/cursor.py | 44 ++++++++++++++++++----------- src/firebolt/db/util.py | 24 ++++++++++------ 6 files changed, 100 insertions(+), 64 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index cea2552fd97..f8116ccacd5 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -23,7 +23,7 @@ _parse_async_query_info_results, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.cache import SecureCacheKey, _firebolt_cache +from firebolt.utils.cache import SecureCacheKey from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -239,9 +239,7 @@ async def connect( user_clients = additional_parameters.get("user_clients", []) connection_id = uuid4().hex ua_parameters = [] - if disable_cache: - _firebolt_cache.disable() - else: + if not disable_cache: cache_key = SecureCacheKey( [auth.principal, auth.secret, account_name], auth.secret ) @@ -270,6 +268,7 @@ async def connect( engine_name=engine_name, api_endpoint=api_endpoint, connection_id=connection_id, + disable_cache=disable_cache, ) elif auth_version == FireboltAuthVersion.V1: return await connect_v1( @@ -294,6 +293,7 @@ async def connect_v2( database: Optional[str] = None, engine_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, + disable_cache: bool = False, ) -> Connection: """Connect to Firebolt. @@ -320,7 +320,7 @@ async def connect_v2( assert account_name is not None system_engine_info = await _get_system_engine_url_and_params( - auth, account_name, api_endpoint, connection_id + auth, account_name, api_endpoint, connection_id, disable_cache ) client = AsyncClientV2( @@ -344,9 +344,9 @@ async def connect_v2( cursor = system_engine_connection.cursor() if database: - await cursor.use_database(database) + await cursor.use_database(database, cache=not disable_cache) if engine_name: - await cursor.use_engine(engine_name) + await cursor.use_engine(engine_name, cache=not disable_cache) # Ensure cursors created from this connection are using the same starting # database and engine return Connection( diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 42bbe26ccd0..11610e56fb0 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -333,30 +333,40 @@ async def _handle_query_execution( await self._parse_response_headers(resp.headers) await self._append_row_set_from_response(resp) - async def use_database(self, database: str) -> None: + async def use_database(self, database: str, cache: bool = True) -> None: """Switch the current database context with caching.""" - cache = self.get_cache() - cache = cache if cache else ConnectionInfo(id=self.connection.id) - if cache.databases.get(database): - # If database is cached, use it - self.database = database + if cache: + cache_obj = self.get_cache() + cache_obj = ( + cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) + ) + if cache_obj.databases.get(database): + # If database is cached, use it + self.database = database + else: + await self.execute(f'USE DATABASE "{database}"') + cache_obj.databases[database] = DatabaseInfo(database) + self.set_cache(cache_obj) else: await self.execute(f'USE DATABASE "{database}"') - cache.databases[database] = DatabaseInfo(database) - self.set_cache(cache) - async def use_engine(self, engine: str) -> None: + async def use_engine(self, engine: str, cache: bool = True) -> None: """Switch the current engine context with caching.""" - cache = self.get_cache() - cache = cache if cache else ConnectionInfo(id=self.connection.id) - if cache.engines.get(engine): - # If engine is cached, use it - self.engine_url = cache.engines[engine].url - self._update_set_parameters(cache.engines[engine].params) + if cache: + cache_obj = self.get_cache() + cache_obj = ( + cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) + ) + if cache_obj.engines.get(engine): + # If engine is cached, use it + self.engine_url = cache_obj.engines[engine].url + self._update_set_parameters(cache_obj.engines[engine].params) + else: + await self.execute(f'USE ENGINE "{engine}"') + cache_obj.engines[engine] = EngineInfo(self.engine_url, self.parameters) + self.set_cache(cache_obj) else: await self.execute(f'USE ENGINE "{engine}"') - cache.engines[engine] = EngineInfo(self.engine_url, self.parameters) # ?? - self.set_cache(cache) @check_not_closed async def execute( diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index fe22256356c..06368188444 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -24,11 +24,15 @@ async def _get_system_engine_url_and_params( account_name: str, api_endpoint: str, connection_id: str, + disable_cache: bool = False, ) -> EngineInfo: cache_key = SecureCacheKey([auth.principal, auth.secret, account_name], auth.secret) - cache = _firebolt_cache.get(cache_key) - if cache and (result := cache.system_engine): - return result + + if not disable_cache: + cache = _firebolt_cache.get(cache_key) + if cache and (result := cache.system_engine): + return result + async with AsyncClientV2( auth=auth, base_url=api_endpoint, @@ -46,8 +50,12 @@ async def _get_system_engine_url_and_params( f"{response.status_code} {response.content.decode()}" ) url, params = parse_url_and_params(response.json()["engineUrl"]) - if not cache: - cache = ConnectionInfo(id=connection_id) - cache.system_engine = EngineInfo(url=url, params=params) - _firebolt_cache.set(cache_key, cache) - return cache.system_engine + + if not disable_cache: + if not cache: + cache = ConnectionInfo(id=connection_id) + cache.system_engine = EngineInfo(url=url, params=params) + _firebolt_cache.set(cache_key, cache) + return cache.system_engine + + return EngineInfo(url=url, params=params) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index da359aa6ab2..cc421ef147c 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -24,7 +24,7 @@ from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.db.cursor import Cursor, CursorV1, CursorV2 from firebolt.db.util import _get_system_engine_url_and_params -from firebolt.utils.cache import SecureCacheKey, _firebolt_cache +from firebolt.utils.cache import SecureCacheKey from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -68,9 +68,7 @@ def connect( user_clients = additional_parameters.get("user_clients", []) connection_id = uuid4().hex ua_parameters = [] - if disable_cache: - _firebolt_cache.disable() - else: + if not disable_cache: cache_key = SecureCacheKey( [auth.principal, auth.secret, account_name], auth.secret ) @@ -100,6 +98,7 @@ def connect( engine_name=engine_name, api_endpoint=api_endpoint, connection_id=connection_id, + disable_cache=disable_cache, ) elif auth_version == FireboltAuthVersion.V1: return connect_v1( @@ -124,6 +123,7 @@ def connect_v2( database: Optional[str] = None, engine_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, + disable_cache: bool = False, ) -> Connection: """Connect to Firebolt. @@ -150,7 +150,7 @@ def connect_v2( assert account_name is not None system_engine_info = _get_system_engine_url_and_params( - auth, account_name, api_endpoint, connection_id + auth, account_name, api_endpoint, connection_id, disable_cache ) client = ClientV2( @@ -173,9 +173,9 @@ def connect_v2( cursor = system_engine_connection.cursor() if database: - cursor.use_database(database) + cursor.use_database(database, cache=not disable_cache) if engine_name: - cursor.use_engine(engine_name) + cursor.use_engine(engine_name, cache=not disable_cache) # Ensure cursors created from this connection are using the same starting # database and engine return Connection( diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 648fa18d9f4..0d878d66dd5 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -339,30 +339,40 @@ def _handle_query_execution( self._parse_response_headers(resp.headers) self._append_row_set_from_response(resp) - def use_database(self, database: str) -> None: + def use_database(self, database: str, cache: bool = True) -> None: """Switch the current database context with caching.""" - cache = self.get_cache() - cache = cache if cache else ConnectionInfo(id=self.connection.id) - if cache.databases.get(database): - # If database is cached, use it - self.database = database + if cache: + cache_obj = self.get_cache() + cache_obj = ( + cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) + ) + if cache_obj.databases.get(database): + # If database is cached, use it + self.database = database + else: + self.execute(f'USE DATABASE "{database}"') + cache_obj.databases[database] = DatabaseInfo(database) + self.set_cache(cache_obj) else: self.execute(f'USE DATABASE "{database}"') - cache.databases[database] = DatabaseInfo(database) - self.set_cache(cache) - def use_engine(self, engine: str) -> None: + def use_engine(self, engine: str, cache: bool = True) -> None: """Switch the current engine context with caching.""" - cache = self.get_cache() - cache = cache if cache else ConnectionInfo(id=self.connection.id) - if cache.engines.get(engine): - # If engine is cached, use it - self.engine_url = cache.engines[engine].url - self._update_set_parameters(cache.engines[engine].params) + if cache: + cache_obj = self.get_cache() + cache_obj = ( + cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) + ) + if cache_obj.engines.get(engine): + # If engine is cached, use it + self.engine_url = cache_obj.engines[engine].url + self._update_set_parameters(cache_obj.engines[engine].params) + else: + self.execute(f'USE ENGINE "{engine}"') + cache_obj.engines[engine] = EngineInfo(self.engine_url, self.parameters) + self.set_cache(cache_obj) else: self.execute(f'USE ENGINE "{engine}"') - cache.engines[engine] = EngineInfo(self.engine_url, self.parameters) # ?? - self.set_cache(cache) @check_not_closed def execute( diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index a71549b6c8b..c6c6d478131 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -24,11 +24,15 @@ def _get_system_engine_url_and_params( account_name: str, api_endpoint: str, connection_id: str, + disable_cache: bool = False, ) -> EngineInfo: cache_key = SecureCacheKey([auth.principal, auth.secret, account_name], auth.secret) - cache = _firebolt_cache.get(cache_key) - if cache and (result := cache.system_engine): - return result + + if not disable_cache: + cache = _firebolt_cache.get(cache_key) + if cache and (result := cache.system_engine): + return result + with ClientV2( auth=auth, base_url=api_endpoint, @@ -46,8 +50,12 @@ def _get_system_engine_url_and_params( f"{response.status_code} {response.content.decode()}" ) url, params = parse_url_and_params(response.json()["engineUrl"]) - if not cache: - cache = ConnectionInfo(id=connection_id) - cache.system_engine = EngineInfo(url=url, params=params) - _firebolt_cache.set(cache_key, cache) - return cache.system_engine + + if not disable_cache: + if not cache: + cache = ConnectionInfo(id=connection_id) + cache.system_engine = EngineInfo(url=url, params=params) + _firebolt_cache.set(cache_key, cache) + return cache.system_engine + + return EngineInfo(url=url, params=params) From b645aa0f6acd6d34a1b9b44f1c195c4401a5c4a6 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Wed, 6 Aug 2025 15:23:17 +0100 Subject: [PATCH 17/21] expiry --- src/firebolt/utils/cache.py | 26 ++- tests/unit/async_db/test_caching.py | 96 +++++++- tests/unit/db/test_caching.py | 96 +++++++- tests/unit/utils/test_cache.py | 351 ++++++++++++++++++++++++++++ 4 files changed, 566 insertions(+), 3 deletions(-) create mode 100644 tests/unit/utils/test_cache.py diff --git a/src/firebolt/utils/cache.py b/src/firebolt/utils/cache.py index 3315fb8c119..e8f0d8bd9d1 100644 --- a/src/firebolt/utils/cache.py +++ b/src/firebolt/utils/cache.py @@ -1,4 +1,5 @@ import os +import time from dataclasses import dataclass, field from typing import ( Any, @@ -13,6 +14,9 @@ T = TypeVar("T") +# Cache expiry configuration +CACHE_EXPIRY_SECONDS = 3600 # 1 hour + class ReprCacheable(Protocol): def __repr__(self) -> str: @@ -79,11 +83,31 @@ def get(self, key: ReprCacheable) -> Optional[T]: if self.disabled: return None s_key = self.create_key(key) - return self._cache.get(s_key) + value = self._cache.get(s_key) + + if value is not None and self._is_expired(value): + # Cache miss due to expiry - delete the expired item + del self._cache[s_key] + return None + + return value + + def _is_expired(self, value: T) -> bool: + """Check if a cached value has expired.""" + # Only check expiry for ConnectionInfo objects that have expiry_time + if hasattr(value, "expiry_time") and value.expiry_time is not None: + current_time = int(time.time()) + return current_time >= value.expiry_time + return False @noop_if_disabled def set(self, key: ReprCacheable, value: T) -> None: if not self.disabled: + # Set expiry_time for ConnectionInfo objects + if hasattr(value, "expiry_time"): + current_time = int(time.time()) + value.expiry_time = current_time + CACHE_EXPIRY_SECONDS + s_key = self.create_key(key) self._cache[s_key] = value diff --git a/tests/unit/async_db/test_caching.py b/tests/unit/async_db/test_caching.py index 10fa04b6e99..cd035a9e456 100644 --- a/tests/unit/async_db/test_caching.py +++ b/tests/unit/async_db/test_caching.py @@ -1,4 +1,6 @@ +import time from typing import Callable, Generator +from unittest.mock import patch from httpx import URL from pytest import fixture, mark @@ -6,7 +8,7 @@ from firebolt.async_db import connect from firebolt.client.auth import Auth -from firebolt.utils.cache import _firebolt_cache +from firebolt.utils.cache import CACHE_EXPIRY_SECONDS, _firebolt_cache @fixture(autouse=True) @@ -403,3 +405,95 @@ def use_engine_callback_counter(request, **kwargs): assert ( use_database_call_counter == 2 ), "Use database URL was not called for second account" + + +async def test_calls_when_cache_expired( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + connection_test: Callable, +): + """Test that expired cache entries trigger new backend requests.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # First connection - should populate cache + await connection_test(db_name, engine_name, True) # cache_enabled=True + + # Verify initial calls were made + assert system_engine_call_counter == 1, "System engine URL was not called initially" + assert use_database_call_counter == 1, "Use database URL was not called initially" + assert use_engine_call_counter == 1, "Use engine URL was not called initially" + + # Second connection immediately - should use cache + await connection_test(db_name, engine_name, True) + + # Verify no additional calls were made (cache hit) + assert ( + system_engine_call_counter == 1 + ), "System engine URL was called when cache should hit" + assert ( + use_database_call_counter == 1 + ), "Use database URL was called when cache should hit" + assert ( + use_engine_call_counter == 1 + ), "Use engine URL was called when cache should hit" + + # Mock time to simulate cache expiry (1 hour + 1 second past current time) + current_time = int(time.time()) + expired_time = current_time + CACHE_EXPIRY_SECONDS + 1 + + with patch("firebolt.utils.cache.time.time", return_value=expired_time): + # Third connection after cache expiry - should trigger new backend calls + await connection_test(db_name, engine_name, True) + + # Verify additional calls were made due to cache expiry + assert ( + system_engine_call_counter == 2 + ), "System engine URL was not called after cache expiry" + assert ( + use_database_call_counter == 2 + ), "Use database URL was not called after cache expiry" + assert ( + use_engine_call_counter == 2 + ), "Use engine URL was not called after cache expiry" diff --git a/tests/unit/db/test_caching.py b/tests/unit/db/test_caching.py index 2b9253b254f..d4b57da91ce 100644 --- a/tests/unit/db/test_caching.py +++ b/tests/unit/db/test_caching.py @@ -1,4 +1,6 @@ +import time from typing import Callable, Generator +from unittest.mock import patch from httpx import URL from pytest import fixture, mark @@ -6,7 +8,7 @@ from firebolt.client.auth import Auth from firebolt.db import connect -from firebolt.utils.cache import _firebolt_cache +from firebolt.utils.cache import CACHE_EXPIRY_SECONDS, _firebolt_cache @fixture(autouse=True) @@ -403,3 +405,95 @@ def use_engine_callback_counter(request, **kwargs): assert ( use_database_call_counter == 2 ), "Use database URL was not called for second account" + + +def test_calls_when_cache_expired( + db_name: str, + engine_name: str, + auth_url: str, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + system_engine_no_db_query_url: str, + query_url: str, + use_database_callback: Callable, + use_engine_callback: Callable, + query_callback: Callable, + connection_test: Callable, +): + """Test that expired cache entries trigger new backend requests.""" + system_engine_call_counter = 0 + use_database_call_counter = 0 + use_engine_call_counter = 0 + + def system_engine_callback_counter(request, **kwargs): + nonlocal system_engine_call_counter + system_engine_call_counter += 1 + return get_system_engine_callback(request, **kwargs) + + def use_database_callback_counter(request, **kwargs): + nonlocal use_database_call_counter + use_database_call_counter += 1 + return use_database_callback(request, **kwargs) + + def use_engine_callback_counter(request, **kwargs): + nonlocal use_engine_call_counter + use_engine_call_counter += 1 + return use_engine_callback(request, **kwargs) + + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(system_engine_callback_counter, url=get_system_engine_url) + httpx_mock.add_callback( + use_database_callback_counter, + url=system_engine_no_db_query_url, + match_content=f'USE DATABASE "{db_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback( + use_engine_callback_counter, + url=system_engine_query_url, + match_content=f'USE ENGINE "{engine_name}"'.encode("utf-8"), + ) + httpx_mock.add_callback(query_callback, url=query_url) + + # First connection - should populate cache + connection_test(db_name, engine_name, True) # cache_enabled=True + + # Verify initial calls were made + assert system_engine_call_counter == 1, "System engine URL was not called initially" + assert use_database_call_counter == 1, "Use database URL was not called initially" + assert use_engine_call_counter == 1, "Use engine URL was not called initially" + + # Second connection immediately - should use cache + connection_test(db_name, engine_name, True) + + # Verify no additional calls were made (cache hit) + assert ( + system_engine_call_counter == 1 + ), "System engine URL was called when cache should hit" + assert ( + use_database_call_counter == 1 + ), "Use database URL was called when cache should hit" + assert ( + use_engine_call_counter == 1 + ), "Use engine URL was called when cache should hit" + + # Mock time to simulate cache expiry (1 hour + 1 second past current time) + current_time = int(time.time()) + expired_time = current_time + CACHE_EXPIRY_SECONDS + 1 + + with patch("firebolt.utils.cache.time.time", return_value=expired_time): + # Third connection after cache expiry - should trigger new backend calls + connection_test(db_name, engine_name, True) + + # Verify additional calls were made due to cache expiry + assert ( + system_engine_call_counter == 2 + ), "System engine URL was not called after cache expiry" + assert ( + use_database_call_counter == 2 + ), "Use database URL was not called after cache expiry" + assert ( + use_engine_call_counter == 2 + ), "Use engine URL was not called after cache expiry" diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py new file mode 100644 index 00000000000..55a9589e404 --- /dev/null +++ b/tests/unit/utils/test_cache.py @@ -0,0 +1,351 @@ +import time +from typing import Generator +from unittest.mock import patch + +from pytest import fixture, mark + +from firebolt.utils.cache import ( + CACHE_EXPIRY_SECONDS, + ConnectionInfo, + SecureCacheKey, + UtilCache, +) + + +@fixture +def cache() -> Generator[UtilCache[ConnectionInfo], None, None]: + """Create a fresh cache instance for testing.""" + cache = UtilCache[ConnectionInfo](cache_name="test_cache") + cache.enable() # Ensure cache is enabled for tests + yield cache + cache.clear() + + +@fixture +def string_cache() -> Generator[UtilCache[str], None, None]: + """Create a fresh string cache instance for testing.""" + cache = UtilCache[str](cache_name="string_test_cache") + cache.enable() # Ensure cache is enabled for tests + yield cache + cache.clear() + + +@fixture +def disabled_cache() -> UtilCache[ConnectionInfo]: + """Create a disabled cache instance for testing.""" + cache = UtilCache[ConnectionInfo](cache_name="test_disabled_cache") + cache.disable() + return cache + + +@fixture +def sample_connection_info() -> ConnectionInfo: + """Create a sample ConnectionInfo for testing.""" + return ConnectionInfo(id="test_connection") + + +@fixture +def sample_connection_info_with_expiry() -> ConnectionInfo: + """Create a sample ConnectionInfo with explicit None expiry_time.""" + return ConnectionInfo(id="test_connection_with_expiry", expiry_time=None) + + +@fixture +def sample_cache_key() -> SecureCacheKey: + """Create a sample cache key for testing.""" + return SecureCacheKey(["test", "key"], "secret") + + +@fixture +def additional_cache_keys(): + """Create additional cache keys for multi-key tests.""" + return { + "key1": SecureCacheKey(["key1"], "secret"), + "key2": SecureCacheKey(["key2"], "secret"), + "key3": SecureCacheKey(["user", "other"], "secret"), + } + + +@fixture +def fixed_time(): + """Provide a fixed timestamp for consistent testing.""" + return 1000000 + + +@fixture +def test_string(): + """Provide a test string for non-ConnectionInfo cache tests.""" + return "test_value" + + +def test_cache_set_and_get(cache, sample_cache_key, sample_connection_info): + """Test basic cache set and get operations.""" + # Test cache miss initially + assert cache.get(sample_cache_key) is None + + # Set value and verify it's cached + cache.set(sample_cache_key, sample_connection_info) + cached_value = cache.get(sample_cache_key) + + assert cached_value is not None + assert cached_value.id == "test_connection" + assert cached_value.expiry_time is not None + + # Verify expiry_time is set to current time + CACHE_EXPIRY_SECONDS + current_time = int(time.time()) + expected_expiry = current_time + CACHE_EXPIRY_SECONDS + # Allow for small time difference due to test execution time + assert abs(cached_value.expiry_time - expected_expiry) <= 2 + + +def test_cache_expiry(cache, sample_cache_key, sample_connection_info, fixed_time): + """Test that cache entries expire after the specified time.""" + with patch("time.time", return_value=fixed_time): + # Set a value in the cache + cache.set(sample_cache_key, sample_connection_info) + cached_value = cache.get(sample_cache_key) + + # Verify it's cached and expiry_time is set + assert cached_value is not None + assert cached_value.expiry_time == fixed_time + CACHE_EXPIRY_SECONDS + + # Simulate time passing but not enough to expire (59 minutes) + with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS - 60): + cached_value = cache.get(sample_cache_key) + assert cached_value is not None # Should still be cached + + # Simulate time passing to exactly the expiry time + with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS): + cached_value = cache.get(sample_cache_key) + assert cached_value is None # Should be expired and removed + + # Verify the item is actually deleted from cache + assert cache.create_key(sample_cache_key) not in cache._cache + + +def test_cache_expiry_past_expiry_time( + cache, sample_cache_key, sample_connection_info, fixed_time +): + """Test that cache entries are removed when accessed after expiry time.""" + with patch("time.time", return_value=fixed_time): + cache.set(sample_cache_key, sample_connection_info) + + # Simulate time passing beyond expiry (2 hours) + with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS + 3600): + cached_value = cache.get(sample_cache_key) + assert cached_value is None # Should be expired + + # Verify the item is removed from the internal cache + cache_key_str = cache.create_key(sample_cache_key) + assert cache_key_str not in cache._cache + + +def test_cache_disabled_behavior( + disabled_cache, sample_cache_key, sample_connection_info +): + """Test that disabled cache doesn't store or retrieve values.""" + # Try to set a value + disabled_cache.set(sample_cache_key, sample_connection_info) + + # Should return None even though we set a value + assert disabled_cache.get(sample_cache_key) is None + + # Internal cache should be empty + assert len(disabled_cache._cache) == 0 + + +def test_cache_clear( + cache, sample_cache_key, sample_connection_info, additional_cache_keys +): + """Test that cache.clear() removes all entries.""" + # Add some entries + cache.set(additional_cache_keys["key1"], sample_connection_info) + cache.set(additional_cache_keys["key2"], sample_connection_info) + + # Verify entries exist + assert cache.get(additional_cache_keys["key1"]) is not None + assert cache.get(additional_cache_keys["key2"]) is not None + + # Clear cache + cache.clear() + + # Verify all entries are gone + assert cache.get(additional_cache_keys["key1"]) is None + assert cache.get(additional_cache_keys["key2"]) is None + assert len(cache._cache) == 0 + + +def test_cache_delete(cache, sample_cache_key, sample_connection_info): + """Test manual deletion of cache entries.""" + cache.set(sample_cache_key, sample_connection_info) + assert cache.get(sample_cache_key) is not None + + cache.delete(sample_cache_key) + assert cache.get(sample_cache_key) is None + + +def test_cache_contains_operator(cache, sample_cache_key, sample_connection_info): + """Test the 'in' operator for cache.""" + cache_key_str = cache.create_key(sample_cache_key) + + # Initially not in cache + assert cache_key_str not in cache + + # Add to cache + cache.set(sample_cache_key, sample_connection_info) + assert cache_key_str in cache + + # Test with disabled cache + cache.disable() + assert cache_key_str not in cache # Should return False when disabled + + +def test_non_connection_info_objects(string_cache, sample_cache_key, test_string): + """Test that non-ConnectionInfo objects don't get expiry_time set.""" + string_cache.set(sample_cache_key, test_string) + + # Should retrieve the string without expiry logic + cached_value = string_cache.get(sample_cache_key) + assert cached_value == test_string + + +def test_expiry_time_none_handling( + cache, sample_cache_key, sample_connection_info_with_expiry, fixed_time +): + """Test handling of ConnectionInfo with expiry_time set to None.""" + with patch("time.time", return_value=fixed_time): + cache.set(sample_cache_key, sample_connection_info_with_expiry) + + # Should set expiry_time during set operation + cached_value = cache.get(sample_cache_key) + assert cached_value is not None + assert cached_value.expiry_time is not None + + +def test_secure_cache_key_creation(): + """Test SecureCacheKey creation and repr.""" + key = SecureCacheKey(["user", "db", "engine"], "secret_key") + assert key.key == "user#db#engine" + assert key.encryption_key == "secret_key" + assert repr(key) == "SecureCacheKey(user#db#engine)" + + +def test_secure_cache_key_equality(): + """Test SecureCacheKey equality comparison.""" + key1 = SecureCacheKey(["user", "db"], "secret1") + key2 = SecureCacheKey(["user", "db"], "secret2") + key3 = SecureCacheKey(["user", "other"], "secret1") + + assert key1 == key2 # Same key content, different encryption key + assert key1 != key3 # Different key content + assert key1 != "not_a_key" # Different type + + +def test_secure_cache_key_hash(): + """Test SecureCacheKey hash functionality.""" + key1 = SecureCacheKey(["user", "db"], "secret1") + key2 = SecureCacheKey(["user", "db"], "secret2") + + # Same key content should have same hash + assert hash(key1) == hash(key2) + + # Should be usable in sets and dicts + key_set = {key1, key2} + assert len(key_set) == 1 # Should be treated as same key + + +def test_secure_cache_key_with_none_elements(): + """Test SecureCacheKey handling of None elements.""" + key = SecureCacheKey(["user", None, "engine"], "secret") + assert key.key == "user#None#engine" + + +@mark.parametrize("cache_expiry_time_offset", [-60, 0, 3600]) +def test_cache_expiry_parametrized( + cache, + sample_cache_key, + sample_connection_info, + fixed_time, + cache_expiry_time_offset, +): + """Test cache expiry behavior with different time offsets.""" + with patch("time.time", return_value=fixed_time): + cache.set(sample_cache_key, sample_connection_info) + + # Test at different time offsets relative to expiry time + check_time = fixed_time + CACHE_EXPIRY_SECONDS + cache_expiry_time_offset + + with patch("time.time", return_value=check_time): + cached_value = cache.get(sample_cache_key) + + if cache_expiry_time_offset < 0: + # Before expiry - should be cached + assert cached_value is not None + else: + # At or after expiry - should be None + assert cached_value is None + + +def test_cache_expiry_multiple_entries( + cache, additional_cache_keys, sample_connection_info, fixed_time +): + """Test that expiry works correctly with multiple cache entries.""" + # Set multiple entries at different times + with patch("time.time", return_value=fixed_time): + cache.set(additional_cache_keys["key1"], sample_connection_info) + + with patch("time.time", return_value=fixed_time + 1800): # 30 minutes later + cache.set(additional_cache_keys["key2"], sample_connection_info) + + # Check expiry of first entry while second is still valid + with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS): + assert cache.get(additional_cache_keys["key1"]) is None # Expired + assert cache.get(additional_cache_keys["key2"]) is not None # Still valid + + # Check both are expired after sufficient time + with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS + 1800): + assert cache.get(additional_cache_keys["key1"]) is None + assert cache.get(additional_cache_keys["key2"]) is None + + +def test_cache_set_updates_expiry_time( + cache, sample_cache_key, sample_connection_info, fixed_time +): + """Test that setting a value again updates the expiry time.""" + # Set initial value + with patch("time.time", return_value=fixed_time): + cache.set(sample_cache_key, sample_connection_info) + initial_cached = cache.get(sample_cache_key) + initial_expiry = initial_cached.expiry_time + + # Set same key again later + with patch("time.time", return_value=fixed_time + 1800): # 30 minutes later + cache.set(sample_cache_key, sample_connection_info) + updated_cached = cache.get(sample_cache_key) + updated_expiry = updated_cached.expiry_time + + # Expiry time should be updated + assert updated_expiry > initial_expiry + assert updated_expiry == fixed_time + 1800 + CACHE_EXPIRY_SECONDS + + +@mark.parametrize("disable_cache_during_operation", [True, False]) +def test_cache_disable_enable_behavior( + cache, sample_cache_key, sample_connection_info, disable_cache_during_operation +): + """Test cache behavior when disabled and re-enabled.""" + # Set initial value + cache.set(sample_cache_key, sample_connection_info) + assert cache.get(sample_cache_key) is not None + + if disable_cache_during_operation: + # Disable cache - should return None even if value exists + cache.disable() + assert cache.get(sample_cache_key) is None + + # Re-enable cache - should work again + cache.enable() + assert cache.get(sample_cache_key) is not None + else: + # Keep cache enabled - should continue working + assert cache.get(sample_cache_key) is not None From 4985a4724fbd3f008c8672dff65a654f9d3c8da1 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Wed, 6 Aug 2025 16:10:04 +0100 Subject: [PATCH 18/21] fix tests --- tests/unit/utils/test_cache.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py index 55a9589e404..6fce309fbe3 100644 --- a/tests/unit/utils/test_cache.py +++ b/tests/unit/utils/test_cache.py @@ -286,16 +286,18 @@ def test_cache_expiry_parametrized( assert cached_value is None -def test_cache_expiry_multiple_entries( - cache, additional_cache_keys, sample_connection_info, fixed_time -): +def test_cache_expiry_multiple_entries(cache, additional_cache_keys, fixed_time): """Test that expiry works correctly with multiple cache entries.""" + # Create separate ConnectionInfo objects to avoid shared state + connection_info_1 = ConnectionInfo(id="test_connection_1") + connection_info_2 = ConnectionInfo(id="test_connection_2") + # Set multiple entries at different times with patch("time.time", return_value=fixed_time): - cache.set(additional_cache_keys["key1"], sample_connection_info) + cache.set(additional_cache_keys["key1"], connection_info_1) with patch("time.time", return_value=fixed_time + 1800): # 30 minutes later - cache.set(additional_cache_keys["key2"], sample_connection_info) + cache.set(additional_cache_keys["key2"], connection_info_2) # Check expiry of first entry while second is still valid with patch("time.time", return_value=fixed_time + CACHE_EXPIRY_SECONDS): From a20e394bdfaede25a7d276033bb37b6eb3c6bbf2 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Thu, 7 Aug 2025 09:33:25 +0100 Subject: [PATCH 19/21] refactor common logic --- src/firebolt/async_db/connection.py | 52 ++++++++++++++++++++-- src/firebolt/async_db/util.py | 61 -------------------------- src/firebolt/common/base_connection.py | 57 +++++++++++++++++++++++- src/firebolt/db/connection.py | 52 ++++++++++++++++++++-- src/firebolt/db/util.py | 61 -------------------------- 5 files changed, 152 insertions(+), 131 deletions(-) delete mode 100644 src/firebolt/async_db/util.py delete mode 100644 src/firebolt/db/util.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index f8116ccacd5..c586be1dc14 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -5,10 +5,9 @@ from typing import Any, Dict, List, Optional, Type, Union from uuid import uuid4 -from httpx import Timeout +from httpx import Timeout, codes from firebolt.async_db.cursor import Cursor, CursorV1, CursorV2 -from firebolt.async_db.util import _get_system_engine_url_and_params from firebolt.client import DEFAULT_API_URL from firebolt.client.auth import Auth from firebolt.client.auth.base import FireboltAuthVersion @@ -21,24 +20,33 @@ AsyncQueryInfo, BaseConnection, _parse_async_query_info_results, + get_cached_system_engine_info, + set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.cache import SecureCacheKey +from firebolt.utils.cache import EngineInfo, SecureCacheKey from firebolt.utils.exception import ( + AccountNotFoundOrNoAccessError, ConfigurationError, ConnectionClosedError, FireboltError, + InterfaceError, ) from firebolt.utils.firebolt_core import ( get_core_certificate_context, parse_firebolt_core_url, validate_firebolt_core_parameters, ) +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME from firebolt.utils.usage_tracker import ( get_cache_tracking_params, get_user_agent_header, ) -from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 +from firebolt.utils.util import ( + fix_url_schema, + parse_url_and_params, + validate_engine_name_and_url_v1, +) class Connection(BaseConnection): @@ -462,3 +470,39 @@ def connect_core( cursor_type=CursorV2, api_endpoint=verified_url, ) + + +async def _get_system_engine_url_and_params( + auth: Auth, + account_name: str, + api_endpoint: str, + connection_id: str, + disable_cache: bool = False, +) -> EngineInfo: + cache_key, cached_result = get_cached_system_engine_info( + auth, account_name, disable_cache + ) + if cached_result: + return cached_result + + async with AsyncClientV2( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) + response = await client.get(url=url) + if response.status_code == codes.NOT_FOUND: + raise AccountNotFoundOrNoAccessError(account_name) + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve system engine endpoint {url}: " + f"{response.status_code} {response.content.decode()}" + ) + url, params = parse_url_and_params(response.json()["engineUrl"]) + + return set_cached_system_engine_info( + cache_key, connection_id, url, params, disable_cache + ) diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py deleted file mode 100644 index 06368188444..00000000000 --- a/src/firebolt/async_db/util.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -from httpx import Timeout, codes - -from firebolt.client.auth import Auth -from firebolt.client.client import AsyncClientV2 -from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.cache import ( - ConnectionInfo, - EngineInfo, - SecureCacheKey, - _firebolt_cache, -) -from firebolt.utils.exception import ( - AccountNotFoundOrNoAccessError, - InterfaceError, -) -from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import parse_url_and_params - - -async def _get_system_engine_url_and_params( - auth: Auth, - account_name: str, - api_endpoint: str, - connection_id: str, - disable_cache: bool = False, -) -> EngineInfo: - cache_key = SecureCacheKey([auth.principal, auth.secret, account_name], auth.secret) - - if not disable_cache: - cache = _firebolt_cache.get(cache_key) - if cache and (result := cache.system_engine): - return result - - async with AsyncClientV2( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) - response = await client.get(url=url) - if response.status_code == codes.NOT_FOUND: - raise AccountNotFoundOrNoAccessError(account_name) - if response.status_code != codes.OK: - raise InterfaceError( - f"Unable to retrieve system engine endpoint {url}: " - f"{response.status_code} {response.content.decode()}" - ) - url, params = parse_url_and_params(response.json()["engineUrl"]) - - if not disable_cache: - if not cache: - cache = ConnectionInfo(id=connection_id) - cache.system_engine = EngineInfo(url=url, params=params) - _firebolt_cache.set(cache_key, cache) - return cache.system_engine - - return EngineInfo(url=url, params=params) diff --git a/src/firebolt/common/base_connection.py b/src/firebolt/common/base_connection.py index 24880c60bd5..1e2cfe47738 100644 --- a/src/firebolt/common/base_connection.py +++ b/src/firebolt/common/base_connection.py @@ -1,7 +1,14 @@ from collections import namedtuple -from typing import Any, List, Type +from typing import Any, List, Optional, Tuple, Type +from firebolt.client.auth.base import Auth from firebolt.common._types import ColType +from firebolt.utils.cache import ( + ConnectionInfo, + EngineInfo, + SecureCacheKey, + _firebolt_cache, +) from firebolt.utils.exception import ConnectionClosedError, FireboltError ASYNC_QUERY_STATUS_RUNNING = "RUNNING" @@ -75,3 +82,51 @@ def commit(self) -> None: if self.closed: raise ConnectionClosedError("Unable to commit: Connection closed.") + + +def get_cached_system_engine_info( + auth: Auth, + account_name: str, + disable_cache: bool = False, +) -> Tuple[SecureCacheKey, Optional[EngineInfo]]: + """ + Common cache retrieval logic for system engine info. + + Returns: + tuple: (cache_key, cached_engine_info_or_none) + """ + cache_key = SecureCacheKey([auth.principal, auth.secret, account_name], auth.secret) + + if disable_cache: + return cache_key, None + + cache = _firebolt_cache.get(cache_key) + cached_result = cache.system_engine if cache else None + + return cache_key, cached_result + + +def set_cached_system_engine_info( + cache_key: SecureCacheKey, + connection_id: str, + url: str, + params: dict, + disable_cache: bool = False, +) -> EngineInfo: + """ + Common cache setting logic for system engine info. + + Returns: + EngineInfo: The engine info that was cached (or created) + """ + + engine_info = EngineInfo(url=url, params=params) + + if not disable_cache: + cache = _firebolt_cache.get(cache_key) + if not cache: + cache = ConnectionInfo(id=connection_id) + cache.system_engine = engine_info + _firebolt_cache.set(cache_key, cache) + + return engine_info diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index cc421ef147c..7d72db698f3 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -7,7 +7,7 @@ from uuid import uuid4 from warnings import warn -from httpx import Timeout +from httpx import Timeout, codes from firebolt.client import DEFAULT_API_URL, Client, ClientV1, ClientV2 from firebolt.client.auth import Auth @@ -20,26 +20,34 @@ AsyncQueryInfo, BaseConnection, _parse_async_query_info_results, + get_cached_system_engine_info, + set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.db.cursor import Cursor, CursorV1, CursorV2 -from firebolt.db.util import _get_system_engine_url_and_params -from firebolt.utils.cache import SecureCacheKey +from firebolt.utils.cache import EngineInfo, SecureCacheKey from firebolt.utils.exception import ( + AccountNotFoundOrNoAccessError, ConfigurationError, ConnectionClosedError, FireboltError, + InterfaceError, ) from firebolt.utils.firebolt_core import ( get_core_certificate_context, parse_firebolt_core_url, validate_firebolt_core_parameters, ) +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME from firebolt.utils.usage_tracker import ( get_cache_tracking_params, get_user_agent_header, ) -from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1 +from firebolt.utils.util import ( + fix_url_schema, + parse_url_and_params, + validate_engine_name_and_url_v1, +) logger = logging.getLogger(__name__) @@ -467,3 +475,39 @@ def connect_core( cursor_type=CursorV2, api_endpoint=verified_url, ) + + +def _get_system_engine_url_and_params( + auth: Auth, + account_name: str, + api_endpoint: str, + connection_id: str, + disable_cache: bool = False, +) -> EngineInfo: + cache_key, cached_result = get_cached_system_engine_info( + auth, account_name, disable_cache + ) + if cached_result: + return cached_result + + with ClientV2( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) + response = client.get(url=url) + if response.status_code == codes.NOT_FOUND: + raise AccountNotFoundOrNoAccessError(account_name) + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve system engine endpoint {url}: " + f"{response.status_code} {response.content.decode()}" + ) + url, params = parse_url_and_params(response.json()["engineUrl"]) + + return set_cached_system_engine_info( + cache_key, connection_id, url, params, disable_cache + ) diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py deleted file mode 100644 index c6c6d478131..00000000000 --- a/src/firebolt/db/util.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -from httpx import Timeout, codes - -from firebolt.client import ClientV2 -from firebolt.client.auth import Auth -from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.cache import ( - ConnectionInfo, - EngineInfo, - SecureCacheKey, - _firebolt_cache, -) -from firebolt.utils.exception import ( - AccountNotFoundOrNoAccessError, - InterfaceError, -) -from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.util import parse_url_and_params - - -def _get_system_engine_url_and_params( - auth: Auth, - account_name: str, - api_endpoint: str, - connection_id: str, - disable_cache: bool = False, -) -> EngineInfo: - cache_key = SecureCacheKey([auth.principal, auth.secret, account_name], auth.secret) - - if not disable_cache: - cache = _firebolt_cache.get(cache_key) - if cache and (result := cache.system_engine): - return result - - with ClientV2( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) - response = client.get(url=url) - if response.status_code == codes.NOT_FOUND: - raise AccountNotFoundOrNoAccessError(account_name) - if response.status_code != codes.OK: - raise InterfaceError( - f"Unable to retrieve system engine endpoint {url}: " - f"{response.status_code} {response.content.decode()}" - ) - url, params = parse_url_and_params(response.json()["engineUrl"]) - - if not disable_cache: - if not cache: - cache = ConnectionInfo(id=connection_id) - cache.system_engine = EngineInfo(url=url, params=params) - _firebolt_cache.set(cache_key, cache) - return cache.system_engine - - return EngineInfo(url=url, params=params) From 865978ff375e4e516aef2e839b4dad0086e216db Mon Sep 17 00:00:00 2001 From: ptiurin Date: Thu, 7 Aug 2025 11:05:43 +0100 Subject: [PATCH 20/21] refactor more --- src/firebolt/async_db/connection.py | 19 ++++---------- src/firebolt/common/base_connection.py | 31 ++++++++++++++++++++++- src/firebolt/db/connection.py | 19 ++++---------- tests/unit/V1/async_db/test_connection.py | 4 +-- tests/unit/async_db/test_connection.py | 6 ++--- tests/unit/db/test_connection.py | 6 ++--- 6 files changed, 48 insertions(+), 37 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index c586be1dc14..b63270ef7c1 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -21,10 +21,11 @@ BaseConnection, _parse_async_query_info_results, get_cached_system_engine_info, + get_user_agent_for_connection, set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.cache import EngineInfo, SecureCacheKey +from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, ConfigurationError, @@ -38,10 +39,6 @@ validate_firebolt_core_parameters, ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.usage_tracker import ( - get_cache_tracking_params, - get_user_agent_header, -) from firebolt.utils.util import ( fix_url_schema, parse_url_and_params, @@ -243,16 +240,10 @@ async def connect( api_endpoint = fix_url_schema(api_endpoint) # Type checks assert auth is not None - user_drivers = additional_parameters.get("user_drivers", []) - user_clients = additional_parameters.get("user_clients", []) connection_id = uuid4().hex - ua_parameters = [] - if not disable_cache: - cache_key = SecureCacheKey( - [auth.principal, auth.secret, account_name], auth.secret - ) - ua_parameters = get_cache_tracking_params(cache_key, connection_id) - user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) + user_agent_header = get_user_agent_for_connection( + auth, connection_id, account_name, additional_parameters, disable_cache + ) # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials # Use V1 if auth is ServiceAccount or UsernamePassword diff --git a/src/firebolt/common/base_connection.py b/src/firebolt/common/base_connection.py index 1e2cfe47738..a39459be238 100644 --- a/src/firebolt/common/base_connection.py +++ b/src/firebolt/common/base_connection.py @@ -1,5 +1,5 @@ from collections import namedtuple -from typing import Any, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type from firebolt.client.auth.base import Auth from firebolt.common._types import ColType @@ -10,6 +10,10 @@ _firebolt_cache, ) from firebolt.utils.exception import ConnectionClosedError, FireboltError +from firebolt.utils.usage_tracker import ( + get_cache_tracking_params, + get_user_agent_header, +) ASYNC_QUERY_STATUS_RUNNING = "RUNNING" ASYNC_QUERY_STATUS_SUCCESSFUL = "ENDED_SUCCESSFULLY" @@ -130,3 +134,28 @@ def set_cached_system_engine_info( _firebolt_cache.set(cache_key, cache) return engine_info + + +def get_user_agent_for_connection( + auth: Auth, + connection_id: str, + account_name: Optional[str] = None, + additional_parameters: Dict[str, Any] = {}, + disable_cache: bool = False, +) -> str: + """ + Get the user agent string for the Firebolt connection. + + Returns: + str: The user agent string. + """ + user_drivers = additional_parameters.get("user_drivers", []) + user_clients = additional_parameters.get("user_clients", []) + ua_parameters = [] + if not disable_cache: + cache_key = SecureCacheKey( + [auth.principal, auth.secret, account_name], auth.secret + ) + ua_parameters = get_cache_tracking_params(cache_key, connection_id) + user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) + return user_agent_header diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 7d72db698f3..72021ce07ff 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -21,11 +21,12 @@ BaseConnection, _parse_async_query_info_results, get_cached_system_engine_info, + get_user_agent_for_connection, set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.db.cursor import Cursor, CursorV1, CursorV2 -from firebolt.utils.cache import EngineInfo, SecureCacheKey +from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, ConfigurationError, @@ -39,10 +40,6 @@ validate_firebolt_core_parameters, ) from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME -from firebolt.utils.usage_tracker import ( - get_cache_tracking_params, - get_user_agent_header, -) from firebolt.utils.util import ( fix_url_schema, parse_url_and_params, @@ -72,16 +69,10 @@ def connect( api_endpoint = fix_url_schema(api_endpoint) # Type checks assert auth is not None - user_drivers = additional_parameters.get("user_drivers", []) - user_clients = additional_parameters.get("user_clients", []) connection_id = uuid4().hex - ua_parameters = [] - if not disable_cache: - cache_key = SecureCacheKey( - [auth.principal, auth.secret, account_name], auth.secret - ) - ua_parameters = get_cache_tracking_params(cache_key, connection_id) - user_agent_header = get_user_agent_header(user_drivers, user_clients, ua_parameters) + user_agent_header = get_user_agent_for_connection( + auth, connection_id, account_name, additional_parameters, disable_cache + ) auth_version = auth.get_firebolt_version() # Use CORE if auth is FireboltCore # Use V2 if auth is ClientCredentials diff --git a/tests/unit/V1/async_db/test_connection.py b/tests/unit/V1/async_db/test_connection.py index d74bf59ee46..a3deac21994 100644 --- a/tests/unit/V1/async_db/test_connection.py +++ b/tests/unit/V1/async_db/test_connection.py @@ -402,7 +402,7 @@ async def test_connect_with_user_agent( query_url: str, access_token: str, ) -> None: - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" httpx_mock.add_callback( query_callback, @@ -435,7 +435,7 @@ async def test_connect_no_user_agent( query_url: str, access_token: str, ) -> None: - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" httpx_mock.add_callback( query_callback, url=query_url, match_headers={"User-Agent": "Python/3.0"} diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 42381cd280a..e3cc6915fb2 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -353,7 +353,7 @@ async def test_connect_with_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" mock_connection_flow() httpx_mock.add_callback( @@ -389,7 +389,7 @@ async def test_connect_no_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" mock_connection_flow() httpx_mock.add_callback( @@ -433,7 +433,7 @@ async def do_connect(): mock_id = "12345" mock_id2 = "67890" mock_id3 = "54321" - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.side_effect = [ f"connId:{mock_id}", f"connId:{mock_id2}; cachedConnId:{mock_id}-memory", diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 1ac70b33dd6..060a6d764fc 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -371,7 +371,7 @@ def test_connect_with_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with patch("firebolt.db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" mock_connection_flow() httpx_mock.add_callback( @@ -407,7 +407,7 @@ def test_connect_no_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with patch("firebolt.db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" mock_connection_flow() httpx_mock.add_callback( @@ -451,7 +451,7 @@ def do_connect(): mock_id = "12345" mock_id2 = "67890" mock_id3 = "54321" - with patch("firebolt.db.connection.get_user_agent_header") as ut: + with patch("firebolt.common.base_connection.get_user_agent_header") as ut: ut.side_effect = [ f"connId:{mock_id}", f"connId:{mock_id2}; cachedConnId:{mock_id}-memory", From 21d5cce93efe8c0e872c64b46e415f2ba8d5888f Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 12 Aug 2025 14:13:23 +0100 Subject: [PATCH 21/21] rename cache access --- src/firebolt/async_db/cursor.py | 16 ++++++++-------- src/firebolt/common/cursor/base_cursor.py | 10 +++++----- src/firebolt/db/cursor.py | 16 ++++++++-------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 11610e56fb0..dafc9a239fe 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -336,24 +336,24 @@ async def _handle_query_execution( async def use_database(self, database: str, cache: bool = True) -> None: """Switch the current database context with caching.""" if cache: - cache_obj = self.get_cache() - cache_obj = ( - cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) + cache_record = self.get_cache_record() + cache_record = ( + cache_record if cache_record else ConnectionInfo(id=self.connection.id) ) - if cache_obj.databases.get(database): + if cache_record.databases.get(database): # If database is cached, use it self.database = database else: await self.execute(f'USE DATABASE "{database}"') - cache_obj.databases[database] = DatabaseInfo(database) - self.set_cache(cache_obj) + cache_record.databases[database] = DatabaseInfo(database) + self.set_cache_record(cache_record) else: await self.execute(f'USE DATABASE "{database}"') async def use_engine(self, engine: str, cache: bool = True) -> None: """Switch the current engine context with caching.""" if cache: - cache_obj = self.get_cache() + cache_obj = self.get_cache_record() cache_obj = ( cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) ) @@ -364,7 +364,7 @@ async def use_engine(self, engine: str, cache: bool = True) -> None: else: await self.execute(f'USE ENGINE "{engine}"') cache_obj.engines[engine] = EngineInfo(self.engine_url, self.parameters) - self.set_cache(cache_obj) + self.set_cache_record(cache_obj) else: await self.execute(f'USE ENGINE "{engine}"') diff --git a/src/firebolt/common/cursor/base_cursor.py b/src/firebolt/common/cursor/base_cursor.py index 260cad8d1a1..381e43a90f6 100644 --- a/src/firebolt/common/cursor/base_cursor.py +++ b/src/firebolt/common/cursor/base_cursor.py @@ -326,7 +326,7 @@ def _get_output_format(is_streaming: bool) -> str: return JSON_LINES_OUTPUT_FORMAT return JSON_OUTPUT_FORMAT - def get_cache(self) -> Optional[ConnectionInfo]: + def get_cache_record(self) -> Optional[ConnectionInfo]: if not self._client or not self._client.auth: return None assert isinstance(self._client.auth, Auth) # Type check @@ -338,10 +338,10 @@ def get_cache(self) -> Optional[ConnectionInfo]: ], self._client.auth.secret, ) - cache = _firebolt_cache.get(cache_key) - return cache + record = _firebolt_cache.get(cache_key) + return record - def set_cache(self, cache: ConnectionInfo) -> None: + def set_cache_record(self, record: ConnectionInfo) -> None: if not self._client or not self._client.auth: return assert isinstance(self._client.auth, Auth) # Type check @@ -353,4 +353,4 @@ def set_cache(self, cache: ConnectionInfo) -> None: ], self._client.auth.secret, ) - _firebolt_cache.set(cache_key, cache) + _firebolt_cache.set(cache_key, record) diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 0d878d66dd5..b8e1be97840 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -342,24 +342,24 @@ def _handle_query_execution( def use_database(self, database: str, cache: bool = True) -> None: """Switch the current database context with caching.""" if cache: - cache_obj = self.get_cache() - cache_obj = ( - cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) + cache_record = self.get_cache_record() + cache_record = ( + cache_record if cache_record else ConnectionInfo(id=self.connection.id) ) - if cache_obj.databases.get(database): + if cache_record.databases.get(database): # If database is cached, use it self.database = database else: self.execute(f'USE DATABASE "{database}"') - cache_obj.databases[database] = DatabaseInfo(database) - self.set_cache(cache_obj) + cache_record.databases[database] = DatabaseInfo(database) + self.set_cache_record(cache_record) else: self.execute(f'USE DATABASE "{database}"') def use_engine(self, engine: str, cache: bool = True) -> None: """Switch the current engine context with caching.""" if cache: - cache_obj = self.get_cache() + cache_obj = self.get_cache_record() cache_obj = ( cache_obj if cache_obj else ConnectionInfo(id=self.connection.id) ) @@ -370,7 +370,7 @@ def use_engine(self, engine: str, cache: bool = True) -> None: else: self.execute(f'USE ENGINE "{engine}"') cache_obj.engines[engine] = EngineInfo(self.engine_url, self.parameters) - self.set_cache(cache_obj) + self.set_cache_record(cache_obj) else: self.execute(f'USE ENGINE "{engine}"')