From d3ebf9f8ac6af54e6e0cc49b15128f1b458a52d8 Mon Sep 17 00:00:00 2001 From: Ishaan Gandhi Date: Wed, 26 Nov 2025 19:28:56 -0500 Subject: [PATCH] fix: add types to some functions Many return and argument types were missing type annotations. --- google/cloud/firestore_v1/async_batch.py | 3 ++- google/cloud/firestore_v1/async_client.py | 17 ++++++++++++++--- google/cloud/firestore_v1/async_collection.py | 5 +++-- google/cloud/firestore_v1/base_batch.py | 8 +++++--- google/cloud/firestore_v1/base_client.py | 6 ++++-- google/cloud/firestore_v1/base_collection.py | 3 ++- google/cloud/firestore_v1/base_document.py | 6 +++--- google/cloud/firestore_v1/client.py | 7 +++++-- google/cloud/firestore_v1/document.py | 2 +- 9 files changed, 39 insertions(+), 18 deletions(-) diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index 689753fe9..f74ccacea 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -19,6 +19,7 @@ from google.api_core import retry_async as retries from google.cloud.firestore_v1.base_batch import BaseWriteBatch +from google.cloud.firestore_v1.types.write import WriteResult class AsyncWriteBatch(BaseWriteBatch): @@ -40,7 +41,7 @@ async def commit( self, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, - ) -> list: + ) -> list[WriteResult]: """Commit the changes accumulated in this batch. Args: diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 15b31af31..fd016dfe7 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -25,7 +25,15 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Iterable, + List, + Optional, + Union, +) from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -40,6 +48,7 @@ from google.cloud.firestore_v1.async_transaction import AsyncTransaction from google.cloud.firestore_v1.base_client import _parse_batch_get # type: ignore from google.cloud.firestore_v1.base_client import _CLIENT_INFO, BaseClient, _path_helper +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.services.firestore import ( async_client as firestore_client, @@ -410,7 +419,9 @@ def batch(self) -> AsyncWriteBatch: """ return AsyncWriteBatch(self) - def transaction(self, **kwargs) -> AsyncTransaction: + def transaction( + self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False + ) -> AsyncTransaction: """Get a transaction that uses this client. See :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` for @@ -426,4 +437,4 @@ def transaction(self, **kwargs) -> AsyncTransaction: :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`: A transaction attached to this client. """ - return AsyncTransaction(self, **kwargs) + return AsyncTransaction(self, max_attempts=max_attempts, read_only=read_only) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index cc99aa460..561111163 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -15,7 +15,7 @@ """Classes for representing collections for the Google Cloud Firestore API.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple, cast from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -153,7 +153,8 @@ def document(self, document_id: str | None = None) -> AsyncDocumentReference: :class:`~google.cloud.firestore_v1.document.async_document.AsyncDocumentReference`: The child document. """ - return super(AsyncCollectionReference, self).document(document_id) + doc = super(AsyncCollectionReference, self).document(document_id) + return cast("AsyncDocumentReference", doc) async def list_documents( self, diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index b0d50f1f4..851c7849f 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -15,7 +15,7 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" from __future__ import annotations import abc -from typing import Dict, Union +from typing import Any, Dict, Union # Types needed only for Type Hints from google.api_core import retry as retries @@ -67,7 +67,9 @@ def commit(self): write depend on the implementing class.""" raise NotImplementedError() - def create(self, reference: BaseDocumentReference, document_data: dict) -> None: + def create( + self, reference: BaseDocumentReference, document_data: dict[str, Any] + ) -> None: """Add a "change" to this batch to create a document. If the document given by ``reference`` already exists, then this @@ -120,7 +122,7 @@ def set( def update( self, reference: BaseDocumentReference, - field_updates: dict, + field_updates: dict[str, Any], option: _helpers.WriteOption | None = None, ) -> None: """Add a "change" to update a document. diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 4a0e3f6b8..f3eeeae49 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -57,7 +57,7 @@ DocumentSnapshot, ) from google.cloud.firestore_v1.base_query import BaseQuery -from google.cloud.firestore_v1.base_transaction import BaseTransaction +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS, BaseTransaction from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.services.firestore import client as firestore_client @@ -497,7 +497,9 @@ def collections( def batch(self) -> BaseWriteBatch: raise NotImplementedError - def transaction(self, **kwargs) -> BaseTransaction: + def transaction( + self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False + ) -> BaseTransaction: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 1b1ef0411..be817c5fe 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -35,6 +35,7 @@ from google.api_core import retry as retries from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_document import BaseDocumentReference from google.cloud.firestore_v1.base_query import QueryType if TYPE_CHECKING: # pragma: NO COVER @@ -133,7 +134,7 @@ def _aggregation_query(self) -> BaseAggregationQuery: def _vector_query(self) -> BaseVectorQuery: raise NotImplementedError - def document(self, document_id: Optional[str] = None): + def document(self, document_id: Optional[str] = None) -> BaseDocumentReference: """Create a sub-document underneath the current collection. Args: diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index 517db20d3..fe6113bfc 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -418,7 +418,7 @@ def _client(self): return self._reference._client @property - def exists(self): + def exists(self) -> bool: """Existence flag. Indicates if the document existed at the time this snapshot @@ -430,7 +430,7 @@ def exists(self): return self._exists @property - def id(self): + def id(self) -> str: """The document identifier (within its collection). Returns: @@ -439,7 +439,7 @@ def id(self): return self._reference.id @property - def reference(self): + def reference(self) -> BaseDocumentReference: """Document reference corresponding to document that owns this data. Returns: diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index ec906f991..54943aded 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -39,6 +39,7 @@ # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS from google.cloud.firestore_v1.batch import WriteBatch from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.document import DocumentReference @@ -391,7 +392,9 @@ def batch(self) -> WriteBatch: """ return WriteBatch(self) - def transaction(self, **kwargs) -> Transaction: + def transaction( + self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False + ) -> Transaction: """Get a transaction that uses this client. See :class:`~google.cloud.firestore_v1.transaction.Transaction` for @@ -407,4 +410,4 @@ def transaction(self, **kwargs) -> Transaction: :class:`~google.cloud.firestore_v1.transaction.Transaction`: A transaction attached to this client. """ - return Transaction(self, **kwargs) + return Transaction(self, max_attempts=max_attempts, read_only=read_only) diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 4e0132e49..4bb6399a7 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -169,7 +169,7 @@ def set( def update( self, - field_updates: dict, + field_updates: dict[str, Any], option: _helpers.WriteOption | None = None, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None,