Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from google.cloud import bigquery
from google.cloud.bigquery import StandardSqlDataType
from google.cloud.bigquery.client import Client as BigQueryClient
from google.cloud.bigquery.job import QueryJob
from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult
from google.cloud.bigquery.table import Table as BigQueryTable

Expand Down Expand Up @@ -186,6 +187,31 @@ def query_factory() -> Query:
)
]

def close(self) -> t.Any:
    """Cancel every pending BigQuery job tracked by the pool, then close.

    Jobs are collected from all threads' connection attributes so that an
    interrupted run (e.g. Ctrl-C) does not leave queries running server-side.
    Cancellation is best-effort: failures are logged at debug level only.
    """
    for job in self._connection_pool.get_all_attributes("query_job"):
        if not job:
            continue
        try:
            # Only issue a cancel for jobs that have not finished yet.
            if not self._db_call(job.done):
                self._db_call(job.cancel)
                logger.debug(
                    "Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
                    job.project,
                    job.location,
                    job.job_id,
                )
        except Exception as ex:
            # Best-effort cleanup; never let cancellation errors mask close().
            logger.debug(
                "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
                job.project,
                job.location,
                job.job_id,
                str(ex),
            )

    return super().close()

def _begin_session(self, properties: SessionProperties) -> None:
from google.cloud.bigquery import QueryJobConfig

Expand Down Expand Up @@ -318,7 +344,10 @@ def create_mapping_schema(
if len(table.parts) == 3 and "." in table.name:
# The client's `get_table` method can't handle paths with >3 identifiers
self.execute(exp.select("*").from_(table).limit(0))
query_results = self._query_job._query_results
query_job = self._query_job
assert query_job is not None

query_results = query_job._query_results
columns = create_mapping_schema(query_results.schema)
else:
bq_table = self._get_table(table)
Expand Down Expand Up @@ -717,7 +746,9 @@ def _fetch_native_df(
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
) -> DF:
self.execute(query, quote_identifiers=quote_identifiers)
return self._query_job.to_dataframe()
query_job = self._query_job
assert query_job is not None
return query_job.to_dataframe()

def _create_column_comments(
self,
Expand Down Expand Up @@ -1021,20 +1052,23 @@ def _execute(
job_config=job_config,
timeout=self._extra_config.get("job_creation_timeout_seconds"),
)
query_job = self._query_job
assert query_job is not None

logger.debug(
"BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
self._query_job.project,
self._query_job.location,
self._query_job.job_id,
query_job.project,
query_job.location,
query_job.job_id,
)

results = self._db_call(
self._query_job.result,
query_job.result,
timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore
)

self._query_data = iter(results) if results.total_rows else iter([])
query_results = self._query_job._query_results
query_results = query_job._query_results
self.cursor._set_rowcount(query_results)
self.cursor._set_description(query_results.schema)

Expand Down Expand Up @@ -1198,23 +1232,23 @@ def _query_data(self) -> t.Any:

@_query_data.setter
def _query_data(self, value: t.Any) -> None:
    """Store the current thread's query data on the connection pool."""
    # A setter declared `-> None` should not propagate the call's result;
    # the duplicated pre-change `return ...` line from the diff is dropped.
    self._connection_pool.set_attribute("query_data", value)

@property
def _query_job(self) -> t.Optional[QueryJob]:
    """The current thread's in-flight/most-recent QueryJob, or None.

    The block contained both the old untyped and the new typed `def`
    header from the diff; only the typed one is kept.
    """
    return self._connection_pool.get_attribute("query_job")

@_query_job.setter
def _query_job(self, value: t.Any) -> None:
    """Store the current thread's QueryJob on the connection pool."""
    # No `return` — consistent with the `-> None` annotation; the
    # duplicated pre-change line from the diff is dropped.
    self._connection_pool.set_attribute("query_job", value)

@property
def _session_id(self) -> t.Any:
    """The BigQuery session id stored for the current thread, if any."""
    session_id = self._connection_pool.get_attribute("session_id")
    return session_id

@_session_id.setter
def _session_id(self, value: t.Any) -> None:
    """Store the current thread's BigQuery session id on the connection pool."""
    # No `return` — consistent with the `-> None` annotation; the
    # duplicated pre-change line from the diff is dropped.
    self._connection_pool.set_attribute("session_id", value)


class _ErrorCounter:
Expand Down
24 changes: 24 additions & 0 deletions sqlmesh/utils/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ def set_attribute(self, key: str, value: t.Any) -> None:
value: Attribute value.
"""

@abc.abstractmethod
def get_all_attributes(self, key: str) -> t.List[t.Any]:
    """Collect the value stored under ``key`` for every connection/thread.

    Args:
        key: Attribute key.

    Returns:
        List of attribute values from all connections/threads.
    """

@abc.abstractmethod
def begin(self) -> None:
"""Starts a new transaction."""
Expand Down Expand Up @@ -142,6 +153,14 @@ def set_attribute(self, key: str, value: t.Any) -> None:
thread_id = get_ident()
self._thread_attributes[thread_id][key] = value

def get_all_attributes(self, key: str) -> t.List[t.Any]:
    """Collect the value stored under ``key`` from every thread's attributes.

    Threads that never set ``key`` are simply skipped; insertion order of
    the per-thread attribute dicts is preserved.
    """
    values: t.List[t.Any] = []
    for attrs in self._thread_attributes.values():
        if key in attrs:
            values.append(attrs[key])
    return values

def begin(self) -> None:
self._do_begin()
with self._thread_transactions_lock:
Expand Down Expand Up @@ -282,6 +301,11 @@ def get_attribute(self, key: str) -> t.Optional[t.Any]:
def set_attribute(self, key: str, value: t.Any) -> None:
self._attributes[key] = value

def get_all_attributes(self, key: str) -> t.List[t.Any]:
    """Return the attribute stored under ``key`` as a (0- or 1-element) list.

    A single-threaded pool holds at most one value per key; a stored
    ``None`` counts as "not set" and yields an empty list.
    """
    value = self._attributes.get(key)
    return [] if value is None else [value]

def begin(self) -> None:
self._do_begin()
self._is_transaction_active = True
Expand Down
78 changes: 78 additions & 0 deletions tests/core/engine_adapter/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,3 +1072,81 @@ def test_get_alter_expressions_includes_catalog(
assert schema.db == "bar"
assert schema.sql(dialect="bigquery") == "catalog2.bar"
assert tables == {"bing"}


def test_job_cancellation_on_keyboard_interrupt_job_still_running(mocker: MockerFixture):
    """A job still running when the adapter closes must be cancelled."""
    # Fake DBAPI connection whose cursor points back at it.
    mock_connection = mocker.NonCallableMock()
    mock_cursor = mocker.Mock()
    mock_cursor.connection = mock_connection
    mock_connection.cursor.return_value = mock_cursor

    # Fake query job: unfinished, and interrupted while waiting on result().
    job = mocker.Mock()
    job.project = "test-project"
    job.location = "us-central1"
    job.job_id = "test-job-123"
    job.done.return_value = False  # Job is still running
    job.result.side_effect = KeyboardInterrupt()
    job._query_results = mocker.Mock()
    job._query_results.total_rows = 0
    job._query_results.schema = []

    # The BigQuery client hands back our fake job.
    mock_connection._client.query.return_value = job

    adapter = BigQueryEngineAdapter(lambda: mock_connection, job_retries=0)

    # The interrupt raised by result() must propagate out of execute().
    with pytest.raises(KeyboardInterrupt):
        adapter.execute("SELECT 1")

    # Closing the adapter triggers the cancellation sweep.
    adapter.close()

    mock_connection._client.query.assert_called_once()
    # The job was polled once and, being unfinished, was cancelled.
    job.done.assert_called_once()
    job.cancel.assert_called_once()


def test_job_cancellation_on_keyboard_interrupt_job_already_done(mocker: MockerFixture):
    """A job that already finished must not be cancelled on close()."""
    # Fake DBAPI connection whose cursor points back at it.
    mock_connection = mocker.NonCallableMock()
    mock_cursor = mocker.Mock()
    mock_cursor.connection = mock_connection
    mock_connection.cursor.return_value = mock_cursor

    # Fake query job: already finished, interrupted while waiting on result().
    job = mocker.Mock()
    job.project = "test-project"
    job.location = "us-central1"
    job.job_id = "test-job-456"
    job.done.return_value = True  # Job is already done
    job.result.side_effect = KeyboardInterrupt()
    job._query_results = mocker.Mock()
    job._query_results.total_rows = 0
    job._query_results.schema = []

    # The BigQuery client hands back our fake job.
    mock_connection._client.query.return_value = job

    adapter = BigQueryEngineAdapter(lambda: mock_connection, job_retries=0)

    # The interrupt raised by result() must propagate out of execute().
    with pytest.raises(KeyboardInterrupt):
        adapter.execute("SELECT 1")

    # Closing the adapter triggers the cancellation sweep.
    adapter.close()

    mock_connection._client.query.assert_called_once()
    # The job was polled once; since it reported done, no cancel was issued.
    job.done.assert_called_once()
    job.cancel.assert_not_called()
Loading