|
33 | 33 | from google.cloud import bigquery
|
34 | 34 | from google.cloud.bigquery import StandardSqlDataType
|
35 | 35 | from google.cloud.bigquery.client import Client as BigQueryClient
|
| 36 | + from google.cloud.bigquery.job import QueryJob |
36 | 37 | from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult
|
37 | 38 | from google.cloud.bigquery.table import Table as BigQueryTable
|
38 | 39 |
|
@@ -187,21 +188,23 @@ def query_factory() -> Query:
|
187 | 188 | ]
|
188 | 189 |
|
189 | 190 | def close(self) -> t.Any:
|
190 |
| - # Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts |
191 |
| - for query_job in self._query_jobs: |
192 |
| - try: |
193 |
| - if not self._db_call(query_job.done): |
194 |
| - self._db_call(query_job.cancel) |
195 |
| - except Exception as ex: |
196 |
| - logger.debug( |
197 |
| - "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", |
198 |
| - self._query_job.project, |
199 |
| - self._query_job.location, |
200 |
| - self._query_job.job_id, |
201 |
| - str(ex), |
202 |
| - ) |
| 191 | + query_job = self._query_job |
| 192 | + if not query_job: |
| 193 | + return super().close() |
| 194 | + |
| 195 | + # Cancel the last submitted query job if it's still pending, to avoid it becoming orphaned (e.g., if interrupted) |
| 196 | + try: |
| 197 | + if not self._db_call(query_job.done): |
| 198 | + self._db_call(query_job.cancel) |
| 199 | + except Exception as ex: |
| 200 | + logger.debug( |
| 201 | + "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", |
| 202 | + query_job.project, |
| 203 | + query_job.location, |
| 204 | + query_job.job_id, |
| 205 | + str(ex), |
| 206 | + ) |
203 | 207 |
|
204 |
| - self._query_jobs.clear() |
205 | 208 | return super().close()
|
206 | 209 |
|
207 | 210 | def _begin_session(self, properties: SessionProperties) -> None:
|
@@ -336,7 +339,10 @@ def create_mapping_schema(
|
336 | 339 | if len(table.parts) == 3 and "." in table.name:
|
337 | 340 | # The client's `get_table` method can't handle paths with >3 identifiers
|
338 | 341 | self.execute(exp.select("*").from_(table).limit(0))
|
339 |
| - query_results = self._query_job._query_results |
| 342 | + query_job = self._query_job |
| 343 | + assert query_job is not None |
| 344 | + |
| 345 | + query_results = query_job._query_results |
340 | 346 | columns = create_mapping_schema(query_results.schema)
|
341 | 347 | else:
|
342 | 348 | bq_table = self._get_table(table)
|
@@ -723,7 +729,9 @@ def _fetch_native_df(
|
723 | 729 | self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
|
724 | 730 | ) -> DF:
|
725 | 731 | self.execute(query, quote_identifiers=quote_identifiers)
|
726 |
| - return self._query_job.to_dataframe() |
| 732 | + query_job = self._query_job |
| 733 | + assert query_job is not None |
| 734 | + return query_job.to_dataframe() |
727 | 735 |
|
728 | 736 | def _create_column_comments(
|
729 | 737 | self,
|
@@ -1027,24 +1035,23 @@ def _execute(
|
1027 | 1035 | job_config=job_config,
|
1028 | 1036 | timeout=self._extra_config.get("job_creation_timeout_seconds"),
|
1029 | 1037 | )
|
1030 |
| - self._query_jobs.add(self._query_job) |
| 1038 | + query_job = self._query_job |
| 1039 | + assert query_job is not None |
1031 | 1040 |
|
1032 | 1041 | logger.debug(
|
1033 | 1042 | "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
|
1034 |
| - self._query_job.project, |
1035 |
| - self._query_job.location, |
1036 |
| - self._query_job.job_id, |
| 1043 | + query_job.project, |
| 1044 | + query_job.location, |
| 1045 | + query_job.job_id, |
1037 | 1046 | )
|
1038 | 1047 |
|
1039 | 1048 | results = self._db_call(
|
1040 |
| - self._query_job.result, |
| 1049 | + query_job.result, |
1041 | 1050 | timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore
|
1042 | 1051 | )
|
1043 | 1052 |
|
1044 |
| - self._query_jobs.remove(self._query_job) |
1045 |
| - |
1046 | 1053 | self._query_data = iter(results) if results.total_rows else iter([])
|
1047 |
| - query_results = self._query_job._query_results |
| 1054 | + query_results = query_job._query_results |
1048 | 1055 | self.cursor._set_rowcount(query_results)
|
1049 | 1056 | self.cursor._set_description(query_results.schema)
|
1050 | 1057 |
|
@@ -1208,32 +1215,23 @@ def _query_data(self) -> t.Any:
|
1208 | 1215 |
|
1209 | 1216 | @_query_data.setter
|
1210 | 1217 | def _query_data(self, value: t.Any) -> None:
|
1211 |
| - return self._connection_pool.set_attribute("query_data", value) |
1212 |
| - |
1213 |
| - @property |
1214 |
| - def _query_jobs(self) -> t.Any: |
1215 |
| - query_jobs = self._connection_pool.get_attribute("query_jobs") |
1216 |
| - if not isinstance(query_jobs, set): |
1217 |
| - query_jobs = set() |
1218 |
| - self._connection_pool.set_attribute("query_jobs", query_jobs) |
1219 |
| - |
1220 |
| - return query_jobs |
| 1218 | + self._connection_pool.set_attribute("query_data", value) |
1221 | 1219 |
|
1222 | 1220 | @property
|
1223 |
| - def _query_job(self) -> t.Any: |
| 1221 | + def _query_job(self) -> t.Optional[QueryJob]: |
1224 | 1222 | return self._connection_pool.get_attribute("query_job")
|
1225 | 1223 |
|
1226 | 1224 | @_query_job.setter
|
1227 | 1225 | def _query_job(self, value: t.Any) -> None:
|
1228 |
| - return self._connection_pool.set_attribute("query_job", value) |
| 1226 | + self._connection_pool.set_attribute("query_job", value) |
1229 | 1227 |
|
1230 | 1228 | @property
|
1231 | 1229 | def _session_id(self) -> t.Any:
|
1232 | 1230 | return self._connection_pool.get_attribute("session_id")
|
1233 | 1231 |
|
1234 | 1232 | @_session_id.setter
|
1235 | 1233 | def _session_id(self, value: t.Any) -> None:
|
1236 |
| - return self._connection_pool.set_attribute("session_id", value) |
| 1234 | + self._connection_pool.set_attribute("session_id", value) |
1237 | 1235 |
|
1238 | 1236 |
|
1239 | 1237 | class _ErrorCounter:
|
|
0 commit comments