From 3cb08b94b2df1762ef8ee0220493d8876f73be7b Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 15 Jul 2025 13:12:24 +0300 Subject: [PATCH 1/6] Feat!: cancel submitted BigQuery jobs on keyboard interrupts --- sqlmesh/core/engine_adapter/bigquery.py | 20 ++++-- tests/core/engine_adapter/test_bigquery.py | 72 ++++++++++++++++++++++ 2 files changed, 88 insertions(+), 4 deletions(-) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index e953f4d1d0..11cbb408cb 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -1029,10 +1029,22 @@ def _execute( self._query_job.job_id, ) - results = self._db_call( - self._query_job.result, - timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore - ) + try: + results = self._db_call( + self._query_job.result, + timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore + ) + except KeyboardInterrupt: + # Wrapping this in another try-except to ensure the subsequent db calls don't change + # the original exception type. + try: + if not self._db_call(self._query_job.done): + self._db_call(self._query_job.cancel) + except: + pass + + raise + self._query_data = iter(results) if results.total_rows else iter([]) query_results = self._query_job._query_results self.cursor._set_rowcount(query_results) diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py index e01e42049b..6b260e3f86 100644 --- a/tests/core/engine_adapter/test_bigquery.py +++ b/tests/core/engine_adapter/test_bigquery.py @@ -1072,3 +1072,75 @@ def test_get_alter_expressions_includes_catalog( assert schema.db == "bar" assert schema.sql(dialect="bigquery") == "catalog2.bar" assert tables == {"bing"} + + +def test_job_cancellation_on_keyboard_interrupt_job_still_running(mocker: MockerFixture): + # Create a mock connection + connection_mock = mocker.NonCallableMock() + cursor_mock = mocker.Mock() + cursor_mock.connection = connection_mock + connection_mock.cursor.return_value = cursor_mock + + # Mock the query job + mock_job = mocker.Mock() + mock_job.project = "test-project" + mock_job.location = "us-central1" + mock_job.job_id = "test-job-123" + mock_job.done.return_value = False # Job is still running + mock_job.result.side_effect = KeyboardInterrupt() + mock_job._query_results = mocker.Mock() + mock_job._query_results.total_rows = 0 + mock_job._query_results.schema = [] + + # Set up the client to return our mock job + connection_mock._client.query.return_value = mock_job + + # Create adapter with the mocked connection + adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0) + + # Execute a query and expect KeyboardInterrupt + with pytest.raises(KeyboardInterrupt): + adapter.execute("SELECT 1") + + # Verify the job was created + connection_mock._client.query.assert_called_once() + + # Verify job status was checked and cancellation was called + mock_job.done.assert_called_once() + mock_job.cancel.assert_called_once() + + +def test_job_cancellation_on_keyboard_interrupt_job_already_done(mocker: MockerFixture): + # Create a mock connection + connection_mock = mocker.NonCallableMock() + cursor_mock = mocker.Mock() + cursor_mock.connection = connection_mock + connection_mock.cursor.return_value = cursor_mock + + # Mock the query job + mock_job = mocker.Mock() + mock_job.project = "test-project" + mock_job.location = "us-central1" + mock_job.job_id = "test-job-456" + mock_job.done.return_value = True # Job is already done + mock_job.result.side_effect = KeyboardInterrupt() + mock_job._query_results = mocker.Mock() + mock_job._query_results.total_rows = 0 + mock_job._query_results.schema = [] + + # Set up the client to return our mock job + connection_mock._client.query.return_value = mock_job + + # Create adapter with the mocked connection + adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0) + + # Execute a query and expect KeyboardInterrupt + with pytest.raises(KeyboardInterrupt): + adapter.execute("SELECT 1") + + # Verify the job was created + connection_mock._client.query.assert_called_once() + + # Verify job status was checked but cancellation was NOT called + mock_job.done.assert_called_once() + mock_job.cancel.assert_not_called() From ba62b6b7f4be50bad90e21cd3f8b5f35311a7ef8 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Fri, 18 Jul 2025 14:23:22 +0300 Subject: [PATCH 2/6] Refactor: cancel jobs on `close()` --- sqlmesh/core/engine_adapter/bigquery.py | 37 ++++++++++++++-------- tests/core/engine_adapter/test_bigquery.py | 6 ++++ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 11cbb408cb..fefb407860 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -186,6 +186,14 @@ def query_factory() -> Query: ) ] + def close(self) -> t.Any: + # Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts + for query_job in self._query_jobs: + if not self._db_call(query_job.done): + self._db_call(query_job.cancel) + + return super().close() + def _begin_session(self, properties: SessionProperties) -> None: from google.cloud.bigquery import QueryJobConfig @@ -1021,6 +1029,7 @@ def _execute( job_config=job_config, timeout=self._extra_config.get("job_creation_timeout_seconds"), ) + self._query_jobs.add(self._query_job) logger.debug( "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", @@ -1029,21 +1038,12 @@ def _execute( self._query_job.job_id, ) - try: - results = self._db_call( - self._query_job.result, - timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore - ) - except KeyboardInterrupt: - # Wrapping this in another try-except to ensure the subsequent db calls don't change - # the original exception type. - try: - if not self._db_call(self._query_job.done): - self._db_call(self._query_job.cancel) - except: - pass + results = self._db_call( + self._query_job.result, + timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore + ) - raise + self._query_jobs.remove(self._query_job) self._query_data = iter(results) if results.total_rows else iter([]) query_results = self._query_job._query_results @@ -1212,6 +1212,15 @@ def _query_data(self) -> t.Any: def _query_data(self, value: t.Any) -> None: return self._connection_pool.set_attribute("query_data", value) + @property + def _query_jobs(self) -> t.Any: + query_jobs = self._connection_pool.get_attribute("query_jobs") + if not isinstance(query_jobs, set): + query_jobs = set() + self._connection_pool.set_attribute("query_jobs", query_jobs) + + return query_jobs + @property def _query_job(self) -> t.Any: return self._connection_pool.get_attribute("query_job") diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py index 6b260e3f86..32377ac1de 100644 --- a/tests/core/engine_adapter/test_bigquery.py +++ b/tests/core/engine_adapter/test_bigquery.py @@ -1102,6 +1102,9 @@ def test_job_cancellation_on_keyboard_interrupt_job_still_running(mocker: Mocker with pytest.raises(KeyboardInterrupt): adapter.execute("SELECT 1") + # Ensure the adapter's closed, so that the job can be aborted + adapter.close() + # Verify the job was created connection_mock._client.query.assert_called_once() @@ -1138,6 +1141,9 @@ def test_job_cancellation_on_keyboard_interrupt_job_already_done(mocker: MockerF with pytest.raises(KeyboardInterrupt): adapter.execute("SELECT 1") + # Ensure the adapter's closed, so that the job can be aborted + adapter.close() + # Verify the job was created connection_mock._client.query.assert_called_once() From 0a9bd5de245a89e39e4310d731a4f00ab34a91d9 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 21 Jul 2025 11:54:22 +0300 Subject: [PATCH 3/6] Clear query_jobs set on `close()` --- sqlmesh/core/engine_adapter/bigquery.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index fefb407860..0584bfae1e 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -192,6 +192,7 @@ def close(self) -> t.Any: if not self._db_call(query_job.done): self._db_call(query_job.cancel) + self._query_jobs.clear() return super().close() def _begin_session(self, properties: SessionProperties) -> None: From 9538a4fba6fa5ec32fc510c9992ecf9dae262684 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 23 Jul 2025 12:06:10 +0300 Subject: [PATCH 4/6] PR feedback: use try/except to avoid failing when cancelling a job --- sqlmesh/core/engine_adapter/bigquery.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 0584bfae1e..40c62cce85 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -189,8 +189,17 @@ def query_factory() -> Query: def close(self) -> t.Any: # Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts for query_job in self._query_jobs: - if not self._db_call(query_job.done): - self._db_call(query_job.cancel) + try: + if not self._db_call(query_job.done): + self._db_call(query_job.cancel) + except Exception as ex: + logger.debug( + "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", + self._query_job.project, + self._query_job.location, + self._query_job.job_id, + str(ex), + ) self._query_jobs.clear() return super().close() From 88a3501eb4e3bdab9f80307644d93fd6d3581271 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Fri, 25 Jul 2025 16:39:38 +0300 Subject: [PATCH 5/6] Refactor --- sqlmesh/core/engine_adapter/bigquery.py | 72 ++++++++++++------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 40c62cce85..1c134cbfd7 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -33,6 +33,7 @@ from google.cloud import bigquery from google.cloud.bigquery import StandardSqlDataType from google.cloud.bigquery.client import Client as BigQueryClient + from google.cloud.bigquery.job import QueryJob from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult from google.cloud.bigquery.table import Table as BigQueryTable @@ -187,21 +188,23 @@ def query_factory() -> Query: ] def close(self) -> t.Any: - # Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts - for query_job in self._query_jobs: - try: - if not self._db_call(query_job.done): - self._db_call(query_job.cancel) - except Exception as ex: - logger.debug( - "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", - self._query_job.project, - self._query_job.location, - self._query_job.job_id, - str(ex), - ) + query_job = self._query_job + if not query_job: + return super().close() + + # Cancel the last submitted query job if it's still pending, to avoid it becoming orphan (e.g., if interrupted) + try: + if not self._db_call(query_job.done): + self._db_call(query_job.cancel) + except Exception as ex: + logger.debug( + "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", + query_job.project, + query_job.location, + query_job.job_id, + str(ex), + ) - self._query_jobs.clear() return super().close() def _begin_session(self, properties: SessionProperties) -> None: @@ -336,7 +339,10 @@ def create_mapping_schema( if len(table.parts) == 3 and "." in table.name: # The client's `get_table` method can't handle paths with >3 identifiers self.execute(exp.select("*").from_(table).limit(0)) - query_results = self._query_job._query_results + query_job = self._query_job + assert query_job is not None + + query_results = query_job._query_results columns = create_mapping_schema(query_results.schema) else: bq_table = self._get_table(table) @@ -735,7 +741,9 @@ def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False ) -> DF: self.execute(query, quote_identifiers=quote_identifiers) - return self._query_job.to_dataframe() + query_job = self._query_job + assert query_job is not None + return query_job.to_dataframe() def _create_column_comments( self, @@ -1039,24 +1047,23 @@ def _execute( job_config=job_config, timeout=self._extra_config.get("job_creation_timeout_seconds"), ) - self._query_jobs.add(self._query_job) + query_job = self._query_job + assert query_job is not None logger.debug( "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", - self._query_job.project, - self._query_job.location, - self._query_job.job_id, + query_job.project, + query_job.location, + query_job.job_id, ) results = self._db_call( - self._query_job.result, + query_job.result, timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore ) - self._query_jobs.remove(self._query_job) - self._query_data = iter(results) if results.total_rows else iter([]) - query_results = self._query_job._query_results + query_results = query_job._query_results self.cursor._set_rowcount(query_results) self.cursor._set_description(query_results.schema) @@ -1220,24 +1227,15 @@ def _query_data(self) -> t.Any: @_query_data.setter def _query_data(self, value: t.Any) -> None: - return self._connection_pool.set_attribute("query_data", value) - - @property - def _query_jobs(self) -> t.Any: - query_jobs = self._connection_pool.get_attribute("query_jobs") - if not isinstance(query_jobs, set): - query_jobs = set() - self._connection_pool.set_attribute("query_jobs", query_jobs) - - return query_jobs + self._connection_pool.set_attribute("query_data", value) @property - def _query_job(self) -> t.Any: + def _query_job(self) -> t.Optional[QueryJob]: return self._connection_pool.get_attribute("query_job") @_query_job.setter def _query_job(self, value: t.Any) -> None: - return self._connection_pool.set_attribute("query_job", value) + self._connection_pool.set_attribute("query_job", value) @property def _session_id(self) -> t.Any: @@ -1245,7 +1243,7 @@ def _session_id(self) -> t.Any: @_session_id.setter def _session_id(self, value: t.Any) -> None: - return self._connection_pool.set_attribute("session_id", value) + self._connection_pool.set_attribute("session_id", value) class _ErrorCounter: From 1bd8ae9bec2e27ddd07c8da465f5223942c33712 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 12 Aug 2025 18:25:43 +0300 Subject: [PATCH 6/6] Refactor to fetch all jobs across all threads --- sqlmesh/core/engine_adapter/bigquery.py | 37 ++++++++++++++----------- sqlmesh/utils/connection_pool.py | 24 ++++++++++++++++ 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 1c134cbfd7..1b32195f77 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -188,22 +188,27 @@ def query_factory() -> Query: ] def close(self) -> t.Any: - query_job = self._query_job - if not query_job: - return super().close() - - # Cancel the last submitted query job if it's still pending, to avoid it becoming orphan (e.g., if interrupted) - try: - if not self._db_call(query_job.done): - self._db_call(query_job.cancel) - except Exception as ex: - logger.debug( - "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", - query_job.project, - query_job.location, - query_job.job_id, - str(ex), - ) + # Cancel all pending query jobs across all threads + all_query_jobs = self._connection_pool.get_all_attributes("query_job") + for query_job in all_query_jobs: + if query_job: + try: + if not self._db_call(query_job.done): + self._db_call(query_job.cancel) + logger.debug( + "Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", + query_job.project, + query_job.location, + query_job.job_id, + ) + except Exception as ex: + logger.debug( + "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", + query_job.project, + query_job.location, + query_job.job_id, + str(ex), + ) return super().close() diff --git a/sqlmesh/utils/connection_pool.py b/sqlmesh/utils/connection_pool.py index 54d62a2f8c..a4f9486184 100644 --- a/sqlmesh/utils/connection_pool.py +++ b/sqlmesh/utils/connection_pool.py @@ -48,6 +48,17 @@ def set_attribute(self, key: str, value: t.Any) -> None: value: Attribute value. """ + @abc.abstractmethod + def get_all_attributes(self, key: str) -> t.List[t.Any]: + """Returns all attributes with the given key across all connections/threads. + + Args: + key: Attribute key. + + Returns: + List of attribute values from all connections/threads. + """ + @abc.abstractmethod def begin(self) -> None: """Starts a new transaction.""" @@ -142,6 +153,14 @@ def set_attribute(self, key: str, value: t.Any) -> None: thread_id = get_ident() self._thread_attributes[thread_id][key] = value + def get_all_attributes(self, key: str) -> t.List[t.Any]: + """Returns all attributes with the given key across all threads.""" + return [ + thread_attrs[key] + for thread_attrs in self._thread_attributes.values() + if key in thread_attrs + ] + def begin(self) -> None: self._do_begin() with self._thread_transactions_lock: @@ -282,6 +301,11 @@ def get_attribute(self, key: str) -> t.Optional[t.Any]: def set_attribute(self, key: str, value: t.Any) -> None: self._attributes[key] = value + def get_all_attributes(self, key: str) -> t.List[t.Any]: + """Returns all attributes with the given key (single-threaded pool has at most one).""" + value = self._attributes.get(key) + return [value] if value is not None else [] + def begin(self) -> None: self._do_begin() self._is_transaction_active = True