Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 74 additions & 25 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from ssl import SSLContext
from types import TracebackType
from typing import Any, Dict, List, Optional, Type, Union
from uuid import uuid4

from httpx import Timeout
from httpx import Timeout, codes

from firebolt.async_db.cursor import Cursor, CursorV1, CursorV2
from firebolt.async_db.util import _get_system_engine_url_and_params
from firebolt.client import DEFAULT_API_URL
from firebolt.client.auth import Auth
from firebolt.client.auth.base import FireboltAuthVersion
Expand All @@ -20,21 +20,30 @@
AsyncQueryInfo,
BaseConnection,
_parse_async_query_info_results,
get_cached_system_engine_info,
get_user_agent_for_connection,
set_cached_system_engine_info,
)
from firebolt.common.cache import _firebolt_system_engine_cache
from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS
from firebolt.utils.cache import EngineInfo
from firebolt.utils.exception import (
AccountNotFoundOrNoAccessError,
ConfigurationError,
ConnectionClosedError,
FireboltError,
InterfaceError,
)
from firebolt.utils.firebolt_core import (
get_core_certificate_context,
parse_firebolt_core_url,
validate_firebolt_core_parameters,
)
from firebolt.utils.usage_tracker import get_user_agent_header
from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1
from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME
from firebolt.utils.util import (
fix_url_schema,
parse_url_and_params,
validate_engine_name_and_url_v1,
)


class Connection(BaseConnection):
Expand Down Expand Up @@ -71,6 +80,7 @@ class Connection(BaseConnection):
"_is_closed",
"client_class",
"cursor_type",
"id",
)

def __init__(
Expand All @@ -81,12 +91,14 @@ def __init__(
cursor_type: Type[Cursor],
api_endpoint: str,
init_parameters: Optional[Dict[str, Any]] = None,
id: str = uuid4().hex,
):
super().__init__(cursor_type)
self.api_endpoint = api_endpoint
self.engine_url = engine_url
self._cursors: List[Cursor] = []
self._client = client
self.id = id
self.init_parameters = init_parameters or {}
if database:
self.init_parameters["database"] = database
Expand Down Expand Up @@ -225,13 +237,13 @@ async def connect(
if not auth:
raise ConfigurationError("auth is required to connect.")

api_endpoint = fix_url_schema(api_endpoint)
# Type checks
assert auth is not None
user_drivers = additional_parameters.get("user_drivers", [])
user_clients = additional_parameters.get("user_clients", [])
user_agent_header = get_user_agent_header(user_drivers, user_clients)
if disable_cache:
_firebolt_system_engine_cache.disable()
connection_id = uuid4().hex
user_agent_header = get_user_agent_for_connection(
auth, connection_id, account_name, additional_parameters, disable_cache
)
# Use CORE if auth is FireboltCore
# Use V2 if auth is ClientCredentials
# Use V1 if auth is ServiceAccount or UsernamePassword
Expand All @@ -254,6 +266,8 @@ async def connect(
database=database,
engine_name=engine_name,
api_endpoint=api_endpoint,
connection_id=connection_id,
disable_cache=disable_cache,
)
elif auth_version == FireboltAuthVersion.V1:
return await connect_v1(
Expand All @@ -264,6 +278,7 @@ async def connect(
engine_name=engine_name,
engine_url=engine_url,
api_endpoint=api_endpoint,
connection_id=connection_id,
)
else:
raise ConfigurationError(f"Unsupported auth type: {type(auth)}")
Expand All @@ -272,10 +287,12 @@ async def connect(
async def connect_v2(
auth: Auth,
user_agent_header: str,
connection_id: str,
account_name: Optional[str] = None,
database: Optional[str] = None,
engine_name: Optional[str] = None,
api_endpoint: str = DEFAULT_API_URL,
disable_cache: bool = False,
) -> Connection:
"""Connect to Firebolt.

Expand All @@ -301,10 +318,8 @@ async def connect_v2(
assert auth is not None
assert account_name is not None

api_endpoint = fix_url_schema(api_endpoint)

system_engine_url, system_engine_params = await _get_system_engine_url_and_params(
auth, account_name, api_endpoint
system_engine_info = await _get_system_engine_url_and_params(
auth, account_name, api_endpoint, connection_id, disable_cache
)

client = AsyncClientV2(
Expand All @@ -316,19 +331,21 @@ async def connect_v2(
)

async with Connection(
system_engine_url,
system_engine_info.url,
None,
client,
CursorV2,
api_endpoint,
system_engine_params,
system_engine_info.params,
connection_id,
) as system_engine_connection:

cursor = system_engine_connection.cursor()

if database:
await cursor.execute(f'USE DATABASE "{database}"')
await cursor.use_database(database, cache=not disable_cache)
if engine_name:
await cursor.execute(f'USE ENGINE "{engine_name}"')
await cursor.use_engine(engine_name, cache=not disable_cache)
# Ensure cursors created from this connection are using the same starting
# database and engine
return Connection(
Expand All @@ -338,12 +355,14 @@ async def connect_v2(
CursorV2,
api_endpoint,
cursor.parameters,
connection_id,
)


async def connect_v1(
auth: Auth,
user_agent_header: str,
connection_id: str,
database: Optional[str] = None,
account_name: Optional[str] = None,
engine_name: Optional[str] = None,
Expand All @@ -358,8 +377,6 @@ async def connect_v1(

validate_engine_name_and_url_v1(engine_name, engine_url)

api_endpoint = fix_url_schema(api_endpoint)

no_engine_client = AsyncClientV1(
auth=auth,
base_url=api_endpoint,
Expand Down Expand Up @@ -397,11 +414,7 @@ async def connect_v1(
headers={"User-Agent": user_agent_header},
)
return Connection(
engine_url,
database,
client,
CursorV1,
api_endpoint,
engine_url, database, client, CursorV1, api_endpoint, id=connection_id
)


Expand Down Expand Up @@ -448,3 +461,39 @@ def connect_core(
cursor_type=CursorV2,
api_endpoint=verified_url,
)


async def _get_system_engine_url_and_params(
auth: Auth,
account_name: str,
api_endpoint: str,
connection_id: str,
disable_cache: bool = False,
) -> EngineInfo:
cache_key, cached_result = get_cached_system_engine_info(
auth, account_name, disable_cache
)
if cached_result:
return cached_result

async with AsyncClientV2(
auth=auth,
base_url=api_endpoint,
account_name=account_name,
api_endpoint=api_endpoint,
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS),
) as client:
url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name)
response = await client.get(url=url)
if response.status_code == codes.NOT_FOUND:
raise AccountNotFoundOrNoAccessError(account_name)
if response.status_code != codes.OK:
raise InterfaceError(
f"Unable to retrieve system engine endpoint {url}: "
f"{response.status_code} {response.content.decode()}"
)
url, params = parse_url_and_params(response.json()["engineUrl"])

return set_cached_system_engine_info(
cache_key, connection_id, url, params, disable_cache
)
38 changes: 37 additions & 1 deletion src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from firebolt.async_db.connection import Connection

from firebolt.utils.async_util import anext, async_islice
from firebolt.utils.cache import ConnectionInfo, DatabaseInfo, EngineInfo
from firebolt.utils.util import Timer, raise_error_from_response

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -85,8 +86,8 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._client = client
self.connection = connection
self._client: AsyncClient = client
self.engine_url = connection.engine_url
self._row_set: Optional[BaseAsyncRowSet] = None
if connection.init_parameters:
Expand Down Expand Up @@ -332,6 +333,41 @@ async def _handle_query_execution(
await self._parse_response_headers(resp.headers)
await self._append_row_set_from_response(resp)

async def use_database(self, database: str, cache: bool = True) -> None:
"""Switch the current database context with caching."""
if cache:
cache_record = self.get_cache_record()
cache_record = (
cache_record if cache_record else ConnectionInfo(id=self.connection.id)
)
if cache_record.databases.get(database):
# If database is cached, use it
self.database = database
else:
await self.execute(f'USE DATABASE "{database}"')
cache_record.databases[database] = DatabaseInfo(database)
self.set_cache_record(cache_record)
else:
await self.execute(f'USE DATABASE "{database}"')

async def use_engine(self, engine: str, cache: bool = True) -> None:
"""Switch the current engine context with caching."""
if cache:
cache_obj = self.get_cache_record()
cache_obj = (
cache_obj if cache_obj else ConnectionInfo(id=self.connection.id)
)
if cache_obj.engines.get(engine):
# If engine is cached, use it
self.engine_url = cache_obj.engines[engine].url
self._update_set_parameters(cache_obj.engines[engine].params)
else:
await self.execute(f'USE ENGINE "{engine}"')
cache_obj.engines[engine] = EngineInfo(self.engine_url, self.parameters)
self.set_cache_record(cache_obj)
else:
await self.execute(f'USE ENGINE "{engine}"')

@check_not_closed
async def execute(
self,
Expand Down
46 changes: 0 additions & 46 deletions src/firebolt/async_db/util.py

This file was deleted.

18 changes: 18 additions & 0 deletions src/firebolt/client/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ def token(self) -> Optional[str]:
"""
return self._token

@property
@abstractmethod
def principal(self) -> str:
"""Get the principal (username or id) associated with the token.

Returns:
str: Principal string
"""

@property
@abstractmethod
def secret(self) -> str:
"""Get the secret (password or secret key) associated with the token.

Returns:
str: Secret string
"""

@abstractmethod
def get_firebolt_version(self) -> FireboltAuthVersion:
"""Get Firebolt version from auth.
Expand Down
18 changes: 18 additions & 0 deletions src/firebolt/client/auth/client_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ def copy(self) -> "ClientCredentials":
self.client_id, self.client_secret, self._use_token_cache
)

@property
def principal(self) -> str:
"""Get the principal (client id) associated with this auth.

Returns:
str: Principal client id
"""
return self.client_id

@property
def secret(self) -> str:
"""Get the secret (secret key) associated with this auth.

Returns:
str: Secret
"""
return self.client_secret

def get_firebolt_version(self) -> FireboltAuthVersion:
"""Get Firebolt version from auth.

Expand Down
Loading
Loading