Commit d177625

Feat!: cancel submitted BigQuery jobs on keyboard interrupts (#4979)
1 parent eed4c26 commit d177625

File tree: 3 files changed, +147 -11 lines changed

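In effect, `BigQueryEngineAdapter.close()` now cancels any BigQuery query job that was submitted but has not yet finished, so interrupting a long-running command with Ctrl-C also releases the warehouse work it already started. A rough caller-side sketch of that flow; the connection factory below is a placeholder, not code from this commit:

```python
import typing as t

from sqlmesh.core.engine_adapter.bigquery import BigQueryEngineAdapter


def make_connection() -> t.Any:
    """Placeholder: return whatever DB-API connection the adapter is normally given."""
    raise NotImplementedError


adapter = BigQueryEngineAdapter(make_connection)
try:
    adapter.execute("SELECT 1")  # submits a query job and waits on its result
except KeyboardInterrupt:
    adapter.close()  # new behavior: cancels any still-running job before closing
    raise
```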

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 45 additions & 11 deletions
@@ -33,6 +33,7 @@
 from google.cloud import bigquery
 from google.cloud.bigquery import StandardSqlDataType
 from google.cloud.bigquery.client import Client as BigQueryClient
+from google.cloud.bigquery.job import QueryJob
 from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult
 from google.cloud.bigquery.table import Table as BigQueryTable

@@ -186,6 +187,31 @@ def query_factory() -> Query:
             )
         ]

+    def close(self) -> t.Any:
+        # Cancel all pending query jobs across all threads
+        all_query_jobs = self._connection_pool.get_all_attributes("query_job")
+        for query_job in all_query_jobs:
+            if query_job:
+                try:
+                    if not self._db_call(query_job.done):
+                        self._db_call(query_job.cancel)
+                        logger.debug(
+                            "Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
+                            query_job.project,
+                            query_job.location,
+                            query_job.job_id,
+                        )
+                except Exception as ex:
+                    logger.debug(
+                        "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
+                        query_job.project,
+                        query_job.location,
+                        query_job.job_id,
+                        str(ex),
+                    )
+
+        return super().close()
+
     def _begin_session(self, properties: SessionProperties) -> None:
         from google.cloud.bigquery import QueryJobConfig

@@ -318,7 +344,10 @@ def create_mapping_schema(
         if len(table.parts) == 3 and "." in table.name:
             # The client's `get_table` method can't handle paths with >3 identifiers
             self.execute(exp.select("*").from_(table).limit(0))
-            query_results = self._query_job._query_results
+            query_job = self._query_job
+            assert query_job is not None
+
+            query_results = query_job._query_results
             columns = create_mapping_schema(query_results.schema)
         else:
             bq_table = self._get_table(table)
@@ -717,7 +746,9 @@ def _fetch_native_df(
         self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
     ) -> DF:
         self.execute(query, quote_identifiers=quote_identifiers)
-        return self._query_job.to_dataframe()
+        query_job = self._query_job
+        assert query_job is not None
+        return query_job.to_dataframe()

     def _create_column_comments(
         self,
@@ -1021,20 +1052,23 @@ def _execute(
             job_config=job_config,
             timeout=self._extra_config.get("job_creation_timeout_seconds"),
         )
+        query_job = self._query_job
+        assert query_job is not None

         logger.debug(
             "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
-            self._query_job.project,
-            self._query_job.location,
-            self._query_job.job_id,
+            query_job.project,
+            query_job.location,
+            query_job.job_id,
         )

         results = self._db_call(
-            self._query_job.result,
+            query_job.result,
             timeout=self._extra_config.get("job_execution_timeout_seconds"),  # type: ignore
         )
+
         self._query_data = iter(results) if results.total_rows else iter([])
-        query_results = self._query_job._query_results
+        query_results = query_job._query_results
         self.cursor._set_rowcount(query_results)
         self.cursor._set_description(query_results.schema)

@@ -1198,23 +1232,23 @@ def _query_data(self) -> t.Any:

     @_query_data.setter
     def _query_data(self, value: t.Any) -> None:
-        return self._connection_pool.set_attribute("query_data", value)
+        self._connection_pool.set_attribute("query_data", value)

     @property
-    def _query_job(self) -> t.Any:
+    def _query_job(self) -> t.Optional[QueryJob]:
         return self._connection_pool.get_attribute("query_job")

     @_query_job.setter
     def _query_job(self, value: t.Any) -> None:
-        return self._connection_pool.set_attribute("query_job", value)
+        self._connection_pool.set_attribute("query_job", value)

     @property
     def _session_id(self) -> t.Any:
         return self._connection_pool.get_attribute("session_id")

     @_session_id.setter
     def _session_id(self, value: t.Any) -> None:
-        return self._connection_pool.set_attribute("session_id", value)
+        self._connection_pool.set_attribute("session_id", value)


 class _ErrorCounter:
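For orientation, the `done()`/`cancel()` calls that `close()` wraps with `_db_call` are part of the public `google.cloud.bigquery` job API. A standalone sketch of those primitives, assuming application-default credentials and a reachable project (illustrative, not taken from this diff):

```python
from google.cloud import bigquery

client = bigquery.Client()                    # uses application-default credentials
job = client.query("SELECT 1")                # returns a QueryJob without waiting
if not job.done():                            # job is still pending or running
    job.cancel()                              # best-effort cancellation request
print(job.project, job.location, job.job_id)  # the fields used to build the console URL above
```

Cancellation is best-effort on the BigQuery side, which is presumably why the adapter only logs the outcome at debug level rather than raising.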

sqlmesh/utils/connection_pool.py

Lines changed: 24 additions & 0 deletions
@@ -48,6 +48,17 @@ def set_attribute(self, key: str, value: t.Any) -> None:
             value: Attribute value.
         """

+    @abc.abstractmethod
+    def get_all_attributes(self, key: str) -> t.List[t.Any]:
+        """Returns all attributes with the given key across all connections/threads.
+
+        Args:
+            key: Attribute key.
+
+        Returns:
+            List of attribute values from all connections/threads.
+        """
+
     @abc.abstractmethod
     def begin(self) -> None:
         """Starts a new transaction."""
@@ -142,6 +153,14 @@ def set_attribute(self, key: str, value: t.Any) -> None:
         thread_id = get_ident()
         self._thread_attributes[thread_id][key] = value

+    def get_all_attributes(self, key: str) -> t.List[t.Any]:
+        """Returns all attributes with the given key across all threads."""
+        return [
+            thread_attrs[key]
+            for thread_attrs in self._thread_attributes.values()
+            if key in thread_attrs
+        ]
+
     def begin(self) -> None:
         self._do_begin()
         with self._thread_transactions_lock:
@@ -282,6 +301,11 @@ def get_attribute(self, key: str) -> t.Optional[t.Any]:
     def set_attribute(self, key: str, value: t.Any) -> None:
         self._attributes[key] = value

+    def get_all_attributes(self, key: str) -> t.List[t.Any]:
+        """Returns all attributes with the given key (single-threaded pool has at most one)."""
+        value = self._attributes.get(key)
+        return [value] if value is not None else []
+
     def begin(self) -> None:
         self._do_begin()
         self._is_transaction_active = True
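The new `get_all_attributes` hook is what lets a single `close()` call see query jobs that were stored by other threads. A condensed sketch of that per-thread bookkeeping pattern (simplified; the real thread-local pool in this file also tracks connections and transactions):

```python
import typing as t
from collections import defaultdict
from threading import get_ident


class PerThreadAttributes:
    """Simplified stand-in for the attribute storage used by the thread-local pool."""

    def __init__(self) -> None:
        self._thread_attributes: t.Dict[int, t.Dict[str, t.Any]] = defaultdict(dict)

    def set_attribute(self, key: str, value: t.Any) -> None:
        # Each thread writes into its own bucket, keyed by thread id.
        self._thread_attributes[get_ident()][key] = value

    def get_attribute(self, key: str) -> t.Optional[t.Any]:
        # Reads only see the calling thread's bucket.
        return self._thread_attributes[get_ident()].get(key)

    def get_all_attributes(self, key: str) -> t.List[t.Any]:
        # Aggregates the value from every thread that has set it.
        return [attrs[key] for attrs in self._thread_attributes.values() if key in attrs]
```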

tests/core/engine_adapter/test_bigquery.py

Lines changed: 78 additions & 0 deletions
@@ -1072,3 +1072,81 @@ def test_get_alter_expressions_includes_catalog(
     assert schema.db == "bar"
     assert schema.sql(dialect="bigquery") == "catalog2.bar"
     assert tables == {"bing"}
+
+
+def test_job_cancellation_on_keyboard_interrupt_job_still_running(mocker: MockerFixture):
+    # Create a mock connection
+    connection_mock = mocker.NonCallableMock()
+    cursor_mock = mocker.Mock()
+    cursor_mock.connection = connection_mock
+    connection_mock.cursor.return_value = cursor_mock
+
+    # Mock the query job
+    mock_job = mocker.Mock()
+    mock_job.project = "test-project"
+    mock_job.location = "us-central1"
+    mock_job.job_id = "test-job-123"
+    mock_job.done.return_value = False  # Job is still running
+    mock_job.result.side_effect = KeyboardInterrupt()
+    mock_job._query_results = mocker.Mock()
+    mock_job._query_results.total_rows = 0
+    mock_job._query_results.schema = []
+
+    # Set up the client to return our mock job
+    connection_mock._client.query.return_value = mock_job
+
+    # Create adapter with the mocked connection
+    adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0)
+
+    # Execute a query and expect KeyboardInterrupt
+    with pytest.raises(KeyboardInterrupt):
+        adapter.execute("SELECT 1")
+
+    # Ensure the adapter's closed, so that the job can be aborted
+    adapter.close()
+
+    # Verify the job was created
+    connection_mock._client.query.assert_called_once()
+
+    # Verify job status was checked and cancellation was called
+    mock_job.done.assert_called_once()
+    mock_job.cancel.assert_called_once()
+
+
+def test_job_cancellation_on_keyboard_interrupt_job_already_done(mocker: MockerFixture):
+    # Create a mock connection
+    connection_mock = mocker.NonCallableMock()
+    cursor_mock = mocker.Mock()
+    cursor_mock.connection = connection_mock
+    connection_mock.cursor.return_value = cursor_mock
+
+    # Mock the query job
+    mock_job = mocker.Mock()
+    mock_job.project = "test-project"
+    mock_job.location = "us-central1"
+    mock_job.job_id = "test-job-456"
+    mock_job.done.return_value = True  # Job is already done
+    mock_job.result.side_effect = KeyboardInterrupt()
+    mock_job._query_results = mocker.Mock()
+    mock_job._query_results.total_rows = 0
+    mock_job._query_results.schema = []
+
+    # Set up the client to return our mock job
+    connection_mock._client.query.return_value = mock_job
+
+    # Create adapter with the mocked connection
+    adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0)
+
+    # Execute a query and expect KeyboardInterrupt
+    with pytest.raises(KeyboardInterrupt):
+        adapter.execute("SELECT 1")
+
+    # Ensure the adapter's closed, so that the job can be aborted
+    adapter.close()
+
+    # Verify the job was created
+    connection_mock._client.query.assert_called_once()
+
+    # Verify job status was checked but cancellation was NOT called
+    mock_job.done.assert_called_once()
+    mock_job.cancel.assert_not_called()
