|
33 | 33 | from google.cloud import bigquery
|
34 | 34 | from google.cloud.bigquery import StandardSqlDataType
|
35 | 35 | from google.cloud.bigquery.client import Client as BigQueryClient
|
| 36 | + from google.cloud.bigquery.job import QueryJob |
36 | 37 | from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult
|
37 | 38 | from google.cloud.bigquery.table import Table as BigQueryTable
|
38 | 39 |
|
@@ -187,21 +188,23 @@ def query_factory() -> Query:
|
187 | 188 | ]
|
188 | 189 |
|
def close(self) -> t.Any:
    """Release the connection, first cancelling any still-running query job.

    The last submitted job is cancelled if it has not finished yet, so that
    an interrupt does not leave an orphaned BigQuery job behind. Cancellation
    is best-effort: any failure is only logged at debug level and never
    prevents the underlying connection from being closed.
    """
    job = self._query_job
    if job:
        try:
            # Only issue a cancel request for a job that is still pending.
            if not self._db_call(job.done):
                self._db_call(job.cancel)
        except Exception as ex:
            logger.debug(
                "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
                job.project,
                job.location,
                job.job_id,
                str(ex),
            )
    return super().close()
|
206 | 209 |
|
207 | 210 | def _begin_session(self, properties: SessionProperties) -> None:
|
@@ -336,7 +339,10 @@ def create_mapping_schema(
|
336 | 339 | if len(table.parts) == 3 and "." in table.name:
|
337 | 340 | # The client's `get_table` method can't handle paths with >3 identifiers
|
338 | 341 | self.execute(exp.select("*").from_(table).limit(0))
|
339 |
| - query_results = self._query_job._query_results |
| 342 | + query_job = self._query_job |
| 343 | + assert query_job is not None |
| 344 | + |
| 345 | + query_results = query_job._query_results |
340 | 346 | columns = create_mapping_schema(query_results.schema)
|
341 | 347 | else:
|
342 | 348 | bq_table = self._get_table(table)
|
def _fetch_native_df(
    self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
) -> DF:
    """Run *query* and return its result set as a DataFrame.

    The query is submitted through ``self.execute`` and the resulting
    BigQuery job is materialized via ``to_dataframe``.
    """
    self.execute(query, quote_identifiers=quote_identifiers)
    # ``execute`` always records the submitted job, so it must exist here.
    job = self._query_job
    assert job is not None
    return job.to_dataframe()
739 | 747 |
|
740 | 748 | def _create_column_comments(
|
741 | 749 | self,
|
@@ -1039,24 +1047,23 @@ def _execute(
|
1039 | 1047 | job_config=job_config,
|
1040 | 1048 | timeout=self._extra_config.get("job_creation_timeout_seconds"),
|
1041 | 1049 | )
|
1042 |
| - self._query_jobs.add(self._query_job) |
| 1050 | + query_job = self._query_job |
| 1051 | + assert query_job is not None |
1043 | 1052 |
|
1044 | 1053 | logger.debug(
|
1045 | 1054 | "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
|
1046 |
| - self._query_job.project, |
1047 |
| - self._query_job.location, |
1048 |
| - self._query_job.job_id, |
| 1055 | + query_job.project, |
| 1056 | + query_job.location, |
| 1057 | + query_job.job_id, |
1049 | 1058 | )
|
1050 | 1059 |
|
1051 | 1060 | results = self._db_call(
|
1052 |
| - self._query_job.result, |
| 1061 | + query_job.result, |
1053 | 1062 | timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore
|
1054 | 1063 | )
|
1055 | 1064 |
|
1056 |
| - self._query_jobs.remove(self._query_job) |
1057 |
| - |
1058 | 1065 | self._query_data = iter(results) if results.total_rows else iter([])
|
1059 |
| - query_results = self._query_job._query_results |
| 1066 | + query_results = query_job._query_results |
1060 | 1067 | self.cursor._set_rowcount(query_results)
|
1061 | 1068 | self.cursor._set_description(query_results.schema)
|
1062 | 1069 |
|
@@ -1220,32 +1227,23 @@ def _query_data(self) -> t.Any:
|
1220 | 1227 |
|
@_query_data.setter
def _query_data(self, value: t.Any) -> None:
    # Stash the fetched-row iterator on the pooled (per-connection) state so
    # subsequent fetch calls on this connection can consume it.
    self._connection_pool.set_attribute("query_data", value)
1233 | 1231 |
|
@property
def _query_job(self) -> t.Optional[QueryJob]:
    # The most recently submitted BigQuery job for this pooled connection,
    # or None if no query has been executed yet.
    return self._connection_pool.get_attribute("query_job")

@_query_job.setter
def _query_job(self, value: t.Any) -> None:
    # Record the job on the pooled connection state so close()/fetch helpers
    # operating on the same connection can find it.
    self._connection_pool.set_attribute("query_job", value)
1241 | 1239 |
|
@property
def _session_id(self) -> t.Any:
    # Identifier of the active BigQuery session, stored per pooled connection;
    # None when no session has been started.
    return self._connection_pool.get_attribute("session_id")

@_session_id.setter
def _session_id(self, value: t.Any) -> None:
    self._connection_pool.set_attribute("session_id", value)
1249 | 1247 |
|
1250 | 1248 |
|
1251 | 1249 | class _ErrorCounter:
|
|
0 commit comments