Skip to content

Commit f1809d8

Browse files
committed
Refactor
1 parent df5ffcb commit f1809d8

File tree

1 file changed

+35
-37
lines changed

1 file changed

+35
-37
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 35 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
from google.cloud import bigquery
3434
from google.cloud.bigquery import StandardSqlDataType
3535
from google.cloud.bigquery.client import Client as BigQueryClient
36+
from google.cloud.bigquery.job import QueryJob
3637
from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult
3738
from google.cloud.bigquery.table import Table as BigQueryTable
3839

@@ -187,21 +188,23 @@ def query_factory() -> Query:
187188
]
188189

189190
def close(self) -> t.Any:
190-
# Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts
191-
for query_job in self._query_jobs:
192-
try:
193-
if not self._db_call(query_job.done):
194-
self._db_call(query_job.cancel)
195-
except Exception as ex:
196-
logger.debug(
197-
"Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
198-
self._query_job.project,
199-
self._query_job.location,
200-
self._query_job.job_id,
201-
str(ex),
202-
)
191+
query_job = self._query_job
192+
if not query_job:
193+
return super().close()
194+
195+
# Cancel the last submitted query job if it's still pending, to avoid it becoming orphan (e.g., if interrupted)
196+
try:
197+
if not self._db_call(query_job.done):
198+
self._db_call(query_job.cancel)
199+
except Exception as ex:
200+
logger.debug(
201+
"Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
202+
query_job.project,
203+
query_job.location,
204+
query_job.job_id,
205+
str(ex),
206+
)
203207

204-
self._query_jobs.clear()
205208
return super().close()
206209

207210
def _begin_session(self, properties: SessionProperties) -> None:
@@ -336,7 +339,10 @@ def create_mapping_schema(
336339
if len(table.parts) == 3 and "." in table.name:
337340
# The client's `get_table` method can't handle paths with >3 identifiers
338341
self.execute(exp.select("*").from_(table).limit(0))
339-
query_results = self._query_job._query_results
342+
query_job = self._query_job
343+
assert query_job is not None
344+
345+
query_results = query_job._query_results
340346
columns = create_mapping_schema(query_results.schema)
341347
else:
342348
bq_table = self._get_table(table)
@@ -723,7 +729,9 @@ def _fetch_native_df(
723729
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
724730
) -> DF:
725731
self.execute(query, quote_identifiers=quote_identifiers)
726-
return self._query_job.to_dataframe()
732+
query_job = self._query_job
733+
assert query_job is not None
734+
return query_job.to_dataframe()
727735

728736
def _create_column_comments(
729737
self,
@@ -1027,24 +1035,23 @@ def _execute(
10271035
job_config=job_config,
10281036
timeout=self._extra_config.get("job_creation_timeout_seconds"),
10291037
)
1030-
self._query_jobs.add(self._query_job)
1038+
query_job = self._query_job
1039+
assert query_job is not None
10311040

10321041
logger.debug(
10331042
"BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
1034-
self._query_job.project,
1035-
self._query_job.location,
1036-
self._query_job.job_id,
1043+
query_job.project,
1044+
query_job.location,
1045+
query_job.job_id,
10371046
)
10381047

10391048
results = self._db_call(
1040-
self._query_job.result,
1049+
query_job.result,
10411050
timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore
10421051
)
10431052

1044-
self._query_jobs.remove(self._query_job)
1045-
10461053
self._query_data = iter(results) if results.total_rows else iter([])
1047-
query_results = self._query_job._query_results
1054+
query_results = query_job._query_results
10481055
self.cursor._set_rowcount(query_results)
10491056
self.cursor._set_description(query_results.schema)
10501057

@@ -1208,32 +1215,23 @@ def _query_data(self) -> t.Any:
12081215

12091216
@_query_data.setter
12101217
def _query_data(self, value: t.Any) -> None:
1211-
return self._connection_pool.set_attribute("query_data", value)
1212-
1213-
@property
1214-
def _query_jobs(self) -> t.Any:
1215-
query_jobs = self._connection_pool.get_attribute("query_jobs")
1216-
if not isinstance(query_jobs, set):
1217-
query_jobs = set()
1218-
self._connection_pool.set_attribute("query_jobs", query_jobs)
1219-
1220-
return query_jobs
1218+
self._connection_pool.set_attribute("query_data", value)
12211219

12221220
@property
1223-
def _query_job(self) -> t.Any:
1221+
def _query_job(self) -> t.Optional[QueryJob]:
12241222
return self._connection_pool.get_attribute("query_job")
12251223

12261224
@_query_job.setter
12271225
def _query_job(self, value: t.Any) -> None:
1228-
return self._connection_pool.set_attribute("query_job", value)
1226+
self._connection_pool.set_attribute("query_job", value)
12291227

12301228
@property
12311229
def _session_id(self) -> t.Any:
12321230
return self._connection_pool.get_attribute("session_id")
12331231

12341232
@_session_id.setter
12351233
def _session_id(self, value: t.Any) -> None:
1236-
return self._connection_pool.set_attribute("session_id", value)
1234+
self._connection_pool.set_attribute("session_id", value)
12371235

12381236

12391237
class _ErrorCounter:

0 commit comments

Comments
 (0)