Skip to content

Commit 38b13cf

Browse files
committed
working example
1 parent 6bf80a0 commit 38b13cf

File tree

8 files changed

+153
-33
lines changed

8 files changed

+153
-33
lines changed

src/firebolt/async_db/connection.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from typing import Any, Dict, List, Optional, Type, Union
66
from uuid import uuid4
77

8-
from httpx import Timeout, codes
8+
import trio
9+
from httpx import Request, Response, Timeout, codes
910

1011
from firebolt.async_db.cursor import Cursor, CursorV1, CursorV2
1112
from firebolt.client import DEFAULT_API_URL
@@ -78,6 +79,9 @@ class Connection(BaseConnection):
7879
"engine_url",
7980
"api_endpoint",
8081
"_is_closed",
82+
"_transaction_id",
83+
"_transaction_sequence_id",
84+
"_transaction_lock",
8185
"client_class",
8286
"cursor_type",
8387
"id",
@@ -99,6 +103,7 @@ def __init__(
99103
self._cursors: List[Cursor] = []
100104
self._client = client
101105
self.id = id
106+
self._transaction_lock: trio.Lock = trio.Lock()
102107
self.init_parameters = init_parameters or {}
103108
if database:
104109
self.init_parameters["database"] = database
@@ -192,6 +197,25 @@ async def cancel_async_query(self, token: str) -> None:
192197
cursor = self.cursor()
193198
await cursor.execute(ASYNC_QUERY_CANCEL, [async_query_info[0].query_id])
194199

200+
async def _execute_query_impl(self, request: Request) -> Response:
201+
self._add_transaction_headers(request)
202+
response = await self._client.send(request, stream=True)
203+
self._handle_transaction_updates(response.headers)
204+
return response
205+
206+
async def _execute_query(self, request: Request) -> Response:
207+
if self.in_transaction():
208+
async with self._transaction_lock:
209+
return await self._execute_query_impl(request)
210+
else:
211+
return await self._execute_query_impl(request)
212+
213+
async def commit(self) -> None:
214+
await self.cursor().execute("COMMIT")
215+
216+
async def rollback(self) -> None:
217+
await self.cursor().execute("ROLLBACK")
218+
195219
# Context manager support
196220
async def __aenter__(self) -> Connection:
197221
if self.closed:

src/firebolt/async_db/cursor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
)
3030
from firebolt.common.cursor.base_cursor import (
3131
BaseCursor,
32-
_parse_remove_parameters,
3332
_parse_update_endpoint,
34-
_parse_update_parameters,
3533
_raise_if_internal_set_parameter,
3634
)
3735
from firebolt.common.cursor.decorators import (
@@ -58,6 +56,10 @@
5856
)
5957
from firebolt.utils.timeout_controller import TimeoutController
6058
from firebolt.utils.urls import DATABASES_URL, ENGINES_URL
59+
from firebolt.utils.util import (
60+
_parse_remove_parameters,
61+
_parse_update_parameters,
62+
)
6163

6264
if TYPE_CHECKING:
6365
from firebolt.async_db.connection import Connection
@@ -135,7 +137,7 @@ async def _api_request(
135137
content=query,
136138
timeout=timeout if timeout is not None else USE_CLIENT_DEFAULT,
137139
)
138-
return await self._client.send(req, stream=True)
140+
return await self.connection._execute_query(req)
139141
except TimeoutException:
140142
raise QueryTimeoutError()
141143

src/firebolt/common/base_connection.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,32 @@
11
from collections import namedtuple
22
from typing import Any, Dict, List, Optional, Tuple, Type
33

4+
from httpx import Headers, Request
5+
46
from firebolt.client.auth.base import Auth
57
from firebolt.common._types import ColType
8+
from firebolt.common.constants import (
9+
REMOVE_PARAMETERS_HEADER,
10+
RESET_SESSION_HEADER,
11+
TRANSACTION_ID_SETTING,
12+
TRANSACTION_SEQUENCE_ID_SETTING,
13+
UPDATE_PARAMETERS_HEADER,
14+
)
615
from firebolt.utils.cache import (
716
ConnectionInfo,
817
EngineInfo,
918
SecureCacheKey,
1019
_firebolt_cache,
1120
)
12-
from firebolt.utils.exception import ConnectionClosedError, FireboltError
21+
from firebolt.utils.exception import FireboltError
1322
from firebolt.utils.usage_tracker import (
1423
get_cache_tracking_params,
1524
get_user_agent_header,
1625
)
26+
from firebolt.utils.util import (
27+
_parse_remove_parameters,
28+
_parse_update_parameters,
29+
)
1730

1831
ASYNC_QUERY_STATUS_RUNNING = "RUNNING"
1932
ASYNC_QUERY_STATUS_SUCCESSFUL = "ENDED_SUCCESSFULLY"
@@ -68,6 +81,8 @@ def __init__(self, cursor_type: Type) -> None:
6881
self.cursor_type = cursor_type
6982
self._cursors: List[Any] = []
7083
self._is_closed = False
84+
self._transaction_id: Optional[str] = None
85+
self._transaction_sequence_id: Optional[str] = None
7186

7287
def _remove_cursor(self, cursor: Any) -> None:
7388
# This way it's atomic
@@ -76,17 +91,62 @@ def _remove_cursor(self, cursor: Any) -> None:
7691
except ValueError:
7792
pass
7893

94+
def in_transaction(self) -> bool:
95+
"""`True` if connection is in a transaction; `False` otherwise."""
96+
return self._transaction_id is not None
97+
98+
def _parse_response_headers_transaction(self, headers: Headers) -> None:
99+
parameters_header = headers.get(UPDATE_PARAMETERS_HEADER)
100+
if not parameters_header:
101+
return
102+
parameters = _parse_update_parameters(parameters_header)
103+
transaction_id = parameters.get(TRANSACTION_ID_SETTING)
104+
if transaction_id:
105+
self._transaction_id = transaction_id
106+
sequence_id = parameters.get(TRANSACTION_SEQUENCE_ID_SETTING)
107+
if sequence_id:
108+
self._transaction_sequence_id = sequence_id
109+
110+
def _parse_remove_headers_transaction(self, headers: Headers) -> None:
111+
parameters_header = headers.get(REMOVE_PARAMETERS_HEADER)
112+
if not parameters_header:
113+
return
114+
parameters = _parse_remove_parameters(parameters_header)
115+
for param in parameters:
116+
if param == TRANSACTION_ID_SETTING:
117+
self._transaction_id = None
118+
elif param == TRANSACTION_SEQUENCE_ID_SETTING:
119+
self._transaction_sequence_id = None
120+
121+
def _reset_transaction_state(self) -> None:
122+
self._transaction_id = None
123+
self._transaction_sequence_id = None
124+
125+
def create_transaction_params(self) -> Dict[str, str]:
126+
params: Dict[str, str] = {}
127+
if self._transaction_id:
128+
params[TRANSACTION_ID_SETTING] = self._transaction_id
129+
if self._transaction_sequence_id is not None:
130+
params[TRANSACTION_SEQUENCE_ID_SETTING] = str(self._transaction_sequence_id)
131+
return params
132+
133+
def _add_transaction_headers(self, request: Request) -> None:
134+
transaction_params = self.create_transaction_params()
135+
for key, value in transaction_params.items():
136+
request.headers[key] = value
137+
138+
def _handle_transaction_updates(self, headers: Headers) -> None:
139+
self._parse_response_headers_transaction(headers)
140+
if headers.get(RESET_SESSION_HEADER):
141+
self._reset_transaction_state()
142+
if headers.get(REMOVE_PARAMETERS_HEADER):
143+
self._parse_remove_headers_transaction(headers)
144+
79145
@property
80146
def closed(self) -> bool:
81147
"""`True` if connection is closed; `False` otherwise."""
82148
return self._is_closed
83149

84-
def commit(self) -> None:
85-
"""Does nothing since Firebolt doesn't have transactions."""
86-
87-
if self.closed:
88-
raise ConnectionClosedError("Unable to commit: Connection closed.")
89-
90150

91151
def get_cached_system_engine_info(
92152
auth: Auth,

src/firebolt/common/constants.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,19 @@ class ParameterStyle(Enum):
2525
FB_NUMERIC = "fb_numeric" # $1, $2, ... as placeholders (server-side)
2626

2727

28+
TRANSACTION_ID_SETTING = "transaction_id"
29+
TRANSACTION_SEQUENCE_ID_SETTING = "transaction_sequence_id"
30+
2831
# Parameters that should be set using USE instead of SET
2932
USE_PARAMETER_LIST = ["database", "engine"]
3033
# parameters that can only be set by the backend
3134
DISALLOWED_PARAMETER_LIST = ["output_format"]
35+
# Connection level transaction management
36+
TRANSACTION_PARAMETER_LIST = [TRANSACTION_ID_SETTING, TRANSACTION_SEQUENCE_ID_SETTING]
3237
# parameters that are set by the backend and should not be set by the user
33-
IMMUTABLE_PARAMETER_LIST = USE_PARAMETER_LIST + DISALLOWED_PARAMETER_LIST
38+
IMMUTABLE_PARAMETER_LIST = (
39+
USE_PARAMETER_LIST + DISALLOWED_PARAMETER_LIST + TRANSACTION_PARAMETER_LIST
40+
)
3441
UPDATE_ENDPOINT_HEADER = "Firebolt-Update-Endpoint"
3542
UPDATE_PARAMETERS_HEADER = "Firebolt-Update-Parameters"
3643
RESET_SESSION_HEADER = "Firebolt-Reset-Session"

src/firebolt/common/cursor/base_cursor.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,6 @@
3333
logger = logging.getLogger(__name__)
3434

3535

36-
def _parse_update_parameters(parameter_header: str) -> Dict[str, str]:
37-
"""Parse update parameters and set them as attributes."""
38-
# parse key1=value1,key2=value2 comma separated string into dict
39-
param_dict = dict(item.split("=") for item in parameter_header.split(","))
40-
# strip whitespace from keys and values
41-
param_dict = {key.strip(): value.strip() for key, value in param_dict.items()}
42-
return param_dict
43-
44-
45-
def _parse_remove_parameters(parameter_header: str) -> List[str]:
46-
"""Parse remove parameters header and return list of parameter names to remove."""
47-
# parse key1,key2,key3 comma separated string into list
48-
param_list = [item.strip() for item in parameter_header.split(",")]
49-
return param_list
50-
51-
5236
def _parse_update_endpoint(
5337
new_engine_endpoint_header: str,
5438
) -> Tuple[str, Dict[str, str]]:

src/firebolt/db/connection.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

33
import logging
4+
import threading
45
from ssl import SSLContext
56
from types import TracebackType
67
from typing import Any, Dict, List, Optional, Type, Union
78
from uuid import uuid4
89
from warnings import warn
910

10-
from httpx import Timeout, codes
11+
from httpx import Request, Response, Timeout, codes
1112

1213
from firebolt.client import DEFAULT_API_URL, Client, ClientV1, ClientV2
1314
from firebolt.client.auth import Auth
@@ -213,6 +214,9 @@ class Connection(BaseConnection):
213214
"engine_url",
214215
"api_endpoint",
215216
"_is_closed",
217+
"_transaction_id",
218+
"_transaction_sequence_id",
219+
"_transaction_lock",
216220
"client_class",
217221
"cursor_type",
218222
"id",
@@ -234,6 +238,7 @@ def __init__(
234238
self._cursors: List[Cursor] = []
235239
self._client = client
236240
self.id = id
241+
self._transaction_lock: threading.Lock = threading.Lock()
237242
self.init_parameters = init_parameters or {}
238243
if database:
239244
self.init_parameters["database"] = database
@@ -264,6 +269,25 @@ def close(self) -> None:
264269
self._client.close()
265270
self._is_closed = True
266271

272+
def _execute_query_impl(self, request: Request) -> Response:
273+
self._add_transaction_headers(request)
274+
response = self._client.send(request, stream=True)
275+
self._handle_transaction_updates(response.headers)
276+
return response
277+
278+
def _execute_query(self, request: Request) -> Response:
279+
if self.in_transaction():
280+
with self._transaction_lock:
281+
return self._execute_query_impl(request)
282+
else:
283+
return self._execute_query_impl(request)
284+
285+
def commit(self) -> None:
286+
self.cursor().execute("COMMIT")
287+
288+
def rollback(self) -> None:
289+
self.cursor().execute("ROLLBACK")
290+
267291
# Server-side async methods
268292

269293
def get_async_query_info(self, token: str) -> List[AsyncQueryInfo]:

src/firebolt/db/cursor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@
3737
)
3838
from firebolt.common.cursor.base_cursor import (
3939
BaseCursor,
40-
_parse_remove_parameters,
4140
_parse_update_endpoint,
42-
_parse_update_parameters,
4341
_raise_if_internal_set_parameter,
4442
)
4543
from firebolt.common.cursor.decorators import (
@@ -67,7 +65,12 @@
6765
)
6866
from firebolt.utils.timeout_controller import TimeoutController
6967
from firebolt.utils.urls import DATABASES_URL, ENGINES_URL
70-
from firebolt.utils.util import Timer, raise_error_from_response
68+
from firebolt.utils.util import (
69+
Timer,
70+
_parse_remove_parameters,
71+
_parse_update_parameters,
72+
raise_error_from_response,
73+
)
7174

7275
if TYPE_CHECKING:
7376
from firebolt.db.connection import Connection
@@ -163,7 +166,7 @@ def _api_request(
163166
content=query,
164167
timeout=timeout if timeout is not None else USE_CLIENT_DEFAULT,
165168
)
166-
return self._client.send(req, stream=True)
169+
return self.connection._execute_query(req)
167170
except TimeoutException:
168171
raise QueryTimeoutError()
169172

src/firebolt/utils/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,22 @@ def raise_error_from_response(resp: Response) -> None:
189189
resp.raise_for_status()
190190

191191

192+
def _parse_update_parameters(parameter_header: str) -> Dict[str, str]:
193+
"""Parse update parameters and set them as attributes."""
194+
# parse key1=value1,key2=value2 comma separated string into dict
195+
param_dict = dict(item.split("=") for item in parameter_header.split(","))
196+
# strip whitespace from keys and values
197+
param_dict = {key.strip(): value.strip() for key, value in param_dict.items()}
198+
return param_dict
199+
200+
201+
def _parse_remove_parameters(parameter_header: str) -> List[str]:
202+
"""Parse remove parameters header and return list of parameter names to remove."""
203+
# parse key1,key2,key3 comma separated string into list
204+
param_list = [item.strip() for item in parameter_header.split(",")]
205+
return param_list
206+
207+
192208
class Timer:
193209
def __init__(self, message: str = ""):
194210
self._message = message

0 commit comments

Comments
 (0)