diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index e953f4d1d0..1b32195f77 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -33,6 +33,7 @@ from google.cloud import bigquery from google.cloud.bigquery import StandardSqlDataType from google.cloud.bigquery.client import Client as BigQueryClient + from google.cloud.bigquery.job import QueryJob from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult from google.cloud.bigquery.table import Table as BigQueryTable @@ -186,6 +187,31 @@ def query_factory() -> Query: ) ] + def close(self) -> t.Any: + # Cancel all pending query jobs across all threads + all_query_jobs = self._connection_pool.get_all_attributes("query_job") + for query_job in all_query_jobs: + if query_job: + try: + if not self._db_call(query_job.done): + self._db_call(query_job.cancel) + logger.debug( + "Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", + query_job.project, + query_job.location, + query_job.job_id, + ) + except Exception as ex: + logger.debug( + "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", + query_job.project, + query_job.location, + query_job.job_id, + str(ex), + ) + + return super().close() + def _begin_session(self, properties: SessionProperties) -> None: from google.cloud.bigquery import QueryJobConfig @@ -318,7 +344,10 @@ def create_mapping_schema( if len(table.parts) == 3 and "." in table.name: # The client's `get_table` method can't handle paths with >3 identifiers self.execute(exp.select("*").from_(table).limit(0)) - query_results = self._query_job._query_results + query_job = self._query_job + assert query_job is not None + + query_results = query_job._query_results columns = create_mapping_schema(query_results.schema) else: bq_table = self._get_table(table) @@ -717,7 +746,9 @@ def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False ) -> DF: self.execute(query, quote_identifiers=quote_identifiers) - return self._query_job.to_dataframe() + query_job = self._query_job + assert query_job is not None + return query_job.to_dataframe() def _create_column_comments( self, @@ -1021,20 +1052,23 @@ def _execute( job_config=job_config, timeout=self._extra_config.get("job_creation_timeout_seconds"), ) + query_job = self._query_job + assert query_job is not None logger.debug( "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", - self._query_job.project, - self._query_job.location, - self._query_job.job_id, + query_job.project, + query_job.location, + query_job.job_id, ) results = self._db_call( - self._query_job.result, + query_job.result, timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore ) + self._query_data = iter(results) if results.total_rows else iter([]) - query_results = self._query_job._query_results + query_results = query_job._query_results self.cursor._set_rowcount(query_results) self.cursor._set_description(query_results.schema) @@ -1198,15 +1232,15 @@ def _query_data(self) -> t.Any: @_query_data.setter def _query_data(self, value: t.Any) -> None: - return self._connection_pool.set_attribute("query_data", value) + self._connection_pool.set_attribute("query_data", value) @property - def _query_job(self) -> t.Any: + def _query_job(self) -> t.Optional[QueryJob]: return self._connection_pool.get_attribute("query_job") @_query_job.setter def _query_job(self, value: t.Any) -> None: - return self._connection_pool.set_attribute("query_job", value) + self._connection_pool.set_attribute("query_job", value) @property def _session_id(self) -> t.Any: @@ -1214,7 +1248,7 @@ def _session_id(self) -> t.Any: @_session_id.setter def _session_id(self, value: t.Any) -> None: - return self._connection_pool.set_attribute("session_id", value) + self._connection_pool.set_attribute("session_id", value) class _ErrorCounter: diff --git a/sqlmesh/utils/connection_pool.py b/sqlmesh/utils/connection_pool.py index 54d62a2f8c..a4f9486184 100644 --- a/sqlmesh/utils/connection_pool.py +++ b/sqlmesh/utils/connection_pool.py @@ -48,6 +48,17 @@ def set_attribute(self, key: str, value: t.Any) -> None: value: Attribute value. """ + @abc.abstractmethod + def get_all_attributes(self, key: str) -> t.List[t.Any]: + """Returns all attributes with the given key across all connections/threads. + + Args: + key: Attribute key. + + Returns: + List of attribute values from all connections/threads. + """ + @abc.abstractmethod def begin(self) -> None: """Starts a new transaction.""" @@ -142,6 +153,14 @@ def set_attribute(self, key: str, value: t.Any) -> None: thread_id = get_ident() self._thread_attributes[thread_id][key] = value + def get_all_attributes(self, key: str) -> t.List[t.Any]: + """Returns all attributes with the given key across all threads.""" + return [ + thread_attrs[key] + for thread_attrs in self._thread_attributes.values() + if key in thread_attrs + ] + def begin(self) -> None: self._do_begin() with self._thread_transactions_lock: @@ -282,6 +301,11 @@ def get_attribute(self, key: str) -> t.Optional[t.Any]: def set_attribute(self, key: str, value: t.Any) -> None: self._attributes[key] = value + def get_all_attributes(self, key: str) -> t.List[t.Any]: + """Returns all attributes with the given key (single-threaded pool has at most one).""" + value = self._attributes.get(key) + return [value] if value is not None else [] + def begin(self) -> None: self._do_begin() self._is_transaction_active = True diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py index e01e42049b..32377ac1de 100644 --- a/tests/core/engine_adapter/test_bigquery.py +++ b/tests/core/engine_adapter/test_bigquery.py @@ -1072,3 +1072,81 @@ def test_get_alter_expressions_includes_catalog( assert schema.db == "bar" assert schema.sql(dialect="bigquery") == "catalog2.bar" assert tables == {"bing"} + + +def test_job_cancellation_on_keyboard_interrupt_job_still_running(mocker: MockerFixture): + # Create a mock connection + connection_mock = mocker.NonCallableMock() + cursor_mock = mocker.Mock() + cursor_mock.connection = connection_mock + connection_mock.cursor.return_value = cursor_mock + + # Mock the query job + mock_job = mocker.Mock() + mock_job.project = "test-project" + mock_job.location = "us-central1" + mock_job.job_id = "test-job-123" + mock_job.done.return_value = False # Job is still running + mock_job.result.side_effect = KeyboardInterrupt() + mock_job._query_results = mocker.Mock() + mock_job._query_results.total_rows = 0 + mock_job._query_results.schema = [] + + # Set up the client to return our mock job + connection_mock._client.query.return_value = mock_job + + # Create adapter with the mocked connection + adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0) + + # Execute a query and expect KeyboardInterrupt + with pytest.raises(KeyboardInterrupt): + adapter.execute("SELECT 1") + + # Ensure the adapter's closed, so that the job can be aborted + adapter.close() + + # Verify the job was created + connection_mock._client.query.assert_called_once() + + # Verify job status was checked and cancellation was called + mock_job.done.assert_called_once() + mock_job.cancel.assert_called_once() + + +def test_job_cancellation_on_keyboard_interrupt_job_already_done(mocker: MockerFixture): + # Create a mock connection + connection_mock = mocker.NonCallableMock() + cursor_mock = mocker.Mock() + cursor_mock.connection = connection_mock + connection_mock.cursor.return_value = cursor_mock + + # Mock the query job + mock_job = mocker.Mock() + mock_job.project = "test-project" + mock_job.location = "us-central1" + mock_job.job_id = "test-job-456" + mock_job.done.return_value = True # Job is already done + mock_job.result.side_effect = KeyboardInterrupt() + mock_job._query_results = mocker.Mock() + mock_job._query_results.total_rows = 0 + mock_job._query_results.schema = [] + + # Set up the client to return our mock job + connection_mock._client.query.return_value = mock_job + + # Create adapter with the mocked connection + adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0) + + # Execute a query and expect KeyboardInterrupt + with pytest.raises(KeyboardInterrupt): + adapter.execute("SELECT 1") + + # Ensure the adapter's closed, so that the job can be aborted + adapter.close() + + # Verify the job was created + connection_mock._client.query.assert_called_once() + + # Verify job status was checked but cancellation was NOT called + mock_job.done.assert_called_once() + mock_job.cancel.assert_not_called()