Skip to content

Commit 76cd59a

Browse files
committed
Refactor
1 parent c125ca3 commit 76cd59a

File tree

1 file changed

+35
-37
lines changed

1 file changed

+35
-37
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from google.cloud import bigquery
3434
from google.cloud.bigquery import StandardSqlDataType
3535
from google.cloud.bigquery.client import Client as BigQueryClient
36+
from google.cloud.bigquery.job import QueryJob
3637
from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult
3738
from google.cloud.bigquery.table import Table as BigQueryTable
3839

@@ -187,21 +188,23 @@ def query_factory() -> Query:
187188
]
188189

189190
def close(self) -> t.Any:
190-
# Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts
191-
for query_job in self._query_jobs:
192-
try:
193-
if not self._db_call(query_job.done):
194-
self._db_call(query_job.cancel)
195-
except Exception as ex:
196-
logger.debug(
197-
"Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
198-
self._query_job.project,
199-
self._query_job.location,
200-
self._query_job.job_id,
201-
str(ex),
202-
)
191+
query_job = self._query_job
192+
if not query_job:
193+
return super().close()
194+
195+
# Cancel the last submitted query job if it's still pending, to avoid it becoming orphan (e.g., if interrupted)
196+
try:
197+
if not self._db_call(query_job.done):
198+
self._db_call(query_job.cancel)
199+
except Exception as ex:
200+
logger.debug(
201+
"Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
202+
query_job.project,
203+
query_job.location,
204+
query_job.job_id,
205+
str(ex),
206+
)
203207

204-
self._query_jobs.clear()
205208
return super().close()
206209

207210
def _begin_session(self, properties: SessionProperties) -> None:
@@ -336,7 +339,10 @@ def create_mapping_schema(
336339
if len(table.parts) == 3 and "." in table.name:
337340
# The client's `get_table` method can't handle paths with >3 identifiers
338341
self.execute(exp.select("*").from_(table).limit(0))
339-
query_results = self._query_job._query_results
342+
query_job = self._query_job
343+
assert query_job is not None
344+
345+
query_results = query_job._query_results
340346
columns = create_mapping_schema(query_results.schema)
341347
else:
342348
bq_table = self._get_table(table)
@@ -735,7 +741,9 @@ def _fetch_native_df(
735741
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
736742
) -> DF:
737743
self.execute(query, quote_identifiers=quote_identifiers)
738-
return self._query_job.to_dataframe()
744+
query_job = self._query_job
745+
assert query_job is not None
746+
return query_job.to_dataframe()
739747

740748
def _create_column_comments(
741749
self,
@@ -1039,24 +1047,23 @@ def _execute(
10391047
job_config=job_config,
10401048
timeout=self._extra_config.get("job_creation_timeout_seconds"),
10411049
)
1042-
self._query_jobs.add(self._query_job)
1050+
query_job = self._query_job
1051+
assert query_job is not None
10431052

10441053
logger.debug(
10451054
"BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
1046-
self._query_job.project,
1047-
self._query_job.location,
1048-
self._query_job.job_id,
1055+
query_job.project,
1056+
query_job.location,
1057+
query_job.job_id,
10491058
)
10501059

10511060
results = self._db_call(
1052-
self._query_job.result,
1061+
query_job.result,
10531062
timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore
10541063
)
10551064

1056-
self._query_jobs.remove(self._query_job)
1057-
10581065
self._query_data = iter(results) if results.total_rows else iter([])
1059-
query_results = self._query_job._query_results
1066+
query_results = query_job._query_results
10601067
self.cursor._set_rowcount(query_results)
10611068
self.cursor._set_description(query_results.schema)
10621069

@@ -1220,32 +1227,23 @@ def _query_data(self) -> t.Any:
12201227

12211228
@_query_data.setter
12221229
def _query_data(self, value: t.Any) -> None:
1223-
return self._connection_pool.set_attribute("query_data", value)
1224-
1225-
@property
1226-
def _query_jobs(self) -> t.Any:
1227-
query_jobs = self._connection_pool.get_attribute("query_jobs")
1228-
if not isinstance(query_jobs, set):
1229-
query_jobs = set()
1230-
self._connection_pool.set_attribute("query_jobs", query_jobs)
1231-
1232-
return query_jobs
1230+
self._connection_pool.set_attribute("query_data", value)
12331231

12341232
@property
1235-
def _query_job(self) -> t.Any:
1233+
def _query_job(self) -> t.Optional[QueryJob]:
12361234
return self._connection_pool.get_attribute("query_job")
12371235

12381236
@_query_job.setter
12391237
def _query_job(self, value: t.Any) -> None:
1240-
return self._connection_pool.set_attribute("query_job", value)
1238+
self._connection_pool.set_attribute("query_job", value)
12411239

12421240
@property
12431241
def _session_id(self) -> t.Any:
12441242
return self._connection_pool.get_attribute("session_id")
12451243

12461244
@_session_id.setter
12471245
def _session_id(self, value: t.Any) -> None:
1248-
return self._connection_pool.set_attribute("session_id", value)
1246+
self._connection_pool.set_attribute("session_id", value)
12491247

12501248

12511249
class _ErrorCounter:

0 commit comments

Comments
 (0)