Skip to content

Commit 2a5b223

Browse files
committed
Refactor: cancel jobs on close()
1 parent 32eee2c commit 2a5b223

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,14 @@ def query_factory() -> Query:
186186
)
187187
]
188188

189+
def close(self) -> t.Any:
190+
# Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts
191+
for query_job in self._query_jobs:
192+
if not self._db_call(query_job.done):
193+
self._db_call(query_job.cancel)
194+
195+
return super().close()
196+
189197
def _begin_session(self, properties: SessionProperties) -> None:
190198
from google.cloud.bigquery import QueryJobConfig
191199

@@ -1009,6 +1017,7 @@ def _execute(
10091017
job_config=job_config,
10101018
timeout=self._extra_config.get("job_creation_timeout_seconds"),
10111019
)
1020+
self._query_jobs.add(self._query_job)
10121021

10131022
logger.debug(
10141023
"BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
@@ -1017,21 +1026,12 @@ def _execute(
10171026
self._query_job.job_id,
10181027
)
10191028

1020-
try:
1021-
results = self._db_call(
1022-
self._query_job.result,
1023-
timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore
1024-
)
1025-
except KeyboardInterrupt:
1026-
# Wrapping this in another try-except to ensure the subsequent db calls don't change
1027-
# the original exception type.
1028-
try:
1029-
if not self._db_call(self._query_job.done):
1030-
self._db_call(self._query_job.cancel)
1031-
except:
1032-
pass
1029+
results = self._db_call(
1030+
self._query_job.result,
1031+
timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore
1032+
)
10331033

1034-
raise
1034+
self._query_jobs.remove(self._query_job)
10351035

10361036
self._query_data = iter(results) if results.total_rows else iter([])
10371037
query_results = self._query_job._query_results
@@ -1200,6 +1200,15 @@ def _query_data(self) -> t.Any:
12001200
def _query_data(self, value: t.Any) -> None:
12011201
return self._connection_pool.set_attribute("query_data", value)
12021202

1203+
@property
1204+
def _query_jobs(self) -> t.Any:
1205+
query_jobs = self._connection_pool.get_attribute("query_jobs")
1206+
if not isinstance(query_jobs, set):
1207+
query_jobs = set()
1208+
self._connection_pool.set_attribute("query_jobs", query_jobs)
1209+
1210+
return query_jobs
1211+
12031212
@property
12041213
def _query_job(self) -> t.Any:
12051214
return self._connection_pool.get_attribute("query_job")

tests/core/engine_adapter/test_bigquery.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,9 @@ def test_job_cancellation_on_keyboard_interrupt_job_still_running(mocker: Mocker
11021102
with pytest.raises(KeyboardInterrupt):
11031103
adapter.execute("SELECT 1")
11041104

1105+
# Ensure the adapter's closed, so that the job can be aborted
1106+
adapter.close()
1107+
11051108
# Verify the job was created
11061109
connection_mock._client.query.assert_called_once()
11071110

@@ -1138,6 +1141,9 @@ def test_job_cancellation_on_keyboard_interrupt_job_already_done(mocker: MockerF
11381141
with pytest.raises(KeyboardInterrupt):
11391142
adapter.execute("SELECT 1")
11401143

1144+
# Ensure the adapter's closed, so that the job can be aborted
1145+
adapter.close()
1146+
11411147
# Verify the job was created
11421148
connection_mock._client.query.assert_called_once()
11431149

0 commit comments

Comments
 (0)