From 8bb9046b583ec3a1ec4036e07aecc189ccb676c6 Mon Sep 17 00:00:00 2001 From: kiraksi Date: Mon, 4 Mar 2024 10:25:25 -0800 Subject: [PATCH 1/6] feat: AsyncClient for async query_and_wait for BQ jobs --- google/cloud/bigquery/async_client.py | 225 ++++++++++++++++++ tests/unit/test_async_client.py | 324 ++++++++++++++++++++++++++ 2 files changed, 549 insertions(+) create mode 100644 google/cloud/bigquery/async_client.py create mode 100644 tests/unit/test_async_client.py diff --git a/google/cloud/bigquery/async_client.py b/google/cloud/bigquery/async_client.py new file mode 100644 index 000000000..67550966b --- /dev/null +++ b/google/cloud/bigquery/async_client.py @@ -0,0 +1,225 @@ +from google.cloud.bigquery.client import * +from google.cloud.bigquery import _job_helpers +from google.cloud.bigquery import table +import asyncio +from google.api_core import gapic_v1, retry_async + +class AsyncClient(Client): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + + async def query_and_wait( + self, + query, + *, + job_config: Optional[QueryJobConfig] = None, + location: Optional[str] = None, + project: Optional[str] = None, + api_timeout: TimeoutType = DEFAULT_TIMEOUT, + wait_timeout: TimeoutType = None, + retry: retries.Retry = DEFAULT_RETRY, + job_retry: retries.Retry = DEFAULT_JOB_RETRY, + page_size: Optional[int] = None, + max_results: Optional[int] = None, + ) -> RowIterator: + + if project is None: + project = self.project + + if location is None: + location = self.location + + # if job_config is not None: + # self._verify_job_config_type(job_config, QueryJobConfig) + + # if job_config is not None: + # self._verify_job_config_type(job_config, QueryJobConfig) + + job_config = _job_helpers.job_config_with_defaults( + job_config, self._default_query_job_config + ) + + return await async_query_and_wait( + self, + query, + job_config=job_config, + location=location, + project=project, + api_timeout=api_timeout, + wait_timeout=wait_timeout, + retry=retry, + job_retry=job_retry, + page_size=page_size, + max_results=max_results, + ) + + +async def async_query_and_wait( + client: "Client", + query: str, + *, + job_config: Optional[job.QueryJobConfig], + location: Optional[str], + project: str, + api_timeout: Optional[float] = None, + wait_timeout: Optional[float] = None, + retry: Optional[retries.Retry], + job_retry: Optional[retries.Retry], + page_size: Optional[int] = None, + max_results: Optional[int] = None, +) -> table.RowIterator: + + # Some API parameters aren't supported by the jobs.query API. In these + # cases, fallback to a jobs.insert call. + if not _job_helpers._supported_by_jobs_query(job_config): + return await async_wait_or_cancel( + _job_helpers.query_jobs_insert( + client=client, + query=query, + job_id=None, + job_id_prefix=None, + job_config=job_config, + location=location, + project=project, + retry=retry, + timeout=api_timeout, + job_retry=job_retry, + ), + api_timeout=api_timeout, + wait_timeout=wait_timeout, + retry=retry, + page_size=page_size, + max_results=max_results, + ) + + path = _job_helpers._to_query_path(project) + request_body = _job_helpers._to_query_request( + query=query, job_config=job_config, location=location, timeout=api_timeout + ) + + if page_size is not None and max_results is not None: + request_body["maxResults"] = min(page_size, max_results) + elif page_size is not None or max_results is not None: + request_body["maxResults"] = page_size or max_results + + if os.getenv("QUERY_PREVIEW_ENABLED", "").casefold() == "true": + request_body["jobCreationMode"] = "JOB_CREATION_OPTIONAL" + + async def do_query(): + request_body["requestId"] = _job_helpers.make_job_id() + span_attributes = {"path": path} + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + client._call_api, + default_retry=retry_async.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.ServiceUnavailable, + ), + deadline=60.0, + ), + default_timeout=60.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # For easier testing, handle the retries ourselves. + # if retry is not None: + # response = retry(client._call_api)( + # retry=None, # We're calling the retry decorator ourselves. + # span_name="BigQuery.query", + # span_attributes=span_attributes, + # method="POST", + # path=path, + # data=request_body, + # timeout=api_timeout, + # ) + # else: + response = await rpc( + retry=None, + span_name="BigQuery.query", + span_attributes=span_attributes, + method="POST", + path=path, + data=request_body, + timeout=api_timeout, + ) + + # Even if we run with JOB_CREATION_OPTIONAL, if there are more pages + # to fetch, there will be a job ID for jobs.getQueryResults. + query_results = google.cloud.bigquery.query._QueryResults.from_api_repr( + response + ) + page_token = query_results.page_token + more_pages = page_token is not None + + if more_pages or not query_results.complete: + # TODO(swast): Avoid a call to jobs.get in some cases (few + # remaining pages) by waiting for the query to finish and calling + # client._list_rows_from_query_results directly. Need to update + # RowIterator to fetch destination table via the job ID if needed. + return await async_wait_or_cancel( + _job_helpers._to_query_job(client, query, job_config, response), + api_timeout=api_timeout, + wait_timeout=wait_timeout, + retry=retry, + page_size=page_size, + max_results=max_results, + ) + + return table.RowIterator( + client=client, + api_request=functools.partial(client._call_api, retry, timeout=api_timeout), + path=None, + schema=query_results.schema, + max_results=max_results, + page_size=page_size, + total_rows=query_results.total_rows, + first_page_response=response, + location=query_results.location, + job_id=query_results.job_id, + query_id=query_results.query_id, + project=query_results.project, + num_dml_affected_rows=query_results.num_dml_affected_rows, + ) + + + if job_retry is not None: + return job_retry(do_query)() + else: + return do_query() + +async def async_wait_or_cancel( + job: job.QueryJob, + api_timeout: Optional[float], + wait_timeout: Optional[float], + retry: Optional[retries.Retry], + page_size: Optional[int], + max_results: Optional[int], +) -> table.RowIterator: + try: + return await job.result( + page_size=page_size, + max_results=max_results, + retry=retry, + timeout=wait_timeout, + ) + except Exception: + # Attempt to cancel the job since we can't return the results. + try: + job.cancel(retry=retry, timeout=api_timeout) + except Exception: + # Don't eat the original exception if cancel fails. + pass + raise + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + "3.17.2" +) + +__all__ = ("AsyncClient",) \ No newline at end of file diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py new file mode 100644 index 000000000..9df9b5513 --- /dev/null +++ b/tests/unit/test_async_client.py @@ -0,0 +1,324 @@ +import copy +import collections +import datetime +import decimal +import email +import gzip +import http.client +import io +import itertools +import json +import operator +import unittest +import warnings + +import mock +import requests +import packaging +import pytest +import sys +import inspect + +if sys.version_info >= (3, 9): + import asyncio + +try: + import importlib.metadata as metadata +except ImportError: + import importlib_metadata as metadata + +try: + import pandas +except (ImportError, AttributeError): # pragma: NO COVER + pandas = None + +try: + import opentelemetry +except ImportError: + opentelemetry = None + +if opentelemetry is not None: + try: + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + except (ImportError, AttributeError) as exc: # pragma: NO COVER + msg = "Error importing from opentelemetry, is the installed version compatible?" + raise ImportError(msg) from exc + +try: + import pyarrow +except (ImportError, AttributeError): # pragma: NO COVER + pyarrow = None + +import google.api_core.exceptions +from google.api_core import client_info +import google.cloud._helpers +from google.cloud import bigquery + +from google.cloud.bigquery.dataset import DatasetReference +from google.cloud.bigquery import exceptions +from google.cloud.bigquery import ParquetOptions +from google.cloud.bigquery.retry import DEFAULT_TIMEOUT +import google.cloud.bigquery.table + +try: + from google.cloud import bigquery_storage +except (ImportError, AttributeError): # pragma: NO COVER + bigquery_storage = None +from test_utils.imports import maybe_fail_import +from tests.unit.helpers import make_connection + +if pandas is not None: + PANDAS_INSTALLED_VERSION = metadata.version("pandas") +else: + PANDAS_INSTALLED_VERSION = "0.0.0" + +def asyncio_run(async_func): + def wrapper(*args, **kwargs): + return asyncio.run(async_func(*args, **kwargs)) + + wrapper.__signature__ = inspect.signature( + async_func + ) # without this, fixtures are not injected + + return wrapper + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + + +class TestClient(unittest.TestCase): + PROJECT = "PROJECT" + DS_ID = "DATASET_ID" + TABLE_ID = "TABLE_ID" + MODEL_ID = "MODEL_ID" + TABLE_REF = DatasetReference(PROJECT, DS_ID).table(TABLE_ID) + KMS_KEY_NAME = "projects/1/locations/us/keyRings/1/cryptoKeys/1" + LOCATION = "us-central" + + @staticmethod + def _get_target_class(): + from google.cloud.bigquery.async_client import AsyncClient + + return AsyncClient + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def _make_table_resource(self): + return { + "id": "%s:%s:%s" % (self.PROJECT, self.DS_ID, self.TABLE_ID), + "tableReference": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "tableId": self.TABLE_ID, + }, + } + + def test_ctor_defaults(self): + from google.cloud.bigquery._http import Connection + + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertIsNone(client.location) + self.assertEqual( + client._connection.API_BASE_URL, Connection.DEFAULT_API_ENDPOINT + ) + + def test_ctor_w_empty_client_options(self): + from google.api_core.client_options import ClientOptions + + creds = _make_credentials() + http = object() + client_options = ClientOptions() + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + ) + self.assertEqual( + client._connection.API_BASE_URL, client._connection.DEFAULT_API_ENDPOINT + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_query_and_wait_defaults(self): + query = "select count(*) from `bigquery-public-data.usa_names.usa_1910_2013`" + jobs_query_response = { + "jobComplete": True, + "schema": { + "fields": [ + { + "name": "f0_", + "type": "INTEGER", + "mode": "NULLABLE", + }, + ], + }, + "totalRows": "1", + "rows": [{"f": [{"v": "5552452"}]}], + "queryId": "job_abcDEF_", + } + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(jobs_query_response) + + rows = await client.query_and_wait(query) + + self.assertIsInstance(rows, google.cloud.bigquery.table.RowIterator) + self.assertEqual(rows.query_id, "job_abcDEF_") + self.assertEqual(rows.total_rows, 1) + # No job reference in the response should be OK for completed query. + self.assertIsNone(rows.job_id) + self.assertIsNone(rows.project) + self.assertIsNone(rows.location) + + # Verify the request we send is to jobs.query. + conn.api_request = await conn.api_request + conn.api_request.assert_called_once() + _, req = conn.api_request.call_args + self.assertEqual(req["method"], "POST") + self.assertEqual(req["path"], "/projects/PROJECT/queries") + self.assertEqual(req["timeout"], DEFAULT_TIMEOUT) + sent = req["data"] + self.assertEqual(sent["query"], query) + self.assertFalse(sent["useLegacySql"]) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_query_and_wait_w_default_query_job_config(self): + from google.cloud.bigquery import job + + query = "select count(*) from `bigquery-public-data.usa_names.usa_1910_2013`" + jobs_query_response = { + "jobComplete": True, + } + creds = _make_credentials() + http = object() + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + default_query_job_config=job.QueryJobConfig( + labels={ + "default-label": "default-value", + }, + ), + ) + conn = client._connection = make_connection(jobs_query_response) + + future_result = client.query_and_wait(query) + _ = await future_result + + # Verify the request we send is to jobs.query. + # Instantiate my query path, dumping call stacks to see where I am. Get the address of my mocked call and actual call thats invoked, see if thats the same. See if my mocked thing is the thing getting invoked or not. + # conn.api_request.assert_called_once() + _, req = conn.api_request.call_args + self.assertEqual(req["method"], "POST") + self.assertEqual(req["path"], f"/projects/{self.PROJECT}/queries") + sent = req["data"] + self.assertEqual(sent["labels"], {"default-label": "default-value"}) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_query_and_wait_w_job_config(self): + from google.cloud.bigquery import job + + query = "select count(*) from `bigquery-public-data.usa_names.usa_1910_2013`" + jobs_query_response = { + "jobComplete": True, + } + creds = _make_credentials() + http = object() + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + ) + conn = client._connection = make_connection(jobs_query_response) + + future_result = client.query_and_wait( + query, + job_config=job.QueryJobConfig( + labels={ + "job_config-label": "job_config-value", + }, + ), + ) + rows = await future_result + + # Verify the request we send is to jobs.query. + # conn.api_request.assert_called_once() + _, req = conn.api_request.call_args + self.assertEqual(req["method"], "POST") + self.assertEqual(req["path"], f"/projects/{self.PROJECT}/queries") + sent = req["data"] + self.assertEqual(sent["labels"], {"job_config-label": "job_config-value"}) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_query_and_wait_w_location(self): + query = "select count(*) from `bigquery-public-data.usa_names.usa_1910_2013`" + jobs_query_response = { + "jobComplete": True, + } + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(jobs_query_response) + + future_result = client.query_and_wait(query, location="not-the-client-location") + _ = await future_result + + # Verify the request we send is to jobs.query. + # conn.api_request.assert_called_once() + _, req = conn.api_request.call_args + self.assertEqual(req["method"], "POST") + self.assertEqual(req["path"], f"/projects/{self.PROJECT}/queries") + sent = req["data"] + self.assertEqual(sent["location"], "not-the-client-location") + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_query_and_wait_w_project(self): + query = "select count(*) from `bigquery-public-data.usa_names.usa_1910_2013`" + jobs_query_response = { + "jobComplete": True, + } + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(jobs_query_response) + + future_result = client.query_and_wait(query, project="not-the-client-project") + _ = await future_result + + # Verify the request we send is to jobs.query. + # conn.api_request.assert_called_once() + _, req = conn.api_request.call_args + self.assertEqual(req["method"], "POST") + self.assertEqual(req["path"], "/projects/not-the-client-project/queries") \ No newline at end of file From 4abc0b569cd3d88f3b01cbe08a1f54b48ab3967a Mon Sep 17 00:00:00 2001 From: kiraksi Date: Mon, 4 Mar 2024 10:26:27 -0800 Subject: [PATCH 2/6] refactor to make it closer to synchronous execution for testing --- google/cloud/bigquery/async_client.py | 70 +++++++++------------------ 1 file changed, 22 insertions(+), 48 deletions(-) diff --git a/google/cloud/bigquery/async_client.py b/google/cloud/bigquery/async_client.py index 67550966b..08f0f4d1e 100644 --- a/google/cloud/bigquery/async_client.py +++ b/google/cloud/bigquery/async_client.py @@ -2,7 +2,6 @@ from google.cloud.bigquery import _job_helpers from google.cloud.bigquery import table import asyncio -from google.api_core import gapic_v1, retry_async class AsyncClient(Client): def __init__(self, *args, **kwargs): @@ -110,44 +109,27 @@ async def do_query(): request_body["requestId"] = _job_helpers.make_job_id() span_attributes = {"path": path} - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - client._call_api, - default_retry=retry_async.AsyncRetry( - initial=0.1, - maximum=60.0, - multiplier=1.3, - predicate=retries.if_exception_type( - core_exceptions.ServiceUnavailable, - ), - deadline=60.0, - ), - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) - # For easier testing, handle the retries ourselves. - # if retry is not None: - # response = retry(client._call_api)( - # retry=None, # We're calling the retry decorator ourselves. - # span_name="BigQuery.query", - # span_attributes=span_attributes, - # method="POST", - # path=path, - # data=request_body, - # timeout=api_timeout, - # ) - # else: - response = await rpc( - retry=None, - span_name="BigQuery.query", - span_attributes=span_attributes, - method="POST", - path=path, - data=request_body, - timeout=api_timeout, - ) + if retry is not None: + response = retry(client._call_api)( + retry=None, # We're calling the retry decorator ourselves. + span_name="BigQuery.query", + span_attributes=span_attributes, + method="POST", + path=path, + data=request_body, + timeout=api_timeout, + ) + else: + response = client._call_api( + retry=None, + span_name="BigQuery.query", + span_attributes=span_attributes, + method="POST", + path=path, + data=request_body, + timeout=api_timeout, + ) # Even if we run with JOB_CREATION_OPTIONAL, if there are more pages # to fetch, there will be a job ID for jobs.getQueryResults. @@ -186,12 +168,11 @@ async def do_query(): project=query_results.project, num_dml_affected_rows=query_results.num_dml_affected_rows, ) - if job_retry is not None: return job_retry(do_query)() else: - return do_query() + return await do_query() async def async_wait_or_cancel( job: job.QueryJob, @@ -215,11 +196,4 @@ async def async_wait_or_cancel( except Exception: # Don't eat the original exception if cancel fails. pass - raise - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - "3.17.2" -) - -__all__ = ("AsyncClient",) \ No newline at end of file + raise \ No newline at end of file From 7ba62e8c8559b6ac14d7b28376ec1e760318576d Mon Sep 17 00:00:00 2001 From: kiraksi Date: Tue, 5 Mar 2024 08:53:28 -0800 Subject: [PATCH 3/6] refactoring by plumping async through more coroutines, added notes for more work, added async_retries(breaking tests) --- google/cloud/bigquery/async_client.py | 165 ++++++++++++++------------ google/cloud/bigquery/retry.py | 8 +- setup.py | 3 + tests/unit/test_async_client.py | 15 ++- 4 files changed, 104 insertions(+), 87 deletions(-) diff --git a/google/cloud/bigquery/async_client.py b/google/cloud/bigquery/async_client.py index 08f0f4d1e..c471cfaaa 100644 --- a/google/cloud/bigquery/async_client.py +++ b/google/cloud/bigquery/async_client.py @@ -1,11 +1,19 @@ from google.cloud.bigquery.client import * from google.cloud.bigquery import _job_helpers from google.cloud.bigquery import table +from google.cloud.bigquery.retry import ( + DEFAULT_ASYNC_JOB_RETRY, + DEFAULT_ASYNC_RETRY, + DEFAULT_TIMEOUT, +) +from google.api_core import retry_async as retries import asyncio +import google.auth.transport._aiohttp_requests -class AsyncClient(Client): + +class AsyncClient(): def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + self._client = Client(*args, **kwargs) async def query_and_wait( @@ -17,30 +25,30 @@ async def query_and_wait( project: Optional[str] = None, api_timeout: TimeoutType = DEFAULT_TIMEOUT, wait_timeout: TimeoutType = None, - retry: retries.Retry = DEFAULT_RETRY, - job_retry: retries.Retry = DEFAULT_JOB_RETRY, + retry: retries.AsyncRetry = DEFAULT_ASYNC_RETRY, + job_retry: retries.AsyncRetry = DEFAULT_ASYNC_JOB_RETRY, page_size: Optional[int] = None, max_results: Optional[int] = None, ) -> RowIterator: if project is None: - project = self.project + project = self._client.project if location is None: - location = self.location + location = self._client.location # if job_config is not None: - # self._verify_job_config_type(job_config, QueryJobConfig) + # self._client._verify_job_config_type(job_config, QueryJobConfig) # if job_config is not None: - # self._verify_job_config_type(job_config, QueryJobConfig) + # self._client._verify_job_config_type(job_config, QueryJobConfig) job_config = _job_helpers.job_config_with_defaults( - job_config, self._default_query_job_config + job_config, self._client._default_query_job_config ) return await async_query_and_wait( - self, + self._client, query, job_config=job_config, location=location, @@ -63,8 +71,8 @@ async def async_query_and_wait( project: str, api_timeout: Optional[float] = None, wait_timeout: Optional[float] = None, - retry: Optional[retries.Retry], - job_retry: Optional[retries.Retry], + retry: Optional[retries.AsyncRetry], + job_retry: Optional[retries.AsyncRetry], page_size: Optional[int] = None, max_results: Optional[int] = None, ) -> table.RowIterator: @@ -73,7 +81,7 @@ async def async_query_and_wait( # cases, fallback to a jobs.insert call. if not _job_helpers._supported_by_jobs_query(job_config): return await async_wait_or_cancel( - _job_helpers.query_jobs_insert( + asyncio.to_thread(_job_helpers.query_jobs_insert( # throw in a background thread client=client, query=query, job_id=None, @@ -84,7 +92,7 @@ async def async_query_and_wait( retry=retry, timeout=api_timeout, job_retry=job_retry, - ), + )), api_timeout=api_timeout, wait_timeout=wait_timeout, retry=retry, @@ -105,90 +113,91 @@ async def async_query_and_wait( if os.getenv("QUERY_PREVIEW_ENABLED", "").casefold() == "true": request_body["jobCreationMode"] = "JOB_CREATION_OPTIONAL" - async def do_query(): - request_body["requestId"] = _job_helpers.make_job_id() - span_attributes = {"path": path} - - # For easier testing, handle the retries ourselves. - if retry is not None: - response = retry(client._call_api)( - retry=None, # We're calling the retry decorator ourselves. - span_name="BigQuery.query", - span_attributes=span_attributes, - method="POST", - path=path, - data=request_body, - timeout=api_timeout, - ) - else: - response = client._call_api( - retry=None, - span_name="BigQuery.query", - span_attributes=span_attributes, - method="POST", - path=path, - data=request_body, - timeout=api_timeout, - ) - # Even if we run with JOB_CREATION_OPTIONAL, if there are more pages - # to fetch, there will be a job ID for jobs.getQueryResults. - query_results = google.cloud.bigquery.query._QueryResults.from_api_repr( - response + request_body["requestId"] = _job_helpers.make_job_id() + span_attributes = {"path": path} + + # For easier testing, handle the retries ourselves. + if retry is not None: + response = retry(client._call_api)( # ASYNCHRONOUS HTTP CALLS aiohttp (optional of google-auth) + retry=None, # We're calling the retry decorator ourselves, async_retries + span_name="BigQuery.query", + span_attributes=span_attributes, + method="POST", + path=path, + data=request_body, + timeout=api_timeout, ) - page_token = query_results.page_token - more_pages = page_token is not None - - if more_pages or not query_results.complete: - # TODO(swast): Avoid a call to jobs.get in some cases (few - # remaining pages) by waiting for the query to finish and calling - # client._list_rows_from_query_results directly. Need to update - # RowIterator to fetch destination table via the job ID if needed. - return await async_wait_or_cancel( - _job_helpers._to_query_job(client, query, job_config, response), - api_timeout=api_timeout, - wait_timeout=wait_timeout, - retry=retry, - page_size=page_size, - max_results=max_results, - ) - - return table.RowIterator( - client=client, - api_request=functools.partial(client._call_api, retry, timeout=api_timeout), - path=None, - schema=query_results.schema, - max_results=max_results, + else: + response = client._call_api( + retry=None, + span_name="BigQuery.query", + span_attributes=span_attributes, + method="POST", + path=path, + data=request_body, + timeout=api_timeout, + ) + + # Even if we run with JOB_CREATION_OPTIONAL, if there are more pages + # to fetch, there will be a job ID for jobs.getQueryResults. + query_results = google.cloud.bigquery.query._QueryResults.from_api_repr( + await response + ) + page_token = query_results.page_token + more_pages = page_token is not None + + if more_pages or not query_results.complete: + # TODO(swast): Avoid a call to jobs.get in some cases (few + # remaining pages) by waiting for the query to finish and calling + # client._list_rows_from_query_results directly. Need to update + # RowIterator to fetch destination table via the job ID if needed. + result = await async_wait_or_cancel( + _job_helpers._to_query_job(client, query, job_config, response), + api_timeout=api_timeout, + wait_timeout=wait_timeout, + retry=retry, page_size=page_size, - total_rows=query_results.total_rows, - first_page_response=response, - location=query_results.location, - job_id=query_results.job_id, - query_id=query_results.query_id, - project=query_results.project, - num_dml_affected_rows=query_results.num_dml_affected_rows, + max_results=max_results, ) + result = table.RowIterator( # async of RowIterator? async version without all the pandas stuff + client=client, + api_request=functools.partial(client._call_api, retry, timeout=api_timeout), + path=None, + schema=query_results.schema, + max_results=max_results, + page_size=page_size, + total_rows=query_results.total_rows, + first_page_response=response, + location=query_results.location, + job_id=query_results.job_id, + query_id=query_results.query_id, + project=query_results.project, + num_dml_affected_rows=query_results.num_dml_affected_rows, + ) + + if job_retry is not None: - return job_retry(do_query)() + return job_retry(result) # AsyncRetries, new default objects, default_job_retry_async, default_retry_async else: - return await do_query() + return result async def async_wait_or_cancel( job: job.QueryJob, api_timeout: Optional[float], wait_timeout: Optional[float], - retry: Optional[retries.Retry], + retry: Optional[retries.AsyncRetry], page_size: Optional[int], max_results: Optional[int], ) -> table.RowIterator: try: - return await job.result( + return asyncio.to_thread(job.result( # run in a background thread page_size=page_size, max_results=max_results, retry=retry, timeout=wait_timeout, - ) + )) except Exception: # Attempt to cancel the job since we can't return the results. try: diff --git a/google/cloud/bigquery/retry.py b/google/cloud/bigquery/retry.py index 01b127972..9acbf1382 100644 --- a/google/cloud/bigquery/retry.py +++ b/google/cloud/bigquery/retry.py @@ -13,7 +13,7 @@ # limitations under the License. from google.api_core import exceptions -from google.api_core import retry +from google.api_core import retry, retry_async from google.auth import exceptions as auth_exceptions # type: ignore import requests.exceptions @@ -90,3 +90,9 @@ def _job_should_retry(exc): """ The default job retry object. """ + +DEFAULT_ASYNC_RETRY = retry_async.AsyncRetry(predicate=_should_retry, deadline=_DEFAULT_RETRY_DEADLINE) # deadline is deprecated + +DEFAULT_ASYNC_JOB_RETRY = retry_async.AsyncRetry( + predicate=_job_should_retry, deadline=_DEFAULT_JOB_DEADLINE # deadline is deprecated +) \ No newline at end of file diff --git a/setup.py b/setup.py index 5a35f4136..1c5025f29 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,9 @@ "proto-plus >= 1.15.0, <2.0.0dev", "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", # For the legacy proto-based types. ], + "google-auth": [ + "aiohttp", + ] } all_extras = [] diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py index 9df9b5513..a190b5973 100644 --- a/tests/unit/test_async_client.py +++ b/tests/unit/test_async_client.py @@ -128,7 +128,7 @@ def test_ctor_defaults(self): creds = _make_credentials() http = object() - client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)._client self.assertIsInstance(client._connection, Connection) self.assertIs(client._connection.credentials, creds) self.assertIs(client._connection.http, http) @@ -148,7 +148,7 @@ def test_ctor_w_empty_client_options(self): credentials=creds, _http=http, client_options=client_options, - ) + )._client self.assertEqual( client._connection.API_BASE_URL, client._connection.DEFAULT_API_ENDPOINT ) @@ -177,7 +177,7 @@ async def test_query_and_wait_defaults(self): creds = _make_credentials() http = object() client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) - conn = client._connection = make_connection(jobs_query_response) + conn = client._client._connection = make_connection(jobs_query_response) rows = await client.query_and_wait(query) @@ -190,7 +190,6 @@ async def test_query_and_wait_defaults(self): self.assertIsNone(rows.location) # Verify the request we send is to jobs.query. - conn.api_request = await conn.api_request conn.api_request.assert_called_once() _, req = conn.api_request.call_args self.assertEqual(req["method"], "POST") @@ -223,7 +222,7 @@ async def test_query_and_wait_w_default_query_job_config(self): }, ), ) - conn = client._connection = make_connection(jobs_query_response) + conn = client._client._connection = make_connection(jobs_query_response) future_result = client.query_and_wait(query) _ = await future_result @@ -255,7 +254,7 @@ async def test_query_and_wait_w_job_config(self): credentials=creds, _http=http, ) - conn = client._connection = make_connection(jobs_query_response) + conn = client._client._connection = make_connection(jobs_query_response) future_result = client.query_and_wait( query, @@ -287,7 +286,7 @@ async def test_query_and_wait_w_location(self): creds = _make_credentials() http = object() client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) - conn = client._connection = make_connection(jobs_query_response) + conn = client._client._connection = make_connection(jobs_query_response) future_result = client.query_and_wait(query, location="not-the-client-location") _ = await future_result @@ -312,7 +311,7 @@ async def test_query_and_wait_w_project(self): creds = _make_credentials() http = object() client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) - conn = client._connection = make_connection(jobs_query_response) + conn = client._client._connection = make_connection(jobs_query_response) future_result = client.query_and_wait(query, project="not-the-client-project") _ = await future_result From 6e74478d2aaacd8d2e9f3ca8ab0530cbcfc559ee Mon Sep 17 00:00:00 2001 From: kiraksi Date: Wed, 6 Mar 2024 01:08:51 -0800 Subject: [PATCH 4/6] Remove async retry from api call for now, remove rendudant comments, add in AsyncRetry and aiohttp --- google/cloud/bigquery/async_client.py | 70 ++++++----- google/cloud/bigquery/retry.py | 9 +- setup.py | 7 +- testing/constraints-3.9.txt | 1 + tests/unit/test_async_client.py | 163 ++++++++++++++++++++++++-- 5 files changed, 200 insertions(+), 50 deletions(-) diff --git a/google/cloud/bigquery/async_client.py b/google/cloud/bigquery/async_client.py index c471cfaaa..81bb9a197 100644 --- a/google/cloud/bigquery/async_client.py +++ b/google/cloud/bigquery/async_client.py @@ -8,14 +8,12 @@ ) from google.api_core import retry_async as retries import asyncio -import google.auth.transport._aiohttp_requests -class AsyncClient(): +class AsyncClient: def __init__(self, *args, **kwargs): self._client = Client(*args, **kwargs) - async def query_and_wait( self, query, @@ -29,14 +27,14 @@ async def query_and_wait( job_retry: retries.AsyncRetry = DEFAULT_ASYNC_JOB_RETRY, page_size: Optional[int] = None, max_results: Optional[int] = None, - ) -> RowIterator: - + ) -> RowIterator: if project is None: project = self._client.project if location is None: location = self._client.location + # for some reason these cannot find the function call # if job_config is not None: # self._client._verify_job_config_type(job_config, QueryJobConfig) @@ -62,7 +60,7 @@ async def query_and_wait( ) -async def async_query_and_wait( +async def async_query_and_wait( client: "Client", query: str, *, @@ -76,23 +74,24 @@ async def async_query_and_wait( page_size: Optional[int] = None, max_results: Optional[int] = None, ) -> table.RowIterator: - # Some API parameters aren't supported by the jobs.query API. In these # cases, fallback to a jobs.insert call. if not _job_helpers._supported_by_jobs_query(job_config): return await async_wait_or_cancel( - asyncio.to_thread(_job_helpers.query_jobs_insert( # throw in a background thread - client=client, - query=query, - job_id=None, - job_id_prefix=None, - job_config=job_config, - location=location, - project=project, - retry=retry, - timeout=api_timeout, - job_retry=job_retry, - )), + asyncio.to_thread( + _job_helpers.query_jobs_insert( + client=client, + query=query, + job_id=None, + job_id_prefix=None, + job_config=job_config, + location=location, + project=project, + retry=retry, + timeout=api_timeout, + job_retry=job_retry, + ) + ), api_timeout=api_timeout, wait_timeout=wait_timeout, retry=retry, @@ -113,14 +112,12 @@ async def async_query_and_wait( if os.getenv("QUERY_PREVIEW_ENABLED", "").casefold() == "true": request_body["jobCreationMode"] = "JOB_CREATION_OPTIONAL" - request_body["requestId"] = _job_helpers.make_job_id() span_attributes = {"path": path} - # For easier testing, handle the retries ourselves. if retry is not None: - response = retry(client._call_api)( # ASYNCHRONOUS HTTP CALLS aiohttp (optional of google-auth) - retry=None, # We're calling the retry decorator ourselves, async_retries + response = client._call_api( # ASYNCHRONOUS HTTP CALLS aiohttp (optional of google-auth), add back retry() + retry=None, # We're calling the retry decorator ourselves, async_retries, need to implement after making HTTP calls async span_name="BigQuery.query", span_attributes=span_attributes, method="POST", @@ -128,6 +125,7 @@ async def async_query_and_wait( data=request_body, timeout=api_timeout, ) + else: response = client._call_api( retry=None, @@ -141,9 +139,7 @@ async def async_query_and_wait( # Even if we run with JOB_CREATION_OPTIONAL, if there are more pages # to fetch, there will be a job ID for jobs.getQueryResults. - query_results = google.cloud.bigquery.query._QueryResults.from_api_repr( - await response - ) + query_results = google.cloud.bigquery.query._QueryResults.from_api_repr(response) page_token = query_results.page_token more_pages = page_token is not None @@ -161,7 +157,7 @@ async def async_query_and_wait( max_results=max_results, ) - result = table.RowIterator( # async of RowIterator? async version without all the pandas stuff + result = table.RowIterator( # async of RowIterator? async version without all the pandas stuff client=client, api_request=functools.partial(client._call_api, retry, timeout=api_timeout), path=None, @@ -177,12 +173,12 @@ async def async_query_and_wait( num_dml_affected_rows=query_results.num_dml_affected_rows, ) - if job_retry is not None: - return job_retry(result) # AsyncRetries, new default objects, default_job_retry_async, default_retry_async + return job_retry(result) else: return result + async def async_wait_or_cancel( job: job.QueryJob, api_timeout: Optional[float], @@ -192,12 +188,14 @@ async def async_wait_or_cancel( max_results: Optional[int], ) -> table.RowIterator: try: - return asyncio.to_thread(job.result( # run in a background thread - page_size=page_size, - max_results=max_results, - retry=retry, - timeout=wait_timeout, - )) + return asyncio.to_thread( + job.result( # run in a background thread + page_size=page_size, + max_results=max_results, + retry=retry, + timeout=wait_timeout, + ) + ) except Exception: # Attempt to cancel the job since we can't return the results. try: @@ -205,4 +203,4 @@ async def async_wait_or_cancel( except Exception: # Don't eat the original exception if cancel fails. pass - raise \ No newline at end of file + raise diff --git a/google/cloud/bigquery/retry.py b/google/cloud/bigquery/retry.py index 9acbf1382..c5fbb7fda 100644 --- a/google/cloud/bigquery/retry.py +++ b/google/cloud/bigquery/retry.py @@ -91,8 +91,11 @@ def _job_should_retry(exc): The default job retry object. """ -DEFAULT_ASYNC_RETRY = retry_async.AsyncRetry(predicate=_should_retry, deadline=_DEFAULT_RETRY_DEADLINE) # deadline is deprecated +DEFAULT_ASYNC_RETRY = retry_async.AsyncRetry( + predicate=_should_retry, deadline=_DEFAULT_RETRY_DEADLINE +) # deadline is deprecated DEFAULT_ASYNC_JOB_RETRY = retry_async.AsyncRetry( - predicate=_job_should_retry, deadline=_DEFAULT_JOB_DEADLINE # deadline is deprecated -) \ No newline at end of file + predicate=_job_should_retry, + deadline=_DEFAULT_JOB_DEADLINE, # deadline is deprecated +) diff --git a/setup.py b/setup.py index 1c5025f29..9f6fabcfc 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ # NOTE: Maintainers, please do not require google-cloud-core>=2.x.x # Until this issue is closed # https://github.com/googleapis/google-cloud-python/issues/10566 + "google-auth >= 2.14.1, <3.0.0dev", "google-cloud-core >= 1.6.0, <3.0.0dev", "google-resumable-media >= 0.6.0, < 3.0dev", "packaging >= 20.0.0", @@ -84,9 +85,9 @@ "proto-plus >= 1.15.0, <2.0.0dev", "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", # For the legacy proto-based types. ], - "google-auth": [ - "aiohttp", - ] + "aiohttp": [ + "google-auth[aiohttp]", + ], } all_extras = [] diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index d4c302867..f4adf95c3 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -4,5 +4,6 @@ # # NOTE: Not comprehensive yet, will eventually be maintained semi-automatically by # the renovate bot. +aiohttp==3.6.2 grpcio==1.47.0 pyarrow>=4.0.0 diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py index a190b5973..472504711 100644 --- a/tests/unit/test_async_client.py +++ b/tests/unit/test_async_client.py @@ -77,6 +77,17 @@ else: PANDAS_INSTALLED_VERSION = "0.0.0" +from google.cloud.bigquery.retry import ( + DEFAULT_ASYNC_JOB_RETRY, + DEFAULT_ASYNC_RETRY, + DEFAULT_TIMEOUT, +) +from google.api_core import retry_async as retries +from google.cloud.bigquery import async_client +from google.cloud.bigquery.async_client import AsyncClient +from google.cloud.bigquery.job import query as job_query + + def asyncio_run(async_func): def wrapper(*args, **kwargs): return asyncio.run(async_func(*args, **kwargs)) @@ -94,7 +105,6 @@ def _make_credentials(): return mock.Mock(spec=google.auth.credentials.Credentials) - class TestClient(unittest.TestCase): PROJECT = "PROJECT" DS_ID = "DATASET_ID" @@ -123,12 +133,17 @@ def _make_table_resource(self): }, } + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) def test_ctor_defaults(self): from google.cloud.bigquery._http import Connection creds = _make_credentials() http = object() - client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)._client + client = self._make_one( + project=self.PROJECT, credentials=creds, _http=http + )._client self.assertIsInstance(client._connection, Connection) self.assertIs(client._connection.credentials, creds) self.assertIs(client._connection.http, http) @@ -137,6 +152,9 @@ def test_ctor_defaults(self): client._connection.API_BASE_URL, Connection.DEFAULT_API_ENDPOINT ) + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) def test_ctor_w_empty_client_options(self): from google.api_core.client_options import ClientOptions @@ -154,7 +172,133 @@ def test_ctor_w_empty_client_options(self): ) @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_client_options_dict(self): + creds = _make_credentials() + http = object() + client_options = {"api_endpoint": "https://www.foo-googleapis.com"} + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual( + client._connection.API_BASE_URL, "https://www.foo-googleapis.com" + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_client_options_object(self): + from google.api_core.client_options import ClientOptions + + creds = _make_credentials() + http = object() + client_options = ClientOptions(api_endpoint="https://www.foo-googleapis.com") + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual( + client._connection.API_BASE_URL, "https://www.foo-googleapis.com" + ) + + @pytest.mark.skipif( + packaging.version.parse(getattr(google.api_core, "__version__", "0.0.0")) + < packaging.version.Version("2.15.0"), + reason="universe_domain not supported with google-api-core < 2.15.0", + ) + def test_ctor_w_client_options_universe(self): + creds = _make_credentials() + http = object() + client_options = {"universe_domain": "foo.com"} + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual(client._connection.API_BASE_URL, "https://bigquery.foo.com") + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_location(self): + from google.cloud.bigquery._http import Connection + + creds = _make_credentials() + http = object() + location = "us-central" + client = self._make_one( + project=self.PROJECT, credentials=creds, _http=http, location=location + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertEqual(client.location, location) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_query_job_config(self): + from google.cloud.bigquery._http import Connection + from google.cloud.bigquery import QueryJobConfig + + creds = _make_credentials() + http = object() + location = "us-central" + job_config = QueryJobConfig() + job_config.dry_run = True + + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + location=location, + default_query_job_config=job_config, + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertEqual(client.location, location) + + self.assertIsInstance(client._default_query_job_config, QueryJobConfig) + self.assertTrue(client._default_query_job_config.dry_run) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_load_job_config(self): + from google.cloud.bigquery._http import Connection + from google.cloud.bigquery import LoadJobConfig + + creds = _make_credentials() + http = object() + location = "us-central" + job_config = LoadJobConfig() + job_config.create_session = True + + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + location=location, + default_load_job_config=job_config, + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertEqual(client.location, location) + + self.assertIsInstance(client._default_load_job_config, LoadJobConfig) + self.assertTrue(client._default_load_job_config.create_session) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_defaults(self): @@ -200,7 +344,7 @@ async def test_query_and_wait_defaults(self): self.assertFalse(sent["useLegacySql"]) @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_w_default_query_job_config(self): @@ -237,7 +381,7 @@ async def test_query_and_wait_w_default_query_job_config(self): self.assertEqual(sent["labels"], {"default-label": "default-value"}) @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_w_job_config(self): @@ -275,7 +419,7 @@ async def test_query_and_wait_w_job_config(self): self.assertEqual(sent["labels"], {"job_config-label": "job_config-value"}) @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_w_location(self): @@ -300,7 +444,7 @@ async def test_query_and_wait_w_location(self): self.assertEqual(sent["location"], "not-the-client-location") @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_w_project(self): @@ -320,4 +464,7 @@ async def test_query_and_wait_w_project(self): # conn.api_request.assert_called_once() _, req = conn.api_request.call_args self.assertEqual(req["method"], "POST") - self.assertEqual(req["path"], "/projects/not-the-client-project/queries") \ No newline at end of file + self.assertEqual(req["path"], "/projects/not-the-client-project/queries") + + +# Add tests for async_query_and_wait and async_wait_or_cancel From 56a0a0f59380ba56f020abd925e4731e7b87465f Mon Sep 17 00:00:00 2001 From: kiraksi Date: Mon, 11 Mar 2024 13:49:25 -0700 Subject: [PATCH 5/6] add async _call_api, RowIterator and get_job to implementation --- google/cloud/bigquery/async_client.py | 155 +++++++++++++++--- .../cloud/bigquery/opentelemetry_tracing.py | 33 +++- noxfile.py | 12 +- setup.py | 1 - tests/unit/test_async_client.py | 99 +++++++++++ 5 files changed, 272 insertions(+), 28 deletions(-) diff --git a/google/cloud/bigquery/async_client.py b/google/cloud/bigquery/async_client.py index 81bb9a197..3dc7632d5 100644 --- a/google/cloud/bigquery/async_client.py +++ b/google/cloud/bigquery/async_client.py @@ -1,6 +1,12 @@ from google.cloud.bigquery.client import * +from google.cloud.bigquery.client import ( + _add_server_timeout_header, + _extract_job_reference, +) +from google.cloud.bigquery.opentelemetry_tracing import async_create_span from google.cloud.bigquery import _job_helpers -from google.cloud.bigquery import table +from google.cloud.bigquery.table import * +from google.api_core.page_iterator import HTTPIterator from google.cloud.bigquery.retry import ( DEFAULT_ASYNC_JOB_RETRY, DEFAULT_ASYNC_RETRY, @@ -8,12 +14,54 @@ ) from google.api_core import retry_async as retries import asyncio +from google.auth.transport import _aiohttp_requests + +# This code is experimental class AsyncClient: def __init__(self, *args, **kwargs): self._client = Client(*args, **kwargs) + async def get_job( + self, + job_id: Union[str, job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob], + project: Optional[str] = None, + location: Optional[str] = None, + retry: retries.AsyncRetry = DEFAULT_ASYNC_RETRY, + timeout: TimeoutType = DEFAULT_TIMEOUT, + ) -> Union[job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob, job.UnknownJob]: + extra_params = {"projection": "full"} + + project, location, job_id = _extract_job_reference( + job_id, project=project, location=location + ) + + if project is None: + project = self._client.project + + if location is None: + location = self._client.location + + if location is not None: + extra_params["location"] = location + + path = "/projects/{}/jobs/{}".format(project, job_id) + + span_attributes = {"path": path, "job_id": job_id, "location": location} + + resource = await self._call_api( + retry, + span_name="BigQuery.getJob", + span_attributes=span_attributes, + method="GET", + path=path, + query_params=extra_params, + timeout=timeout, + ) + + return await asyncio.to_thread(self._client.job_from_resource(await resource)) + async def query_and_wait( self, query, @@ -46,7 +94,7 @@ async def query_and_wait( ) return await async_query_and_wait( - self._client, + self, query, job_config=job_config, location=location, @@ -59,9 +107,41 @@ async def query_and_wait( max_results=max_results, ) + async def _call_api( + self, + retry: Optional[retries.AsyncRetry] = None, + span_name: Optional[str] = None, + span_attributes: Optional[Dict] = None, + job_ref=None, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ): + kwargs = _add_server_timeout_header(headers, kwargs) + + # Prepare the asynchronous request function + # async with _aiohttp_requests.Request(**kwargs) as response: + # response.raise_for_status() + # response = await response.json() # or response.text() + + async_call = functools.partial(self._client._connection.api_request, **kwargs) + + if retry: + async_call = retry(async_call) + + if span_name is not None: + async with async_create_span( + name=span_name, + attributes=span_attributes, + client=self._client, + job_ref=job_ref, + ): + return async_call() # Await the asynchronous call + + return async_call() # Await the asynchronous call + async def async_query_and_wait( - client: "Client", + client: "AsyncClient", query: str, *, job_config: Optional[job.QueryJobConfig], @@ -73,14 +153,12 @@ async def async_query_and_wait( job_retry: Optional[retries.AsyncRetry], page_size: Optional[int] = None, max_results: Optional[int] = None, -) -> table.RowIterator: - # Some API parameters aren't supported by the jobs.query API. In these - # cases, fallback to a jobs.insert call. +) -> RowIterator: if not _job_helpers._supported_by_jobs_query(job_config): return await async_wait_or_cancel( asyncio.to_thread( _job_helpers.query_jobs_insert( - client=client, + client=client._client, query=query, job_id=None, job_id_prefix=None, @@ -116,7 +194,7 @@ async def async_query_and_wait( span_attributes = {"path": path} if retry is not None: - response = client._call_api( # ASYNCHRONOUS HTTP CALLS aiohttp (optional of google-auth), add back retry() + response = await client._call_api( # ASYNCHRONOUS HTTP CALLS aiohttp (optional of google-auth), add back retry() retry=None, # We're calling the retry decorator ourselves, async_retries, need to implement after making HTTP calls async span_name="BigQuery.query", span_attributes=span_attributes, @@ -127,7 +205,7 @@ async def async_query_and_wait( ) else: - response = client._call_api( + response = await client._call_api( retry=None, span_name="BigQuery.query", span_attributes=span_attributes, @@ -149,17 +227,28 @@ async def async_query_and_wait( # client._list_rows_from_query_results directly. Need to update # RowIterator to fetch destination table via the job ID if needed. result = await async_wait_or_cancel( - _job_helpers._to_query_job(client, query, job_config, response), - api_timeout=api_timeout, - wait_timeout=wait_timeout, - retry=retry, - page_size=page_size, - max_results=max_results, + asyncio.to_thread( + _job_helpers._to_query_job(client._client, query, job_config, response), + api_timeout=api_timeout, + wait_timeout=wait_timeout, + retry=retry, + page_size=page_size, + max_results=max_results, + ) + ) + + def api_request(*args, **kwargs): + return client._call_api( + span_name="BigQuery.query", + span_attributes=span_attributes, + *args, + timeout=api_timeout, + **kwargs, ) - result = table.RowIterator( # async of RowIterator? async version without all the pandas stuff - client=client, - api_request=functools.partial(client._call_api, retry, timeout=api_timeout), + result = AsyncRowIterator( # async of RowIterator? async version without all the pandas stuff + client=client._client, + api_request=api_request, path=None, schema=query_results.schema, max_results=max_results, @@ -186,10 +275,10 @@ async def async_wait_or_cancel( retry: Optional[retries.AsyncRetry], page_size: Optional[int], max_results: Optional[int], -) -> table.RowIterator: +) -> RowIterator: try: return asyncio.to_thread( - job.result( # run in a background thread + job.result( page_size=page_size, max_results=max_results, retry=retry, @@ -204,3 +293,29 @@ async def async_wait_or_cancel( # Don't eat the original exception if cancel fails. pass raise + + +class AsyncRowIterator(RowIterator): + async def _get_next_page_response(self): + """Asynchronous version of fetching the next response page.""" + if self._first_page_response: + rows = self._first_page_response.get(self._items_key, [])[ + : self.max_results + ] + response = { + self._items_key: rows, + } + if self._next_token in self._first_page_response: + response[self._next_token] = self._first_page_response[self._next_token] + + self._first_page_response = None + return response + + params = self._get_query_params() + if self._page_size is not None: + if self.page_number and "startIndex" in params: + del params["startIndex"] + params["maxResults"] = self._page_size + return await self.api_request( + method=self._HTTP_METHOD, path=self.path, query_params=params + ) diff --git a/google/cloud/bigquery/opentelemetry_tracing.py b/google/cloud/bigquery/opentelemetry_tracing.py index e2a05e4d0..c1594c1a2 100644 --- a/google/cloud/bigquery/opentelemetry_tracing.py +++ b/google/cloud/bigquery/opentelemetry_tracing.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager from google.api_core.exceptions import GoogleAPICallError # type: ignore logger = logging.getLogger(__name__) @@ -86,6 +86,37 @@ def create_span(name, attributes=None, client=None, job_ref=None): raise +@asynccontextmanager +async def async_create_span(name, attributes=None, client=None, job_ref=None): + """Asynchronous context manager for creating and exporting OpenTelemetry spans.""" + global _warned_telemetry + final_attributes = _get_final_span_attributes(attributes, client, job_ref) + + if not HAS_OPENTELEMETRY: + if not _warned_telemetry: + logger.debug( + "This service is instrumented using OpenTelemetry. " + "OpenTelemetry or one of its components could not be imported; " + "please add compatible versions of opentelemetry-api and " + "opentelemetry-instrumentation packages in order to get BigQuery " + "Tracing data." + ) + _warned_telemetry = True + yield None + return + tracer = trace.get_tracer(__name__) + + async with tracer.start_as_current_span( + name=name, attributes=final_attributes + ) as span: + try: + yield span + except GoogleAPICallError as error: + if error.code is not None: + span.set_status(Status(http_status_to_status_code(error.code))) + raise + + def _get_final_span_attributes(attributes=None, client=None, job_ref=None): """Compiles attributes from: client, job_ref, user-provided attributes. diff --git a/noxfile.py b/noxfile.py index c31d098b8..26d55111f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -80,8 +80,8 @@ def default(session, install_extras=True): constraints_path, ) - if install_extras and session.python in ["3.11", "3.12"]: - install_target = ".[bqstorage,ipywidgets,pandas,tqdm,opentelemetry]" + if install_extras and session.python in ["3.12"]: + install_target = ".[bqstorage,ipywidgets,pandas,tqdm,opentelemetry,aiohttp]" elif install_extras: install_target = ".[all]" else: @@ -188,8 +188,8 @@ def system(session): # Data Catalog needed for the column ACL test with a real Policy Tag. session.install("google-cloud-datacatalog", "-c", constraints_path) - if session.python in ["3.11", "3.12"]: - extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry]" + if session.python in ["3.12"]: + extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry,aiohttp]" # look at geopandas to see if it supports 3.11/3.12 (up to 3.11) else: extras = "[all]" session.install("-e", f".{extras}", "-c", constraints_path) @@ -254,8 +254,8 @@ def snippets(session): session.install("google-cloud-storage", "-c", constraints_path) session.install("grpcio", "-c", constraints_path) - if session.python in ["3.11", "3.12"]: - extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry]" + if session.python in ["3.12"]: + extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry,aiohttp]" else: extras = "[all]" session.install("-e", f".{extras}", "-c", constraints_path) diff --git a/setup.py b/setup.py index 9f6fabcfc..7d672d239 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,6 @@ # NOTE: Maintainers, please do not require google-cloud-core>=2.x.x # Until this issue is closed # https://github.com/googleapis/google-cloud-python/issues/10566 - "google-auth >= 2.14.1, <3.0.0dev", "google-cloud-core >= 1.6.0, <3.0.0dev", "google-resumable-media >= 0.6.0, < 3.0dev", "packaging >= 20.0.0", diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py index 472504711..e500c6340 100644 --- a/tests/unit/test_async_client.py +++ b/tests/unit/test_async_client.py @@ -297,6 +297,105 @@ def test_ctor_w_load_job_config(self): self.assertIsInstance(client._default_load_job_config, LoadJobConfig) self.assertTrue(client._default_load_job_config.create_session) + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_job_miss_w_explict_project(self): + from google.cloud.exceptions import NotFound + + OTHER_PROJECT = "OTHER_PROJECT" + JOB_ID = "NONESUCH" + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + conn = client._client._connection = make_connection() + + with self.assertRaises(NotFound): + await client.get_job(JOB_ID, project=OTHER_PROJECT) + + conn.api_request.assert_called_once_with( + method="GET", + path="/projects/OTHER_PROJECT/jobs/NONESUCH", + query_params={"projection": "full"}, + timeout=DEFAULT_TIMEOUT, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_job_miss_w_client_location(self): + from google.cloud.exceptions import NotFound + + JOB_ID = "NONESUCH" + creds = _make_credentials() + client = self._make_one("client-proj", creds, location="client-loc") + conn = client._client._connection = make_connection() + + with self.assertRaises(NotFound): + await client.get_job(JOB_ID) + + conn.api_request.assert_called_once_with( + method="GET", + path="/projects/client-proj/jobs/NONESUCH", + query_params={"projection": "full", "location": "client-loc"}, + timeout=DEFAULT_TIMEOUT, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_job_hit_w_timeout(self): + from google.cloud.bigquery.job import CreateDisposition + from google.cloud.bigquery.job import QueryJob + from google.cloud.bigquery.job import WriteDisposition + + JOB_ID = "query_job" + QUERY_DESTINATION_TABLE = "query_destination_table" + QUERY = "SELECT * from test_dataset:test_table" + ASYNC_QUERY_DATA = { + "id": "{}:{}".format(self.PROJECT, JOB_ID), + "jobReference": { + "projectId": "resource-proj", + "jobId": "query_job", + "location": "us-east1", + }, + "state": "DONE", + "configuration": { + "query": { + "query": QUERY, + "destinationTable": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "tableId": QUERY_DESTINATION_TABLE, + }, + "createDisposition": CreateDisposition.CREATE_IF_NEEDED, + "writeDisposition": WriteDisposition.WRITE_TRUNCATE, + } + }, + } + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + conn = client._client._connection = make_connection(ASYNC_QUERY_DATA) + job_from_resource = QueryJob.from_api_repr(ASYNC_QUERY_DATA, client._client) + + job = await client.get_job(job_from_resource, timeout=7.5) + + self.assertIsInstance(job, QueryJob) + self.assertEqual(job.job_id, JOB_ID) + self.assertEqual(job.project, "resource-proj") + self.assertEqual(job.location, "us-east1") + self.assertEqual(job.create_disposition, CreateDisposition.CREATE_IF_NEEDED) + self.assertEqual(job.write_disposition, WriteDisposition.WRITE_TRUNCATE) + + conn.api_request.assert_called_once_with( + method="GET", + path="/projects/resource-proj/jobs/query_job", + query_params={"projection": "full", "location": "us-east1"}, + timeout=7.5, + ) + @pytest.mark.skipif( sys.version_info < (3, 9), reason="requires python3.9 or higher" ) From 869cc7ea94ffa9ab206d3464fdb920aee5c053e9 Mon Sep 17 00:00:00 2001 From: kiraksi Date: Mon, 11 Mar 2024 16:13:52 -0700 Subject: [PATCH 6/6] removed AsyncRowIterator until AsyncHTTPIterator is supported --- google/cloud/bigquery/async_client.py | 67 +--- google/cloud/bigquery/table.py | 6 + tests/unit/test_async_client.py | 498 ++++++++++++++++++++------ 3 files changed, 404 insertions(+), 167 deletions(-) diff --git a/google/cloud/bigquery/async_client.py b/google/cloud/bigquery/async_client.py index 3dc7632d5..8cedbb434 100644 --- a/google/cloud/bigquery/async_client.py +++ b/google/cloud/bigquery/async_client.py @@ -6,7 +6,6 @@ from google.cloud.bigquery.opentelemetry_tracing import async_create_span from google.cloud.bigquery import _job_helpers from google.cloud.bigquery.table import * -from google.api_core.page_iterator import HTTPIterator from google.cloud.bigquery.retry import ( DEFAULT_ASYNC_JOB_RETRY, DEFAULT_ASYNC_RETRY, @@ -15,6 +14,7 @@ from google.api_core import retry_async as retries import asyncio from google.auth.transport import _aiohttp_requests +from google.api_core.page_iterator_async import AsyncIterator # This code is experimental @@ -23,45 +23,6 @@ class AsyncClient: def __init__(self, *args, **kwargs): self._client = Client(*args, **kwargs) - async def get_job( - self, - job_id: Union[str, job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob], - project: Optional[str] = None, - location: Optional[str] = None, - retry: retries.AsyncRetry = DEFAULT_ASYNC_RETRY, - timeout: TimeoutType = DEFAULT_TIMEOUT, - ) -> Union[job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob, job.UnknownJob]: - extra_params = {"projection": "full"} - - project, location, job_id = _extract_job_reference( - job_id, project=project, location=location - ) - - if project is None: - project = self._client.project - - if location is None: - location = self._client.location - - if location is not None: - extra_params["location"] = location - - path = "/projects/{}/jobs/{}".format(project, job_id) - - span_attributes = {"path": path, "job_id": job_id, "location": location} - - resource = await self._call_api( - retry, - span_name="BigQuery.getJob", - span_attributes=span_attributes, - method="GET", - path=path, - query_params=extra_params, - timeout=timeout, - ) - - return await asyncio.to_thread(self._client.job_from_resource(await resource)) - async def query_and_wait( self, query, @@ -295,27 +256,5 @@ async def async_wait_or_cancel( raise -class AsyncRowIterator(RowIterator): - async def _get_next_page_response(self): - """Asynchronous version of fetching the next response page.""" - if self._first_page_response: - rows = self._first_page_response.get(self._items_key, [])[ - : self.max_results - ] - response = { - self._items_key: rows, - } - if self._next_token in self._first_page_response: - response[self._next_token] = self._first_page_response[self._next_token] - - self._first_page_response = None - return response - - params = self._get_query_params() - if self._page_size is not None: - if self.page_number and "startIndex" in params: - del params["startIndex"] - params["maxResults"] = self._page_size - return await self.api_request( - method=self._HTTP_METHOD, path=self.path, query_params=params - ) +class AsyncRowIterator(AsyncHTTPIterator): + pass diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index b3be4ff90..1b93091b3 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -57,6 +57,8 @@ import google.api_core.exceptions from google.api_core.page_iterator import HTTPIterator +# from google.api_core.page_iterator_async import AsyncHTTPIterator <- when supported in google api core + import google.cloud._helpers # type: ignore from google.cloud.bigquery import _helpers from google.cloud.bigquery import _pandas_helpers @@ -2444,6 +2446,10 @@ def to_geodataframe( ) +# class AsyncRowIterator(AsyncHTTPIterator): +# pass + + class _EmptyRowIterator(RowIterator): """An empty row iterator. diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py index e500c6340..8c3b09349 100644 --- a/tests/unit/test_async_client.py +++ b/tests/unit/test_async_client.py @@ -11,7 +11,7 @@ import operator import unittest import warnings - +import freezegun import mock import requests import packaging @@ -83,8 +83,8 @@ DEFAULT_TIMEOUT, ) from google.api_core import retry_async as retries -from google.cloud.bigquery import async_client -from google.cloud.bigquery.async_client import AsyncClient +from google.cloud.bigquery.async_client import AsyncClient, async_query_and_wait +from google.cloud.bigquery.client import Client from google.cloud.bigquery.job import query as job_query @@ -297,105 +297,6 @@ def test_ctor_w_load_job_config(self): self.assertIsInstance(client._default_load_job_config, LoadJobConfig) self.assertTrue(client._default_load_job_config.create_session) - @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" - ) - @asyncio_run - async def test_get_job_miss_w_explict_project(self): - from google.cloud.exceptions import NotFound - - OTHER_PROJECT = "OTHER_PROJECT" - JOB_ID = "NONESUCH" - creds = _make_credentials() - client = self._make_one(self.PROJECT, creds) - conn = client._client._connection = make_connection() - - with self.assertRaises(NotFound): - await client.get_job(JOB_ID, project=OTHER_PROJECT) - - conn.api_request.assert_called_once_with( - method="GET", - path="/projects/OTHER_PROJECT/jobs/NONESUCH", - query_params={"projection": "full"}, - timeout=DEFAULT_TIMEOUT, - ) - - @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" - ) - @asyncio_run - async def test_get_job_miss_w_client_location(self): - from google.cloud.exceptions import NotFound - - JOB_ID = "NONESUCH" - creds = _make_credentials() - client = self._make_one("client-proj", creds, location="client-loc") - conn = client._client._connection = make_connection() - - with self.assertRaises(NotFound): - await client.get_job(JOB_ID) - - conn.api_request.assert_called_once_with( - method="GET", - path="/projects/client-proj/jobs/NONESUCH", - query_params={"projection": "full", "location": "client-loc"}, - timeout=DEFAULT_TIMEOUT, - ) - - @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" - ) - @asyncio_run - async def test_get_job_hit_w_timeout(self): - from google.cloud.bigquery.job import CreateDisposition - from google.cloud.bigquery.job import QueryJob - from google.cloud.bigquery.job import WriteDisposition - - JOB_ID = "query_job" - QUERY_DESTINATION_TABLE = "query_destination_table" - QUERY = "SELECT * from test_dataset:test_table" - ASYNC_QUERY_DATA = { - "id": "{}:{}".format(self.PROJECT, JOB_ID), - "jobReference": { - "projectId": "resource-proj", - "jobId": "query_job", - "location": "us-east1", - }, - "state": "DONE", - "configuration": { - "query": { - "query": QUERY, - "destinationTable": { - "projectId": self.PROJECT, - "datasetId": self.DS_ID, - "tableId": QUERY_DESTINATION_TABLE, - }, - "createDisposition": CreateDisposition.CREATE_IF_NEEDED, - "writeDisposition": WriteDisposition.WRITE_TRUNCATE, - } - }, - } - creds = _make_credentials() - client = self._make_one(self.PROJECT, creds) - conn = client._client._connection = make_connection(ASYNC_QUERY_DATA) - job_from_resource = QueryJob.from_api_repr(ASYNC_QUERY_DATA, client._client) - - job = await client.get_job(job_from_resource, timeout=7.5) - - self.assertIsInstance(job, QueryJob) - self.assertEqual(job.job_id, JOB_ID) - self.assertEqual(job.project, "resource-proj") - self.assertEqual(job.location, "us-east1") - self.assertEqual(job.create_disposition, CreateDisposition.CREATE_IF_NEEDED) - self.assertEqual(job.write_disposition, WriteDisposition.WRITE_TRUNCATE) - - conn.api_request.assert_called_once_with( - method="GET", - path="/projects/resource-proj/jobs/query_job", - query_params={"projection": "full", "location": "us-east1"}, - timeout=7.5, - ) - @pytest.mark.skipif( sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @@ -565,5 +466,396 @@ async def test_query_and_wait_w_project(self): self.assertEqual(req["method"], "POST") self.assertEqual(req["path"], "/projects/not-the-client-project/queries") +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" +) +@asyncio_run +async def test_query_and_wait_retries_job(): + freezegun.freeze_time(auto_tick_seconds=100) + client = mock.create_autospec(AsyncClient) + client._client = mock.create_autospec(Client) + client._call_api.__name__ = "_call_api" + client._call_api.__qualname__ = "Client._call_api" + client._call_api.__annotations__ = {} + client._call_api.__type_params__ = () + client._call_api.side_effect = ( + google.api_core.exceptions.BadGateway("retry me"), + google.api_core.exceptions.InternalServerError("job_retry me"), + google.api_core.exceptions.BadGateway("retry me"), + { + "jobReference": { + "projectId": "response-project", + "jobId": "abc", + "location": "response-location", + }, + "jobComplete": True, + "schema": { + "fields": [ + {"name": "full_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "age", "type": "INT64", "mode": "NULLABLE"}, + ], + }, + "rows": [ + {"f": [{"v": "Whillma Phlyntstone"}, {"v": "27"}]}, + {"f": [{"v": "Bhetty Rhubble"}, {"v": "28"}]}, + {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}, + {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, + ], + }, + ) + rows = await async_query_and_wait( + client, + query="SELECT 1", + location="request-location", + project="request-project", + job_config=None, + page_size=None, + max_results=None, + retry=retries.AsyncRetry( + lambda exc: isinstance(exc, google.api_core.exceptions.BadGateway), + multiplier=1.0, + ).with_deadline( + 200.0 + ), # Since auto_tick_seconds is 100, we should get at least 1 retry. + job_retry=retries.AsyncRetry( + lambda exc: isinstance(exc, google.api_core.exceptions.InternalServerError), + multiplier=1.0, + ).with_deadline(600.0), + ) + assert len(list(rows)) == 4 + + # For this code path, where the query has finished immediately, we should + # only be calling the jobs.query API and no other request path. + request_path = "/projects/request-project/queries" + for call in client._call_api.call_args_list: + _, kwargs = call + assert kwargs["method"] == "POST" + assert kwargs["path"] == request_path + + +@freezegun.freeze_time(auto_tick_seconds=100) +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" +) +@asyncio_run +async def test_query_and_wait_retries_job_times_out(): + client = mock.create_autospec(AsyncClient) + client._client = mock.create_autospec(Client) + client._call_api.__name__ = "_call_api" + client._call_api.__qualname__ = "Client._call_api" + client._call_api.__annotations__ = {} + client._call_api.__type_params__ = () + client._call_api.side_effect = ( + google.api_core.exceptions.BadGateway("retry me"), + google.api_core.exceptions.InternalServerError("job_retry me"), + google.api_core.exceptions.BadGateway("retry me"), + google.api_core.exceptions.InternalServerError("job_retry me"), + ) + + with pytest.raises(google.api_core.exceptions.RetryError) as exc_info: + await async_query_and_wait( + client, + query="SELECT 1", + location="request-location", + project="request-project", + job_config=None, + page_size=None, + max_results=None, + retry=retries.AsyncRetry( + lambda exc: isinstance(exc, google.api_core.exceptions.BadGateway), + multiplier=1.0, + ).with_deadline( + 200.0 + ), # Since auto_tick_seconds is 100, we should get at least 1 retry. + job_retry=retries.AsyncRetry( + lambda exc: isinstance( + exc, google.api_core.exceptions.InternalServerError + ), + multiplier=1.0, + ).with_deadline(400.0), + ) + + assert isinstance( + exc_info.value.cause, google.api_core.exceptions.InternalServerError + ) + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" +) +@asyncio_run +async def test_query_and_wait_sets_job_creation_mode(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv( + "QUERY_PREVIEW_ENABLED", + # The comparison should be case insensitive. + "TrUe", + ) + client = mock.create_autospec(AsyncClient) + client._client = mock.create_autospec(Client) + client._call_api.return_value = { + "jobReference": { + "projectId": "response-project", + "jobId": "abc", + "location": "response-location", + }, + "jobComplete": True, + } + async_query_and_wait( + client, + query="SELECT 1", + location="request-location", + project="request-project", + job_config=None, + retry=None, + job_retry=None, + page_size=None, + max_results=None, + ) + + # We should only call jobs.query once, no additional row requests needed. + request_path = "/projects/request-project/queries" + client._call_api.assert_called_once_with( + None, # retry + span_name="BigQuery.query", + span_attributes={"path": request_path}, + method="POST", + path=request_path, + data={ + "query": "SELECT 1", + "location": "request-location", + "useLegacySql": False, + "formatOptions": { + "useInt64Timestamp": True, + }, + "requestId": mock.ANY, + "jobCreationMode": "JOB_CREATION_OPTIONAL", + }, + timeout=None, + ) + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" +) +@asyncio_run +async def test_query_and_wait_sets_location(): + client = mock.create_autospec(AsyncClient) + client._client = mock.create_autospec(Client) + client._call_api.return_value = { + "jobReference": { + "projectId": "response-project", + "jobId": "abc", + "location": "response-location", + }, + "jobComplete": True, + } + rows = await async_query_and_wait( + client, + query="SELECT 1", + location="request-location", + project="request-project", + job_config=None, + retry=None, + job_retry=None, + page_size=None, + max_results=None, + ) + assert rows.location == "response-location" + + # We should only call jobs.query once, no additional row requests needed. + request_path = "/projects/request-project/queries" + client._call_api.assert_called_once_with( + None, # retry + span_name="BigQuery.query", + span_attributes={"path": request_path}, + method="POST", + path=request_path, + data={ + "query": "SELECT 1", + "location": "request-location", + "useLegacySql": False, + "formatOptions": { + "useInt64Timestamp": True, + }, + "requestId": mock.ANY, + }, + timeout=None, + ) + + +@pytest.mark.parametrize( + ("max_results", "page_size", "expected"), + [ + (10, None, 10), + (None, 11, 11), + (12, 100, 12), + (100, 13, 13), + ], +) +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" +) +@asyncio_run +async def test_query_and_wait_sets_max_results(max_results, page_size, expected): + client = mock.create_autospec(AsyncClient) + client._client = mock.create_autospec(Client) + client._call_api.return_value = { + "jobReference": { + "projectId": "response-project", + "jobId": "abc", + "location": "response-location", + }, + "jobComplete": True, + } + rows = await async_query_and_wait( + client, + query="SELECT 1", + location="request-location", + project="request-project", + job_config=None, + retry=None, + job_retry=None, + page_size=page_size, + max_results=max_results, + ) + assert rows.location == "response-location" + + # We should only call jobs.query once, no additional row requests needed. + request_path = "/projects/request-project/queries" + client._call_api.assert_called_once_with( + None, # retry + span_name="BigQuery.query", + span_attributes={"path": request_path}, + method="POST", + path=request_path, + data={ + "query": "SELECT 1", + "location": "request-location", + "useLegacySql": False, + "formatOptions": { + "useInt64Timestamp": True, + }, + "requestId": mock.ANY, + "maxResults": expected, + }, + timeout=None, + ) + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" +) +@asyncio_run +async def test_query_and_wait_caches_completed_query_results_one_page(): + client = mock.create_autospec(AsyncClient) + client._client = mock.create_autospec(Client) + client._call_api.return_value = { + "jobReference": { + "projectId": "response-project", + "jobId": "abc", + "location": "US", + }, + "jobComplete": True, + "queryId": "xyz", + "schema": { + "fields": [ + {"name": "full_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "age", "type": "INT64", "mode": "NULLABLE"}, + ], + }, + "rows": [ + {"f": [{"v": "Whillma Phlyntstone"}, {"v": "27"}]}, + {"f": [{"v": "Bhetty Rhubble"}, {"v": "28"}]}, + {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}, + {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, + ], + # Even though totalRows > len(rows), we should use the presence of a + # next page token to decide if there are any more pages. + "totalRows": 8, + } + rows = await async_query_and_wait( + client, + query="SELECT full_name, age FROM people;", + job_config=None, + location=None, + project="request-project", + retry=None, + job_retry=None, + page_size=None, + max_results=None, + ) + rows_list = list(rows) + assert rows.project == "response-project" + assert rows.job_id == "abc" + assert rows.location == "US" + assert rows.query_id == "xyz" + assert rows.total_rows == 8 + assert len(rows_list) == 4 + + # We should only call jobs.query once, no additional row requests needed. + request_path = "/projects/request-project/queries" + client._call_api.assert_called_once_with( + None, # retry + span_name="BigQuery.query", + span_attributes={"path": request_path}, + method="POST", + path=request_path, + data={ + "query": "SELECT full_name, age FROM people;", + "useLegacySql": False, + "formatOptions": { + "useInt64Timestamp": True, + }, + "requestId": mock.ANY, + }, + timeout=None, + ) -# Add tests for async_query_and_wait and async_wait_or_cancel +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" +) +@asyncio_run +async def test_query_and_wait_caches_completed_query_results_one_page_no_rows(): + client = mock.create_autospec(AsyncClient) + client._client = mock.create_autospec(Client) + client._call_api.return_value = { + "jobReference": { + "projectId": "response-project", + "jobId": "abc", + "location": "US", + }, + "jobComplete": True, + "queryId": "xyz", + } + rows = await async_query_and_wait( + client, + query="CREATE TABLE abc;", + project="request-project", + job_config=None, + location=None, + retry=None, + job_retry=None, + page_size=None, + max_results=None, + ) + assert rows.project == "response-project" + assert rows.job_id == "abc" + assert rows.location == "US" + assert rows.query_id == "xyz" + assert list(rows) == [] + + # We should only call jobs.query once, no additional row requests needed. + request_path = "/projects/request-project/queries" + client._call_api.assert_called_once_with( + None, # retry + span_name="BigQuery.query", + span_attributes={"path": request_path}, + method="POST", + path=request_path, + data={ + "query": "CREATE TABLE abc;", + "useLegacySql": False, + "formatOptions": { + "useInt64Timestamp": True, + }, + "requestId": mock.ANY, + }, + timeout=None, + )