diff --git a/.librarian/state.yaml b/.librarian/state.yaml index 95cf19f13..7054c7957 100644 --- a/.librarian/state.yaml +++ b/.librarian/state.yaml @@ -2,7 +2,7 @@ image: us-central1-docker.pkg.dev/cloud-sdk-librarian-prod/images-prod/python-li libraries: - id: google-cloud-firestore version: 2.21.0 - last_generated_commit: 659ea6e98acc7d58661ce2aa7b4cf76a7ef3fd42 + last_generated_commit: b60f5a5783d5ec0e8a8d254f73ad00316b9b646f apis: - path: google/firestore/v1 service_config: firestore_v1.yaml diff --git a/google/cloud/firestore_v1/gapic_metadata.json b/google/cloud/firestore_v1/gapic_metadata.json index d0462f964..03a6e428b 100644 --- a/google/cloud/firestore_v1/gapic_metadata.json +++ b/google/cloud/firestore_v1/gapic_metadata.json @@ -40,6 +40,11 @@ "delete_document" ] }, + "ExecutePipeline": { + "methods": [ + "execute_pipeline" + ] + }, "GetDocument": { "methods": [ "get_document" @@ -125,6 +130,11 @@ "delete_document" ] }, + "ExecutePipeline": { + "methods": [ + "execute_pipeline" + ] + }, "GetDocument": { "methods": [ "get_document" @@ -210,6 +220,11 @@ "delete_document" ] }, + "ExecutePipeline": { + "methods": [ + "execute_pipeline" + ] + }, "GetDocument": { "methods": [ "get_document" diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py index b904229b0..3557eb94c 100644 --- a/google/cloud/firestore_v1/services/firestore/async_client.py +++ b/google/cloud/firestore_v1/services/firestore/async_client.py @@ -53,6 +53,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -1248,6 +1249,109 @@ async def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Awaitable[AsyncIterable[firestore.ExecutePipelineResponse]]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+            # - It may require specifying regional endpoints when creating the service
+            #   client as shown in:
+            #   https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.cloud import firestore_v1
+
+            async def sample_execute_pipeline():
+                # Create a client
+                client = firestore_v1.FirestoreAsyncClient()
+
+                # Initialize request argument(s)
+                structured_pipeline = firestore_v1.StructuredPipeline()
+                structured_pipeline.pipeline.stages.name = "name_value"
+
+                request = firestore_v1.ExecutePipelineRequest(
+                    structured_pipeline=structured_pipeline,
+                    transaction=b'transaction_blob',
+                    database="database_value",
+                )
+
+                # Make the request
+                stream = await client.execute_pipeline(request=request)
+
+                # Handle the response
+                async for response in stream:
+                    print(response)
+
+        Args:
+            request (Optional[Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]]):
+                The request object. The request for
+                [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline].
+            retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be
+                sent along with the request as metadata. Normally, each value must be of type `str`,
+                but for metadata keys ending with the suffix `-bin`, the corresponding values must
+                be of type `bytes`.
+
+        Returns:
+            AsyncIterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]:
+                The response for [Firestore.Execute][].
+        """
+        # Create or coerce a protobuf request object.
+        # - Use the request object if provided (there's no risk of modifying the input as
+        #   there are no flattened fields), or create one.
+        if not isinstance(request, firestore.ExecutePipelineRequest):
+            request = firestore.ExecutePipelineRequest(request)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._client._transport._wrapped_methods[
+            self._client._transport.execute_pipeline
+        ]
+
+        header_params = {}
+
+        routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$")
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("project_id"):
+            header_params["project_id"] = regex_match.group("project_id")
+
+        routing_param_regex = re.compile(
+            "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$"
+        )
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("database_id"):
+            header_params["database_id"] = regex_match.group("database_id")
+
+        if header_params:
+            metadata = tuple(metadata) + (
+                gapic_v1.routing_header.to_grpc_metadata(header_params),
+            )
+
+        # Validate the universe domain.
+        self._client._validate_universe_domain()
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+ return response + def run_aggregation_query( self, request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None, diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index 805561242..83b58f307 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -68,6 +68,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -1631,6 +1632,107 @@ def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Iterable[firestore.ExecutePipelineResponse]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_v1 + + def sample_execute_pipeline(): + # Create a client + client = firestore_v1.FirestoreClient() + + # Initialize request argument(s) + structured_pipeline = firestore_v1.StructuredPipeline() + structured_pipeline.pipeline.stages.name = "name_value" + + request = firestore_v1.ExecutePipelineRequest( + structured_pipeline=structured_pipeline, + transaction=b'transaction_blob', + database="database_value", + ) + + # Make the request + stream = client.execute_pipeline(request=request) + + # Handle the response + for response in stream: + print(response) + + Args: + request (Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + Iterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: + The response for [Firestore.Execute][]. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore.ExecutePipelineRequest): + request = firestore.ExecutePipelineRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+        rpc = self._transport._wrapped_methods[self._transport.execute_pipeline]
+
+        header_params = {}
+
+        routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$")
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("project_id"):
+            header_params["project_id"] = regex_match.group("project_id")
+
+        routing_param_regex = re.compile(
+            "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$"
+        )
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("database_id"):
+            header_params["database_id"] = regex_match.group("database_id")
+
+        if header_params:
+            metadata = tuple(metadata) + (
+                gapic_v1.routing_header.to_grpc_metadata(header_params),
+            )
+
+        # Validate the universe domain.
+        self._validate_universe_domain()
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
     def run_aggregation_query(
         self,
         request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None,
diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py
index 02d6c0bbc..905dded09 100644
--- a/google/cloud/firestore_v1/services/firestore/transports/base.py
+++ b/google/cloud/firestore_v1/services/firestore/transports/base.py
@@ -291,6 +291,23 @@ def _prep_wrapped_messages(self, client_info):
                 default_timeout=300.0,
                 client_info=client_info,
             ),
+            self.execute_pipeline: gapic_v1.method.wrap_method(
+                self.execute_pipeline,
+                default_retry=retries.Retry(
+                    initial=0.1,
+                    maximum=60.0,
+                    multiplier=1.3,
+                    predicate=retries.if_exception_type(
+                        core_exceptions.DeadlineExceeded,
+                        core_exceptions.InternalServerError,
+                        core_exceptions.ResourceExhausted,
+                        core_exceptions.ServiceUnavailable,
+                    ),
+                    deadline=300.0,
+                ),
+                default_timeout=300.0,
+                client_info=client_info,
+            ),
             self.run_aggregation_query: gapic_v1.method.wrap_method(
                 self.run_aggregation_query,
                 default_retry=retries.Retry(
@@ -514,6 +531,18 @@ def run_query(
     ]:
         raise NotImplementedError()
 
+    @property
+    def execute_pipeline(
+        self,
+    ) -> Callable[
+        [firestore.ExecutePipelineRequest],
+        Union[
+            firestore.ExecutePipelineResponse,
+            Awaitable[firestore.ExecutePipelineResponse],
+        ],
+    ]:
+        raise NotImplementedError()
+
     @property
     def run_aggregation_query(
         self,
diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc.py b/google/cloud/firestore_v1/services/firestore/transports/grpc.py
index 3c5bded2d..f057d16e3 100644
--- a/google/cloud/firestore_v1/services/firestore/transports/grpc.py
+++ b/google/cloud/firestore_v1/services/firestore/transports/grpc.py
@@ -573,6 +573,34 @@ def run_query(
         )
         return self._stubs["run_query"]
 
+    @property
+    def execute_pipeline(
+        self,
+    ) -> Callable[
+        [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse
+    ]:
+        r"""Return a callable for the execute pipeline method over gRPC.
+
+        Executes a pipeline query.
+
+        Returns:
+            Callable[[~.ExecutePipelineRequest],
+                    ~.ExecutePipelineResponse]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+ if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py index 6cc93e21a..cf6006672 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py @@ -589,6 +589,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], Awaitable[firestore.ExecutePipelineResponse] + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + Awaitable[~.ExecutePipelineResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, @@ -964,6 +992,23 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: self._wrap_method( + self.execute_pipeline, + default_retry=retries.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.InternalServerError, + core_exceptions.ResourceExhausted, + core_exceptions.ServiceUnavailable, + ), + deadline=300.0, + ), + default_timeout=300.0, + client_info=client_info, + ), self.run_aggregation_query: self._wrap_method( self.run_aggregation_query, default_retry=retries.AsyncRetry( diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index a32a7e84e..845569d97 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -127,6 +127,14 @@ def pre_delete_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata + def pre_execute_pipeline(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_execute_pipeline(self, response): + logging.log(f"Received response: {response}") + return response + def pre_get_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -445,6 +453,56 @@ def pre_delete_document( """ return request, metadata + def pre_execute_pipeline( + self, + request: firestore.ExecutePipelineRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.ExecutePipelineRequest, 
Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for execute_pipeline + + Override in a subclass to manipulate the request or metadata + before they are sent to the Firestore server. + """ + return request, metadata + + def post_execute_pipeline( + self, response: rest_streaming.ResponseIterator + ) -> rest_streaming.ResponseIterator: + """Post-rpc interceptor for execute_pipeline + + DEPRECATED. Please use the `post_execute_pipeline_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the Firestore server but before + it is returned to user code. This `post_execute_pipeline` interceptor runs + before the `post_execute_pipeline_with_metadata` interceptor. + """ + return response + + def post_execute_pipeline_with_metadata( + self, + response: rest_streaming.ResponseIterator, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + rest_streaming.ResponseIterator, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for execute_pipeline + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_execute_pipeline_with_metadata` + interceptor in new development instead of the `post_execute_pipeline` interceptor. + When both interceptors are used, this `post_execute_pipeline_with_metadata` interceptor runs after the + `post_execute_pipeline` interceptor. The (possibly modified) response returned by + `post_execute_pipeline` will be passed to + `post_execute_pipeline_with_metadata`. + """ + return response, metadata + def pre_get_document( self, request: firestore.GetDocumentRequest, @@ -1873,6 +1931,158 @@ def __call__( if response.status_code >= 400: raise core_exceptions.from_http_response(response) + class _ExecutePipeline( + _BaseFirestoreRestTransport._BaseExecutePipeline, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.ExecutePipeline") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response + + def __call__( + self, + request: firestore.ExecutePipelineRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> rest_streaming.ResponseIterator: + r"""Call the execute pipeline method over HTTP. + + Args: + request (~.firestore.ExecutePipelineRequest): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
+ + Returns: + ~.firestore.ExecutePipelineResponse: + The response for [Firestore.Execute][]. + """ + + http_options = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_http_options() + ) + + request, metadata = self._interceptor.pre_execute_pipeline( + request, metadata + ) + transcoded_request = _BaseFirestoreRestTransport._BaseExecutePipeline._get_transcoded_request( + http_options, request + ) + + body = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_request_body_json( + transcoded_request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.ExecutePipeline", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ExecutePipeline", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FirestoreRestTransport._ExecutePipeline._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = rest_streaming.ResponseIterator( + response, firestore.ExecutePipelineResponse + ) + + resp = self._interceptor.post_execute_pipeline(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_execute_pipeline_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + http_response = { + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.execute_pipeline", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ExecutePipeline", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + class _GetDocument(_BaseFirestoreRestTransport._BaseGetDocument, FirestoreRestStub): def __hash__(self): return hash("FirestoreRestTransport.GetDocument") @@ -3143,6 +3353,16 @@ def delete_document( # In C++ this would require a dynamic_cast return self._DeleteDocument(self._session, self._host, self._interceptor) # type: ignore + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._ExecutePipeline(self._session, self._host, self._interceptor) # type: ignore + @property def get_document( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py index 1d95cd16e..80ce35e49 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -426,6 +426,63 @@ def _get_query_params_json(transcoded_request): query_params["$alt"] = "json;enum-encoding=int" return query_params + class _BaseExecutePipeline: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ExecutePipelineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + class _BaseGetDocument: def __hash__(self): # pragma: NO COVER return NotImplementedError("__hash__ must be implemented.") diff --git a/google/cloud/firestore_v1/types/__init__.py b/google/cloud/firestore_v1/types/__init__.py index ae1004e13..ed1965d7f 100644 --- a/google/cloud/firestore_v1/types/__init__.py +++ b/google/cloud/firestore_v1/types/__init__.py @@ -28,9 +28,14 @@ from .document import ( ArrayValue, Document, + Function, MapValue, + Pipeline, Value, ) +from .explain_stats import ( + ExplainStats, +) from .firestore import ( BatchGetDocumentsRequest, BatchGetDocumentsResponse, @@ -42,6 +47,8 @@ CommitResponse, CreateDocumentRequest, DeleteDocumentRequest, + ExecutePipelineRequest, + ExecutePipelineResponse, GetDocumentRequest, ListCollectionIdsRequest, ListCollectionIdsResponse, @@ -62,6 +69,9 @@ WriteRequest, WriteResponse, ) +from .pipeline import ( + StructuredPipeline, +) from .query import ( Cursor, StructuredAggregationQuery, @@ -92,8 +102,11 @@ "TransactionOptions", "ArrayValue", "Document", + "Function", "MapValue", + "Pipeline", "Value", + "ExplainStats", "BatchGetDocumentsRequest", "BatchGetDocumentsResponse", "BatchWriteRequest", @@ -104,6 +117,8 @@ "CommitResponse", "CreateDocumentRequest", "DeleteDocumentRequest", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "GetDocumentRequest", "ListCollectionIdsRequest", "ListCollectionIdsResponse", @@ -123,6 +138,7 @@ "UpdateDocumentRequest", "WriteRequest", 
"WriteResponse", + "StructuredPipeline", "Cursor", "StructuredAggregationQuery", "StructuredQuery", diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py index 22fe79b73..8073ad97a 100644 --- a/google/cloud/firestore_v1/types/document.py +++ b/google/cloud/firestore_v1/types/document.py @@ -31,6 +31,8 @@ "Value", "ArrayValue", "MapValue", + "Function", + "Pipeline", }, ) @@ -183,6 +185,37 @@ class Value(proto.Message): map_value (google.cloud.firestore_v1.types.MapValue): A map value. + This field is a member of `oneof`_ ``value_type``. + field_reference_value (str): + Value which references a field. + + This is considered relative (vs absolute) since it only + refers to a field and not a field within a particular + document. + + **Requires:** + + - Must follow [field reference][FieldReference.field_path] + limitations. + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + function_value (google.cloud.firestore_v1.types.Function): + A value that represents an unevaluated expression. + + **Requires:** + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + pipeline_value (google.cloud.firestore_v1.types.Pipeline): + A value that represents an unevaluated pipeline. + + **Requires:** + + - Not allowed to be used when writing documents. + This field is a member of `oneof`_ ``value_type``. """ @@ -246,6 +279,23 @@ class Value(proto.Message): oneof="value_type", message="MapValue", ) + field_reference_value: str = proto.Field( + proto.STRING, + number=19, + oneof="value_type", + ) + function_value: "Function" = proto.Field( + proto.MESSAGE, + number=20, + oneof="value_type", + message="Function", + ) + pipeline_value: "Pipeline" = proto.Field( + proto.MESSAGE, + number=21, + oneof="value_type", + message="Pipeline", + ) class ArrayValue(proto.Message): @@ -285,4 +335,119 @@ class MapValue(proto.Message): ) +class Function(proto.Message): + r"""Represents an unevaluated scalar expression. + + For example, the expression ``like(user_name, "%alice%")`` is + represented as: + + :: + + name: "like" + args { field_reference: "user_name" } + args { string_value: "%alice%" } + + Attributes: + name (str): + Required. The name of the function to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + function expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + +class Pipeline(proto.Message): + r"""A Firestore query represented as an ordered list of + operations / stages. + + Attributes: + stages (MutableSequence[google.cloud.firestore_v1.types.Pipeline.Stage]): + Required. Ordered list of stages to evaluate. + """ + + class Stage(proto.Message): + r"""A single operation within a pipeline. + + A stage is made up of a unique name, and a list of arguments. The + exact number of arguments & types is dependent on the stage type. 
+ + To give an example, the stage ``filter(state = "MD")`` would be + encoded as: + + :: + + name: "filter" + args { + function_value { + name: "eq" + args { field_reference_value: "state" } + args { string_value: "MD" } + } + } + + See public documentation for the full list. + + Attributes: + name (str): + Required. The name of the stage to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + stage expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + stages: MutableSequence[Stage] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=Stage, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/explain_stats.py b/google/cloud/firestore_v1/types/explain_stats.py new file mode 100644 index 000000000..b0f9421ba --- /dev/null +++ b/google/cloud/firestore_v1/types/explain_stats.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.protobuf import any_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "ExplainStats", + }, +) + + +class ExplainStats(proto.Message): + r"""Pipeline explain stats. + + Depending on the explain options in the original request, this + can contain the optimized plan and / or execution stats. + + Attributes: + data (google.protobuf.any_pb2.Any): + The format depends on the ``output_format`` options in the + request. + + Currently there are two supported options: ``TEXT`` and + ``JSON``. Both supply a ``google.protobuf.StringValue``. 
+ """ + + data: any_pb2.Any = proto.Field( + proto.MESSAGE, + number=1, + message=any_pb2.Any, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py index 190f55d28..4e53ba313 100644 --- a/google/cloud/firestore_v1/types/firestore.py +++ b/google/cloud/firestore_v1/types/firestore.py @@ -22,6 +22,8 @@ from google.cloud.firestore_v1.types import aggregation_result from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats as gf_explain_stats +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query as gf_query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write @@ -48,6 +50,8 @@ "RollbackRequest", "RunQueryRequest", "RunQueryResponse", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "RunAggregationQueryRequest", "RunAggregationQueryResponse", "PartitionQueryRequest", @@ -835,6 +839,151 @@ class RunQueryResponse(proto.Message): ) +class ExecutePipelineRequest(proto.Message): + r"""The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + database (str): + Required. Database identifier, in the form + ``projects/{project}/databases/{database}``. + structured_pipeline (google.cloud.firestore_v1.types.StructuredPipeline): + A pipelined operation. + + This field is a member of `oneof`_ ``pipeline_type``. + transaction (bytes): + Run the query within an already active + transaction. + The value here is the opaque transaction ID to + execute the query in. + + This field is a member of `oneof`_ ``consistency_selector``. + new_transaction (google.cloud.firestore_v1.types.TransactionOptions): + Execute the pipeline in a new transaction. + + The identifier of the newly created transaction + will be returned in the first response on the + stream. This defaults to a read-only + transaction. + + This field is a member of `oneof`_ ``consistency_selector``. + read_time (google.protobuf.timestamp_pb2.Timestamp): + Execute the pipeline in a snapshot + transaction at the given time. + This must be a microsecond precision timestamp + within the past one hour, or if Point-in-Time + Recovery is enabled, can additionally be a whole + minute timestamp within the past 7 days. + + This field is a member of `oneof`_ ``consistency_selector``. 
+ """ + + database: str = proto.Field( + proto.STRING, + number=1, + ) + structured_pipeline: pipeline.StructuredPipeline = proto.Field( + proto.MESSAGE, + number=2, + oneof="pipeline_type", + message=pipeline.StructuredPipeline, + ) + transaction: bytes = proto.Field( + proto.BYTES, + number=5, + oneof="consistency_selector", + ) + new_transaction: common.TransactionOptions = proto.Field( + proto.MESSAGE, + number=6, + oneof="consistency_selector", + message=common.TransactionOptions, + ) + read_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=7, + oneof="consistency_selector", + message=timestamp_pb2.Timestamp, + ) + + +class ExecutePipelineResponse(proto.Message): + r"""The response for [Firestore.Execute][]. + + Attributes: + transaction (bytes): + Newly created transaction identifier. + + This field is only specified as part of the first response + from the server, alongside the ``results`` field when the + original request specified + [ExecuteRequest.new_transaction][]. + results (MutableSequence[google.cloud.firestore_v1.types.Document]): + An ordered batch of results returned executing a pipeline. + + The batch size is variable, and can even be zero for when + only a partial progress message is returned. + + The fields present in the returned documents are only those + that were explicitly requested in the pipeline, this + includes those like + [``__name__``][google.firestore.v1.Document.name] and + [``__update_time__``][google.firestore.v1.Document.update_time]. + This is explicitly a divergence from ``Firestore.RunQuery`` + / ``Firestore.GetDocument`` RPCs which always return such + fields even when they are not specified in the + [``mask``][google.firestore.v1.DocumentMask]. + execution_time (google.protobuf.timestamp_pb2.Timestamp): + The time at which the results are valid. + + This is a (not strictly) monotonically increasing value + across multiple responses in the same stream. The API + guarantees that all previously returned results are still + valid at the latest ``execution_time``. This allows the API + consumer to treat the query if it ran at the latest + ``execution_time`` returned. + + If the query returns no results, a response with + ``execution_time`` and no ``results`` will be sent, and this + represents the time at which the operation was run. + explain_stats (google.cloud.firestore_v1.types.ExplainStats): + Query explain stats. + + This is present on the **last** response if the request + configured explain to run in 'analyze' or 'explain' mode in + the pipeline options. If the query does not return any + results, a response with ``explain_stats`` and no + ``results`` will still be sent. + """ + + transaction: bytes = proto.Field( + proto.BYTES, + number=1, + ) + results: MutableSequence[gf_document.Document] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=gf_document.Document, + ) + execution_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + explain_stats: gf_explain_stats.ExplainStats = proto.Field( + proto.MESSAGE, + number=4, + message=gf_explain_stats.ExplainStats, + ) + + class RunAggregationQueryRequest(proto.Message): r"""The request for [Firestore.RunAggregationQuery][google.firestore.v1.Firestore.RunAggregationQuery]. @@ -1416,9 +1565,9 @@ class Target(proto.Message): Note that if the client sends multiple ``AddTarget`` requests without an ID, the order of IDs returned in - ``TargetChage.target_ids`` are undefined. 
Therefore, clients - should provide a target ID instead of relying on the server - to assign one. + ``TargetChange.target_ids`` are undefined. Therefore, + clients should provide a target ID instead of relying on the + server to assign one. If ``target_id`` is non-zero, there must not be an existing active target on this stream with the same ID. diff --git a/google/cloud/firestore_v1/types/pipeline.py b/google/cloud/firestore_v1/types/pipeline.py new file mode 100644 index 000000000..07688dda7 --- /dev/null +++ b/google/cloud/firestore_v1/types/pipeline.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.firestore_v1.types import document + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "StructuredPipeline", + }, +) + + +class StructuredPipeline(proto.Message): + r"""A Firestore query represented as an ordered list of operations / + stages. + + This is considered the top-level function which plans and executes a + query. It is logically equivalent to ``query(stages, options)``, but + prevents the client from having to build a function wrapper. + + Attributes: + pipeline (google.cloud.firestore_v1.types.Pipeline): + Required. The pipeline query to execute. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional query-level arguments. + """ + + pipeline: document.Pipeline = proto.Field( + proto.MESSAGE, + number=1, + message=document.Pipeline, + ) + options: MutableMapping[str, document.Value] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=2, + message=document.Value, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/query.py b/google/cloud/firestore_v1/types/query.py index c2856d0b4..d50742785 100644 --- a/google/cloud/firestore_v1/types/query.py +++ b/google/cloud/firestore_v1/types/query.py @@ -555,9 +555,9 @@ class FindNearest(proto.Message): when the vectors are more similar, the comparison is inverted. 
- - For EUCLIDEAN, COSINE: WHERE distance <= - distance_threshold - - For DOT_PRODUCT: WHERE distance >= distance_threshold + - For EUCLIDEAN, COSINE: + ``WHERE distance <= distance_threshold`` + - For DOT_PRODUCT: ``WHERE distance >= distance_threshold`` """ class DistanceMeasure(proto.Enum): diff --git a/scripts/fixup_firestore_v1_keywords.py b/scripts/fixup_firestore_v1_keywords.py index 6481e76bb..da4d2b0ec 100644 --- a/scripts/fixup_firestore_v1_keywords.py +++ b/scripts/fixup_firestore_v1_keywords.py @@ -51,6 +51,7 @@ class firestoreCallTransformer(cst.CSTTransformer): 'commit': ('database', 'writes', 'transaction', ), 'create_document': ('parent', 'collection_id', 'document', 'document_id', 'mask', ), 'delete_document': ('name', 'current_document', ), + 'execute_pipeline': ('database', 'structured_pipeline', 'transaction', 'new_transaction', 'read_time', ), 'get_document': ('name', 'mask', 'transaction', 'read_time', ), 'list_collection_ids': ('parent', 'page_size', 'page_token', 'read_time', ), 'list_documents': ('parent', 'collection_id', 'page_size', 'page_token', 'order_by', 'mask', 'transaction', 'read_time', 'show_missing', ), diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index eac609cab..20fb2059f 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -61,7 +61,9 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write as gf_write @@ -3884,6 +3886,185 @@ async def test_run_query_field_headers_async(): ) in kw["metadata"] +@pytest.mark.parametrize( + "request_type", + [ + firestore.ExecutePipelineRequest, + dict, + ], +) +def test_execute_pipeline(request_type, transport: str = "grpc"): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = iter([firestore.ExecutePipelineResponse()]) + response = client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = firestore.ExecutePipelineRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, firestore.ExecutePipelineResponse) + + +def test_execute_pipeline_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. 
+ client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = firestore.ExecutePipelineRequest( + database="database_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.execute_pipeline(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == firestore.ExecutePipelineRequest( + database="database_value", + ) + + +def test_execute_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.execute_pipeline in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.execute_pipeline + ] = mock_rpc + request = {} + client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.execute_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_execute_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.execute_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.execute_pipeline + ] = mock_rpc + + request = {} + await client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + await client.execute_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_execute_pipeline_async( + transport: str = "grpc_asyncio", request_type=firestore.ExecutePipelineRequest +): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + response = await client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = firestore.ExecutePipelineRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + message = await response.read() + assert isinstance(message, firestore.ExecutePipelineResponse) + + +@pytest.mark.asyncio +async def test_execute_pipeline_async_from_dict(): + await test_execute_pipeline_async(request_type=dict) + + @pytest.mark.parametrize( "request_type", [ @@ -7410,7 +7591,7 @@ def test_run_query_rest_unset_required_fields(): assert set(unset_fields) == (set(()) & set(("parent",))) -def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): +def test_execute_pipeline_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -7424,10 +7605,7 @@ def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert ( - client._transport.run_aggregation_query - in client._transport._wrapped_methods - ) + assert client._transport.execute_pipeline in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() @@ -7435,29 +7613,29 @@ def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): "foo" # operation_request.operation in compute client(s) expect a string. ) client._transport._wrapped_methods[ - client._transport.run_aggregation_query + client._transport.execute_pipeline ] = mock_rpc request = {} - client.run_aggregation_query(request) + client.execute_pipeline(request) # Establish that the underlying gRPC stub method was called. 
assert mock_rpc.call_count == 1 - client.run_aggregation_query(request) + client.execute_pipeline(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_run_aggregation_query_rest_required_fields( - request_type=firestore.RunAggregationQueryRequest, +def test_execute_pipeline_rest_required_fields( + request_type=firestore.ExecutePipelineRequest, ): transport_class = transports.FirestoreRestTransport request_init = {} - request_init["parent"] = "" + request_init["database"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -7468,21 +7646,21 @@ def test_run_aggregation_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).run_aggregation_query._get_unset_required_fields(jsonified_request) + ).execute_pipeline._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["database"] = "database_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).run_aggregation_query._get_unset_required_fields(jsonified_request) + ).execute_pipeline._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "database" in jsonified_request + assert jsonified_request["database"] == "database_value" client = FirestoreClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7491,7 +7669,7 @@ def test_run_aggregation_query_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = firestore.RunAggregationQueryResponse() + return_value = firestore.ExecutePipelineResponse() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7513,7 +7691,7 @@ def test_run_aggregation_query_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = firestore.RunAggregationQueryResponse.pb(return_value) + return_value = firestore.ExecutePipelineResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) @@ -7523,23 +7701,23 @@ def test_run_aggregation_query_rest_required_fields( with mock.patch.object(response_value, "iter_content") as iter_content: iter_content.return_value = iter(json_return_value) - response = client.run_aggregation_query(request) + response = client.execute_pipeline(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_run_aggregation_query_rest_unset_required_fields(): +def test_execute_pipeline_rest_unset_required_fields(): transport = transports.FirestoreRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.run_aggregation_query._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("parent",))) + unset_fields = transport.execute_pipeline._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("database",))) -def test_partition_query_rest_use_cached_wrapped_rpc(): +def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -7553,30 +7731,35 @@ def test_partition_query_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.partition_query in client._transport._wrapped_methods + assert ( + client._transport.run_aggregation_query + in client._transport._wrapped_methods + ) # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.partition_query] = mock_rpc + client._transport._wrapped_methods[ + client._transport.run_aggregation_query + ] = mock_rpc request = {} - client.partition_query(request) + client.run_aggregation_query(request) # Establish that the underlying gRPC stub method was called. 
assert mock_rpc.call_count == 1 - client.partition_query(request) + client.run_aggregation_query(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_partition_query_rest_required_fields( - request_type=firestore.PartitionQueryRequest, +def test_run_aggregation_query_rest_required_fields( + request_type=firestore.RunAggregationQueryRequest, ): transport_class = transports.FirestoreRestTransport @@ -7592,7 +7775,7 @@ def test_partition_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).partition_query._get_unset_required_fields(jsonified_request) + ).run_aggregation_query._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -7601,7 +7784,7 @@ def test_partition_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).partition_query._get_unset_required_fields(jsonified_request) + ).run_aggregation_query._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -7615,7 +7798,7 @@ def test_partition_query_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = firestore.PartitionQueryResponse() + return_value = firestore.RunAggregationQueryResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7637,68 +7820,192 @@ def test_partition_query_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = firestore.PartitionQueryResponse.pb(return_value) + return_value = firestore.RunAggregationQueryResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.partition_query(request) + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + response = client.run_aggregation_query(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_partition_query_rest_unset_required_fields(): +def test_run_aggregation_query_rest_unset_required_fields(): transport = transports.FirestoreRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.partition_query._get_unset_required_fields({}) + unset_fields = transport.run_aggregation_query._get_unset_required_fields({}) assert set(unset_fields) == (set(()) & set(("parent",))) -def test_partition_query_rest_pager(transport: str = "rest"): - client = FirestoreClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, "request") as req: - # TODO(kbandes): remove this mock unless there's a good reason for it. 
-        # with mock.patch.object(path_template, 'transcode') as transcode:
-        # Set the response as a series of pages
-        response = (
-            firestore.PartitionQueryResponse(
-                partitions=[
-                    query.Cursor(),
-                    query.Cursor(),
-                    query.Cursor(),
-                ],
-                next_page_token="abc",
-            ),
-            firestore.PartitionQueryResponse(
-                partitions=[],
-                next_page_token="def",
-            ),
-            firestore.PartitionQueryResponse(
-                partitions=[
-                    query.Cursor(),
-                ],
-                next_page_token="ghi",
-            ),
-            firestore.PartitionQueryResponse(
-                partitions=[
-                    query.Cursor(),
-                    query.Cursor(),
-                ],
-            ),
+def test_partition_query_rest_use_cached_wrapped_rpc():
+    # Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
+    # instead of constructing them on each call
+    with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
+        client = FirestoreClient(
+            credentials=ga_credentials.AnonymousCredentials(),
+            transport="rest",
         )
-        # Two responses for two calls
-        response = response + response
+
+        # Should wrap all calls on client creation
+        assert wrapper_fn.call_count > 0
+        wrapper_fn.reset_mock()
+
+        # Ensure method has been cached
+        assert client._transport.partition_query in client._transport._wrapped_methods
+
+        # Replace cached wrapped function with mock
+        mock_rpc = mock.Mock()
+        mock_rpc.return_value.name = (
+            "foo"  # operation_request.operation in compute client(s) expect a string.
+        )
+        client._transport._wrapped_methods[client._transport.partition_query] = mock_rpc
+
+        request = {}
+        client.partition_query(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert mock_rpc.call_count == 1
+
+        client.partition_query(request)
+
+        # Establish that a new wrapper was not created for this call
+        assert wrapper_fn.call_count == 0
+        assert mock_rpc.call_count == 2
+
+
+def test_partition_query_rest_required_fields(
+    request_type=firestore.PartitionQueryRequest,
+):
+    transport_class = transports.FirestoreRestTransport
+
+    request_init = {}
+    request_init["parent"] = ""
+    request = request_type(**request_init)
+    pb_request = request_type.pb(request)
+    jsonified_request = json.loads(
+        json_format.MessageToJson(pb_request, use_integers_for_enums=False)
+    )
+
+    # verify fields with default values are dropped
+
+    unset_fields = transport_class(
+        credentials=ga_credentials.AnonymousCredentials()
+    ).partition_query._get_unset_required_fields(jsonified_request)
+    jsonified_request.update(unset_fields)
+
+    # verify required fields with default values are now present
+
+    jsonified_request["parent"] = "parent_value"
+
+    unset_fields = transport_class(
+        credentials=ga_credentials.AnonymousCredentials()
+    ).partition_query._get_unset_required_fields(jsonified_request)
+    jsonified_request.update(unset_fields)
+
+    # verify required fields with non-default values are left alone
+    assert "parent" in jsonified_request
+    assert jsonified_request["parent"] == "parent_value"
+
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+    request = request_type(**request_init)
+
+    # Designate an appropriate value for the returned response.
+    return_value = firestore.PartitionQueryResponse()
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(Session, "request") as req:
+        # We need to mock transcode() because providing default values
+        # for required fields will fail the real version if the http_options
+        # expect actual values for those fields.
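+        # (transcode() normally matches the request onto the method's
+        # http_options rule; stubbing it keeps the test focused on the
+        # query params the transport actually sends.)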
+        with mock.patch.object(path_template, "transcode") as transcode:
+            # A uri without fields and an empty body will force all the
+            # request fields to show up in the query_params.
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                "uri": "v1/sample_method",
+                "method": "post",
+                "query_params": pb_request,
+            }
+            transcode_result["body"] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            # Convert return value to protobuf type
+            return_value = firestore.PartitionQueryResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(return_value)
+
+            response_value._content = json_return_value.encode("UTF-8")
+            req.return_value = response_value
+            req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"}
+
+            response = client.partition_query(request)
+
+            expected_params = [("$alt", "json;enum-encoding=int")]
+            actual_params = req.call_args.kwargs["params"]
+            assert expected_params == actual_params
+
+
+def test_partition_query_rest_unset_required_fields():
+    transport = transports.FirestoreRestTransport(
+        credentials=ga_credentials.AnonymousCredentials
+    )
+
+    unset_fields = transport.partition_query._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("parent",)))
+
+
+def test_partition_query_rest_pager(transport: str = "rest"):
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(Session, "request") as req:
+        # TODO(kbandes): remove this mock unless there's a good reason for it.
+        # with mock.patch.object(path_template, 'transcode') as transcode:
+        # Set the response as a series of pages
+        response = (
+            firestore.PartitionQueryResponse(
+                partitions=[
+                    query.Cursor(),
+                    query.Cursor(),
+                    query.Cursor(),
+                ],
+                next_page_token="abc",
+            ),
+            firestore.PartitionQueryResponse(
+                partitions=[],
+                next_page_token="def",
+            ),
+            firestore.PartitionQueryResponse(
+                partitions=[
+                    query.Cursor(),
+                ],
+                next_page_token="ghi",
+            ),
+            firestore.PartitionQueryResponse(
+                partitions=[
+                    query.Cursor(),
+                    query.Cursor(),
+                ],
+            ),
+        )
+        # Two responses for two calls
+        response = response + response
 
         # Wrap the values into proper Response objs
         response = tuple(firestore.PartitionQueryResponse.to_json(x) for x in response)
@@ -8553,6 +8860,27 @@ def test_run_query_empty_call_grpc():
     assert args[0] == request_msg
 
 
+# This test is a coverage failsafe to make sure that totally empty calls,
+# i.e. request == None and no flattened fields passed, work.
+def test_execute_pipeline_empty_call_grpc():
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="grpc",
+    )
+
+    # Mock the actual call, and fake the request.
+    with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call:
+        call.return_value = iter([firestore.ExecutePipelineResponse()])
+        client.execute_pipeline(request=None)
+
+        # Establish that the underlying stub method was called.
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        request_msg = firestore.ExecutePipelineRequest()
+
+        assert args[0] == request_msg
+
+
 # This test is a coverage failsafe to make sure that totally empty calls,
 # i.e. request == None and no flattened fields passed, work.
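 # (With request=None the client still constructs and sends the method's
 # default request message, which the mocked call captures below.)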
 def test_run_aggregation_query_empty_call_grpc():
@@ -8662,6 +8990,60 @@ def test_create_document_empty_call_grpc():
     assert args[0] == request_msg
 
 
+def test_execute_pipeline_routing_parameters_request_1_grpc():
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="grpc",
+    )
+
+    # Mock the actual call, and fake the request.
+    with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call:
+        call.return_value = iter([firestore.ExecutePipelineResponse()])
+        client.execute_pipeline(request={"database": "projects/sample1/sample2"})
+
+        # Establish that the underlying stub method was called.
+        call.assert_called()
+        _, args, kw = call.mock_calls[0]
+        request_msg = firestore.ExecutePipelineRequest(
+            **{"database": "projects/sample1/sample2"}
+        )
+
+        assert args[0] == request_msg
+
+        expected_headers = {"project_id": "sample1"}
+        assert (
+            gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"]
+        )
+
+
+def test_execute_pipeline_routing_parameters_request_2_grpc():
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="grpc",
+    )
+
+    # Mock the actual call, and fake the request.
+    with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call:
+        call.return_value = iter([firestore.ExecutePipelineResponse()])
+        client.execute_pipeline(
+            request={"database": "projects/sample1/databases/sample2/sample3"}
+        )
+
+        # Establish that the underlying stub method was called.
+        call.assert_called()
+        _, args, kw = call.mock_calls[0]
+        request_msg = firestore.ExecutePipelineRequest(
+            **{"database": "projects/sample1/databases/sample2/sample3"}
+        )
+
+        assert args[0] == request_msg
+
+        expected_headers = {"project_id": "sample1", "database_id": "sample2"}
+        assert (
+            gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"]
+        )
+
+
 def test_transport_kind_grpc_asyncio():
     transport = FirestoreAsyncClient.get_transport_class("grpc_asyncio")(
         credentials=async_anonymous_credentials()
@@ -8911,6 +9293,32 @@ async def test_run_query_empty_call_grpc_asyncio():
     assert args[0] == request_msg
 
 
+# This test is a coverage failsafe to make sure that totally empty calls,
+# i.e. request == None and no flattened fields passed, work.
+@pytest.mark.asyncio
+async def test_execute_pipeline_empty_call_grpc_asyncio():
+    client = FirestoreAsyncClient(
+        credentials=async_anonymous_credentials(),
+        transport="grpc_asyncio",
+    )
+
+    # Mock the actual call, and fake the request.
+    with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True)
+        call.return_value.read = mock.AsyncMock(
+            side_effect=[firestore.ExecutePipelineResponse()]
+        )
+        await client.execute_pipeline(request=None)
+
+        # Establish that the underlying stub method was called.
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        request_msg = firestore.ExecutePipelineRequest()
+
+        assert args[0] == request_msg
+
+
 # This test is a coverage failsafe to make sure that totally empty calls,
 # i.e. request == None and no flattened fields passed, work.
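 # (The asyncio variants exercise the same checks through the grpc_asyncio
 # transport, with AsyncMock-backed streaming calls.)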
 @pytest.mark.asyncio
@@ -9048,6 +9456,70 @@ async def test_create_document_empty_call_grpc_asyncio():
     assert args[0] == request_msg
 
 
+@pytest.mark.asyncio
+async def test_execute_pipeline_routing_parameters_request_1_grpc_asyncio():
+    client = FirestoreAsyncClient(
+        credentials=async_anonymous_credentials(),
+        transport="grpc_asyncio",
+    )
+
+    # Mock the actual call, and fake the request.
+    with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True)
+        call.return_value.read = mock.AsyncMock(
+            side_effect=[firestore.ExecutePipelineResponse()]
+        )
+        await client.execute_pipeline(request={"database": "projects/sample1/sample2"})
+
+        # Establish that the underlying stub method was called.
+        call.assert_called()
+        _, args, kw = call.mock_calls[0]
+        request_msg = firestore.ExecutePipelineRequest(
+            **{"database": "projects/sample1/sample2"}
+        )
+
+        assert args[0] == request_msg
+
+        expected_headers = {"project_id": "sample1"}
+        assert (
+            gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"]
+        )
+
+
+@pytest.mark.asyncio
+async def test_execute_pipeline_routing_parameters_request_2_grpc_asyncio():
+    client = FirestoreAsyncClient(
+        credentials=async_anonymous_credentials(),
+        transport="grpc_asyncio",
+    )
+
+    # Mock the actual call, and fake the request.
+    with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True)
+        call.return_value.read = mock.AsyncMock(
+            side_effect=[firestore.ExecutePipelineResponse()]
+        )
+        await client.execute_pipeline(
+            request={"database": "projects/sample1/databases/sample2/sample3"}
+        )
+
+        # Establish that the underlying stub method was called.
+        call.assert_called()
+        _, args, kw = call.mock_calls[0]
+        request_msg = firestore.ExecutePipelineRequest(
+            **{"database": "projects/sample1/databases/sample2/sample3"}
+        )
+
+        assert args[0] == request_msg
+
+        expected_headers = {"project_id": "sample1", "database_id": "sample2"}
+        assert (
+            gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"]
+        )
+
+
 def test_transport_kind_rest():
     transport = FirestoreClient.get_transport_class("rest")(
         credentials=ga_credentials.AnonymousCredentials()
@@ -10233,6 +10705,137 @@ def test_run_query_rest_interceptors(null_interceptor):
     post_with_metadata.assert_called_once()
 
 
+def test_execute_pipeline_rest_bad_request(
+    request_type=firestore.ExecutePipelineRequest,
+):
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(), transport="rest"
+    )
+    # send a request that will satisfy transcoding
+    request_init = {"database": "projects/sample1/databases/sample2"}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
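+    # (A 400 status with an empty JSON body must surface to the caller as
+    # core_exceptions.BadRequest.)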
+    with mock.patch.object(Session, "request") as req, pytest.raises(
+        core_exceptions.BadRequest
+    ):
+        # Wrap the value into a proper Response obj
+        response_value = mock.Mock()
+        json_return_value = ""
+        response_value.json = mock.Mock(return_value={})
+        response_value.status_code = 400
+        response_value.request = mock.Mock()
+        req.return_value = response_value
+        req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"}
+        client.execute_pipeline(request)
+
+
+@pytest.mark.parametrize(
+    "request_type",
+    [
+        firestore.ExecutePipelineRequest,
+        dict,
+    ],
+)
+def test_execute_pipeline_rest_call_success(request_type):
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(), transport="rest"
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {"database": "projects/sample1/databases/sample2"}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(type(client.transport._session), "request") as req:
+        # Designate an appropriate value for the returned response.
+        return_value = firestore.ExecutePipelineResponse(
+            transaction=b"transaction_blob",
+        )
+
+        # Wrap the value into a proper Response obj
+        response_value = mock.Mock()
+        response_value.status_code = 200
+
+        # Convert return value to protobuf type
+        return_value = firestore.ExecutePipelineResponse.pb(return_value)
+        json_return_value = json_format.MessageToJson(return_value)
+        json_return_value = "[{}]".format(json_return_value)
+        response_value.iter_content = mock.Mock(return_value=iter(json_return_value))
+        req.return_value = response_value
+        req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"}
+        response = client.execute_pipeline(request)
+
+    assert isinstance(response, Iterable)
+    response = next(response)
+
+    # Establish that the response is the type that we expect.
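+    # (The stream was faked as a one-element JSON array, so next() yields a
+    # single decoded ExecutePipelineResponse.)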
+    assert isinstance(response, firestore.ExecutePipelineResponse)
+    assert response.transaction == b"transaction_blob"
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_execute_pipeline_rest_interceptors(null_interceptor):
+    transport = transports.FirestoreRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.FirestoreRestInterceptor(),
+    )
+    client = FirestoreClient(transport=transport)
+
+    with mock.patch.object(
+        type(client.transport._session), "request"
+    ) as req, mock.patch.object(
+        path_template, "transcode"
+    ) as transcode, mock.patch.object(
+        transports.FirestoreRestInterceptor, "post_execute_pipeline"
+    ) as post, mock.patch.object(
+        transports.FirestoreRestInterceptor, "post_execute_pipeline_with_metadata"
+    ) as post_with_metadata, mock.patch.object(
+        transports.FirestoreRestInterceptor, "pre_execute_pipeline"
+    ) as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        post_with_metadata.assert_not_called()
+        pb_message = firestore.ExecutePipelineRequest.pb(
+            firestore.ExecutePipelineRequest()
+        )
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = mock.Mock()
+        req.return_value.status_code = 200
+        req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"}
+        return_value = firestore.ExecutePipelineResponse.to_json(
+            firestore.ExecutePipelineResponse()
+        )
+        req.return_value.iter_content = mock.Mock(return_value=iter(return_value))
+
+        request = firestore.ExecutePipelineRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = firestore.ExecutePipelineResponse()
+        post_with_metadata.return_value = firestore.ExecutePipelineResponse(), metadata
+
+        client.execute_pipeline(
+            request,
+            metadata=[
+                ("key", "val"),
+                ("cephalopod", "squid"),
+            ],
+        )
+
+        pre.assert_called_once()
+        post.assert_called_once()
+        post_with_metadata.assert_called_once()
+
+
 def test_run_aggregation_query_rest_bad_request(
     request_type=firestore.RunAggregationQueryRequest,
 ):
@@ -11409,6 +12012,26 @@ def test_run_query_empty_call_rest():
     assert args[0] == request_msg
 
 
+# This test is a coverage failsafe to make sure that totally empty calls,
+# i.e. request == None and no flattened fields passed, work.
+def test_execute_pipeline_empty_call_rest():
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the actual call, and fake the request.
+    with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call:
+        client.execute_pipeline(request=None)
+
+        # Establish that the underlying stub method was called.
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        request_msg = firestore.ExecutePipelineRequest()
+
+        assert args[0] == request_msg
+
+
 # This test is a coverage failsafe to make sure that totally empty calls,
 # i.e. request == None and no flattened fields passed, work.
 def test_run_aggregation_query_empty_call_rest():
@@ -11513,6 +12136,58 @@ def test_create_document_empty_call_rest():
     assert args[0] == request_msg
 
 
+def test_execute_pipeline_routing_parameters_request_1_rest():
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the actual call, and fake the request.
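+    # (Only the project segment of "projects/sample1/sample2" matches the
+    # routing pattern, so just project_id is expected in the metadata.)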
+    with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call:
+        client.execute_pipeline(request={"database": "projects/sample1/sample2"})
+
+        # Establish that the underlying stub method was called.
+        call.assert_called()
+        _, args, kw = call.mock_calls[0]
+        request_msg = firestore.ExecutePipelineRequest(
+            **{"database": "projects/sample1/sample2"}
+        )
+
+        assert args[0] == request_msg
+
+        expected_headers = {"project_id": "sample1"}
+        assert (
+            gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"]
+        )
+
+
+def test_execute_pipeline_routing_parameters_request_2_rest():
+    client = FirestoreClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the actual call, and fake the request.
+    with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call:
+        client.execute_pipeline(
+            request={"database": "projects/sample1/databases/sample2/sample3"}
+        )
+
+        # Establish that the underlying stub method was called.
+        call.assert_called()
+        _, args, kw = call.mock_calls[0]
+        request_msg = firestore.ExecutePipelineRequest(
+            **{"database": "projects/sample1/databases/sample2/sample3"}
+        )
+
+        assert args[0] == request_msg
+
+        expected_headers = {"project_id": "sample1", "database_id": "sample2"}
+        assert (
+            gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"]
+        )
+
+
 def test_transport_grpc_default():
     # A client should use the gRPC transport by default.
     client = FirestoreClient(
@@ -11555,6 +12230,7 @@ def test_firestore_base_transport():
         "commit",
         "rollback",
         "run_query",
+        "execute_pipeline",
         "run_aggregation_query",
         "partition_query",
         "write",
@@ -11860,6 +12536,9 @@ def test_firestore_client_transport_session_collision(transport_name):
    session1 = client1.transport.run_query._session
    session2 = client2.transport.run_query._session
    assert session1 != session2
+    session1 = client1.transport.execute_pipeline._session
+    session2 = client2.transport.execute_pipeline._session
+    assert session1 != session2
    session1 = client1.transport.run_aggregation_query._session
    session2 = client2.transport.run_aggregation_query._session
    assert session1 != session2
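For context on what the routing-parameter tests above pin down: the `database` field of an `ExecutePipelineRequest` is matched against resource-path patterns, and any captured segments are forwarded as explicit routing metadata (the `x-goog-request-params` header that `gapic_v1.routing_header.to_grpc_metadata` produces). A minimal, self-contained sketch of that derivation follows; the helper name and regexes are illustrative reconstructions from the test expectations, not code copied from the generated client.

import re

# Hypothetical helper mirroring the routing behavior the tests assert:
# "projects/{project_id}/**" contributes project_id, and
# "projects/*/databases/{database_id}/**" additionally contributes database_id.
def routing_headers_for_database(database: str) -> dict:
    headers = {}
    match = re.match(r"^projects/(?P<project_id>[^/]+)(?:/.*)?$", database)
    if match:
        headers["project_id"] = match.group("project_id")
    match = re.match(
        r"^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$", database
    )
    if match:
        headers["database_id"] = match.group("database_id")
    return headers

# The two sample paths used throughout the tests round-trip as expected:
assert routing_headers_for_database("projects/sample1/sample2") == {
    "project_id": "sample1"
}
assert routing_headers_for_database(
    "projects/sample1/databases/sample2/sample3"
) == {"project_id": "sample1", "database_id": "sample2"}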