diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index c5fc56bcc9..ee39f80d61 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -27,8 +27,9 @@ from google.api_core.retry import Retry from google.api_core.retry import if_exception_type from google.cloud.exceptions import NotFound -from google.api_core.exceptions import Aborted +from google.api_core.exceptions import Aborted, GoogleAPICallError from google.api_core import gapic_v1 +from google.cloud.spanner_v1.exceptions import SpannerException from google.iam.v1 import iam_policy_pb2 from google.iam.v1 import options_pb2 from google.protobuf.field_mask_pb2 import FieldMask @@ -525,11 +526,11 @@ def create(self): database_dialect=self._database_dialect, proto_descriptors=self._proto_descriptors, ) - future = api.create_database( - request=request, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + return _call_api_with_request_id( + api.create_database, + request, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) - return future def exists(self): """Test whether this database exists. @@ -544,11 +545,10 @@ def exists(self): metadata = _metadata_with_prefix(self.name) try: - api.get_database_ddl( - database=self.name, - metadata=self.metadata_with_request_id( - self._next_nth_request, 1, metadata - ), + _call_api_with_request_id( + api.get_database_ddl, + {"database": self.name}, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) except NotFound: return False @@ -566,15 +566,17 @@ def reload(self): """ api = self._instance._client.database_admin_api metadata = _metadata_with_prefix(self.name) - response = api.get_database_ddl( - database=self.name, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + response = _call_api_with_request_id( + api.get_database_ddl, + {"database": self.name}, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) self._ddl_statements = tuple(response.statements) self._proto_descriptors = response.proto_descriptors - response = api.get_database( - name=self.name, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + response = _call_api_with_request_id( + api.get_database, + {"name": self.name}, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) self._state = DatabasePB.State(response.state) self._create_time = response.create_time @@ -620,11 +622,11 @@ def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None): proto_descriptors=proto_descriptors, ) - future = api.update_database_ddl( - request=request, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + return _call_api_with_request_id( + api.update_database_ddl, + request, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) - return future def update(self, fields): """Update this database. @@ -660,14 +662,12 @@ def update(self, fields): field_mask = FieldMask(paths=fields) metadata = _metadata_with_prefix(self.name) - future = api.update_database( - database=database_pb, - update_mask=field_mask, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + return _call_api_with_request_id( + api.update_database, + {"database": database_pb, "update_mask": field_mask}, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) - return future - def drop(self): """Drop this database. @@ -676,9 +676,10 @@ def drop(self): """ api = self._instance._client.database_admin_api metadata = _metadata_with_prefix(self.name) - api.drop_database( - database=self.name, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + _call_api_with_request_id( + api.drop_database, + {"database": self.name}, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) def execute_partitioned_dml( @@ -1071,11 +1072,11 @@ def restore(self, source): backup=source.name, encryption_config=self._encryption_config or None, ) - future = api.restore_database( - request=request, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + return _call_api_with_request_id( + api.restore_database, + request, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) - return future def is_ready(self): """Test whether this database is ready for use. @@ -1142,9 +1143,10 @@ def list_database_roles(self, page_size=None): parent=self.name, page_size=page_size, ) - return api.list_database_roles( - request=request, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + return _call_api_with_request_id( + api.list_database_roles, + request, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) def table(self, table_id): @@ -1229,11 +1231,11 @@ def get_iam_policy(self, policy_version=None): requested_policy_version=policy_version ), ) - response = api.get_iam_policy( - request=request, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + return _call_api_with_request_id( + api.get_iam_policy, + request, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) - return response def set_iam_policy(self, policy): """Sets the access control policy on a database resource. @@ -1254,11 +1256,11 @@ def set_iam_policy(self, policy): resource=self.name, policy=policy, ) - response = api.set_iam_policy( - request=request, - metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + return _call_api_with_request_id( + api.set_iam_policy, + request, + self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) - return response @property def observability_options(self): @@ -2005,6 +2007,39 @@ def close(self): self._session.delete() +def _call_api_with_request_id(api_callable, request, metadata): + """Helper to call a GAPIC API callable and wrap exceptions. + + :type api_callable: callable + :param api_callable: GAPIC method implementing the API call. + + :type request: + :class:`~google.cloud.spanner_admin_database_v1.types.CreateDatabaseRequest` + or :class:`~google.cloud.spanner_admin_database_v1.types.UpdateDatabaseDdlRequest` + or :class:`~google.cloud.spanner_admin_database_v1.types.DropDatabaseRequest` + or :class:`~google.cloud.spanner_admin_database_v1.types.GetDatabaseDdlRequest` + :param request: The request protobuf. + + :type metadata: list of tuple + :param metadata: The metadata for the request. + + :rtype: varies + :returns: The result of the API call. + :raises: :class:`~google.cloud.spanner_v1.exceptions.SpannerException` + if the API call fails. + """ + try: + return api_callable(request=request, metadata=metadata) + except GoogleAPICallError as e: + request_id = dict(metadata).get("x-goog-spanner-request-id") + raise SpannerException( + message=e.message, + errors=e.errors, + response=e.response, + request_id=request_id, + ) from e + + def _check_ddl_statements(value): """Validate DDL Statements used to define database schema. diff --git a/google/cloud/spanner_v1/exceptions.py b/google/cloud/spanner_v1/exceptions.py new file mode 100644 index 0000000000..bfed156fb6 --- /dev/null +++ b/google/cloud/spanner_v1/exceptions.py @@ -0,0 +1,34 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom exceptions for Google Cloud Spanner.""" + +from google.api_core import exceptions + + +class SpannerException(exceptions.GoogleAPICallError): + """Base class for all Spanner exceptions.""" + + def __init__(self, message, errors=None, response=None, request_id=None): + super().__init__(message, errors, response) + self._request_id = request_id + + @property + def request_id(self): + """The request ID associated with the failed API call. + + :rtype: str + :returns: The request ID. + """ + return self._request_id