@@ -186,6 +186,14 @@ def query_factory() -> Query:
186
186
)
187
187
]
188
188
189
+ def close (self ) -> t .Any :
190
+ # Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts
191
+ for query_job in self ._query_jobs :
192
+ if not self ._db_call (query_job .done ):
193
+ self ._db_call (query_job .cancel )
194
+
195
+ return super ().close ()
196
+
189
197
def _begin_session (self , properties : SessionProperties ) -> None :
190
198
from google .cloud .bigquery import QueryJobConfig
191
199
@@ -1009,6 +1017,7 @@ def _execute(
1009
1017
job_config = job_config ,
1010
1018
timeout = self ._extra_config .get ("job_creation_timeout_seconds" ),
1011
1019
)
1020
+ self ._query_jobs .add (self ._query_job )
1012
1021
1013
1022
logger .debug (
1014
1023
"BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s" ,
@@ -1017,21 +1026,12 @@ def _execute(
1017
1026
self ._query_job .job_id ,
1018
1027
)
1019
1028
1020
- try :
1021
- results = self ._db_call (
1022
- self ._query_job .result ,
1023
- timeout = self ._extra_config .get ("job_execution_timeout_seconds" ), # type: ignore
1024
- )
1025
- except KeyboardInterrupt :
1026
- # Wrapping this in another try-except to ensure the subsequent db calls don't change
1027
- # the original exception type.
1028
- try :
1029
- if not self ._db_call (self ._query_job .done ):
1030
- self ._db_call (self ._query_job .cancel )
1031
- except :
1032
- pass
1029
+ results = self ._db_call (
1030
+ self ._query_job .result ,
1031
+ timeout = self ._extra_config .get ("job_execution_timeout_seconds" ), # type: ignore
1032
+ )
1033
1033
1034
- raise
1034
+ self . _query_jobs . remove ( self . _query_job )
1035
1035
1036
1036
self ._query_data = iter (results ) if results .total_rows else iter ([])
1037
1037
query_results = self ._query_job ._query_results
@@ -1200,6 +1200,15 @@ def _query_data(self) -> t.Any:
1200
1200
def _query_data (self , value : t .Any ) -> None :
1201
1201
return self ._connection_pool .set_attribute ("query_data" , value )
1202
1202
1203
+ @property
1204
+ def _query_jobs (self ) -> t .Any :
1205
+ query_jobs = self ._connection_pool .get_attribute ("query_jobs" )
1206
+ if not isinstance (query_jobs , set ):
1207
+ query_jobs = set ()
1208
+ self ._connection_pool .set_attribute ("query_jobs" , query_jobs )
1209
+
1210
+ return query_jobs
1211
+
1203
1212
@property
1204
1213
def _query_job (self ) -> t .Any :
1205
1214
return self ._connection_pool .get_attribute ("query_job" )
0 commit comments