Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions .github/workflows/pythontest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,26 @@ name: Python testing
on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ".[dev]"
- name: lint with ruff
run: |
ruff format tdclient --diff --exit-non-zero-on-fix
ruff check tdclient
- name: Run pyright
run: |
pyright tdclient

test:
runs-on: ${{ matrix.os }}
strategy:
Expand All @@ -23,9 +43,6 @@ jobs:
pip install ".[dev]"
pip install -r requirements.txt -r test-requirements.txt
pip install -U coveralls pyyaml
- name: Run pyright
run: |
pyright tdclient
- name: Run test
run: |
coverage run --source=tdclient -m pytest tdclient/test
Expand Down
4 changes: 1 addition & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def linkcode_resolve(domain, info):
except Exception:
linenum = ""

return "https://github.com/{}/{}/blob/{}/{}/{}#L{}".format(
GH_ORGANIZATION, GH_PROJECT, revision, MODULE, relpath, linenum
)
return f"https://github.com/{GH_ORGANIZATION}/{GH_PROJECT}/blob/{revision}/{MODULE}/{relpath}#L{linenum}"


# -- Project information -----------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ line-length = 88
[tool.ruff.lint]
select = [
"E",
"W",
"F",
"I",
"UP",
"B",
]
exclude = ["tdclient/test/*"]
ignore = ["E203", "E501"]
Expand Down
66 changes: 25 additions & 41 deletions tdclient/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import json
import logging
import os
import socket
import ssl
import tempfile
import time
Expand Down Expand Up @@ -108,11 +107,11 @@ def __init__(
if user_agent is not None:
self._user_agent = user_agent
else:
self._user_agent = "TD-Client-Python/%s" % (version.__version__)
self._user_agent = f"TD-Client-Python/{version.__version__}"

if endpoint is not None:
if not urlparse.urlparse(endpoint).scheme:
endpoint = "https://{}".format(endpoint)
endpoint = f"https://{endpoint}"
self._endpoint = endpoint
elif os.getenv("TD_API_SERVER"):
self._endpoint = os.getenv("TD_API_SERVER")
Expand Down Expand Up @@ -154,7 +153,7 @@ def _init_http(
if http_proxy.startswith("http://"):
return self._init_http_proxy(http_proxy, **kwargs)
else:
return self._init_http_proxy("http://%s" % (http_proxy,), **kwargs)
return self._init_http_proxy(f"http://{http_proxy}", **kwargs)

def _init_http_proxy(self, http_proxy: str, **kwargs: Any) -> urllib3.ProxyManager:
pool_options = dict(kwargs)
Expand All @@ -164,7 +163,7 @@ def _init_http_proxy(self, http_proxy: str, **kwargs: Any) -> urllib3.ProxyManag
if "@" in netloc:
auth, netloc = netloc.split("@", 2)
pool_options["proxy_headers"] = urllib3.make_headers(proxy_basic_auth=auth)
return urllib3.ProxyManager("%s://%s" % (scheme, netloc), **pool_options)
return urllib3.ProxyManager(f"{scheme}://{netloc}", **pool_options)

def get(
self,
Expand Down Expand Up @@ -214,12 +213,12 @@ def get(
self._max_cumul_retry_delay,
)
except (
OSError,
urllib3.exceptions.TimeoutStateError,
urllib3.exceptions.TimeoutError,
urllib3.exceptions.PoolError,
http.client.IncompleteRead,
TimeoutError,
socket.error,
):
pass

Expand All @@ -235,12 +234,7 @@ def get(
retry_delay *= 2
else:
raise APIError(
"Retrying stopped after %d seconds. (cumulative: %d/%d)"
% (
self._max_cumul_retry_delay,
cumul_retry_delay,
self._max_cumul_retry_delay,
)
f"Retrying stopped after {self._max_cumul_retry_delay} seconds. (cumulative: {cumul_retry_delay}/{self._max_cumul_retry_delay})"
)

log.debug(
Expand Down Expand Up @@ -314,13 +308,15 @@ def post(
self._max_cumul_retry_delay,
)
except (
OSError,
urllib3.exceptions.TimeoutStateError,
urllib3.exceptions.TimeoutError,
urllib3.exceptions.PoolError,
socket.error,
):
if not self._retry_post_requests:
raise APIError("Retrying stopped by retry_post_requests == False")
raise APIError(
"Retrying stopped by retry_post_requests == False"
) from None

if cumul_retry_delay <= self._max_cumul_retry_delay:
log.warning(
Expand All @@ -334,12 +330,7 @@ def post(
retry_delay *= 2
else:
raise APIError(
"Retrying stopped after %d seconds. (cumulative: %d/%d)"
% (
self._max_cumul_retry_delay,
cumul_retry_delay,
self._max_cumul_retry_delay,
)
f"Retrying stopped after {self._max_cumul_retry_delay} seconds. (cumulative: {cumul_retry_delay}/{self._max_cumul_retry_delay})"
)

log.debug(
Expand Down Expand Up @@ -408,12 +399,12 @@ def put(
else:
raise APIError("Error %d: %s", response.status, response.data)
except (
OSError,
urllib3.exceptions.TimeoutStateError,
urllib3.exceptions.TimeoutError,
urllib3.exceptions.PoolError,
socket.error,
):
raise APIError("Error: %s" % (repr(response)))
raise APIError(f"Error: {repr(response)}") from None

log.debug(
"REST PUT response:\n headers: %s\n status: %d\n body: <omitted>",
Expand Down Expand Up @@ -470,10 +461,10 @@ def delete(
self._max_cumul_retry_delay,
)
except (
OSError,
urllib3.exceptions.TimeoutStateError,
urllib3.exceptions.TimeoutError,
urllib3.exceptions.PoolError,
socket.error,
):
pass

Expand All @@ -489,12 +480,7 @@ def delete(
retry_delay *= 2
else:
raise APIError(
"Retrying stopped after %d seconds. (cumulative: %d/%d)"
% (
self._max_cumul_retry_delay,
cumul_retry_delay,
self._max_cumul_retry_delay,
)
f"Retrying stopped after {self._max_cumul_retry_delay} seconds. (cumulative: {cumul_retry_delay}/{self._max_cumul_retry_delay})"
)

log.debug(
Expand Down Expand Up @@ -536,7 +522,7 @@ def build_request(
# use default headers first
_headers = dict(self._headers)
# add default headers
_headers["authorization"] = "TD1 %s" % (self._apikey,)
_headers["authorization"] = f"TD1 {self._apikey}"
_headers["date"] = email.utils.formatdate(time.time())
_headers["user-agent"] = self._user_agent
# override given headers
Expand Down Expand Up @@ -571,28 +557,26 @@ def raise_error(
status_code = res.status
s = body if isinstance(body, str) else body.decode("utf-8")
if status_code == 404:
raise errors.NotFoundError("%s: %s" % (msg, s))
raise errors.NotFoundError(f"{msg}: {s}")
elif status_code == 409:
raise errors.AlreadyExistsError("%s: %s" % (msg, s))
raise errors.AlreadyExistsError(f"{msg}: {s}")
elif status_code == 401:
raise errors.AuthError("%s: %s" % (msg, s))
raise errors.AuthError(f"{msg}: {s}")
elif status_code == 403:
raise errors.ForbiddenError("%s: %s" % (msg, s))
raise errors.ForbiddenError(f"{msg}: {s}")
else:
raise errors.APIError("%d: %s: %s" % (status_code, msg, s))
raise errors.APIError(f"{status_code}: {msg}: {s}")

def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]:
js = None
try:
js = json.loads(body.decode("utf-8"))
except ValueError as error:
raise APIError("Unexpected API response: %s: %s" % (error, repr(body)))
raise APIError(f"Unexpected API response: {error}: {repr(body)}") from error
js = dict(js)
if 0 < [k in js for k in required].count(False):
missing = [k for k in required if k not in js]
raise APIError(
"Unexpected API response: %s: %s" % (repr(missing), repr(body))
)
raise APIError(f"Unexpected API response: {repr(missing)}: {repr(body)}")
return js

def close(self) -> None:
Expand All @@ -619,11 +603,11 @@ def _read_file(self, file_like, fmt, **kwargs):
compressed = fmt.endswith(".gz")
if compressed:
fmt = fmt[0 : len(fmt) - len(".gz")]
reader_name = "_read_%s_file" % (fmt,)
reader_name = f"_read_{fmt}_file"
if hasattr(self, reader_name):
reader = getattr(self, reader_name)
else:
raise TypeError("unknown format: %s" % (fmt,))
raise TypeError(f"unknown format: {fmt}")
if hasattr(file_like, "read"):
if compressed:
file_like = gzip.GzipFile(fileobj=file_like)
Expand Down
7 changes: 3 additions & 4 deletions tdclient/bulk_import_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ def validate_part_name(part_name: str) -> None:

if 1 < d["."]:
raise ValueError(
"part names cannot contain multiple periods: %s" % (repr(part_name))
f"part names cannot contain multiple periods: {repr(part_name)}"
)

if 0 < part_name.find("/"):
raise ValueError("part name must not contain '/': %s" % (repr(part_name)))
raise ValueError(f"part name must not contain '/': {repr(part_name)}")

def bulk_import_upload_part(
self, name: str, part_name: str, stream: BytesOrStream, size: int
Expand Down Expand Up @@ -372,5 +372,4 @@ def bulk_import_error_records(
decompressor = gzip.GzipFile(fileobj=body)

unpacker = msgpack.Unpacker(decompressor, raw=False)
for row in unpacker:
yield row
yield from unpacker
9 changes: 3 additions & 6 deletions tdclient/bulk_import_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BulkImport(Model):
STATUS_COMMITTED = "committed"

def __init__(self, client: Client, **kwargs: Any) -> None:
super(BulkImport, self).__init__(client)
super().__init__(client)
self._feed(kwargs)

def _feed(self, data: dict[str, Any] | None = None) -> None:
Expand Down Expand Up @@ -128,9 +128,7 @@ def perform(
"""
self.update()
if not self.upload_frozen:
raise (
RuntimeError('bulk import session "%s" is not frozen' % (self.name,))
)
raise (RuntimeError(f'bulk import session "{self.name}" is not frozen'))
job = self._client.perform_bulk_import(self.name)
if wait:
job.wait(
Expand Down Expand Up @@ -164,8 +162,7 @@ def error_record_items(self) -> Iterator[dict[str, Any]]:
Yields:
Error record
"""
for record in self._client.bulk_import_error_records(self.name):
yield record
yield from self._client.bulk_import_error_records(self.name)

def upload_part(self, part_name: str, bytes_or_stream: FileLike, size: int) -> bool:
"""Upload a part to bulk import session
Expand Down
20 changes: 8 additions & 12 deletions tdclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import datetime
import json
from collections.abc import Iterator
from typing import Any, cast, Literal

from typing import Any, Literal, cast

from tdclient import api, models
from tdclient.types import (
Expand Down Expand Up @@ -103,7 +102,7 @@ def database(self, db_name: str) -> models.Database:
for name, kwargs in databases.items():
if name == db_name:
return models.Database(self, name, **kwargs)
raise api.NotFoundError("Database '%s' does not exist" % (db_name))
raise api.NotFoundError(f"Database '{db_name}' does not exist")

def create_log_table(self, db_name: str, table_name: str) -> bool:
"""
Expand Down Expand Up @@ -212,7 +211,7 @@ def table(self, db_name: str, table_name: str) -> models.Table:
for table in tables:
if table.table_name == table_name:
return table
raise api.NotFoundError("Table '%s.%s' does not exist" % (db_name, table_name))
raise api.NotFoundError(f"Table '{db_name}.{table_name}' does not exist")

def tail(
self,
Expand Down Expand Up @@ -281,7 +280,7 @@ def query(
"""
# for compatibility, assume type is hive unless specifically specified
if type not in ["hive", "pig", "impala", "presto", "trino"]:
raise ValueError("The specified query type is not supported: %s" % (type))
raise ValueError(f"The specified query type is not supported: {type}")
# Cast type to expected literal since we've validated it
query_type = cast(Literal["hive", "presto", "trino", "bulkload"], type)
job_id = self.api.query(
Expand Down Expand Up @@ -359,8 +358,7 @@ def job_result_each(self, job_id: str | int) -> Iterator[dict[str, Any]]:
Returns:
an iterator of result set
"""
for row in self.api.job_result_each(str(job_id)):
yield row
yield from self.api.job_result_each(str(job_id))

def job_result_format(
self, job_id: str | int, format: ResultFormat, header: bool = False
Expand Down Expand Up @@ -397,14 +395,13 @@ def job_result_format_each(
Returns:
an iterator of rows in result set
"""
for row in self.api.job_result_format_each(
yield from self.api.job_result_format_each(
str(job_id),
format,
header=header,
store_tmpfile=store_tmpfile,
num_threads=num_threads,
):
yield row
)

def download_job_result(
self, job_id: str | int, path: str, num_threads: int = 4
Expand Down Expand Up @@ -561,8 +558,7 @@ def bulk_import_error_records(self, name: str) -> Iterator[dict[str, Any]]:
Returns:
an iterator of error records
"""
for record in self.api.bulk_import_error_records(name):
yield record
yield from self.api.bulk_import_error_records(name)

def bulk_import(self, name: str) -> models.BulkImport:
"""Get a bulk import session
Expand Down
Loading