From 8f6c10b57cc3cff59fe99e7481babcfb909edbe6 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 15:17:03 -0700 Subject: [PATCH 01/13] Setup pyright --- .github/workflows/pythontest.yml | 5 ++++- .gitignore | 3 ++- .pre-commit-config.yaml | 4 ++++ pyproject.toml | 16 +++++++++++++++- tdclient/py.typed | 0 5 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 tdclient/py.typed diff --git a/.github/workflows/pythontest.yml b/.github/workflows/pythontest.yml index e8e486c..d3a08a3 100644 --- a/.github/workflows/pythontest.yml +++ b/.github/workflows/pythontest.yml @@ -20,9 +20,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install . + pip install ".[dev]" pip install -r requirements.txt -r test-requirements.txt pip install -U coveralls pyyaml + - name: Run pyright + run: | + pyright tdclient - name: Run test run: | coverage run --source=tdclient -m pytest tdclient/test diff --git a/.gitignore b/.gitignore index 3022a30..5f1c7df 100644 --- a/.gitignore +++ b/.gitignore @@ -8,10 +8,11 @@ /.eggs /.envrc /.tox -/.venv +/.venv* /build /dist /tmp +/.agent # JetBrains IDE .idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e9e020..d6a2ed1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,7 @@ repos: hooks: - id: black language_version: python3.7 + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.407 + hooks: + - id: pyright diff --git a/pyproject.toml b/pyproject.toml index b7b208a..f24a87b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,12 +34,15 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["ruff"] +dev = ["ruff", "pyright"] docs = ["sphinx", "sphinx_rtd_theme"] [tool.setuptools] packages = ["tdclient"] +[tool.setuptools.package-data] +tdclient = ["py.typed"] + [tool.ruff] line-length = 88 @@ -53,3 +56,14 @@ ignore = ["E203", "E501"] [tool.ruff.lint.isort] known-third-party = ["dateutil","msgpack","pkg_resources","pytest","setuptools","urllib3"] + +[tool.pyright] +include = ["tdclient"] +exclude = ["**/__pycache__", "tdclient/test"] +typeCheckingMode = "basic" +pythonVersion = "3.9" +pythonPlatform = "All" +reportMissingTypeStubs = false +reportUnknownMemberType = false +reportUnknownArgumentType = false +reportUnknownVariableType = false diff --git a/tdclient/py.typed b/tdclient/py.typed new file mode 100644 index 0000000..e69de29 From 85ec7605d7ddfb1f9ffa9884f6e0fa54af77555a Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 15:44:56 -0700 Subject: [PATCH 02/13] Use ruff for pre-commit --- .pre-commit-config.yaml | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d6a2ed1..5fbb266 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,19 +8,14 @@ repos: - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files - - repo: https://github.com/asottile/seed-isort-config - rev: v1.9.3 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.2 hooks: - - id: seed-isort-config - - repo: https://github.com/pre-commit/mirrors-isort - rev: v4.3.21 - hooks: - - id: isort - - repo: https://github.com/python/black - rev: stable - hooks: - - id: black - language_version: python3.7 + # Run the linter. + - id: ruff + args: [--fix] + # Run the formatter. + - id: ruff-format - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.407 hooks: From c8e5d7d50bbbf8abe7c6445395c619ebfb48fad7 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 15:46:46 -0700 Subject: [PATCH 03/13] Make pyright manual on pre-commit --- .pre-commit-config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5fbb266..0ab8edc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -20,3 +20,4 @@ repos: rev: v1.1.407 hooks: - id: pyright + stages: [manual] From ee1e3d8e0745b463bb760336786ffb65a44011be Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 15:45:28 -0700 Subject: [PATCH 04/13] Add type hint for client.py --- docs/conf.py | 14 +- pyproject.toml | 1 + tdclient/__init__.py | 20 +- tdclient/bulk_import_model.py | 6 +- tdclient/client.py | 282 +++++++++++++++----- tdclient/job_api.py | 4 +- tdclient/result_model.py | 1 - tdclient/test/bulk_import_model_test.py | 4 +- tdclient/test/client_test.py | 2 + tdclient/test/dtypes_and_converters_test.py | 2 +- 10 files changed, 240 insertions(+), 96 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 84b0b85..534f521 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -78,9 +78,9 @@ def linkcode_resolve(domain, info): # -- Project information ----------------------------------------------------- -project = 'td-client-python' -copyright = '2019, Arm Treasure Data' -author = 'Arm Treasure Data' +project = "td-client-python" +copyright = "2019, Arm Treasure Data" +author = "Arm Treasure Data" # The full version, including alpha/beta/rc tags release = pkg_resources.get_distribution(PACKAGE).version @@ -98,12 +98,12 @@ def linkcode_resolve(domain, info): ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -111,11 +111,11 @@ def linkcode_resolve(domain, info): # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] -autodoc_member_order = 'groupwise' +autodoc_member_order = "groupwise" diff --git a/pyproject.toml b/pyproject.toml index f24a87b..17dceac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "python-dateutil", "msgpack>=0.6.2", "urllib3", + "typing-extensions>=4.0.0", ] [project.optional-dependencies] diff --git a/tdclient/__init__.py b/tdclient/__init__.py index 0b9ac40..f569a41 100644 --- a/tdclient/__init__.py +++ b/tdclient/__init__.py @@ -1,4 +1,8 @@ +from __future__ import annotations + import datetime +import time +from typing import Any from tdclient import client, connection, errors, version @@ -7,7 +11,7 @@ Client = client.Client -def connect(*args, **kwargs): +def connect(*args: Any, **kwargs: Any) -> connection.Connection: """Returns a DBAPI compatible connection object Args: @@ -44,19 +48,19 @@ def connect(*args, **kwargs): Timestamp = datetime.datetime -def DateFromTicks(ticks): - return datetime.date(*datetime.localtime(ticks)[:3]) +def DateFromTicks(ticks: float) -> datetime.date: + return datetime.date(*time.localtime(ticks)[:3]) -def TimeFromTicks(ticks): - return datetime.time(*datetime.localtime(ticks)[3:6]) +def TimeFromTicks(ticks: float) -> datetime.time: + return datetime.time(*time.localtime(ticks)[3:6]) -def TimestampFromTicks(ticks): - return datetime.datetime(*datetime.localtime(ticks)[:6]) +def TimestampFromTicks(ticks: float) -> datetime.datetime: + return datetime.datetime(*time.localtime(ticks)[:6]) -def Binary(string): +def Binary(string: bytes) -> bytes: return bytes(string) diff --git a/tdclient/bulk_import_model.py b/tdclient/bulk_import_model.py index 9cc0ff6..235d7fd 100644 --- a/tdclient/bulk_import_model.py +++ b/tdclient/bulk_import_model.py @@ -118,7 +118,11 @@ def perform(self, wait=False, wait_interval=5, wait_callback=None, timeout=None) ) job = self._client.perform_bulk_import(self.name) if wait: - job.wait(timeout=timeout, wait_interval=wait_interval, wait_callback=wait_callback) + job.wait( + timeout=timeout, + wait_interval=wait_interval, + wait_callback=wait_callback, + ) self.update() return job diff --git a/tdclient/client.py b/tdclient/client.py index b4b779f..e401508 100644 --- a/tdclient/client.py +++ b/tdclient/client.py @@ -1,44 +1,115 @@ #!/usr/bin/env python +from __future__ import annotations + +import datetime import json +from collections.abc import Iterator +from typing import IO, Any, Literal, TypedDict + +from typing_extensions import TypeAlias from tdclient import api, models +# Type aliases for file-like objects +FileLike: TypeAlias = "str | bytes | IO[bytes]" + +# Common literal types +QueryEngineType: TypeAlias = 'Literal["presto", "hive"]' +EngineVersion: TypeAlias = 'Literal["stable", "experimental"]' +Priority: TypeAlias = "Literal[-2, -1, 0, 1, 2]" +ExportFileFormat: TypeAlias = 'Literal["jsonl.gz", "tsv.gz", "json.gz"]' +DataFormat: TypeAlias = 'Literal["msgpack", "msgpack.gz", "json", "json.gz", "csv", "csv.gz", "tsv", "tsv.gz"]' +ResultFormat: TypeAlias = 'Literal["msgpack", "json", "csv", "tsv"]' + + +class ScheduleParams(TypedDict, total=False): + """Parameters for create_schedule and update_schedule""" + + type: QueryEngineType # Query type + database: str # Target database name + timezone: str # Timezone e.g. "UTC" + cron: str # Schedule: "@daily", "@hourly", or cron expression + delay: int # Delay in seconds before running + query: str # SQL query to execute + priority: Priority # Priority: -2 (very low) to 2 (very high) + retry_limit: int # Automatic retry count + engine_version: EngineVersion # Engine version + pool_name: str # For Presto only: pool name + result: str # Result output location URL + + +class ExportParams(TypedDict, total=False): + """Parameters for export_data""" + + access_key_id: str # ID to access the export destination + secret_access_key: str # Password for access_key_id + file_prefix: str # Filename prefix for exported file + file_format: ExportFileFormat # File format + from_: ( + int # Start time in Unix epoch format (use 'from_' to avoid keyword conflict) + ) + to: int # End time in Unix epoch format + assume_role: str # Assume role ARN + bucket: str # Bucket name + domain_key: str # Job domain key + pool_name: str # For Presto only: pool name + + +class BulkImportParams(TypedDict, total=False): + """Parameters for create_bulk_import""" + + # Add any optional parameters for bulk import if needed + pass + + +class ResultParams(TypedDict, total=False): + """Parameters for create_result""" + + # Add any optional parameters for result creation if needed + pass + + class Client: """API Client for Treasure Data Service""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self._api = api.API(*args, **kwargs) - def __enter__(self): + def __enter__(self) -> Client: return self - def __exit__(self, type, value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: object, + ) -> None: self.close() @property - def api(self): + def api(self) -> api.API: """ an instance of :class:`tdclient.api.API` """ return self._api @property - def apikey(self): + def apikey(self) -> str | None: """ API key string. """ return self._api.apikey - def server_status(self): + def server_status(self) -> str: """ Returns: a string represents current server status. """ return self.api.server_status() - def create_database(self, db_name, **kwargs): + def create_database(self, db_name: str, **kwargs: Any) -> bool: """ Args: db_name (str): name of a database to create @@ -48,7 +119,7 @@ def create_database(self, db_name, **kwargs): """ return self.api.create_database(db_name, **kwargs) - def delete_database(self, db_name): + def delete_database(self, db_name: str) -> bool: """ Args: db_name (str): name of database to delete @@ -58,7 +129,7 @@ def delete_database(self, db_name): """ return self.api.delete_database(db_name) - def databases(self): + def databases(self) -> list[models.Database]: """ Returns: a list of :class:`tdclient.models.Database` @@ -69,7 +140,7 @@ def databases(self): for (db_name, kwargs) in databases.items() ] - def database(self, db_name): + def database(self, db_name: str) -> models.Database: """ Args: db_name (str): name of a database @@ -83,7 +154,7 @@ def database(self, db_name): return models.Database(self, name, **kwargs) raise api.NotFoundError("Database '%s' does not exist" % (db_name)) - def create_log_table(self, db_name, table_name): + def create_log_table(self, db_name: str, table_name: str) -> bool: """ Args: db_name (str): name of a database @@ -94,7 +165,7 @@ def create_log_table(self, db_name, table_name): """ return self.api.create_log_table(db_name, table_name) - def swap_table(self, db_name, table_name1, table_name2): + def swap_table(self, db_name: str, table_name1: str, table_name2: str) -> bool: """ Args: db_name (str): name of a database @@ -106,7 +177,9 @@ def swap_table(self, db_name, table_name1, table_name2): """ return self.api.swap_table(db_name, table_name1, table_name2) - def update_schema(self, db_name, table_name, schema): + def update_schema( + self, db_name: str, table_name: str, schema: list[list[str]] + ) -> bool: """Updates the schema of a table Args: @@ -132,7 +205,7 @@ def update_schema(self, db_name, table_name, schema): """ return self.api.update_schema(db_name, table_name, json.dumps(schema)) - def update_expire(self, db_name, table_name, expire_days): + def update_expire(self, db_name: str, table_name: str, expire_days: int) -> bool: """Set expiration date to a table Args: @@ -145,7 +218,7 @@ def update_expire(self, db_name, table_name, expire_days): """ return self.api.update_expire(db_name, table_name, expire_days) - def delete_table(self, db_name, table_name): + def delete_table(self, db_name: str, table_name: str) -> str: """Delete a table Args: @@ -157,7 +230,7 @@ def delete_table(self, db_name, table_name): """ return self.api.delete_table(db_name, table_name) - def tables(self, db_name): + def tables(self, db_name: str) -> list[models.Table]: """List existing tables Args: @@ -172,7 +245,7 @@ def tables(self, db_name): for (table_name, kwargs) in m.items() ] - def table(self, db_name, table_name): + def table(self, db_name: str, table_name: str) -> models.Table: """ Args: db_name (str): name of a database @@ -190,7 +263,15 @@ def table(self, db_name, table_name): return table raise api.NotFoundError("Table '%s.%s' does not exist" % (db_name, table_name)) - def tail(self, db_name, table_name, count, to=None, _from=None, block=None): + def tail( + self, + db_name: str, + table_name: str, + count: int, + to: None = None, + _from: None = None, + block: None = None, + ) -> list[dict[str, Any]]: """Get the contents of the table in reverse order based on the registered time (last data first). @@ -207,7 +288,7 @@ def tail(self, db_name, table_name, count, to=None, _from=None, block=None): """ return self.api.tail(db_name, table_name, count, to, _from, block) - def change_database(self, db_name, table_name, new_db_name): + def change_database(self, db_name: str, table_name: str, new_db_name: str) -> bool: """Move a target table from it's original database to new destination database. Args: @@ -222,14 +303,14 @@ def change_database(self, db_name, table_name, new_db_name): def query( self, - db_name, - q, - result_url=None, - priority=None, - retry_limit=None, - type="hive", - **kwargs, - ): + db_name: str, + q: str, + result_url: str | None = None, + priority: int | str | None = None, + retry_limit: int | None = None, + type: str = "hive", + **kwargs: Any, + ) -> models.Job: """Run a query on specified database table. Args: @@ -261,7 +342,13 @@ def query( ) return models.Job(self, job_id, type, q) - def jobs(self, _from=None, to=None, status=None, conditions=None): + def jobs( + self, + _from: int | None = None, + to: int | None = None, + status: str | None = None, + conditions: str | None = None, + ) -> list[models.Job]: """List jobs Args: @@ -275,11 +362,11 @@ def jobs(self, _from=None, to=None, status=None, conditions=None): Returns: a list of :class:`tdclient.models.Job` """ - results = self.api.list_jobs(_from, to, status, conditions) + results = self.api.list_jobs(_from or 0, to, status, conditions) return [job_from_dict(self, d) for d in results] - def job(self, job_id): + def job(self, job_id: str | int) -> models.Job: """Get a job from `job_id` Args: @@ -291,7 +378,7 @@ def job(self, job_id): d = self.api.show_job(str(job_id)) return job_from_dict(self, d, job_id=job_id) - def job_status(self, job_id): + def job_status(self, job_id: str | int) -> str: """ Args: job_id (str): job id @@ -301,7 +388,7 @@ def job_status(self, job_id): """ return self.api.job_status(job_id) - def job_result(self, job_id): + def job_result(self, job_id: str | int) -> list[Any]: """ Args: job_id (str): job id @@ -311,7 +398,7 @@ def job_result(self, job_id): """ return self.api.job_result(job_id) - def job_result_each(self, job_id): + def job_result_each(self, job_id: str | int) -> Iterator[Any]: """ Args: job_id (str): job id @@ -322,7 +409,9 @@ def job_result_each(self, job_id): for row in self.api.job_result_each(job_id): yield row - def job_result_format(self, job_id, format, header=False): + def job_result_format( + self, job_id: str | int, format: ResultFormat, header: bool = False + ) -> list[Any]: """ Args: job_id (str): job id @@ -334,8 +423,13 @@ def job_result_format(self, job_id, format, header=False): return self.api.job_result_format(job_id, format, header=header) def job_result_format_each( - self, job_id, format, header=False, store_tmpfile=False, num_threads=4 - ): + self, + job_id: str | int, + format: ResultFormat, + header: bool = False, + store_tmpfile: bool = False, + num_threads: int = 4, + ) -> Iterator[Any]: """ Args: job_id (str): job id @@ -359,7 +453,9 @@ def job_result_format_each( ): yield row - def download_job_result(self, job_id, path, num_threads=4): + def download_job_result( + self, job_id: str | int, path: str, num_threads: int = 4 + ) -> bool: """Save the job result into a msgpack.gz file. Args: job_id (str): job id @@ -372,7 +468,7 @@ def download_job_result(self, job_id, path, num_threads=4): """ return self.api.download_job_result(job_id, path, num_threads=num_threads) - def kill(self, job_id): + def kill(self, job_id: str | int) -> str: """ Args: job_id (str): job id @@ -382,7 +478,13 @@ def kill(self, job_id): """ return self.api.kill(job_id) - def export_data(self, db_name, table_name, storage_type, params=None): + def export_data( + self, + db_name: str, + table_name: str, + storage_type: str, + params: ExportParams | None = None, + ) -> models.Job: """Export data from Treasure Data Service Args: @@ -421,7 +523,13 @@ def export_data(self, db_name, table_name, storage_type, params=None): job_id = self.api.export_data(db_name, table_name, storage_type, params) return models.Job(self, job_id, "export", None) - def create_bulk_import(self, name, database, table, params=None): + def create_bulk_import( + self, + name: str, + database: str, + table: str, + params: BulkImportParams | None = None, + ) -> models.BulkImport: """Create new bulk import session Args: @@ -436,7 +544,7 @@ def create_bulk_import(self, name, database, table, params=None): self.api.create_bulk_import(name, database, table, params) return models.BulkImport(self, name=name, database=database, table=table) - def delete_bulk_import(self, name): + def delete_bulk_import(self, name: str) -> bool: """Delete a bulk import session Args: @@ -447,7 +555,7 @@ def delete_bulk_import(self, name): """ return self.api.delete_bulk_import(name) - def freeze_bulk_import(self, name): + def freeze_bulk_import(self, name: str) -> bool: """Freeze a bulk import session Args: @@ -458,7 +566,7 @@ def freeze_bulk_import(self, name): """ return self.api.freeze_bulk_import(name) - def unfreeze_bulk_import(self, name): + def unfreeze_bulk_import(self, name: str) -> bool: """Unfreeze a bulk import session Args: @@ -469,7 +577,7 @@ def unfreeze_bulk_import(self, name): """ return self.api.unfreeze_bulk_import(name) - def perform_bulk_import(self, name): + def perform_bulk_import(self, name: str) -> models.Job: """Perform a bulk import session Args: @@ -481,7 +589,7 @@ def perform_bulk_import(self, name): job_id = self.api.perform_bulk_import(name) return models.Job(self, job_id, "bulk_import", None) - def commit_bulk_import(self, name): + def commit_bulk_import(self, name: str) -> bool: """Commit a bulk import session Args: @@ -492,7 +600,7 @@ def commit_bulk_import(self, name): """ return self.api.commit_bulk_import(name) - def bulk_import_error_records(self, name): + def bulk_import_error_records(self, name: str) -> Iterator[Any]: """ Args: name (str): name of a bulk import session @@ -503,7 +611,7 @@ def bulk_import_error_records(self, name): for record in self.api.bulk_import_error_records(name): yield record - def bulk_import(self, name): + def bulk_import(self, name: str) -> models.BulkImport: """Get a bulk import session Args: @@ -515,7 +623,7 @@ def bulk_import(self, name): data = self.api.show_bulk_import(name) return models.BulkImport(self, **data) - def bulk_imports(self): + def bulk_imports(self) -> list[models.BulkImport]: """List bulk import sessions Returns: @@ -525,7 +633,9 @@ def bulk_imports(self): models.BulkImport(self, **data) for data in self.api.list_bulk_imports() ] - def bulk_import_upload_part(self, name, part_name, bytes_or_stream, size): + def bulk_import_upload_part( + self, name: str, part_name: str, bytes_or_stream: FileLike, size: int + ) -> None: """Upload a part to a bulk import session Args: @@ -536,7 +646,14 @@ def bulk_import_upload_part(self, name, part_name, bytes_or_stream, size): """ return self.api.bulk_import_upload_part(name, part_name, bytes_or_stream, size) - def bulk_import_upload_file(self, name, part_name, format, file, **kwargs): + def bulk_import_upload_file( + self, + name: str, + part_name: str, + format: DataFormat, + file: FileLike, + **kwargs: Any, + ) -> None: """Upload a part to Bulk Import session, from an existing file on filesystem. Args: @@ -570,7 +687,7 @@ def bulk_import_upload_file(self, name, part_name, format, file, **kwargs): """ return self.api.bulk_import_upload_file(name, part_name, format, file, **kwargs) - def bulk_import_delete_part(self, name, part_name): + def bulk_import_delete_part(self, name: str, part_name: str) -> bool: """Delete a part from a bulk import session Args: @@ -582,7 +699,7 @@ def bulk_import_delete_part(self, name, part_name): """ return self.api.bulk_import_delete_part(name, part_name) - def list_bulk_import_parts(self, name): + def list_bulk_import_parts(self, name: str) -> list[str]: """List parts of a bulk import session Args: @@ -593,7 +710,9 @@ def list_bulk_import_parts(self, name): """ return self.api.list_bulk_import_parts(name) - def create_schedule(self, name, params=None): + def create_schedule( + self, name: str, params: ScheduleParams | None = None + ) -> datetime.datetime | None: """Create a new scheduled query with the specified name. Args: @@ -635,14 +754,14 @@ def create_schedule(self, name, params=None): Returns: :class:`datetime.datetime`: Start date time. """ + params = {} if params is None else params if "cron" not in params: raise ValueError("'cron' option is required") if "query" not in params: raise ValueError("'query' option is required") - params = {} if params is None else params return self.api.create_schedule(name, params) - def delete_schedule(self, name): + def delete_schedule(self, name: str) -> tuple[str, str]: """Delete the scheduled query with the specified name. Args: @@ -652,7 +771,7 @@ def delete_schedule(self, name): """ return self.api.delete_schedule(name) - def schedules(self): + def schedules(self) -> list[models.Schedule]: """Get the list of all the scheduled queries. Returns: @@ -661,7 +780,7 @@ def schedules(self): result = self.api.list_schedules() return [models.Schedule(self, **m) for m in result] - def update_schedule(self, name, params=None): + def update_schedule(self, name: str, params: ScheduleParams | None = None) -> None: """Update the scheduled query. Args: @@ -704,7 +823,9 @@ def update_schedule(self, name, params=None): params = {} if params is None else params self.api.update_schedule(name, params) - def history(self, name, _from=None, to=None): + def history( + self, name: str, _from: int | None = None, to: int | None = None + ) -> list[models.ScheduledJob]: """Get the history details of the saved query for the past 90days. Args: @@ -720,7 +841,7 @@ def history(self, name, _from=None, to=None): Returns: [:class:`tdclient.models.ScheduledJob`] """ - result = self.api.history(name, _from, to) + result = self.api.history(name, _from or 0, to) def scheduled_job(m): ( @@ -756,7 +877,7 @@ def scheduled_job(m): return [scheduled_job(m) for m in result] - def run_schedule(self, name, time, num): + def run_schedule(self, name: str, time: int, num: int) -> list[models.ScheduledJob]: """Execute the specified query. Args: @@ -776,8 +897,14 @@ def scheduled_job(m): return [scheduled_job(m) for m in results] def import_data( - self, db_name, table_name, format, bytes_or_stream, size, unique_id=None - ): + self, + db_name: str, + table_name: str, + format: DataFormat, + bytes_or_stream: FileLike, + size: int, + unique_id: str | None = None, + ) -> float: """Import data into Treasure Data Service Args: @@ -795,7 +922,14 @@ def import_data( db_name, table_name, format, bytes_or_stream, size, unique_id=unique_id ) - def import_file(self, db_name, table_name, format, file, unique_id=None): + def import_file( + self, + db_name: str, + table_name: str, + format: DataFormat, + file: FileLike, + unique_id: str | None = None, + ) -> float: """Import data into Treasure Data Service, from an existing file on filesystem. This method will decompress/deserialize records from given file, and then @@ -815,7 +949,7 @@ def import_file(self, db_name, table_name, format, file, unique_id=None): db_name, table_name, format, file, unique_id=unique_id ) - def results(self): + def results(self) -> list[models.Result]: """Get the list of all the available authentications. Returns: @@ -829,7 +963,9 @@ def result(m): return [result(m) for m in results] - def create_result(self, name, url, params=None): + def create_result( + self, name: str, url: str, params: ResultParams | None = None + ) -> bool: """Create a new authentication with the specified name. Args: @@ -842,7 +978,7 @@ def create_result(self, name, url, params=None): params = {} if params is None else params return self.api.create_result(name, url, params) - def delete_result(self, name): + def delete_result(self, name: str) -> bool: """Delete the authentication having the specified name. Args: @@ -866,7 +1002,7 @@ def user(m): return [user(m) for m in results] - def add_user(self, name, org, email, password): + def add_user(self, name: str, org: str, email: str, password: str) -> bool: """Add a new user Args: @@ -880,7 +1016,7 @@ def add_user(self, name, org, email, password): """ return self.api.add_user(name, org, email, password) - def remove_user(self, name): + def remove_user(self, name: str) -> bool: """Remove a user Args: @@ -891,7 +1027,7 @@ def remove_user(self, name): """ return self.api.remove_user(name) - def list_apikeys(self, name): + def list_apikeys(self, name: str) -> list[str]: """ Args: name (str): name of the user @@ -901,7 +1037,7 @@ def list_apikeys(self, name): """ return self.api.list_apikeys(name) - def add_apikey(self, name): + def add_apikey(self, name: str) -> bool: """ Args: name (str): name of the user @@ -911,7 +1047,7 @@ def add_apikey(self, name): """ return self.api.add_apikey(name) - def remove_apikey(self, name, apikey): + def remove_apikey(self, name: str, apikey: str) -> bool: """ Args: name (str): name of the user @@ -922,12 +1058,12 @@ def remove_apikey(self, name, apikey): """ return self.api.remove_apikey(name, apikey) - def close(self): + def close(self) -> None: """Close opened API connections.""" return self._api.close() -def job_from_dict(client, dd, **values): +def job_from_dict(client: Client, dd: dict[str, Any], **values: Any) -> models.Job: d = dict() d.update(dd) d.update(values) diff --git a/tdclient/job_api.py b/tdclient/job_api.py index bb06e10..e5ea9a0 100644 --- a/tdclient/job_api.py +++ b/tdclient/job_api.py @@ -282,9 +282,7 @@ def job_result_format_each( if code != 200: self.raise_error("Get job result failed", res, "") if format == "msgpack": - unpacker = msgpack.Unpacker( - raw=False, max_buffer_size=1000 * 1024**2 - ) + unpacker = msgpack.Unpacker(raw=False, max_buffer_size=1000 * 1024**2) for chunk in res.stream(1024**2): unpacker.feed(chunk) for row in unpacker: diff --git a/tdclient/result_model.py b/tdclient/result_model.py index 264afcd..fcea1d2 100644 --- a/tdclient/result_model.py +++ b/tdclient/result_model.py @@ -26,4 +26,3 @@ def url(self): def org_name(self): """str: organization name""" return self._org_name - diff --git a/tdclient/test/bulk_import_model_test.py b/tdclient/test/bulk_import_model_test.py index 226da6a..8d99e8d 100644 --- a/tdclient/test/bulk_import_model_test.py +++ b/tdclient/test/bulk_import_model_test.py @@ -124,7 +124,7 @@ def test_bulk_import_perform_with_timeout(): client = mock.MagicMock() job_mock = mock.MagicMock() client.perform_bulk_import.return_value = job_mock - + bulk_import = models.BulkImport( client, name="name", @@ -140,7 +140,7 @@ def test_bulk_import_perform_with_timeout(): ) bulk_import.update = mock.MagicMock() bulk_import.perform(wait=True, timeout=300, wait_interval=10) - + client.perform_bulk_import.assert_called_with("name") job_mock.wait.assert_called_with(timeout=300, wait_interval=10, wait_callback=None) assert bulk_import.update.called diff --git a/tdclient/test/client_test.py b/tdclient/test/client_test.py index 4197527..15bcabd 100644 --- a/tdclient/test/client_test.py +++ b/tdclient/test/client_test.py @@ -197,6 +197,7 @@ def test_query(): ) assert job.job_id == "12345" + def test_trino_query(): td = client.Client("APIKEY") td._api = mock.MagicMock() @@ -212,6 +213,7 @@ def test_trino_query(): ) assert job.job_id == "12345" + def test_jobs(): td = client.Client("APIKEY") td._api = mock.MagicMock() diff --git a/tdclient/test/dtypes_and_converters_test.py b/tdclient/test/dtypes_and_converters_test.py index 3a80d7c..5a5c119 100644 --- a/tdclient/test/dtypes_and_converters_test.py +++ b/tdclient/test/dtypes_and_converters_test.py @@ -269,7 +269,7 @@ def test_dtypes_overridden_by_converters(): DEFAULT_HEADER_BYTE_CSV = ( - b"time,col1,col2,col3,col4\n" b"100,0001,10,1.0,abcd\n" b"200,0002,20,2.0,efgh\n" + b"time,col1,col2,col3,col4\n100,0001,10,1.0,abcd\n200,0002,20,2.0,efgh\n" ) From 1201d4ff6bd3d9b92b96d7047daff39f0515098f Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 16:25:18 -0700 Subject: [PATCH 05/13] Add type hint for api.py --- tdclient/api.py | 123 +++++++++++++++++++++++++++++++++------------ tdclient/client.py | 72 ++++---------------------- tdclient/types.py | 88 ++++++++++++++++++++++++++++++++ 3 files changed, 189 insertions(+), 94 deletions(-) create mode 100644 tdclient/types.py diff --git a/tdclient/api.py b/tdclient/api.py index 296d9ce..b513930 100644 --- a/tdclient/api.py +++ b/tdclient/api.py @@ -1,21 +1,23 @@ #!/usr/bin/env python +from __future__ import annotations + import contextlib import csv import email.utils import gzip -import http +import http.client import io import json import logging import os import socket import ssl -import sys import tempfile import time import urllib.parse as urlparse from array import array +from typing import IO, Any, cast import msgpack import urllib3 @@ -32,6 +34,7 @@ from tdclient.schedule_api import ScheduleAPI from tdclient.server_status_api import ServerStatusAPI from tdclient.table_api import TableAPI +from tdclient.types import BytesOrStream from tdclient.user_api import UserAPI from tdclient.util import ( csv_dict_record_reader, @@ -85,15 +88,15 @@ class API( def __init__( self, - apikey=None, - user_agent=None, - endpoint=None, - headers=None, - retry_post_requests=False, - max_cumul_retry_delay=600, - http_proxy=None, - **kwargs, - ): + apikey: str | None = None, + user_agent: str | None = None, + endpoint: str | None = None, + headers: dict[str, str] | None = None, + retry_post_requests: bool = False, + max_cumul_retry_delay: int = 600, + http_proxy: str | None = None, + **kwargs: Any, + ) -> None: headers = {} if headers is None else headers if apikey is not None: self._apikey = apikey @@ -134,14 +137,17 @@ def __init__( self._headers = {key.lower(): value for (key, value) in headers.items()} @property - def apikey(self): + def apikey(self) -> str | None: return self._apikey @property - def endpoint(self): + def endpoint(self) -> str: + assert self._endpoint is not None # Always set in __init__ return self._endpoint - def _init_http(self, http_proxy=None, **kwargs): + def _init_http( + self, http_proxy: str | None = None, **kwargs: Any + ) -> urllib3.PoolManager | urllib3.ProxyManager: if http_proxy is None: return urllib3.PoolManager(**kwargs) else: @@ -150,7 +156,7 @@ def _init_http(self, http_proxy=None, **kwargs): else: return self._init_http_proxy("http://%s" % (http_proxy,), **kwargs) - def _init_http_proxy(self, http_proxy, **kwargs): + def _init_http_proxy(self, http_proxy: str, **kwargs: Any) -> urllib3.ProxyManager: pool_options = dict(kwargs) p = urlparse.urlparse(http_proxy) scheme = p.scheme @@ -160,7 +166,13 @@ def _init_http_proxy(self, http_proxy, **kwargs): pool_options["proxy_headers"] = urllib3.make_headers(proxy_basic_auth=auth) return urllib3.ProxyManager("%s://%s" % (scheme, netloc), **pool_options) - def get(self, path, params=None, headers=None, **kwargs): + def get( + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> contextlib.AbstractContextManager[urllib3.BaseHTTPResponse]: headers = {} if headers is None else dict(headers) headers["accept-encoding"] = "deflate, gzip" url, headers = self.build_request(path=path, headers=headers, **kwargs) @@ -239,7 +251,13 @@ def get(self, path, params=None, headers=None, **kwargs): return contextlib.closing(response) - def post(self, path, params=None, headers=None, **kwargs): + def post( + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> contextlib.AbstractContextManager[urllib3.BaseHTTPResponse]: headers = {} if headers is None else dict(headers) url, headers = self.build_request(path=path, headers=headers, **kwargs) @@ -332,7 +350,14 @@ def post(self, path, params=None, headers=None, **kwargs): return contextlib.closing(response) - def put(self, path, bytes_or_stream, size, headers=None, **kwargs): + def put( + self, + path: str, + bytes_or_stream: BytesOrStream, + size: int, + headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> contextlib.AbstractContextManager[urllib3.BaseHTTPResponse]: headers = {} if headers is None else dict(headers) headers["content-length"] = str(size) if "content-type" not in headers: @@ -345,23 +370,28 @@ def put(self, path, bytes_or_stream, size, headers=None, **kwargs): repr(path), ) + stream: array[int] | IO[bytes] if hasattr(bytes_or_stream, "read"): # file-like must support `read` and `fileno` to work with `httplib` - fileno_supported = hasattr(bytes_or_stream, "fileno") + # Type guard: if it has 'read', it's IO[bytes] + file_like = cast(IO[bytes], bytes_or_stream) + fileno_supported = hasattr(file_like, "fileno") if fileno_supported: try: - bytes_or_stream.fileno() + file_like.fileno() except io.UnsupportedOperation: # `io.BytesIO` doesn't support `fileno` fileno_supported = False if fileno_supported: - stream = bytes_or_stream + stream = file_like else: - stream = array("b", bytes_or_stream.read()) + stream = array("b", file_like.read()) else: # send request body as an `array.array` since `httplib` requires the request body to be a unicode string - stream = array("b", bytes_or_stream) + # Type guard: if it doesn't have 'read', it's bytes | bytearray + byte_data = cast("bytes | bytearray", bytes_or_stream) + stream = array("b", byte_data) response = None try: @@ -393,7 +423,13 @@ def put(self, path, bytes_or_stream, size, headers=None, **kwargs): return contextlib.closing(response) - def delete(self, path, params=None, headers=None, **kwargs): + def delete( + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> contextlib.AbstractContextManager[urllib3.BaseHTTPResponse]: headers = {} if headers is None else dict(headers) url, headers = self.build_request(path=path, headers=headers, **kwargs) @@ -469,19 +505,32 @@ def delete(self, path, params=None, headers=None, **kwargs): return contextlib.closing(response) - def build_request(self, path=None, headers=None, endpoint=None): + def build_request( + self, + path: str | None = None, + headers: dict[str, str] | None = None, + endpoint: str | None = None, + ) -> tuple[str, dict[str, str]]: headers = {} if headers is None else headers if endpoint is None: endpoint = self._endpoint + assert endpoint is not None # endpoint is always set in __init__ if path is None: - url = endpoint + url: str = endpoint else: p = urlparse.urlparse(endpoint) # should not use `os.path.join` since it returns path string like "/foo\\bar" - request_path = path if p.path == "/" else "/".join([p.path, path]) + # Type assertion: urlparse components are str not bytes for str input + p_path = str(p.path) + p_scheme = str(p.scheme) + p_netloc = str(p.netloc) + p_params = str(p.params) + p_query = str(p.query) + p_fragment = str(p.fragment) + request_path = path if p_path == "/" else "/".join([p_path, path]) url = urlparse.urlunparse( urlparse.ParseResult( - p.scheme, p.netloc, request_path, p.params, p.query, p.fragment + p_scheme, p_netloc, request_path, p_params, p_query, p_fragment ) ) # use default headers first @@ -494,7 +543,15 @@ def build_request(self, path=None, headers=None, endpoint=None): _headers.update({key.lower(): value for (key, value) in headers.items()}) return (url, _headers) - def send_request(self, method, url, fields=None, body=None, headers=None, **kwargs): + def send_request( + self, + method: str, + url: str, + fields: dict[str, Any] | None = None, + body: bytes | bytearray | memoryview | array[int] | IO[bytes] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> urllib3.BaseHTTPResponse: if body is None: return self.http.request( method, url, fields=fields, headers=headers, **kwargs @@ -508,7 +565,9 @@ def send_request(self, method, url, fields=None, body=None, headers=None, **kwar body = body.tobytes() return self.http.urlopen(method, url, body=body, headers=headers, **kwargs) - def raise_error(self, msg, res, body): + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes | str + ) -> None: status_code = res.status s = body if isinstance(body, str) else body.decode("utf-8") if status_code == 404: @@ -522,7 +581,7 @@ def raise_error(self, msg, res, body): else: raise errors.APIError("%d: %s: %s" % (status_code, msg, s)) - def checked_json(self, body, required): + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: js = None try: js = json.loads(body.decode("utf-8")) @@ -536,7 +595,7 @@ def checked_json(self, body, required): ) return js - def close(self): + def close(self) -> None: # urllib3 doesn't allow to close all connections immediately. # all connections in pool will be closed eventually during gc. self.http.clear() diff --git a/tdclient/client.py b/tdclient/client.py index e401508..4fcb0e1 100644 --- a/tdclient/client.py +++ b/tdclient/client.py @@ -5,70 +5,18 @@ import datetime import json from collections.abc import Iterator -from typing import IO, Any, Literal, TypedDict - -from typing_extensions import TypeAlias +from typing import Any from tdclient import api, models - - -# Type aliases for file-like objects -FileLike: TypeAlias = "str | bytes | IO[bytes]" - -# Common literal types -QueryEngineType: TypeAlias = 'Literal["presto", "hive"]' -EngineVersion: TypeAlias = 'Literal["stable", "experimental"]' -Priority: TypeAlias = "Literal[-2, -1, 0, 1, 2]" -ExportFileFormat: TypeAlias = 'Literal["jsonl.gz", "tsv.gz", "json.gz"]' -DataFormat: TypeAlias = 'Literal["msgpack", "msgpack.gz", "json", "json.gz", "csv", "csv.gz", "tsv", "tsv.gz"]' -ResultFormat: TypeAlias = 'Literal["msgpack", "json", "csv", "tsv"]' - - -class ScheduleParams(TypedDict, total=False): - """Parameters for create_schedule and update_schedule""" - - type: QueryEngineType # Query type - database: str # Target database name - timezone: str # Timezone e.g. "UTC" - cron: str # Schedule: "@daily", "@hourly", or cron expression - delay: int # Delay in seconds before running - query: str # SQL query to execute - priority: Priority # Priority: -2 (very low) to 2 (very high) - retry_limit: int # Automatic retry count - engine_version: EngineVersion # Engine version - pool_name: str # For Presto only: pool name - result: str # Result output location URL - - -class ExportParams(TypedDict, total=False): - """Parameters for export_data""" - - access_key_id: str # ID to access the export destination - secret_access_key: str # Password for access_key_id - file_prefix: str # Filename prefix for exported file - file_format: ExportFileFormat # File format - from_: ( - int # Start time in Unix epoch format (use 'from_' to avoid keyword conflict) - ) - to: int # End time in Unix epoch format - assume_role: str # Assume role ARN - bucket: str # Bucket name - domain_key: str # Job domain key - pool_name: str # For Presto only: pool name - - -class BulkImportParams(TypedDict, total=False): - """Parameters for create_bulk_import""" - - # Add any optional parameters for bulk import if needed - pass - - -class ResultParams(TypedDict, total=False): - """Parameters for create_result""" - - # Add any optional parameters for result creation if needed - pass +from tdclient.types import ( + BulkImportParams, + DataFormat, + ExportParams, + FileLike, + ResultFormat, + ResultParams, + ScheduleParams, +) class Client: diff --git a/tdclient/types.py b/tdclient/types.py new file mode 100644 index 0000000..af5a8a7 --- /dev/null +++ b/tdclient/types.py @@ -0,0 +1,88 @@ +"""Type definitions for td-client-python.""" + +from __future__ import annotations + +from array import array +from typing import IO + +from typing_extensions import Literal, TypeAlias, TypedDict + +# File-like types +FileLike: TypeAlias = "str | bytes | IO[bytes]" +"""Type for file inputs: file path, bytes, or file-like object.""" + +BytesOrStream: TypeAlias = "bytes | bytearray | IO[bytes]" +"""Type for byte data or streams (excluding file paths).""" + +StreamBody: TypeAlias = "bytes | bytearray | memoryview | array[int] | IO[bytes] | None" +"""Type for HTTP request body.""" + +# Query engine types +QueryEngineType: TypeAlias = 'Literal["presto", "hive"]' +"""Type for query engine selection.""" + +EngineVersion: TypeAlias = 'Literal["stable", "experimental"]' +"""Type for engine version selection.""" + +Priority: TypeAlias = "Literal[-2, -1, 0, 1, 2]" +"""Type for job priority levels.""" + +# Data format types +ExportFileFormat: TypeAlias = 'Literal["jsonl.gz", "tsv.gz", "json.gz"]' +"""Type for export file formats.""" + +DataFormat: TypeAlias = 'Literal["msgpack", "msgpack.gz", "json", "json.gz", "csv", "csv.gz", "tsv", "tsv.gz"]' +"""Type for data import/export formats.""" + +ResultFormat: TypeAlias = 'Literal["msgpack", "json", "csv", "tsv"]' +"""Type for query result formats.""" + + +# TypedDict classes for structured parameters +class ScheduleParams(TypedDict, total=False): + """Parameters for schedule operations.""" + + type: QueryEngineType # Query type + query: str # SQL query to execute + database: str # Target database name + result: str # Result output location URL + cron: str # Schedule: "@daily", "@hourly", or cron expression + timezone: str # Timezone e.g. "UTC" + delay: int # Delay in seconds before running + priority: Priority # Priority: -2 (very low) to 2 (very high) + retry_limit: int # Automatic retry count + engine_version: EngineVersion # Engine version + pool_name: str # For Presto only: pool name + + +class ExportParams(TypedDict, total=False): + """Parameters for export operations.""" + + storage_type: str # Storage type (e.g. "s3") + bucket: str # Bucket name + access_key_id: str # ID to access the export destination + secret_access_key: str # Password for access_key_id + file_prefix: str # Filename prefix for exported file + file_format: ExportFileFormat # File format + from_: int # Start time in Unix epoch format + to: int # End time in Unix epoch format + assume_role: str # Assume role ARN + domain_key: str # Job domain key + pool_name: str # For Presto only: pool name + + +class BulkImportParams(TypedDict, total=False): + """Parameters for bulk import operations.""" + + name: str + database: str + table: str + + +class ResultParams(TypedDict, total=False): + """Parameters for result operations.""" + + name: str + url: str + user: str + password: str From e3b3b63331ddb626eb332b725fbeaa31e09feb62 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 17:06:06 -0700 Subject: [PATCH 06/13] Add type hint for model related modules --- tdclient/bulk_import_model.py | 87 ++++++++++++-------- tdclient/client.py | 6 +- tdclient/database_model.py | 50 +++++++----- tdclient/job_model.py | 150 +++++++++++++++++++--------------- tdclient/result_model.py | 15 +++- tdclient/schedule_model.py | 80 ++++++++++-------- tdclient/table_model.py | 105 +++++++++++++++--------- tdclient/user_model.py | 25 ++++-- 8 files changed, 317 insertions(+), 201 deletions(-) diff --git a/tdclient/bulk_import_model.py b/tdclient/bulk_import_model.py index 235d7fd..be05e19 100644 --- a/tdclient/bulk_import_model.py +++ b/tdclient/bulk_import_model.py @@ -1,8 +1,17 @@ #!/usr/bin/env python +from __future__ import annotations + import time +from collections.abc import Callable, Iterator +from typing import TYPE_CHECKING, Any from tdclient.model import Model +from tdclient.types import FileLike + +if TYPE_CHECKING: + from tdclient.client import Client + from tdclient.job_model import Job class BulkImport(Model): @@ -14,94 +23,100 @@ class BulkImport(Model): STATUS_COMMITTING = "committing" STATUS_COMMITTED = "committed" - def __init__(self, client, **kwargs): + def __init__(self, client: Client, **kwargs: Any) -> None: super(BulkImport, self).__init__(client) self._feed(kwargs) - def _feed(self, data=None): + def _feed(self, data: dict[str, Any] | None = None) -> None: data = {} if data is None else data - self._name = data["name"] - self._database = data.get("database") - self._table = data.get("table") - self._status = data.get("status") - self._upload_frozen = data.get("upload_frozen") - self._job_id = data.get("job_id") - self._valid_records = data.get("valid_records") - self._error_records = data.get("error_records") - self._valid_parts = data.get("valid_parts") - self._error_parts = data.get("error_parts") - - def update(self): + self._name: str = data["name"] + self._database: str | None = data.get("database") + self._table: str | None = data.get("table") + self._status: str | None = data.get("status") + self._upload_frozen: bool | None = data.get("upload_frozen") + self._job_id: str | None = data.get("job_id") + self._valid_records: int | None = data.get("valid_records") + self._error_records: int | None = data.get("error_records") + self._valid_parts: int | None = data.get("valid_parts") + self._error_parts: int | None = data.get("error_parts") + + def update(self) -> None: data = self._client.api.show_bulk_import(self.name) self._feed(data) @property - def name(self): + def name(self) -> str: """A name of the bulk import session""" return self._name @property - def database(self): + def database(self) -> str | None: """A database name in a string which the bulk import session is working on""" return self._database @property - def table(self): + def table(self) -> str | None: """A table name in a string which the bulk import session is working on""" return self._table @property - def status(self): + def status(self) -> str | None: """The status of the bulk import session in a string""" return self._status @property - def job_id(self): + def job_id(self) -> str | None: """Job ID""" return self._job_id @property - def valid_records(self): + def valid_records(self) -> int | None: """The number of valid records.""" return self._valid_records @property - def error_records(self): + def error_records(self) -> int | None: """The number of error records.""" return self._error_records @property - def valid_parts(self): + def valid_parts(self) -> int | None: """The number of valid parts.""" return self._valid_parts @property - def error_parts(self): + def error_parts(self) -> int | None: """The number of error parts.""" return self._error_parts @property - def upload_frozen(self): + def upload_frozen(self) -> bool | None: """The number of upload frozen.""" return self._upload_frozen - def delete(self): + def delete(self) -> bool: """Delete bulk import""" return self._client.delete_bulk_import(self.name) - def freeze(self): + def freeze(self) -> bool: """Freeze bulk import""" response = self._client.freeze_bulk_import(self.name) self.update() return response - def unfreeze(self): + def unfreeze(self) -> bool: """Unfreeze bulk import""" response = self._client.unfreeze_bulk_import(self.name) self.update() return response - def perform(self, wait=False, wait_interval=5, wait_callback=None, timeout=None): + def perform( + self, + wait: bool = False, + wait_interval: int = 5, + wait_callback: Callable[[], None] | None = None, + timeout: float | None = None, + ) -> Job: """Perform bulk import Args: @@ -126,7 +141,9 @@ def perform(self, wait=False, wait_interval=5, wait_callback=None, timeout=None) self.update() return job - def commit(self, wait=False, wait_interval=5, timeout=None): + def commit( + self, wait: bool = False, wait_interval: int = 5, timeout: float | None = None + ) -> bool: """Commit bulk import""" response = self._client.commit_bulk_import(self.name) if wait: @@ -141,7 +158,7 @@ def commit(self, wait=False, wait_interval=5, timeout=None): self.update() return response - def error_record_items(self): + def error_record_items(self) -> Iterator[dict[str, Any]]: """Fetch error record rows. Yields: @@ -150,7 +167,7 @@ def error_record_items(self): for record in self._client.bulk_import_error_records(self.name): yield record - def upload_part(self, part_name, bytes_or_stream, size): + def upload_part(self, part_name: str, bytes_or_stream: FileLike, size: int) -> bool: """Upload a part to bulk import session Args: @@ -164,7 +181,9 @@ def upload_part(self, part_name, bytes_or_stream, size): self.update() return response - def upload_file(self, part_name, fmt, file_like, **kwargs): + def upload_file( + self, part_name: str, fmt: str, file_like: FileLike, **kwargs: Any + ) -> float: """Upload a part to Bulk Import session, from an existing file on filesystem. Args: @@ -205,7 +224,7 @@ def upload_file(self, part_name, fmt, file_like, **kwargs): self.update() return response - def delete_part(self, part_name): + def delete_part(self, part_name: str) -> bool: """Delete a part of a Bulk Import session Args: @@ -217,7 +236,7 @@ def delete_part(self, part_name): self.update() return response - def list_parts(self): + def list_parts(self) -> list[str]: """Return the list of available parts uploaded through :func:`~BulkImportAPI.bulk_import_upload_part`. diff --git a/tdclient/client.py b/tdclient/client.py index 4fcb0e1..f0d4aee 100644 --- a/tdclient/client.py +++ b/tdclient/client.py @@ -346,7 +346,7 @@ def job_result(self, job_id: str | int) -> list[Any]: """ return self.api.job_result(job_id) - def job_result_each(self, job_id: str | int) -> Iterator[Any]: + def job_result_each(self, job_id: str | int) -> Iterator[dict[str, Any]]: """ Args: job_id (str): job id @@ -377,7 +377,7 @@ def job_result_format_each( header: bool = False, store_tmpfile: bool = False, num_threads: int = 4, - ) -> Iterator[Any]: + ) -> Iterator[dict[str, Any]]: """ Args: job_id (str): job id @@ -548,7 +548,7 @@ def commit_bulk_import(self, name: str) -> bool: """ return self.api.commit_bulk_import(name) - def bulk_import_error_records(self, name: str) -> Iterator[Any]: + def bulk_import_error_records(self, name: str) -> Iterator[dict[str, Any]]: """ Args: name (str): name of a bulk import session diff --git a/tdclient/database_model.py b/tdclient/database_model.py index 5a00021..7652d63 100644 --- a/tdclient/database_model.py +++ b/tdclient/database_model.py @@ -1,7 +1,17 @@ #!/usr/bin/env python +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any + from tdclient.model import Model +if TYPE_CHECKING: + from tdclient.client import Client + from tdclient.job_model import Job + from tdclient.table_model import Table + class Database(Model): """Database on Treasure Data Service""" @@ -9,54 +19,55 @@ class Database(Model): PERMISSIONS = ["administrator", "full_access", "import_only", "query_only"] PERMISSION_LIST_TABLES = ["administrator", "full_access"] - def __init__(self, client, db_name, **kwargs): + def __init__(self, client: Client, db_name: str, **kwargs: Any) -> None: super(Database, self).__init__(client) self._db_name = db_name - self._tables = kwargs.get("tables") - self._count = kwargs.get("count") - self._created_at = kwargs.get("created_at") - self._updated_at = kwargs.get("updated_at") - self._org_name = kwargs.get("org_name") - self._permission = kwargs.get("permission") + self._tables: list[Table] | None = kwargs.get("tables") + self._count: int | None = kwargs.get("count") + self._created_at: datetime.datetime | None = kwargs.get("created_at") + self._updated_at: datetime.datetime | None = kwargs.get("updated_at") + self._org_name: str | None = kwargs.get("org_name") + self._permission: str | None = kwargs.get("permission") @property - def org_name(self): + def org_name(self) -> str | None: """ str: organization name """ return self._org_name @property - def permission(self): + def permission(self) -> str | None: """ str: permission for the database (e.g. "administrator", "full_access", etc.) """ return self._permission @property - def count(self): + def count(self) -> int | None: """ int: Total record counts in a database. """ return self._count @property - def name(self): + def name(self) -> str: """ str: a name of the database """ return self._db_name - def tables(self): + def tables(self) -> list[Table]: """ Returns: a list of :class:`tdclient.model.Table` """ if self._tables is None: self._update_tables() + assert self._tables is not None return self._tables - def create_log_table(self, name): + def create_log_table(self, name: str) -> Table: """ Args: name (str): name of new log table @@ -66,7 +77,7 @@ def create_log_table(self, name): """ return self._client.create_log_table(self._db_name, name) - def table(self, table_name): + def table(self, table_name: str) -> Table: """ Args: table_name (str): name of a table @@ -76,7 +87,7 @@ def table(self, table_name): """ return self._client.table(self._db_name, table_name) - def delete(self): + def delete(self) -> bool: """Delete the database Returns: @@ -84,7 +95,7 @@ def delete(self): """ return self._client.delete_database(self._db_name) - def query(self, q, **kwargs): + def query(self, q: str, **kwargs: Any) -> Job: """Run a query on the database Args: @@ -96,20 +107,21 @@ def query(self, q, **kwargs): return self._client.query(self._db_name, q, **kwargs) @property - def created_at(self): + def created_at(self) -> datetime.datetime | None: """ :class:`datetime.datetime` """ return self._created_at @property - def updated_at(self): + def updated_at(self) -> datetime.datetime | None: """ :class:`datetime.datetime` """ return self._updated_at - def _update_tables(self): + def _update_tables(self) -> None: self._tables = self._client.tables(self._db_name) + assert self._tables is not None for table in self._tables: table.database = self diff --git a/tdclient/job_model.py b/tdclient/job_model.py index 9640e0d..71ab46f 100644 --- a/tdclient/job_model.py +++ b/tdclient/job_model.py @@ -1,45 +1,52 @@ #!/usr/bin/env python +from __future__ import annotations + import time import warnings +from collections.abc import Callable, Iterator +from typing import TYPE_CHECKING, Any from tdclient.model import Model +if TYPE_CHECKING: + from tdclient.client import Client + class Schema: """Schema of a database table on Treasure Data Service""" class Field: - def __init__(self, name, type): + def __init__(self, name: str, type: str) -> None: self._name = name self._type = type @property - def name(self): + def name(self) -> str: """ TODO: add docstring """ return self._name @property - def type(self): + def type(self) -> str: """ TODO: add docstring """ return self._type - def __init__(self, fields=None): + def __init__(self, fields: list[Schema.Field] | None = None) -> None: fields = [] if fields is None else fields self._fields = fields @property - def fields(self): + def fields(self) -> list[Schema.Field]: """ TODO: add docstring """ return self._fields - def add_field(self, name, type): + def add_field(self, name: str, type: str) -> None: """ TODO: add docstring """ @@ -59,90 +66,98 @@ class Job(Model): JOB_PRIORITY = {-2: "VERY LOW", -1: "LOW", 0: "NORMAL", 1: "HIGH", 2: "VERY HIGH"} - def __init__(self, client, job_id, type, query, **kwargs): + def __init__( + self, client: Client, job_id: str, type: str, query: str, **kwargs: Any + ) -> None: super(Job, self).__init__(client) self._job_id = job_id self._type = type self._query = query self._feed(kwargs) - def _feed(self, data=None): + def _feed(self, data: dict[str, Any] | None = None) -> None: data = {} if data is None else data - self._url = data.get("url") - self._status = data.get("status") - self._debug = data.get("debug") - self._start_at = data.get("start_at") - self._end_at = data.get("end_at") - self._created_at = data.get("created_at") - self._updated_at = data.get("updated_at") - self._cpu_time = data.get("cpu_time") - self._result = data.get("result") - self._result_size = data.get("result_size") - self._result_url = data.get("result_url") - self._hive_result_schema = data.get("hive_result_schema") - self._priority = data.get("priority") - self._retry_limit = data.get("retry_limit") - self._org_name = data.get("org_name") - self._database = data.get("database") - self._num_records = data.get("num_records") - self._user_name = data.get("user_name") - self._linked_result_export_job_id = data.get("linked_result_export_job_id") - self._result_export_target_job_id = data.get("result_export_target_job_id") - - def update(self): + self._url: str | None = data.get("url") + self._status: str | None = data.get("status") + self._debug: dict[str, Any] | None = data.get("debug") + self._start_at: str | None = data.get("start_at") + self._end_at: str | None = data.get("end_at") + self._created_at: str | None = data.get("created_at") + self._updated_at: str | None = data.get("updated_at") + self._cpu_time: float | None = data.get("cpu_time") + self._result: str | None = data.get("result") + self._result_size: int | None = data.get("result_size") + self._result_url: str | None = data.get("result_url") + self._hive_result_schema: list[list[str]] | None = data.get( + "hive_result_schema" + ) + self._priority: int | None = data.get("priority") + self._retry_limit: int | None = data.get("retry_limit") + self._org_name: str | None = data.get("org_name") + self._database: str | None = data.get("database") + self._num_records: int | None = data.get("num_records") + self._user_name: str | None = data.get("user_name") + self._linked_result_export_job_id: str | None = data.get( + "linked_result_export_job_id" + ) + self._result_export_target_job_id: str | None = data.get( + "result_export_target_job_id" + ) + + def update(self) -> None: """Update all fields of the job""" data = self._client.api.show_job(self._job_id) self._feed(data) - def _update_status(self): + def _update_status(self) -> None: warnings.warn( "_update_status() will be removed from future release. Please use update() instaed.", category=DeprecationWarning, ) self.update() - def _update_progress(self): + def _update_progress(self) -> None: """Update `_status` field of the job if it's not finished""" if self._status not in self.FINISHED_STATUS: self._status = self._client.job_status(self._job_id) @property - def id(self): + def id(self) -> str: """a string represents the identifier of the job""" return self._job_id @property - def job_id(self): + def job_id(self) -> str: """a string represents the identifier of the job""" return self._job_id @property - def type(self): + def type(self) -> str: """a string represents the engine type of the job (e.g. "hive", "presto", etc.)""" return self._type @property - def result_size(self): + def result_size(self) -> int | None: """the length of job result""" return self._result_size @property - def num_records(self): + def num_records(self) -> int | None: """the number of records of job result""" return self._num_records @property - def result_url(self): + def result_url(self) -> str | None: """a string of URL of the result on Treasure Data Service""" return self._result_url @property - def result_schema(self): + def result_schema(self) -> list[list[str]] | None: """an array of array represents the type of result columns (Hive specific) (e.g. [["_c1", "string"], ["_c2", "bigint"]])""" return self._hive_result_schema @property - def priority(self): + def priority(self) -> str: """a string represents the priority of the job (e.g. "NORMAL", "HIGH", etc.)""" if self._priority in self.JOB_PRIORITY: return self.JOB_PRIORITY[self._priority] @@ -151,41 +166,46 @@ def priority(self): return str(self._priority) @property - def retry_limit(self): + def retry_limit(self) -> int | None: """a number for automatic retry count""" return self._retry_limit @property - def org_name(self): + def org_name(self) -> str | None: """organization name""" return self._org_name @property - def user_name(self): + def user_name(self) -> str | None: """executing user name""" return self._user_name @property - def database(self): + def database(self) -> str | None: """a string represents the name of a database that job is running on""" return self._database @property - def linked_result_export_job_id(self): + def linked_result_export_job_id(self) -> str | None: """Linked result export job ID from query job""" return self._linked_result_export_job_id @property - def result_export_target_job_id(self): + def result_export_target_job_id(self) -> str | None: """Associated query job ID from result export job ID""" return self._result_export_target_job_id @property - def debug(self): + def debug(self) -> dict[str, Any] | None: """a :class:`dict` of debug output (e.g. "cmdout", "stderr")""" return self._debug - def wait(self, timeout=None, wait_interval=5, wait_callback=None): + def wait( + self, + timeout: float | None = None, + wait_interval: int = 5, + wait_callback: Callable[[Job], None] | None = None, + ) -> None: """Sleep until the job has been finished Args: @@ -204,7 +224,7 @@ def wait(self, timeout=None, wait_interval=5, wait_callback=None): raise RuntimeError("timeout") # TODO: throw proper error self.update() - def kill(self): + def kill(self) -> str: """Kill the job Returns: @@ -215,11 +235,11 @@ def kill(self): return response @property - def query(self): + def query(self) -> str: """a string represents the query string of the job""" return self._query - def status(self): + def status(self) -> str | None: """ Returns: str: a string represents the status of the job ("success", "error", "killed", "queued", "running") @@ -229,11 +249,11 @@ def status(self): return self._status @property - def url(self): + def url(self) -> str | None: """a string of URL of the job on Treasure Data Service""" return self._url - def result(self): + def result(self) -> Iterator[dict[str, Any]]: """ Yields: an iterator of rows in result set @@ -246,10 +266,12 @@ def result(self): for row in self._client.job_result_each(self._job_id): yield row else: - for row in self._result: - yield row + for row in self._result: # type: ignore[union-attr] + yield row # type: ignore[misc] - def result_format(self, fmt, store_tmpfile=False, num_threads=4): + def result_format( + self, fmt: str, store_tmpfile: bool = False, num_threads: int = 4 + ) -> Iterator[dict[str, Any]]: """ Args: fmt (str): output format of result set @@ -274,10 +296,10 @@ def result_format(self, fmt, store_tmpfile=False, num_threads=4): ): yield row else: - for row in self._result: - yield row + for row in self._result: # type: ignore[union-attr] + yield row # type: ignore[misc] - def finished(self): + def finished(self) -> bool: """ Returns: `True` if the job has been finished in success, error or killed @@ -285,7 +307,7 @@ def finished(self): self._update_progress() return self._status in self.FINISHED_STATUS - def success(self): + def success(self) -> bool: """ Returns: `True` if the job has been finished in success @@ -293,7 +315,7 @@ def success(self): self._update_progress() return self._status == self.STATUS_SUCCESS - def error(self): + def error(self) -> bool: """ Returns: `True` if the job has been finished in error @@ -301,7 +323,7 @@ def error(self): self._update_progress() return self._status == self.STATUS_ERROR - def killed(self): + def killed(self) -> bool: """ Returns: `True` if the job has been finished in killed @@ -309,7 +331,7 @@ def killed(self): self._update_progress() return self._status == self.STATUS_KILLED - def queued(self): + def queued(self) -> bool: """ Returns: `True` if the job is queued @@ -317,7 +339,7 @@ def queued(self): self._update_progress() return self._status == self.STATUS_QUEUED - def running(self): + def running(self) -> bool: """ Returns: `True` if the job is running diff --git a/tdclient/result_model.py b/tdclient/result_model.py index fcea1d2..78f6e3c 100644 --- a/tdclient/result_model.py +++ b/tdclient/result_model.py @@ -1,28 +1,35 @@ #!/usr/bin/env python +from __future__ import annotations + +from typing import TYPE_CHECKING + from tdclient.model import Model +if TYPE_CHECKING: + from tdclient.client import Client + class Result(Model): """Result on Treasure Data Service""" - def __init__(self, client, name, url, org_name): + def __init__(self, client: Client, name: str, url: str, org_name: str) -> None: super(Result, self).__init__(client) self._name = name self._url = url self._org_name = org_name @property - def name(self): + def name(self) -> str: """str: a name for a authentication""" return self._name @property - def url(self): + def url(self) -> str: """str: a result output URL""" return self._url @property - def org_name(self): + def org_name(self) -> str: """str: organization name""" return self._org_name diff --git a/tdclient/schedule_model.py b/tdclient/schedule_model.py index 5322a77..21679fb 100644 --- a/tdclient/schedule_model.py +++ b/tdclient/schedule_model.py @@ -1,18 +1,34 @@ #!/usr/bin/env python +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any + from tdclient.job_model import Job from tdclient.model import Model +if TYPE_CHECKING: + from tdclient.client import Client + class ScheduledJob(Job): """Scheduled job on Treasure Data Service""" - def __init__(self, client, scheduled_at, job_id, type, query, **kwargs): + def __init__( + self, + client: Client, + scheduled_at: datetime.datetime, + job_id: str, + type: str, + query: str, + **kwargs: Any, + ) -> None: super(ScheduledJob, self).__init__(client, job_id, type, query, **kwargs) self._scheduled_at = scheduled_at @property - def scheduled_at(self): + def scheduled_at(self) -> datetime.datetime: """a :class:`datetime.datetime` represents the schedule of next invocation of the job""" return self._scheduled_at @@ -20,40 +36,40 @@ def scheduled_at(self): class Schedule(Model): """Schedule on Treasure Data Service""" - def __init__(self, client, *args, **kwargs): + def __init__(self, client: Client, *args: Any, **kwargs: Any) -> None: super(Schedule, self).__init__(client) if 0 < len(args): - self._name = args[0] - self._cron = args[1] - self._query = args[2] + self._name: str | None = args[0] + self._cron: str | None = args[1] + self._query: str | None = args[2] else: self._name = kwargs.get("name") self._cron = kwargs.get("cron") self._query = kwargs.get("query") - self._timezone = kwargs.get("timezone") - self._delay = kwargs.get("delay") - self._created_at = kwargs.get("created_at") - self._type = kwargs.get("type") - self._database = kwargs.get("database") - self._user_name = kwargs.get("user_name") - self._priority = kwargs.get("priority") - self._retry_limit = kwargs.get("retry_limit") + self._timezone: str | None = kwargs.get("timezone") + self._delay: int | None = kwargs.get("delay") + self._created_at: datetime.datetime | None = kwargs.get("created_at") + self._type: str | None = kwargs.get("type") + self._database: str | None = kwargs.get("database") + self._user_name: str | None = kwargs.get("user_name") + self._priority: int | str | None = kwargs.get("priority") + self._retry_limit: int | None = kwargs.get("retry_limit") if "result_url" in kwargs: # backward compatibility for td-client-python < 0.6.0 # TODO: remove this code if not necessary with fixing test - self._result = kwargs.get("result_url") + self._result: str | None = kwargs.get("result_url") else: self._result = kwargs.get("result") - self._next_time = kwargs.get("next_time") - self._org_name = kwargs.get("org_name") + self._next_time: datetime.datetime | None = kwargs.get("next_time") + self._org_name: str | None = kwargs.get("org_name") @property - def name(self): + def name(self) -> str | None: """The name of a scheduled job""" return self._name @property - def cron(self): + def cron(self) -> str | None: """The configured schedule of a scheduled job. Returns a string represents the schedule in cron form, or `None` if the @@ -62,32 +78,32 @@ def cron(self): return self._cron @property - def query(self): + def query(self) -> str | None: """The query string of a scheduled job""" return self._query @property - def database(self): + def database(self) -> str | None: """The target database of a scheduled job""" return self._database @property - def result_url(self): + def result_url(self) -> str | None: """The result output configuration in URL form of a scheduled job""" return self._result @property - def timezone(self): + def timezone(self) -> str | None: """The time zone of a scheduled job""" return self._timezone @property - def delay(self): + def delay(self) -> int | None: """A delay ensures all buffered events are imported before running the query.""" return self._delay @property - def priority(self): + def priority(self) -> str: """The priority of a scheduled job""" if self._priority in Job.JOB_PRIORITY: return Job.JOB_PRIORITY[self._priority] @@ -95,42 +111,42 @@ def priority(self): return str(self._priority) @property - def retry_limit(self): + def retry_limit(self) -> int | None: """Automatic retry count.""" return self._retry_limit @property - def org_name(self): + def org_name(self) -> str | None: """ TODO: add docstring """ return self._org_name @property - def next_time(self): + def next_time(self) -> datetime.datetime | None: """ :obj:`datetime.datetime`: Schedule for next run """ return self._next_time @property - def created_at(self): + def created_at(self) -> datetime.datetime | None: """ :obj:`datetime.datetime`: Create date """ return self._created_at @property - def type(self): + def type(self) -> str | None: """Query type. {"presto", "hive"}.""" return self._type @property - def user_name(self): + def user_name(self) -> str | None: """User name of a scheduled job""" return self._user_name - def run(self, time, num=None): + def run(self, time: int, num: int | None = None) -> list[ScheduledJob]: """Run a scheduled job Args: diff --git a/tdclient/table_model.py b/tdclient/table_model.py index 36d5cc5..e632937 100644 --- a/tdclient/table_model.py +++ b/tdclient/table_model.py @@ -1,140 +1,155 @@ #!/usr/bin/env python +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any + from tdclient.model import Model +from tdclient.types import DataFormat, FileLike + +if TYPE_CHECKING: + from tdclient.database_model import Database + from tdclient.job_model import Job class Table(Model): """Database table on Treasure Data Service""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(Table, self).__init__(args[0]) - self.database = None - self._db_name = args[1] - self._table_name = args[2] + self.database: Database | None = None + self._db_name: str = args[1] + self._table_name: str = args[2] if 3 < len(args): - self._type = args[3] - self._schema = args[4] - self._count = args[5] + self._type: str | None = args[3] + self._schema: list[tuple[str, str, str]] | None = args[4] + self._count: int | None = args[5] else: self._type = kwargs.get("type") self._schema = kwargs.get("schema") self._count = kwargs.get("count") - self._created_at = kwargs.get("created_at") - self._updated_at = kwargs.get("updated_at") - self._estimated_storage_size = kwargs.get("estimated_storage_size") - self._last_import = kwargs.get("last_import") - self._last_log_timestamp = kwargs.get("last_log_timestamp") - self._expire_days = kwargs.get("expire_days") - self._primary_key = kwargs.get("primary_key") - self._primary_key_type = kwargs.get("primary_key_type") + self._created_at: datetime.datetime | None = kwargs.get("created_at") + self._updated_at: datetime.datetime | None = kwargs.get("updated_at") + self._estimated_storage_size: int | None = kwargs.get("estimated_storage_size") + self._last_import: datetime.datetime | None = kwargs.get("last_import") + self._last_log_timestamp: datetime.datetime | None = kwargs.get( + "last_log_timestamp" + ) + self._expire_days: int | None = kwargs.get("expire_days") + self._primary_key: str | None = kwargs.get("primary_key") + self._primary_key_type: str | None = kwargs.get("primary_key_type") @property - def type(self): + def type(self) -> str | None: """a string represents the type of the table""" return self._type @property - def db_name(self): + def db_name(self) -> str: """a string represents the name of the database""" return self._db_name @property - def table_name(self): + def table_name(self) -> str: """a string represents the name of the table""" return self._table_name @property - def schema(self): + def schema(self) -> list[tuple[str, str, str]] | None: """ [[column_name:str, column_type:str, alias:str]]: The :obj:`list` of a schema """ return self._schema @property - def count(self): + def count(self) -> int | None: """int: total number of the table""" return self._count @property - def estimated_storage_size(self): + def estimated_storage_size(self) -> int | None: """estimated storage size""" return self._estimated_storage_size @property - def primary_key(self): + def primary_key(self) -> str | None: """ TODO: add docstring """ return self._primary_key @property - def primary_key_type(self): + def primary_key_type(self) -> str | None: """ TODO: add docstring """ return self._primary_key_type @property - def database_name(self): + def database_name(self) -> str: """a string represents the name of the database""" return self._db_name @property - def name(self): + def name(self) -> str: """a string represents the name of the table""" return self._table_name @property - def created_at(self): + def created_at(self) -> datetime.datetime | None: """ :class:`datetime.datetime`: Created datetime """ return self._created_at @property - def updated_at(self): + def updated_at(self) -> datetime.datetime | None: """ :class:`datetime.datetime`: Updated datetime """ return self._updated_at @property - def last_import(self): + def last_import(self) -> datetime.datetime | None: """:class:`datetime.datetime`""" return self._last_import @property - def last_log_timestamp(self): + def last_log_timestamp(self) -> datetime.datetime | None: """:class:`datetime.datetime`""" return self._last_log_timestamp @property - def expire_days(self): + def expire_days(self) -> int | None: """an int represents the days until expiration""" return self._expire_days @property - def permission(self): + def permission(self) -> str | None: """ str: permission for the database (e.g. "administrator", "full_access", etc.) """ if self.database is None: self._update_database() + assert self.database is not None return self.database.permission @property - def identifier(self): + def identifier(self) -> str: """a string identifier of the table""" return "%s.%s" % (self._db_name, self._table_name) - def delete(self): + def delete(self) -> str: """a string represents the type of deleted table""" return self._client.delete_table(self._db_name, self._table_name) - def tail(self, count, to=None, _from=None): + def tail( + self, count: int, to: int | None = None, _from: int | None = None + ) -> list[dict[str, Any]]: """ Args: count (int): Number for record to show up from the end. @@ -147,7 +162,13 @@ def tail(self, count, to=None, _from=None): """ return self._client.tail(self._db_name, self._table_name, count, to, _from) - def import_data(self, format, bytes_or_stream, size, unique_id=None): + def import_data( + self, + format: DataFormat, + bytes_or_stream: FileLike, + size: int, + unique_id: str | None = None, + ) -> float: """Import data into Treasure Data Service Args: @@ -168,7 +189,9 @@ def import_data(self, format, bytes_or_stream, size, unique_id=None): unique_id=unique_id, ) - def import_file(self, format, file, unique_id=None): + def import_file( + self, format: DataFormat, file: FileLike, unique_id: str | None = None + ) -> float: """Import data into Treasure Data Service, from an existing file on filesystem. This method will decompress/deserialize records from given file, and then @@ -185,7 +208,7 @@ def import_file(self, format, file, unique_id=None): self._db_name, self._table_name, format, file, unique_id=unique_id ) - def export_data(self, storage_type, **kwargs): + def export_data(self, storage_type: str, **kwargs: Any) -> Job: """Export data from Treasure Data Service Args: @@ -224,9 +247,11 @@ def export_data(self, storage_type, **kwargs): ) @property - def estimated_storage_size_string(self): + def estimated_storage_size_string(self) -> str: """a string represents estimated size of the table in human-readable format""" - if self._estimated_storage_size <= 1024 * 1024: + if self._estimated_storage_size is None: + return "0.0 GB" + elif self._estimated_storage_size <= 1024 * 1024: return "0.0 GB" elif self._estimated_storage_size <= 60 * 1024 * 1024: return "0.01 GB" @@ -239,5 +264,5 @@ def estimated_storage_size_string(self): float(self._estimated_storage_size) / (1024 * 1024 * 1024) ) - def _update_database(self): + def _update_database(self) -> None: self.database = self._client.database(self._db_name) diff --git a/tdclient/user_model.py b/tdclient/user_model.py index aee7556..4e1045a 100644 --- a/tdclient/user_model.py +++ b/tdclient/user_model.py @@ -1,12 +1,27 @@ #!/usr/bin/env python +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + from tdclient.model import Model +if TYPE_CHECKING: + from tdclient.client import Client + class User(Model): """User on Treasure Data Service""" - def __init__(self, client, name, org_name, role_names, email, **kwargs): + def __init__( + self, + client: Client, + name: str, + org_name: str, + role_names: list[str], + email: str, + **kwargs: Any, + ) -> None: super(User, self).__init__(client) self._name = name self._org_name = org_name @@ -14,28 +29,28 @@ def __init__(self, client, name, org_name, role_names, email, **kwargs): self._email = email @property - def name(self): + def name(self) -> str: """ Returns: name of the user """ return self._name @property - def org_name(self): + def org_name(self) -> str: """ Returns: organization name """ return self._org_name @property - def role_names(self): + def role_names(self) -> list[str]: """ TODO: add docstring """ return self._role_names @property - def email(self): + def email(self) -> str: """ Returns: e-mail address """ From ca3f12d9c288a20a607415655eac3446918bdcf5 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 17:16:18 -0700 Subject: [PATCH 07/13] Add type hint for connection and cursor --- tdclient/connection.py | 47 +++++++++++++++++++----------- tdclient/cursor.py | 65 +++++++++++++++++++++++++++--------------- tdclient/types.py | 6 ++-- 3 files changed, 77 insertions(+), 41 deletions(-) diff --git a/tdclient/connection.py b/tdclient/connection.py index 6f9f97f..7639668 100644 --- a/tdclient/connection.py +++ b/tdclient/connection.py @@ -1,20 +1,30 @@ #!/usr/bin/env python +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable + from tdclient import api, cursor, errors +from tdclient.types import Priority + +if TYPE_CHECKING: + from types import TracebackType + + from tdclient.cursor import Cursor class Connection: def __init__( self, - type=None, - db=None, - result_url=None, - priority=None, - retry_limit=None, - wait_interval=None, - wait_callback=None, - **kwargs, - ): + type: str | None = None, + db: str | None = None, + result_url: str | None = None, + priority: Priority | None = None, + retry_limit: int | None = None, + wait_interval: int | None = None, + wait_callback: Callable[[Cursor], None] | None = None, + **kwargs: Any, + ) -> None: cursor_kwargs = dict() if type is not None: cursor_kwargs["type"] = type @@ -33,24 +43,29 @@ def __init__( self._api = api.API(**kwargs) self._cursor_kwargs = cursor_kwargs - def __enter__(self): + def __enter__(self) -> Connection: return self - def __exit__(self, type, value, traceback): + def __exit__( + self, + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self._api.close() @property - def api(self): + def api(self) -> api.API: return self._api - def close(self): + def close(self) -> None: self._api.close() - def commit(self): + def commit(self) -> None: raise errors.NotSupportedError - def rollback(self): + def rollback(self) -> None: raise errors.NotSupportedError - def cursor(self): + def cursor(self) -> Cursor: return cursor.Cursor(self._api, **self._cursor_kwargs) diff --git a/tdclient/cursor.py b/tdclient/cursor.py index 8438601..38a98ef 100644 --- a/tdclient/cursor.py +++ b/tdclient/cursor.py @@ -1,41 +1,53 @@ #!/usr/bin/env python +from __future__ import annotations + import time +from typing import TYPE_CHECKING, Any, Callable from tdclient import errors +if TYPE_CHECKING: + from tdclient.api import API + class Cursor: - def __init__(self, api, wait_interval=5, wait_callback=None, **kwargs): + def __init__( + self, + api: API, + wait_interval: int = 5, + wait_callback: Callable[[Cursor], None] | None = None, + **kwargs: Any, + ) -> None: self._api = api self._query_kwargs = kwargs - self._executed = None - self._rows = None + self._executed: str | None = None # Job ID + self._rows: list[Any] | None = None self._rownumber = 0 self._rowcount = -1 - self._description = [] + self._description: list[Any] = [] self.wait_interval = wait_interval self.wait_callback = wait_callback @property - def api(self): + def api(self) -> API: return self._api @property - def description(self): + def description(self) -> list[Any]: return self._description @property - def rowcount(self): + def rowcount(self) -> int: return self._rowcount - def callproc(self, procname, *parameters): + def callproc(self, procname: str, *parameters: Any) -> None: raise errors.NotSupportedError - def close(self): + def close(self) -> None: self._api.close() - def execute(self, query, args=None): + def execute(self, query: str, args: dict[str, Any] | None = None) -> str | None: if args is not None: if isinstance(args, dict): query = query.format(**args) @@ -49,16 +61,18 @@ def execute(self, query, args=None): self._do_execute() return self._executed - def executemany(self, operation, seq_of_parameters): + def executemany( + self, operation: str, seq_of_parameters: list[dict[str, Any]] + ) -> list[str | None]: return [ self.execute(operation, args=parameter) for parameter in seq_of_parameters ] - def _check_executed(self): + def _check_executed(self) -> None: if self._executed is None: raise errors.ProgrammingError("execute() first") - def _do_execute(self): + def _do_execute(self) -> None: self._check_executed() if self._rows is None: status = self._api.job_status(self._executed) @@ -81,18 +95,21 @@ def _do_execute(self): self.wait_callback(self) return self._do_execute() - def _result_description(self, result_schema): + def _result_description( + self, result_schema: list[Any] | None + ) -> list[tuple[Any, ...]]: if result_schema is None: result_schema = [] return [ (column[0], None, None, None, None, None, None) for column in result_schema ] - def fetchone(self): + def fetchone(self) -> Any | None: """ Fetch the next row of a query result set, returning a single sequence, or `None` when no more data is available. """ self._check_executed() + assert self._rows is not None if self._rownumber < self._rowcount: row = self._rows[self._rownumber] self._rownumber += 1 @@ -100,7 +117,7 @@ def fetchone(self): else: return None - def fetchmany(self, size=None): + def fetchmany(self, size: int | None = None) -> list[Any]: """ Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a list of tuples). An empty sequence is returned when no more rows are available. @@ -109,6 +126,7 @@ def fetchmany(self, size=None): return self.fetchall() else: self._check_executed() + assert self._rows is not None if self._rownumber + size - 1 < self._rowcount: rows = self._rows[self._rownumber : self._rownumber + size] self._rownumber += size @@ -119,12 +137,13 @@ def fetchmany(self, size=None): % (self._rownumber, self._rowcount) ) - def fetchall(self): + def fetchall(self) -> list[Any]: """ Fetch all (remaining) rows of a query result, returning them as a sequence of sequences (e.g. a list of tuples). Note that the cursor's arraysize attribute can affect the performance of this operation. """ self._check_executed() + assert self._rows is not None if self._rownumber < self._rowcount: rows = self._rows[self._rownumber :] self._rownumber = self._rowcount @@ -132,16 +151,16 @@ def fetchall(self): else: return [] - def nextset(self): + def nextset(self) -> None: raise errors.NotSupportedError - def setinputsizes(self, sizes): + def setinputsizes(self, sizes: Any) -> None: raise errors.NotSupportedError - def setoutputsize(self, size, column=None): + def setoutputsize(self, size: Any, column: Any = None) -> None: raise errors.NotSupportedError - def show_job(self): + def show_job(self) -> dict[str, Any]: """Returns detailed information of a Job Returns: @@ -150,7 +169,7 @@ def show_job(self): self._check_executed() return self._api.show_job(self._executed) - def job_status(self): + def job_status(self) -> str: """Show job status Returns: @@ -159,7 +178,7 @@ def job_status(self): self._check_executed() return self._api.job_status(self._executed) - def job_result(self): + def job_result(self) -> list[dict[str, Any]]: """Fetch job results Returns: diff --git a/tdclient/types.py b/tdclient/types.py index af5a8a7..445e9da 100644 --- a/tdclient/types.py +++ b/tdclient/types.py @@ -24,8 +24,10 @@ EngineVersion: TypeAlias = 'Literal["stable", "experimental"]' """Type for engine version selection.""" -Priority: TypeAlias = "Literal[-2, -1, 0, 1, 2]" -"""Type for job priority levels.""" +Priority: TypeAlias = ( + 'Literal[-2, -1, 0, 1, 2, "VERY LOW", "LOW", "NORMAL", "HIGH", "VERY HIGH"]' +) +"""Type for job priority levels (numeric or string).""" # Data format types ExportFileFormat: TypeAlias = 'Literal["jsonl.gz", "tsv.gz", "json.gz"]' From f311c894bfb204718d736b12bab611ca6e245584 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 18:10:39 -0700 Subject: [PATCH 08/13] Add type hint for API classes --- tdclient/bulk_import_api.py | 92 ++++++++++++++++++++++++++++------- tdclient/client.py | 35 +++++++------ tdclient/connector_api.py | 28 ++++++++++- tdclient/cursor.py | 4 ++ tdclient/database_api.py | 29 +++++++++-- tdclient/export_api.py | 31 ++++++++++-- tdclient/import_api.py | 44 +++++++++++++++-- tdclient/job_api.py | 91 ++++++++++++++++++++++++---------- tdclient/job_model.py | 4 +- tdclient/result_api.py | 38 ++++++++++++--- tdclient/schedule_api.py | 58 +++++++++++++++++----- tdclient/schedule_model.py | 2 +- tdclient/server_status_api.py | 17 ++++++- tdclient/table_api.py | 50 +++++++++++++++---- tdclient/types.py | 1 - tdclient/user_api.py | 39 +++++++++++---- 16 files changed, 448 insertions(+), 115 deletions(-) diff --git a/tdclient/bulk_import_api.py b/tdclient/bulk_import_api.py index 07cb10b..9b257c6 100644 --- a/tdclient/bulk_import_api.py +++ b/tdclient/bulk_import_api.py @@ -1,14 +1,25 @@ #!/usr/bin/env python +from __future__ import annotations + import collections import contextlib import gzip import io import os +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any import msgpack -from .util import create_url +if TYPE_CHECKING: + from contextlib import AbstractContextManager + from typing import IO + + import urllib3 + +from tdclient.types import BulkImportParams, BytesOrStream, DataFormat, FileLike +from tdclient.util import create_url class BulkImportAPI: @@ -17,7 +28,27 @@ class BulkImportAPI: This class is inherited by :class:`tdclient.api.API`. """ - def create_bulk_import(self, name, db, table, params=None): + # Methods from API class + def get( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def post( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def put( + self, url: str, stream: BytesOrStream, size: int + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes | str + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + def _prepare_file( + self, file: FileLike, format: str, **kwargs: Any + ) -> IO[bytes]: ... + + def create_bulk_import( + self, name: str, db: str, table: str, params: BulkImportParams | None = None + ) -> bool: """Enable bulk importing of data to the targeted database and table and stores it in the default resource pool. Default expiration for bulk import is 30days. @@ -30,7 +61,7 @@ def create_bulk_import(self, name, db, table, params=None): Returns: True if succeeded """ - params = {} if params is None else params + post_params = {} if params is None else dict(params) with self.post( create_url( "/v3/bulk_import/create/{name}/{db}/{table}", @@ -38,14 +69,16 @@ def create_bulk_import(self, name, db, table, params=None): db=db, table=table, ), - params, + post_params, ) as res: code, body = res.status, res.read() if code != 200: self.raise_error("Create bulk import failed", res, body) return True - def delete_bulk_import(self, name, params=None): + def delete_bulk_import( + self, name: str, params: dict[str, Any] | None = None + ) -> bool: """Delete the imported information with the specified name Args: @@ -63,7 +96,7 @@ def delete_bulk_import(self, name, params=None): self.raise_error("Delete bulk import failed", res, body) return True - def show_bulk_import(self, name): + def show_bulk_import(self, name: str) -> dict[str, Any]: """Show the details of the bulk import with the specified name Args: @@ -78,7 +111,9 @@ def show_bulk_import(self, name): js = self.checked_json(body, ["status"]) return js - def list_bulk_imports(self, params=None): + def list_bulk_imports( + self, params: dict[str, Any] | None = None + ) -> list[dict[str, Any]]: """Return the list of available bulk imports Args: params (dict, optional): Extra parameters. @@ -93,7 +128,9 @@ def list_bulk_imports(self, params=None): js = self.checked_json(body, ["bulk_imports"]) return js["bulk_imports"] - def list_bulk_import_parts(self, name, params=None): + def list_bulk_import_parts( + self, name: str, params: dict[str, Any] | None = None + ) -> list[str]: """Return the list of available parts uploaded through :func:`~BulkImportAPI.bulk_import_upload_part`. @@ -114,7 +151,7 @@ def list_bulk_import_parts(self, name, params=None): return js["parts"] @staticmethod - def validate_part_name(part_name): + def validate_part_name(part_name: str) -> None: """Make sure the part_name is valid Args: @@ -133,7 +170,9 @@ def validate_part_name(part_name): if 0 < part_name.find("/"): raise ValueError("part name must not contain '/': %s" % (repr(part_name))) - def bulk_import_upload_part(self, name, part_name, stream, size): + def bulk_import_upload_part( + self, name: str, part_name: str, stream: BytesOrStream, size: int + ) -> None: """Upload bulk import having the specified name and part in the path. Args: @@ -156,7 +195,14 @@ def bulk_import_upload_part(self, name, part_name, stream, size): if code / 100 != 2: self.raise_error("Upload a part failed", res, body) - def bulk_import_upload_file(self, name, part_name, format, file, **kwargs): + def bulk_import_upload_file( + self, + name: str, + part_name: str, + format: DataFormat, + file: FileLike, + **kwargs: Any, + ) -> None: """Upload a file with bulk import having the specified name. Args: @@ -193,7 +239,9 @@ def bulk_import_upload_file(self, name, part_name, format, file, **kwargs): size = os.fstat(fp.fileno()).st_size return self.bulk_import_upload_part(name, part_name, fp, size) - def bulk_import_delete_part(self, name, part_name, params=None): + def bulk_import_delete_part( + self, name: str, part_name: str, params: dict[str, Any] | None = None + ) -> bool: """Delete the imported information with the specified name. Args: @@ -218,7 +266,9 @@ def bulk_import_delete_part(self, name, part_name, params=None): self.raise_error("Delete a part failed", res, body) return True - def freeze_bulk_import(self, name, params=None): + def freeze_bulk_import( + self, name: str, params: dict[str, Any] | None = None + ) -> bool: """Freeze the bulk import with the specified name. Args: @@ -236,7 +286,9 @@ def freeze_bulk_import(self, name, params=None): self.raise_error("Freeze bulk import failed", res, body) return True - def unfreeze_bulk_import(self, name, params=None): + def unfreeze_bulk_import( + self, name: str, params: dict[str, Any] | None = None + ) -> bool: """Unfreeze bulk_import with the specified name. Args: @@ -254,7 +306,9 @@ def unfreeze_bulk_import(self, name, params=None): self.raise_error("Unfreeze bulk import failed", res, body) return True - def perform_bulk_import(self, name, params=None): + def perform_bulk_import( + self, name: str, params: dict[str, Any] | None = None + ) -> str: """Execute a job to perform bulk import with the indicated priority using the resource pool if indicated, else it will use the account's default. @@ -274,7 +328,9 @@ def perform_bulk_import(self, name, params=None): js = self.checked_json(body, ["job_id"]) return str(js["job_id"]) - def commit_bulk_import(self, name, params=None): + def commit_bulk_import( + self, name: str, params: dict[str, Any] | None = None + ) -> bool: """Commit the bulk import information having the specified name. Args: @@ -292,7 +348,9 @@ def commit_bulk_import(self, name, params=None): self.raise_error("Commit bulk import failed", res, body) return True - def bulk_import_error_records(self, name, params=None): + def bulk_import_error_records( + self, name: str, params: dict[str, Any] | None = None + ) -> Iterator[dict[str, Any]]: """List the records that have errors under the specified bulk import name. Args: diff --git a/tdclient/client.py b/tdclient/client.py index f0d4aee..5ca8e1e 100644 --- a/tdclient/client.py +++ b/tdclient/client.py @@ -5,14 +5,17 @@ import datetime import json from collections.abc import Iterator -from typing import Any +from typing import Any, cast, Literal + from tdclient import api, models from tdclient.types import ( BulkImportParams, + BytesOrStream, DataFormat, ExportParams, FileLike, + Priority, ResultFormat, ResultParams, ScheduleParams, @@ -254,7 +257,7 @@ def query( db_name: str, q: str, result_url: str | None = None, - priority: int | str | None = None, + priority: Priority | None = None, retry_limit: int | None = None, type: str = "hive", **kwargs: Any, @@ -279,9 +282,11 @@ def query( # for compatibility, assume type is hive unless specifically specified if type not in ["hive", "pig", "impala", "presto", "trino"]: raise ValueError("The specified query type is not supported: %s" % (type)) + # Cast type to expected literal since we've validated it + query_type = cast(Literal["hive", "presto", "trino", "bulkload"], type) job_id = self.api.query( q, - type=type, + type=query_type, db=db_name, result_url=result_url, priority=priority, @@ -295,7 +300,7 @@ def jobs( _from: int | None = None, to: int | None = None, status: str | None = None, - conditions: str | None = None, + conditions: dict[str, Any] | None = None, ) -> list[models.Job]: """List jobs @@ -304,7 +309,7 @@ def jobs( to (int, optional): Gets the Job up to the nth index in the list. By default, the first 20 jobs in the list are displayed status (str, optional): Filter by given status. {"queued", "running", "success", "error"} - conditions (str, optional): Condition for ``TIMESTAMPDIFF()`` to search for slow queries. + conditions (dict[str, Any], optional): Condition for ``TIMESTAMPDIFF()`` to search for slow queries. Avoid using this parameter as it can be dangerous. Returns: @@ -334,7 +339,7 @@ def job_status(self, job_id: str | int) -> str: Returns: a string represents the status of the job ("success", "error", "killed", "queued", "running") """ - return self.api.job_status(job_id) + return self.api.job_status(str(job_id)) def job_result(self, job_id: str | int) -> list[Any]: """ @@ -344,7 +349,7 @@ def job_result(self, job_id: str | int) -> list[Any]: Returns: a list of each rows in result set """ - return self.api.job_result(job_id) + return self.api.job_result(str(job_id)) def job_result_each(self, job_id: str | int) -> Iterator[dict[str, Any]]: """ @@ -354,7 +359,7 @@ def job_result_each(self, job_id: str | int) -> Iterator[dict[str, Any]]: Returns: an iterator of result set """ - for row in self.api.job_result_each(job_id): + for row in self.api.job_result_each(str(job_id)): yield row def job_result_format( @@ -368,7 +373,7 @@ def job_result_format( Returns: a list of each rows in result set """ - return self.api.job_result_format(job_id, format, header=header) + return self.api.job_result_format(str(job_id), format, header=header) def job_result_format_each( self, @@ -393,7 +398,7 @@ def job_result_format_each( an iterator of rows in result set """ for row in self.api.job_result_format_each( - job_id, + str(job_id), format, header=header, store_tmpfile=store_tmpfile, @@ -414,9 +419,9 @@ def download_job_result( Returns: `True` if success """ - return self.api.download_job_result(job_id, path, num_threads=num_threads) + return self.api.download_job_result(str(job_id), path, num_threads=num_threads) - def kill(self, job_id: str | int) -> str: + def kill(self, job_id: str | int) -> str | None: """ Args: job_id (str): job id @@ -424,7 +429,7 @@ def kill(self, job_id: str | int) -> str: Returns: a string represents the status of killed job ("queued", "running") """ - return self.api.kill(job_id) + return self.api.kill(str(job_id)) def export_data( self, @@ -582,7 +587,7 @@ def bulk_imports(self) -> list[models.BulkImport]: ] def bulk_import_upload_part( - self, name: str, part_name: str, bytes_or_stream: FileLike, size: int + self, name: str, part_name: str, bytes_or_stream: BytesOrStream, size: int ) -> None: """Upload a part to a bulk import session @@ -849,7 +854,7 @@ def import_data( db_name: str, table_name: str, format: DataFormat, - bytes_or_stream: FileLike, + bytes_or_stream: BytesOrStream, size: int, unique_id: str | None = None, ) -> float: diff --git a/tdclient/connector_api.py b/tdclient/connector_api.py index bf5b5d4..75549d7 100644 --- a/tdclient/connector_api.py +++ b/tdclient/connector_api.py @@ -1,8 +1,16 @@ #!/usr/bin/env python +from __future__ import annotations + import json +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import urllib3 -from .util import create_url, normalize_connector_config +from tdclient.util import create_url, normalize_connector_config class ConnectorAPI: @@ -11,7 +19,23 @@ class ConnectorAPI: This class is inherited by :class:`tdclient.api.API`. """ - def connector_guess(self, job): + # Methods from API class + def get( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def post( + self, url: str, params: Any, headers: dict[str, str] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def put( + self, url: str, params: Any, size: int, headers: dict[str, str] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def delete(self, url: str) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + + def connector_guess(self, job: dict[str, Any] | bytes) -> dict[str, Any]: """Guess the Data Connector configuration Args: diff --git a/tdclient/cursor.py b/tdclient/cursor.py index 38a98ef..72b2030 100644 --- a/tdclient/cursor.py +++ b/tdclient/cursor.py @@ -74,6 +74,7 @@ def _check_executed(self) -> None: def _do_execute(self) -> None: self._check_executed() + assert self._executed is not None if self._rows is None: status = self._api.job_status(self._executed) if status == "success": @@ -167,6 +168,7 @@ def show_job(self) -> dict[str, Any]: :class:`dict`: Detailed information of a job """ self._check_executed() + assert self._executed is not None return self._api.show_job(self._executed) def job_status(self) -> str: @@ -176,6 +178,7 @@ def job_status(self) -> str: The status information of the given job id at last execution. """ self._check_executed() + assert self._executed is not None return self._api.job_status(self._executed) def job_result(self) -> list[dict[str, Any]]: @@ -185,4 +188,5 @@ def job_result(self) -> list[dict[str, Any]]: Job result in :class:`list` """ self._check_executed() + assert self._executed is not None return self._api.job_result(self._executed) diff --git a/tdclient/database_api.py b/tdclient/database_api.py index 1a78009..3ebd9be 100644 --- a/tdclient/database_api.py +++ b/tdclient/database_api.py @@ -1,6 +1,15 @@ #!/usr/bin/env python -from .util import create_url, get_or_else, parse_date +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import urllib3 + +from tdclient.util import create_url, get_or_else, parse_date class DatabaseAPI: @@ -9,7 +18,19 @@ class DatabaseAPI: This class is inherited by :class:`tdclient.api.API`. """ - def list_databases(self): + # Methods from API class + def get( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def post( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + + def list_databases(self) -> dict[str, Any]: """Get the list of all the databases of the account. Returns: @@ -34,7 +55,7 @@ def list_databases(self): result[name] = m return result - def delete_database(self, db): + def delete_database(self, db: str) -> bool: """Delete a database. Args: @@ -48,7 +69,7 @@ def delete_database(self, db): self.raise_error("Delete database failed", res, body) return True - def create_database(self, db, params=None): + def create_database(self, db: str, params: dict[str, Any] | None = None) -> bool: """Create a new database with the given name. Args: diff --git a/tdclient/export_api.py b/tdclient/export_api.py index d5ce473..73ba1b7 100644 --- a/tdclient/export_api.py +++ b/tdclient/export_api.py @@ -1,6 +1,16 @@ #!/usr/bin/env python -from .util import create_url +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import urllib3 + +from tdclient.util import create_url +from tdclient.types import ExportParams class ExportAPI: @@ -9,7 +19,18 @@ class ExportAPI: This class is inherited by :class:`tdclient.api.API`. """ - def export_data(self, db, table, storage_type, params=None): + # Methods from API class + def post( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + + def export_data( + self, db: str, table: str, storage_type: str, params: ExportParams | None = None + ) -> str: """Creates a job to export the contents from the specified database and table names. @@ -48,10 +69,10 @@ def export_data(self, db, table, storage_type, params=None): Returns: str: Job ID. """ - params = {} if params is None else params - params["storage_type"] = storage_type + post_params = {} if params is None else dict(params) + post_params["storage_type"] = storage_type with self.post( - create_url("/v3/export/run/{db}/{table}", db=db, table=table), params + create_url("/v3/export/run/{db}/{table}", db=db, table=table), post_params ) as res: code, body = res.status, res.read() if code != 200: diff --git a/tdclient/import_api.py b/tdclient/import_api.py index 2d17341..146b959 100644 --- a/tdclient/import_api.py +++ b/tdclient/import_api.py @@ -1,9 +1,19 @@ #!/usr/bin/env python +from __future__ import annotations + import contextlib import os +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + from typing import IO -from .util import create_url + import urllib3 + +from tdclient.types import BytesOrStream, DataFormat, FileLike +from tdclient.util import create_url class ImportAPI: @@ -12,7 +22,27 @@ class ImportAPI: This class is inherited by :class:`tdclient.api.API`. """ - def import_data(self, db, table, format, bytes_or_stream, size, unique_id=None): + # Methods from API class + def put( + self, url: str, stream: BytesOrStream, size: int, **kwargs: Any + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + def _prepare_file( + self, file: FileLike, format: str, **kwargs: Any + ) -> IO[bytes]: ... + + def import_data( + self, + db: str, + table: str, + format: DataFormat, + bytes_or_stream: BytesOrStream, + size: int, + unique_id: str | None = None, + ) -> float: """Import data into Treasure Data Service This method expects data from a file-like object formatted with "msgpack.gz". @@ -53,7 +83,15 @@ def import_data(self, db, table, format, bytes_or_stream, size, unique_id=None): time = float(js["elapsed_time"]) return time - def import_file(self, db, table, format, file, unique_id=None, **kwargs): + def import_file( + self, + db: str, + table: str, + format: DataFormat, + file: FileLike, + unique_id: str | None = None, + **kwargs: Any, + ) -> float: """Import data into Treasure Data Service, from an existing file on filesystem. This method will decompress/deserialize records from given file, and then diff --git a/tdclient/job_api.py b/tdclient/job_api.py index e5ea9a0..ca1bf49 100644 --- a/tdclient/job_api.py +++ b/tdclient/job_api.py @@ -1,16 +1,26 @@ #!/usr/bin/env python +from __future__ import annotations + import codecs import gzip import json import logging import os import tempfile +from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Any, Literal import msgpack -from .util import create_url, get_or_else, parse_date +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import urllib3 + +from tdclient.types import Priority +from tdclient.util import create_url, get_or_else, parse_date log = logging.getLogger(__name__) @@ -22,7 +32,22 @@ class JobAPI: This class is inherited by :class:`tdclient.api.API`. """ - JOB_PRIORITY = { + # Methods from API class + def get( + self, + url: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def post( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes | str + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + + JOB_PRIORITY: dict[str, int] = { "VERY LOW": -2, "VERY-LOW": -2, "VERY_LOW": -2, @@ -35,7 +60,13 @@ class JobAPI: "VERY_HIGH": 2, } - def list_jobs(self, _from=0, to=None, status=None, conditions=None): + def list_jobs( + self, + _from: int = 0, + to: int | None = None, + status: str | None = None, + conditions: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: """Show the list of Jobs. Args: @@ -43,7 +74,7 @@ def list_jobs(self, _from=0, to=None, status=None, conditions=None): to (int, optional): Gets the Job up to the nth index in the list. By default, the first 20 jobs in the list are displayed status (str, optional): Filter by given status. {"queued", "running", "success", "error"} - conditions (str, optional): Condition for ``TIMESTAMPDIFF()`` to search for slow queries. + conditions (dict[str, Any], optional): Condition for ``TIMESTAMPDIFF()`` to search for slow queries. Avoid using this parameter as it can be dangerous. Returns: @@ -109,7 +140,7 @@ def list_jobs(self, _from=0, to=None, status=None, conditions=None): jobs.append(job) return jobs - def show_job(self, job_id): + def show_job(self, job_id: str) -> dict[str, Any]: """Return detailed information of a Job. Args: @@ -167,7 +198,7 @@ def show_job(self, job_id): } return job - def job_status(self, job_id): + def job_status(self, job_id: str) -> str: """Show job status Args: job_id (str): job ID @@ -183,7 +214,7 @@ def job_status(self, job_id): js = self.checked_json(body, ["status"]) return js["status"] - def job_result(self, job_id): + def job_result(self, job_id: str) -> list[dict[str, Any]]: """Return the job result. Args: @@ -197,7 +228,7 @@ def job_result(self, job_id): result.append(row) return result - def job_result_each(self, job_id): + def job_result_each(self, job_id: str) -> Iterator[dict[str, Any]]: """Yield a row of the job result. Args: @@ -209,7 +240,9 @@ def job_result_each(self, job_id): for row in self.job_result_format_each(job_id, "msgpack"): yield row - def job_result_format(self, job_id, format, header=False): + def job_result_format( + self, job_id: str, format: str, header: bool = False + ) -> list[dict[str, Any]]: """Return the job result with specified format. Args: @@ -228,8 +261,13 @@ def job_result_format(self, job_id, format, header=False): return result def job_result_format_each( - self, job_id, format, header=False, store_tmpfile=False, num_threads=4 - ): + self, + job_id: str, + format: str, + header: bool = False, + store_tmpfile: bool = False, + num_threads: int = 4, + ) -> Iterator[dict[str, Any]]: """Yield a row of the job result with specified format. Args: @@ -293,7 +331,7 @@ def job_result_format_each( else: yield res.read() - def download_job_result(self, job_id, path, num_threads=4): + def download_job_result(self, job_id: str, path: str, num_threads: int = 4) -> bool: """Download the job result to the specified path. Args: @@ -357,7 +395,7 @@ def download_file_multithreaded( download_file_multithreaded(url, path, file_size, num_threads=num_threads) return True - def kill(self, job_id): + def kill(self, job_id: str) -> str | None: """Stop the specific job if it is running. Args: @@ -376,14 +414,14 @@ def kill(self, job_id): def query( self, - q, - type="hive", - db=None, - result_url=None, - priority=None, - retry_limit=None, - **kwargs, - ): + q: str, + type: Literal["hive", "presto", "trino", "bulkload"] = "hive", + db: str | None = None, + result_url: str | None = None, + priority: Priority | None = None, + retry_limit: int | None = None, + **kwargs: Any, + ) -> str: """Create a job for given query. Args: @@ -401,18 +439,21 @@ def query( Returns: str: Job ID issued for the query """ - params = {"query": q} + params: dict[str, Any] = {"query": q} params.update(kwargs) if result_url is not None: params["result"] = result_url if priority is not None: + priority_value: int if not isinstance(priority, int): priority_name = str(priority).upper() if priority_name in self.JOB_PRIORITY: - priority = self.JOB_PRIORITY[priority_name] + priority_value = self.JOB_PRIORITY[priority_name] else: - raise (ValueError("unknown job priority: %s" % (priority_name,))) - params["priority"] = priority + raise ValueError("unknown job priority: %s" % (priority_name,)) + else: + priority_value = priority + params["priority"] = priority_value if retry_limit is not None: params["retry_limit"] = retry_limit with self.post( diff --git a/tdclient/job_model.py b/tdclient/job_model.py index 71ab46f..8e2e35c 100644 --- a/tdclient/job_model.py +++ b/tdclient/job_model.py @@ -67,7 +67,7 @@ class Job(Model): JOB_PRIORITY = {-2: "VERY LOW", -1: "LOW", 0: "NORMAL", 1: "HIGH", 2: "VERY HIGH"} def __init__( - self, client: Client, job_id: str, type: str, query: str, **kwargs: Any + self, client: Client, job_id: str, type: str, query: str | None, **kwargs: Any ) -> None: super(Job, self).__init__(client) self._job_id = job_id @@ -235,7 +235,7 @@ def kill(self) -> str: return response @property - def query(self) -> str: + def query(self) -> str | None: """a string represents the query string of the job""" return self._query diff --git a/tdclient/result_api.py b/tdclient/result_api.py index 72ff76b..2eb0e18 100644 --- a/tdclient/result_api.py +++ b/tdclient/result_api.py @@ -1,6 +1,16 @@ #!/usr/bin/env python -from .util import create_url +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import urllib3 + +from tdclient.util import create_url +from tdclient.types import ResultParams class ResultAPI: @@ -9,7 +19,19 @@ class ResultAPI: This class is inherited by :class:`tdclient.api.API`. """ - def list_result(self): + # Methods from API class + def get( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def post( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + + def list_result(self) -> list[tuple[str, str, None]]: """Get the list of all the available authentications. Returns: @@ -25,7 +47,9 @@ def list_result(self): (m["name"], m["url"], None) for m in js["results"] ] # same as database - def create_result(self, name, url, params=None): + def create_result( + self, name: str, url: str, params: ResultParams | None = None + ) -> bool: """Create a new authentication with the specified name. Args: @@ -35,17 +59,17 @@ def create_result(self, name, url, params=None): Returns: bool: True if succeeded. """ - params = {} if params is None else params - params.update({"url": url}) + post_params = {} if params is None else dict(params) + post_params.update({"url": url}) with self.post( - create_url("/v3/result/create/{name}", name=name), params + create_url("/v3/result/create/{name}", name=name), post_params ) as res: code, body = res.status, res.read() if code != 200: self.raise_error("Create result table failed", res, body) return True - def delete_result(self, name): + def delete_result(self, name: str) -> bool: """Delete the authentication having the specified name. Args: diff --git a/tdclient/schedule_api.py b/tdclient/schedule_api.py index eb6e43d..0764e1a 100644 --- a/tdclient/schedule_api.py +++ b/tdclient/schedule_api.py @@ -1,5 +1,17 @@ #!/usr/bin/env python -from .util import create_url, get_or_else, parse_date + +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any + +from tdclient.types import ScheduleParams +from tdclient.util import create_url, get_or_else, parse_date + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import urllib3 class ScheduleAPI: @@ -8,7 +20,21 @@ class ScheduleAPI: This class is inherited by :class:`tdclient.api.API`. """ - def create_schedule(self, name, params=None): + # Methods from API class + def get( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def post( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + + def create_schedule( + self, name: str, params: ScheduleParams | None = None + ) -> datetime.datetime | None: """Create a new scheduled query with the specified name. Args: @@ -50,10 +76,10 @@ def create_schedule(self, name, params=None): Returns: datetime.datetime: Start date time. """ - params = {} if params is None else params - params.update({"type": params.get("type", "hive")}) + post_params = {} if params is None else dict(params) + post_params.update({"type": post_params.get("type", "hive")}) with self.post( - create_url("/v3/schedule/create/{name}", name=name), params + create_url("/v3/schedule/create/{name}", name=name), post_params ) as res: code, body = res.status, res.read() if code != 200: @@ -61,7 +87,7 @@ def create_schedule(self, name, params=None): js = self.checked_json(body, ["start"]) return parse_date(get_or_else(js, "start", "1970-01-01T00:00:00Z")) - def delete_schedule(self, name): + def delete_schedule(self, name: str) -> tuple[str, str]: """Delete the scheduled query with the specified name. Args: @@ -76,7 +102,7 @@ def delete_schedule(self, name): js = self.checked_json(body, ["cron", "query"]) return js["cron"], js["query"] - def list_schedules(self): + def list_schedules(self) -> list[dict[str, Any]]: """Get the list of all the scheduled queries. Returns: @@ -90,12 +116,14 @@ def list_schedules(self): return [schedule_to_tuple(m) for m in js["schedules"]] - def update_schedule(self, name, params=None): + def update_schedule( + self, name: str, params: ScheduleParams | None = None + ) -> datetime.datetime | None: """Update the scheduled query. Args: name (str): Target scheduled query name. - params (dict): Extra parameters. + params (ScheduleParams | None): Extra parameters. - type (str): Query type. {"presto", "hive"}. Default: "hive" @@ -130,15 +158,17 @@ def update_schedule(self, name, params=None): Location where to store the result of the query. e.g. 'tableau://user:password@host.com:1234/datasource' """ - params = {} if params is None else params + post_params = {} if params is None else dict(params) with self.post( - create_url("/v3/schedule/update/{name}", name=name), params + create_url("/v3/schedule/update/{name}", name=name), post_params ) as res: code, body = res.status, res.read() if code != 200: self.raise_error("Update schedule failed", res, body) - def history(self, name, _from=0, to=None): + def history( + self, name: str, _from: int = 0, to: int | None = None + ) -> list[tuple[Any, ...]]: """Get the history details of the saved query for the past 90days. Args: @@ -169,7 +199,9 @@ def history(self, name, _from=0, to=None): return [history_to_tuple(m) for m in js["history"]] - def run_schedule(self, name, time, num=None): + def run_schedule( + self, name: str, time: int, num: int | None = None + ) -> list[tuple[Any, Any, datetime.datetime | None]]: """Execute the specified query. Args: diff --git a/tdclient/schedule_model.py b/tdclient/schedule_model.py index 21679fb..161a661 100644 --- a/tdclient/schedule_model.py +++ b/tdclient/schedule_model.py @@ -21,7 +21,7 @@ def __init__( scheduled_at: datetime.datetime, job_id: str, type: str, - query: str, + query: str | None, **kwargs: Any, ) -> None: super(ScheduledJob, self).__init__(client, job_id, type, query, **kwargs) diff --git a/tdclient/server_status_api.py b/tdclient/server_status_api.py index e3173cf..2cae65a 100644 --- a/tdclient/server_status_api.py +++ b/tdclient/server_status_api.py @@ -1,5 +1,14 @@ #!/usr/bin/env python +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import urllib3 + class ServerStatusAPI: """Access to Server Status API @@ -7,7 +16,13 @@ class ServerStatusAPI: This class is inherited by :class:`tdclient.api.API`. """ - def server_status(self): + # Methods from API class + def get( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + + def server_status(self) -> str: """Show the status of Treasure Data Returns: diff --git a/tdclient/table_api.py b/tdclient/table_api.py index 29fc65f..0f17b44 100644 --- a/tdclient/table_api.py +++ b/tdclient/table_api.py @@ -1,10 +1,18 @@ #!/usr/bin/env python +from __future__ import annotations + import json +from typing import TYPE_CHECKING, Any import msgpack -from .util import create_url, get_or_else, parse_date +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import urllib3 + +from tdclient.util import create_url, get_or_else, parse_date class TableAPI: @@ -13,7 +21,19 @@ class TableAPI: This class is inherited by :class:`tdclient.api.API`. """ - def list_tables(self, db): + # Methods from API class + def get( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def post( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes | str + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + + def list_tables(self, db: str) -> dict[str, Any]: """Gets the list of table in the database. Args: @@ -70,7 +90,7 @@ def list_tables(self, db): result[m["name"]] = m return result - def create_log_table(self, db, table): + def create_log_table(self, db: str, table: str) -> bool: """Create a new table in the database and registers it in PlazmaDB. Args: @@ -82,7 +102,9 @@ def create_log_table(self, db, table): """ return self._create_table(db, table, "log") - def _create_table(self, db, table, type, params=None): + def _create_table( + self, db: str, table: str, type: str, params: dict[str, Any] | None = None + ) -> bool: params = {} if params is None else params with self.post( create_url( @@ -95,7 +117,7 @@ def _create_table(self, db, table, type, params=None): self.raise_error("Create %s table failed" % (type), res, body) return True - def swap_table(self, db, table1, table2): + def swap_table(self, db: str, table1: str, table2: str) -> bool: """Swap the two specified tables with each other belonging to the same database and basically exchanges their names. @@ -119,7 +141,7 @@ def swap_table(self, db, table1, table2): self.raise_error("Swap tables failed", res, body) return True - def update_schema(self, db, table, schema_json): + def update_schema(self, db: str, table: str, schema_json: str) -> bool: """Update the table schema. Args: @@ -140,7 +162,7 @@ def update_schema(self, db, table, schema_json): self.raise_error("Create schema table failed", res, body) return True - def update_expire(self, db, table, expire_days): + def update_expire(self, db: str, table: str, expire_days: int) -> bool: """Update the expire days for the specified table Args: @@ -160,7 +182,7 @@ def update_expire(self, db, table, expire_days): self.raise_error("Update table expiration failed", res, body) return True - def delete_table(self, db, table): + def delete_table(self, db: str, table: str) -> str: """Delete the specified table. Args: @@ -180,7 +202,15 @@ def delete_table(self, db, table): t = js.get("type", "?") return t - def tail(self, db, table, count, to=None, _from=None, block=None): + def tail( + self, + db: str, + table: str, + count: int, + to: Any = None, + _from: Any = None, + block: Any = None, + ) -> list[dict[str, Any]]: """Get the contents of the table in reverse order based on the registered time (last data first). @@ -210,7 +240,7 @@ def tail(self, db, table, count, to=None, _from=None, block=None): return result - def change_database(self, db, table, dest_db): + def change_database(self, db: str, table: str, dest_db: str) -> bool: """Move a target table from it's original database to new destination database. Args: diff --git a/tdclient/types.py b/tdclient/types.py index 445e9da..adcd3b6 100644 --- a/tdclient/types.py +++ b/tdclient/types.py @@ -60,7 +60,6 @@ class ScheduleParams(TypedDict, total=False): class ExportParams(TypedDict, total=False): """Parameters for export operations.""" - storage_type: str # Storage type (e.g. "s3") bucket: str # Bucket name access_key_id: str # ID to access the export destination secret_access_key: str # Password for access_key_id diff --git a/tdclient/user_api.py b/tdclient/user_api.py index af8c2a2..2a6a892 100644 --- a/tdclient/user_api.py +++ b/tdclient/user_api.py @@ -1,10 +1,31 @@ #!/usr/bin/env python -from .util import create_url +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import urllib3 + +from tdclient.util import create_url class UserAPI: - def authenticate(self, user, password): + # Methods from API class + def get( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def post( + self, url: str, params: dict[str, Any] | None = None + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def raise_error( + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes + ) -> None: ... + def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... + + def authenticate(self, user: str, password: str) -> str: with self.post( "/v3/user/authenticate", {"user": user, "password": password} ) as res: @@ -15,7 +36,7 @@ def authenticate(self, user, password): apikey = js["apikey"] return apikey - def list_users(self): + def list_users(self) -> list[tuple[str, None, None, str]]: with self.get("/v3/user/list") as res: code, body = res.status, res.read() if code != 200: @@ -24,7 +45,7 @@ def list_users(self): return [user_to_tuple(roleinfo) for roleinfo in js["users"]] - def add_user(self, name, org, email, password): + def add_user(self, name: str, org: str, email: str, password: str) -> bool: params = {"organization": org, "email": email, "password": password} with self.post(create_url("/v3/user/add/{name}", name=name), params) as res: code, body = res.status, res.read() @@ -32,14 +53,14 @@ def add_user(self, name, org, email, password): self.raise_error("Adding user failed", res, body) return True - def remove_user(self, name): + def remove_user(self, name: str) -> bool: with self.post(create_url("/v3/user/remove/{name}", name=name)) as res: code, body = res.status, res.read() if code != 200: self.raise_error("Removing user failed", res, body) return True - def list_apikeys(self, name): + def list_apikeys(self, name: str) -> list[str]: with self.get(create_url("/v3/user/apikey/list/{name}", name=name)) as res: code, body = res.status, res.read() if code != 200: @@ -47,14 +68,14 @@ def list_apikeys(self, name): js = self.checked_json(body, ["apikeys"]) return js["apikeys"] - def add_apikey(self, name): + def add_apikey(self, name: str) -> bool: with self.post(create_url("/v3/user/apikey/add/{name}", name=name)) as res: code, body = res.status, res.read() if code != 200: self.raise_error("Adding API key failed", res, body) return True - def remove_apikey(self, name, apikey): + def remove_apikey(self, name: str, apikey: str) -> bool: params = {"apikey": apikey} with self.post( create_url("/v3/user/apikey/remove/{name}", name=name), params @@ -65,7 +86,7 @@ def remove_apikey(self, name, apikey): return True -def user_to_tuple(roleinfo): +def user_to_tuple(roleinfo: dict[str, Any]) -> tuple[str, None, None, str]: name = roleinfo["name"] email = roleinfo["email"] return (name, None, None, email) # set None to org and role for API compatibility From f5bc498ad69853ee5ff92516d085a1b75e5318f2 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 18:22:12 -0700 Subject: [PATCH 09/13] Add type hint for utilities --- tdclient/errors.py | 30 ++++++++++++++++ tdclient/types.py | 16 ++++++++- tdclient/util.py | 85 ++++++++++++++++++++++++++++++---------------- 3 files changed, 101 insertions(+), 30 deletions(-) diff --git a/tdclient/errors.py b/tdclient/errors.py index 6bd1fff..b7e1760 100644 --- a/tdclient/errors.py +++ b/tdclient/errors.py @@ -2,66 +2,96 @@ class ParameterValidationError(Exception): + """Exception raised when parameter validation fails.""" + pass # Generic API error class APIError(Exception): + """Base exception for API-related errors.""" + pass # 401 API errors class AuthError(APIError): + """Exception raised for authentication errors (HTTP 401).""" + pass # 403 API errors, used for database permissions class ForbiddenError(APIError): + """Exception raised for forbidden access errors (HTTP 403).""" + pass # 409 API errors class AlreadyExistsError(APIError): + """Exception raised when a resource already exists (HTTP 409).""" + pass # 404 API errors class NotFoundError(APIError): + """Exception raised when a resource is not found (HTTP 404).""" + pass # PEP 0249 errors class Error(Exception): + """Base class for database-related errors (PEP 249).""" + pass class InterfaceError(Error): + """Exception for errors related to the database interface (PEP 249).""" + pass class DatabaseError(Error): + """Exception for errors related to the database (PEP 249).""" + pass class DataError(DatabaseError): + """Exception for errors due to problems with the processed data (PEP 249).""" + pass class OperationalError(DatabaseError): + """Exception for errors related to database operation (PEP 249).""" + pass class IntegrityError(DatabaseError): + """Exception for errors related to relational integrity (PEP 249).""" + pass class InternalError(DatabaseError): + """Exception for internal database errors (PEP 249).""" + pass class ProgrammingError(DatabaseError): + """Exception for programming errors (PEP 249).""" + pass class NotSupportedError(DatabaseError): + """Exception for unsupported operations (PEP 249).""" + pass diff --git a/tdclient/types.py b/tdclient/types.py index adcd3b6..1ee0d49 100644 --- a/tdclient/types.py +++ b/tdclient/types.py @@ -3,10 +3,14 @@ from __future__ import annotations from array import array -from typing import IO +from typing import IO, TYPE_CHECKING from typing_extensions import Literal, TypeAlias, TypedDict +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any + # File-like types FileLike: TypeAlias = "str | bytes | IO[bytes]" """Type for file inputs: file path, bytes, or file-like object.""" @@ -39,6 +43,16 @@ ResultFormat: TypeAlias = 'Literal["msgpack", "json", "csv", "tsv"]' """Type for query result formats.""" +# Utility types for CSV parsing and data processing +CSVValue: TypeAlias = "int | float | str | bool | None" +"""Type for values parsed from CSV files.""" + +Converter: TypeAlias = "Callable[[str], Any]" +"""Type for converter functions that parse string values.""" + +Record: TypeAlias = "dict[str, Any]" +"""Type for data records (dictionaries with string keys and any values).""" + # TypedDict classes for structured parameters class ScheduleParams(TypedDict, total=False): diff --git a/tdclient/util.py b/tdclient/util.py index f3f6d69..55bd63e 100644 --- a/tdclient/util.py +++ b/tdclient/util.py @@ -1,16 +1,23 @@ +from __future__ import annotations + import csv import io import logging import warnings +from collections.abc import Iterator +from datetime import datetime +from typing import Any, BinaryIO from urllib.parse import quote as urlquote import dateutil.parser import msgpack +from tdclient.types import CSVValue, Converter, Record + log = logging.getLogger(__name__) -def create_url(tmpl, **values): +def create_url(tmpl: str, **values: Any) -> str: """Create url with values Args: @@ -21,7 +28,7 @@ def create_url(tmpl, **values): return tmpl.format(**quoted_values) -def validate_record(record): +def validate_record(record: Record) -> bool: """Check that `record` contains a key called "time". Args: @@ -41,7 +48,7 @@ def validate_record(record): return True -def guess_csv_value(s): +def guess_csv_value(s: str) -> CSVValue: """Determine the most appropriate type for `s` and return it. Tries to interpret `s` as a more specific datatype, in the following @@ -75,7 +82,7 @@ def guess_csv_value(s): # Convert our dtype names to callables that parse a string into that type -DTYPE_TO_CALLABLE = { +DTYPE_TO_CALLABLE: dict[str, Converter] = { "bool": bool, "float": float, "int": int, @@ -84,7 +91,9 @@ def guess_csv_value(s): } -def merge_dtypes_and_converters(dtypes=None, converters=None): +def merge_dtypes_and_converters( + dtypes: dict[str, str] | None = None, converters: dict[str, Converter] | None = None +) -> dict[str, Converter]: """Generate a merged dictionary from those given. Args: @@ -113,23 +122,25 @@ def merge_dtypes_and_converters(dtypes=None, converters=None): If a column name occurs in both input dictionaries, the callable specified in `converters` is used. """ - our_converters = {} + our_converters: dict[str, Converter] = {} if dtypes is not None: - try: - for column_name, dtype in dtypes.items(): + for column_name, dtype in dtypes.items(): + try: our_converters[column_name] = DTYPE_TO_CALLABLE[dtype] - except KeyError: - raise ValueError( - "Unrecognized dtype %r, must be one of %s" - % (dtype, ", ".join(repr(k) for k in sorted(DTYPE_TO_CALLABLE))) - ) + except KeyError: + raise ValueError( + "Unrecognized dtype %r, must be one of %s" + % (dtype, ", ".join(repr(k) for k in sorted(DTYPE_TO_CALLABLE))) + ) if converters is not None: for column_name, parse_fn in converters.items(): our_converters[column_name] = parse_fn return our_converters -def parse_csv_value(k, s, converters=None): +def parse_csv_value( + k: str, s: str, converters: dict[str, Converter] | None = None +) -> Any: """Given a CSV (string) value, work out an actual value. Args: @@ -167,7 +178,9 @@ def parse_csv_value(k, s, converters=None): return parse_fn(s) -def csv_dict_record_reader(file_like, encoding, dialect): +def csv_dict_record_reader( + file_like: BinaryIO, encoding: str, dialect: str | type[csv.Dialect] +) -> Iterator[dict[str, str]]: """Yield records from a CSV input using csv.DictReader. This is a reader suitable for use by `tdclient.util.read_csv_records`_. @@ -180,7 +193,7 @@ def csv_dict_record_reader(file_like, encoding, dialect): returns bytes. encoding (str): the name of the encoding to use when turning those bytes into strings. - dialect (str): the name of the CSV dialect to use. + dialect (str | type[csv.Dialect]): the name of the CSV dialect to use, or a Dialect class. Yields: For each row of CSV data read from `file_like`, yields a dictionary @@ -192,7 +205,12 @@ def csv_dict_record_reader(file_like, encoding, dialect): yield row -def csv_text_record_reader(file_like, encoding, dialect, columns): +def csv_text_record_reader( + file_like: BinaryIO, + encoding: str, + dialect: str | type[csv.Dialect], + columns: list[str], +) -> Iterator[dict[str, str]]: """Yield records from a CSV input using csv.reader and explicit column names. This is a reader suitable for use by `tdclient.util.read_csv_records`_. @@ -205,7 +223,7 @@ def csv_text_record_reader(file_like, encoding, dialect, columns): returns bytes. encoding (str): the name of the encoding to use when turning those bytes into strings. - dialect (str): the name of the CSV dialect to use. + dialect (str | type[csv.Dialect]): the name of the CSV dialect to use, or a Dialect class. Yields: For each row of CSV data read from `file_like`, yields a dictionary @@ -217,7 +235,12 @@ def csv_text_record_reader(file_like, encoding, dialect, columns): yield dict(zip(columns, row)) -def read_csv_records(csv_reader, dtypes=None, converters=None, **kwargs): +def read_csv_records( + csv_reader: Iterator[dict[str, str]], + dtypes: dict[str, str] | None = None, + converters: dict[str, Converter] | None = None, + **kwargs: Any, +) -> Iterator[Record]: """Read records using csv_reader and yield the results.""" our_converters = merge_dtypes_and_converters(dtypes, converters) @@ -227,7 +250,7 @@ def read_csv_records(csv_reader, dtypes=None, converters=None, **kwargs): yield record -def create_msgpack(items): +def create_msgpack(items: list[dict[str, Any]]) -> bytes: """Create msgpack streaming bytes from list Args: @@ -256,7 +279,7 @@ def create_msgpack(items): return stream.getvalue() -def normalized_msgpack(value): +def normalized_msgpack(value: Any) -> Any: """Recursively convert int to str if the int "overflows". Args: @@ -292,7 +315,9 @@ def normalized_msgpack(value): return value -def get_or_else(hashmap, key, default_value=None): +def get_or_else( + hashmap: dict[str, str], key: str, default_value: str | None = None +) -> str | None: """Get value or default value It differs from the standard dict ``get`` method in its behaviour when @@ -300,9 +325,9 @@ def get_or_else(hashmap, key, default_value=None): only spaces. Args: - hashmap (dict): target - key (Any): key - default_value (Any): default value + hashmap (dict): target dictionary with string values + key (str): key to look up + default_value (str | None): default value to return if key is missing or value is empty/whitespace Example: @@ -326,7 +351,7 @@ def get_or_else(hashmap, key, default_value=None): return default_value -def parse_date(s): +def parse_date(s: str | None) -> datetime | None: """Parse date from str to datetime TODO: parse datetime using an optional format string @@ -334,11 +359,13 @@ def parse_date(s): For now, this does not use a format string since API may return date in ambiguous format :( Args: - s (str): target str + s (str | None): target str, or None Returns: - datetime + datetime or None """ + if s is None: + return None try: return dateutil.parser.parse(s) except ValueError: @@ -346,7 +373,7 @@ def parse_date(s): return None -def normalize_connector_config(config): +def normalize_connector_config(config: dict[str, Any]) -> dict[str, Any]: """Normalize connector config This is porting of TD CLI's ConnectorConfigNormalizer#normalized_config. From e29af526b10cd2b0f8c7f09fe258ce0201546b4f Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 18:23:07 -0700 Subject: [PATCH 10/13] Re-enable pyright on pre-commit hook --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0ab8edc..6d8a812 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,4 +20,3 @@ repos: rev: v1.1.407 hooks: - id: pyright - stages: [manual] From a103e5822a52c6189d2ce613a3ab77775e1093fd Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 18:29:54 -0700 Subject: [PATCH 11/13] Tweak pre-commit settings --- .pre-commit-config.yaml | 7 +++++++ pyproject.toml | 7 ++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6d8a812..67eec89 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,3 +20,10 @@ repos: rev: v1.1.407 hooks: - id: pyright + exclude: ^docs/ + additional_dependencies: + - msgpack>=0.6.2 + - urllib3 + - python-dateutil + - typing-extensions>=4.0.0 + - certifi diff --git a/pyproject.toml b/pyproject.toml index 17dceac..61ea5d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ known-third-party = ["dateutil","msgpack","pkg_resources","pytest","setuptools", [tool.pyright] include = ["tdclient"] -exclude = ["**/__pycache__", "tdclient/test"] +exclude = ["**/__pycache__", "tdclient/test", "docs"] typeCheckingMode = "basic" pythonVersion = "3.9" pythonPlatform = "All" @@ -68,3 +68,8 @@ reportMissingTypeStubs = false reportUnknownMemberType = false reportUnknownArgumentType = false reportUnknownVariableType = false +reportMissingImports = "warning" + +# Pre-commit venv configuration +venvPath = "." +venv = ".venv" From faa56c430f18c863bc67ffa76223112d1e8c0856 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 18:30:04 -0700 Subject: [PATCH 12/13] Update documents --- README.rst | 40 +++++++++++++++++++++++++++++++++ docs/api/client.rst | 2 -- docs/file_import_parameters.rst | 2 +- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 68f75bb..50a4389 100644 --- a/README.rst +++ b/README.rst @@ -255,6 +255,46 @@ which would produce: 1575454204, "a", "0001", ["a", "b", "c"] 1575454204, "b", "0002", ["d", "e", "f"] +Type Hints +---------- + +td-client-python includes comprehensive type hints (PEP 484) for improved development experience with static type checkers like mypy and pyright. Type hints are available for all public APIs. + +**Features:** + + +* Fully typed public API with precise type annotations +* ``py.typed`` marker file for PEP 561 compliance +* Type aliases in ``tdclient.types`` for common patterns +* Support for type checking with mypy, pyright, and other tools + +**Example with type checking:** + +.. code-block:: python + + import tdclient + + # Type checkers will understand the types + with tdclient.Client(apikey="your_api_key") as client: + # client is inferred as tdclient.Client + job = client.query("sample_db", "SELECT COUNT(1) FROM table", type="presto") + # job is inferred as tdclient.models.Job + job.wait() + for row in job.result(): + # row is inferred as dict[str, Any] + print(row) + +**Using type aliases:** + +.. code-block:: python + + from tdclient.types import QueryEngineType, Priority + + def run_query(engine: QueryEngineType, priority: Priority) -> None: + with tdclient.Client() as client: + job = client.query("mydb", "SELECT 1", type=engine, priority=priority) + job.wait() + Development ----------- diff --git a/docs/api/client.rst b/docs/api/client.rst index 756b5d3..5843532 100644 --- a/docs/api/client.rst +++ b/docs/api/client.rst @@ -11,5 +11,3 @@ tdclient.client :members: :undoc-members: :show-inheritance: - - diff --git a/docs/file_import_parameters.rst b/docs/file_import_parameters.rst index 213d9ee..aad9b61 100644 --- a/docs/file_import_parameters.rst +++ b/docs/file_import_parameters.rst @@ -74,7 +74,7 @@ contains ``"not-an-int"``, the resulting ``ValueError`` will not be caught. To summarise, the default for reading CSV files is: ``dialect=csv.excel, encoding="utf-8", columns=None, dtypes=None, converters=None`` - + TSV data -------- From f243f7f866a5dd6ad8d47517e3c3b590b8c33bef Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Thu, 30 Oct 2025 21:57:08 -0700 Subject: [PATCH 13/13] Fix override signatures --- tdclient/bulk_import_api.py | 19 ++++++++++++++++--- tdclient/connector_api.py | 22 +++++++++++++++++++--- tdclient/database_api.py | 12 ++++++++++-- tdclient/export_api.py | 8 ++++++-- tdclient/import_api.py | 7 ++++++- tdclient/job_api.py | 6 +++++- tdclient/result_api.py | 12 ++++++++++-- tdclient/schedule_api.py | 12 ++++++++++-- tdclient/server_status_api.py | 6 +++++- tdclient/table_api.py | 12 ++++++++++-- tdclient/user_api.py | 12 ++++++++++-- 11 files changed, 107 insertions(+), 21 deletions(-) diff --git a/tdclient/bulk_import_api.py b/tdclient/bulk_import_api.py index 9b257c6..32d4348 100644 --- a/tdclient/bulk_import_api.py +++ b/tdclient/bulk_import_api.py @@ -30,13 +30,26 @@ class BulkImportAPI: # Methods from API class def get( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def post( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def put( - self, url: str, stream: BytesOrStream, size: int + self, + path: str, + bytes_or_stream: BytesOrStream, + size: int, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes | str diff --git a/tdclient/connector_api.py b/tdclient/connector_api.py index 75549d7..77f7271 100644 --- a/tdclient/connector_api.py +++ b/tdclient/connector_api.py @@ -10,6 +10,7 @@ import urllib3 +from tdclient.types import BytesOrStream from tdclient.util import create_url, normalize_connector_config @@ -21,15 +22,30 @@ class ConnectorAPI: # Methods from API class def get( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def post( self, url: str, params: Any, headers: dict[str, str] | None = None ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def put( - self, url: str, params: Any, size: int, headers: dict[str, str] | None = None + self, + path: str, + bytes_or_stream: BytesOrStream, + size: int, + headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... + def delete( + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... - def delete(self, url: str) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes ) -> None: ... diff --git a/tdclient/database_api.py b/tdclient/database_api.py index 3ebd9be..e8208ac 100644 --- a/tdclient/database_api.py +++ b/tdclient/database_api.py @@ -20,10 +20,18 @@ class DatabaseAPI: # Methods from API class def get( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def post( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes diff --git a/tdclient/export_api.py b/tdclient/export_api.py index 73ba1b7..c7a57f7 100644 --- a/tdclient/export_api.py +++ b/tdclient/export_api.py @@ -21,10 +21,14 @@ class ExportAPI: # Methods from API class def post( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( - self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes + self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes | str ) -> None: ... def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... diff --git a/tdclient/import_api.py b/tdclient/import_api.py index 146b959..a7af59d 100644 --- a/tdclient/import_api.py +++ b/tdclient/import_api.py @@ -24,7 +24,12 @@ class ImportAPI: # Methods from API class def put( - self, url: str, stream: BytesOrStream, size: int, **kwargs: Any + self, + path: str, + bytes_or_stream: BytesOrStream, + size: int, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes diff --git a/tdclient/job_api.py b/tdclient/job_api.py index ca1bf49..475821d 100644 --- a/tdclient/job_api.py +++ b/tdclient/job_api.py @@ -40,7 +40,11 @@ def get( headers: dict[str, str] | None = None, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def post( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes | str diff --git a/tdclient/result_api.py b/tdclient/result_api.py index 2eb0e18..8a1f979 100644 --- a/tdclient/result_api.py +++ b/tdclient/result_api.py @@ -21,10 +21,18 @@ class ResultAPI: # Methods from API class def get( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def post( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes diff --git a/tdclient/schedule_api.py b/tdclient/schedule_api.py index 0764e1a..1410d77 100644 --- a/tdclient/schedule_api.py +++ b/tdclient/schedule_api.py @@ -22,10 +22,18 @@ class ScheduleAPI: # Methods from API class def get( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def post( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes diff --git a/tdclient/server_status_api.py b/tdclient/server_status_api.py index 2cae65a..499cbe5 100644 --- a/tdclient/server_status_api.py +++ b/tdclient/server_status_api.py @@ -18,7 +18,11 @@ class ServerStatusAPI: # Methods from API class def get( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ... diff --git a/tdclient/table_api.py b/tdclient/table_api.py index 0f17b44..62b11aa 100644 --- a/tdclient/table_api.py +++ b/tdclient/table_api.py @@ -23,10 +23,18 @@ class TableAPI: # Methods from API class def get( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def post( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes | str diff --git a/tdclient/user_api.py b/tdclient/user_api.py index 2a6a892..6212155 100644 --- a/tdclient/user_api.py +++ b/tdclient/user_api.py @@ -15,10 +15,18 @@ class UserAPI: # Methods from API class def get( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def post( - self, url: str, params: dict[str, Any] | None = None + self, + path: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + **kwargs: Any, ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ... def raise_error( self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes