@@ -186,6 +186,14 @@ def query_factory() -> Query:
186
186
)
187
187
]
188
188
189
+ def close (self ) -> t .Any :
190
+ # Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts
191
+ for query_job in self ._query_jobs :
192
+ if not self ._db_call (query_job .done ):
193
+ self ._db_call (query_job .cancel )
194
+
195
+ return super ().close ()
196
+
189
197
def _begin_session (self , properties : SessionProperties ) -> None :
190
198
from google .cloud .bigquery import QueryJobConfig
191
199
@@ -1021,6 +1029,7 @@ def _execute(
1021
1029
job_config = job_config ,
1022
1030
timeout = self ._extra_config .get ("job_creation_timeout_seconds" ),
1023
1031
)
1032
+ self ._query_jobs .add (self ._query_job )
1024
1033
1025
1034
logger .debug (
1026
1035
"BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s" ,
@@ -1029,21 +1038,12 @@ def _execute(
1029
1038
self ._query_job .job_id ,
1030
1039
)
1031
1040
1032
- try :
1033
- results = self ._db_call (
1034
- self ._query_job .result ,
1035
- timeout = self ._extra_config .get ("job_execution_timeout_seconds" ), # type: ignore
1036
- )
1037
- except KeyboardInterrupt :
1038
- # Wrapping this in another try-except to ensure the subsequent db calls don't change
1039
- # the original exception type.
1040
- try :
1041
- if not self ._db_call (self ._query_job .done ):
1042
- self ._db_call (self ._query_job .cancel )
1043
- except :
1044
- pass
1041
+ results = self ._db_call (
1042
+ self ._query_job .result ,
1043
+ timeout = self ._extra_config .get ("job_execution_timeout_seconds" ), # type: ignore
1044
+ )
1045
1045
1046
- raise
1046
+ self . _query_jobs . remove ( self . _query_job )
1047
1047
1048
1048
self ._query_data = iter (results ) if results .total_rows else iter ([])
1049
1049
query_results = self ._query_job ._query_results
@@ -1212,6 +1212,15 @@ def _query_data(self) -> t.Any:
1212
1212
def _query_data (self , value : t .Any ) -> None :
1213
1213
return self ._connection_pool .set_attribute ("query_data" , value )
1214
1214
1215
+ @property
1216
+ def _query_jobs (self ) -> t .Any :
1217
+ query_jobs = self ._connection_pool .get_attribute ("query_jobs" )
1218
+ if not isinstance (query_jobs , set ):
1219
+ query_jobs = set ()
1220
+ self ._connection_pool .set_attribute ("query_jobs" , query_jobs )
1221
+
1222
+ return query_jobs
1223
+
1215
1224
@property
1216
1225
def _query_job (self ) -> t .Any :
1217
1226
return self ._connection_pool .get_attribute ("query_job" )
0 commit comments