Skip to content

Commit 62add3d

Browse files
authored
Merge pull request #139 from treasure-data/more-ruff
Introduce more ruff rules
2 parents f243f7f + 05f56ad commit 62add3d

21 files changed

+97
-105
lines changed

.github/workflows/pythontest.yml

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,26 @@ name: Python testing
33
on: [push, pull_request]
44

55
jobs:
6+
lint:
7+
runs-on: ubuntu-latest
8+
steps:
9+
- uses: actions/checkout@v4
10+
- name: Set up Python
11+
uses: actions/setup-python@v6
12+
with:
13+
python-version: "3.12"
14+
- name: Install dependencies
15+
run: |
16+
python -m pip install --upgrade pip
17+
pip install ".[dev]"
18+
- name: lint with ruff
19+
run: |
20+
ruff format tdclient --diff --exit-non-zero-on-fix
21+
ruff check tdclient
22+
- name: Run pyright
23+
run: |
24+
pyright tdclient
25+
626
test:
727
runs-on: ${{ matrix.os }}
828
strategy:
@@ -23,9 +43,6 @@ jobs:
2343
pip install ".[dev]"
2444
pip install -r requirements.txt -r test-requirements.txt
2545
pip install -U coveralls pyyaml
26-
- name: Run pyright
27-
run: |
28-
pyright tdclient
2946
- name: Run test
3047
run: |
3148
coverage run --source=tdclient -m pytest tdclient/test

docs/conf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def linkcode_resolve(domain, info):
7171
except Exception:
7272
linenum = ""
7373

74-
return "https://github.com/{}/{}/blob/{}/{}/{}#L{}".format(
75-
GH_ORGANIZATION, GH_PROJECT, revision, MODULE, relpath, linenum
76-
)
74+
return f"https://github.com/{GH_ORGANIZATION}/{GH_PROJECT}/blob/{revision}/{MODULE}/{relpath}#L{linenum}"
7775

7876

7977
# -- Project information -----------------------------------------------------

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ line-length = 88
5050
[tool.ruff.lint]
5151
select = [
5252
"E",
53+
"W",
5354
"F",
55+
"I",
56+
"UP",
57+
"B",
5458
]
5559
exclude = ["tdclient/test/*"]
5660
ignore = ["E203", "E501"]

tdclient/api.py

Lines changed: 25 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import json
1212
import logging
1313
import os
14-
import socket
1514
import ssl
1615
import tempfile
1716
import time
@@ -108,11 +107,11 @@ def __init__(
108107
if user_agent is not None:
109108
self._user_agent = user_agent
110109
else:
111-
self._user_agent = "TD-Client-Python/%s" % (version.__version__)
110+
self._user_agent = f"TD-Client-Python/{version.__version__}"
112111

113112
if endpoint is not None:
114113
if not urlparse.urlparse(endpoint).scheme:
115-
endpoint = "https://{}".format(endpoint)
114+
endpoint = f"https://{endpoint}"
116115
self._endpoint = endpoint
117116
elif os.getenv("TD_API_SERVER"):
118117
self._endpoint = os.getenv("TD_API_SERVER")
@@ -154,7 +153,7 @@ def _init_http(
154153
if http_proxy.startswith("http://"):
155154
return self._init_http_proxy(http_proxy, **kwargs)
156155
else:
157-
return self._init_http_proxy("http://%s" % (http_proxy,), **kwargs)
156+
return self._init_http_proxy(f"http://{http_proxy}", **kwargs)
158157

159158
def _init_http_proxy(self, http_proxy: str, **kwargs: Any) -> urllib3.ProxyManager:
160159
pool_options = dict(kwargs)
@@ -164,7 +163,7 @@ def _init_http_proxy(self, http_proxy: str, **kwargs: Any) -> urllib3.ProxyManag
164163
if "@" in netloc:
165164
auth, netloc = netloc.split("@", 2)
166165
pool_options["proxy_headers"] = urllib3.make_headers(proxy_basic_auth=auth)
167-
return urllib3.ProxyManager("%s://%s" % (scheme, netloc), **pool_options)
166+
return urllib3.ProxyManager(f"{scheme}://{netloc}", **pool_options)
168167

169168
def get(
170169
self,
@@ -214,12 +213,12 @@ def get(
214213
self._max_cumul_retry_delay,
215214
)
216215
except (
216+
OSError,
217217
urllib3.exceptions.TimeoutStateError,
218218
urllib3.exceptions.TimeoutError,
219219
urllib3.exceptions.PoolError,
220220
http.client.IncompleteRead,
221221
TimeoutError,
222-
socket.error,
223222
):
224223
pass
225224

@@ -235,12 +234,7 @@ def get(
235234
retry_delay *= 2
236235
else:
237236
raise APIError(
238-
"Retrying stopped after %d seconds. (cumulative: %d/%d)"
239-
% (
240-
self._max_cumul_retry_delay,
241-
cumul_retry_delay,
242-
self._max_cumul_retry_delay,
243-
)
237+
f"Retrying stopped after {self._max_cumul_retry_delay} seconds. (cumulative: {cumul_retry_delay}/{self._max_cumul_retry_delay})"
244238
)
245239

246240
log.debug(
@@ -314,13 +308,15 @@ def post(
314308
self._max_cumul_retry_delay,
315309
)
316310
except (
311+
OSError,
317312
urllib3.exceptions.TimeoutStateError,
318313
urllib3.exceptions.TimeoutError,
319314
urllib3.exceptions.PoolError,
320-
socket.error,
321315
):
322316
if not self._retry_post_requests:
323-
raise APIError("Retrying stopped by retry_post_requests == False")
317+
raise APIError(
318+
"Retrying stopped by retry_post_requests == False"
319+
) from None
324320

325321
if cumul_retry_delay <= self._max_cumul_retry_delay:
326322
log.warning(
@@ -334,12 +330,7 @@ def post(
334330
retry_delay *= 2
335331
else:
336332
raise APIError(
337-
"Retrying stopped after %d seconds. (cumulative: %d/%d)"
338-
% (
339-
self._max_cumul_retry_delay,
340-
cumul_retry_delay,
341-
self._max_cumul_retry_delay,
342-
)
333+
f"Retrying stopped after {self._max_cumul_retry_delay} seconds. (cumulative: {cumul_retry_delay}/{self._max_cumul_retry_delay})"
343334
)
344335

345336
log.debug(
@@ -408,12 +399,12 @@ def put(
408399
else:
409400
raise APIError("Error %d: %s", response.status, response.data)
410401
except (
402+
OSError,
411403
urllib3.exceptions.TimeoutStateError,
412404
urllib3.exceptions.TimeoutError,
413405
urllib3.exceptions.PoolError,
414-
socket.error,
415406
):
416-
raise APIError("Error: %s" % (repr(response)))
407+
raise APIError(f"Error: {repr(response)}") from None
417408

418409
log.debug(
419410
"REST PUT response:\n headers: %s\n status: %d\n body: <omitted>",
@@ -470,10 +461,10 @@ def delete(
470461
self._max_cumul_retry_delay,
471462
)
472463
except (
464+
OSError,
473465
urllib3.exceptions.TimeoutStateError,
474466
urllib3.exceptions.TimeoutError,
475467
urllib3.exceptions.PoolError,
476-
socket.error,
477468
):
478469
pass
479470

@@ -489,12 +480,7 @@ def delete(
489480
retry_delay *= 2
490481
else:
491482
raise APIError(
492-
"Retrying stopped after %d seconds. (cumulative: %d/%d)"
493-
% (
494-
self._max_cumul_retry_delay,
495-
cumul_retry_delay,
496-
self._max_cumul_retry_delay,
497-
)
483+
f"Retrying stopped after {self._max_cumul_retry_delay} seconds. (cumulative: {cumul_retry_delay}/{self._max_cumul_retry_delay})"
498484
)
499485

500486
log.debug(
@@ -536,7 +522,7 @@ def build_request(
536522
# use default headers first
537523
_headers = dict(self._headers)
538524
# add default headers
539-
_headers["authorization"] = "TD1 %s" % (self._apikey,)
525+
_headers["authorization"] = f"TD1 {self._apikey}"
540526
_headers["date"] = email.utils.formatdate(time.time())
541527
_headers["user-agent"] = self._user_agent
542528
# override given headers
@@ -571,28 +557,26 @@ def raise_error(
571557
status_code = res.status
572558
s = body if isinstance(body, str) else body.decode("utf-8")
573559
if status_code == 404:
574-
raise errors.NotFoundError("%s: %s" % (msg, s))
560+
raise errors.NotFoundError(f"{msg}: {s}")
575561
elif status_code == 409:
576-
raise errors.AlreadyExistsError("%s: %s" % (msg, s))
562+
raise errors.AlreadyExistsError(f"{msg}: {s}")
577563
elif status_code == 401:
578-
raise errors.AuthError("%s: %s" % (msg, s))
564+
raise errors.AuthError(f"{msg}: {s}")
579565
elif status_code == 403:
580-
raise errors.ForbiddenError("%s: %s" % (msg, s))
566+
raise errors.ForbiddenError(f"{msg}: {s}")
581567
else:
582-
raise errors.APIError("%d: %s: %s" % (status_code, msg, s))
568+
raise errors.APIError(f"{status_code}: {msg}: {s}")
583569

584570
def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]:
585571
js = None
586572
try:
587573
js = json.loads(body.decode("utf-8"))
588574
except ValueError as error:
589-
raise APIError("Unexpected API response: %s: %s" % (error, repr(body)))
575+
raise APIError(f"Unexpected API response: {error}: {repr(body)}") from error
590576
js = dict(js)
591577
if 0 < [k in js for k in required].count(False):
592578
missing = [k for k in required if k not in js]
593-
raise APIError(
594-
"Unexpected API response: %s: %s" % (repr(missing), repr(body))
595-
)
579+
raise APIError(f"Unexpected API response: {repr(missing)}: {repr(body)}")
596580
return js
597581

598582
def close(self) -> None:
@@ -619,11 +603,11 @@ def _read_file(self, file_like, fmt, **kwargs):
619603
compressed = fmt.endswith(".gz")
620604
if compressed:
621605
fmt = fmt[0 : len(fmt) - len(".gz")]
622-
reader_name = "_read_%s_file" % (fmt,)
606+
reader_name = f"_read_{fmt}_file"
623607
if hasattr(self, reader_name):
624608
reader = getattr(self, reader_name)
625609
else:
626-
raise TypeError("unknown format: %s" % (fmt,))
610+
raise TypeError(f"unknown format: {fmt}")
627611
if hasattr(file_like, "read"):
628612
if compressed:
629613
file_like = gzip.GzipFile(fileobj=file_like)

tdclient/bulk_import_api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,11 @@ def validate_part_name(part_name: str) -> None:
177177

178178
if 1 < d["."]:
179179
raise ValueError(
180-
"part names cannot contain multiple periods: %s" % (repr(part_name))
180+
f"part names cannot contain multiple periods: {repr(part_name)}"
181181
)
182182

183183
if 0 < part_name.find("/"):
184-
raise ValueError("part name must not contain '/': %s" % (repr(part_name)))
184+
raise ValueError(f"part name must not contain '/': {repr(part_name)}")
185185

186186
def bulk_import_upload_part(
187187
self, name: str, part_name: str, stream: BytesOrStream, size: int
@@ -385,5 +385,4 @@ def bulk_import_error_records(
385385
decompressor = gzip.GzipFile(fileobj=body)
386386

387387
unpacker = msgpack.Unpacker(decompressor, raw=False)
388-
for row in unpacker:
389-
yield row
388+
yield from unpacker

tdclient/bulk_import_model.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class BulkImport(Model):
2424
STATUS_COMMITTED = "committed"
2525

2626
def __init__(self, client: Client, **kwargs: Any) -> None:
27-
super(BulkImport, self).__init__(client)
27+
super().__init__(client)
2828
self._feed(kwargs)
2929

3030
def _feed(self, data: dict[str, Any] | None = None) -> None:
@@ -128,9 +128,7 @@ def perform(
128128
"""
129129
self.update()
130130
if not self.upload_frozen:
131-
raise (
132-
RuntimeError('bulk import session "%s" is not frozen' % (self.name,))
133-
)
131+
raise (RuntimeError(f'bulk import session "{self.name}" is not frozen'))
134132
job = self._client.perform_bulk_import(self.name)
135133
if wait:
136134
job.wait(
@@ -164,8 +162,7 @@ def error_record_items(self) -> Iterator[dict[str, Any]]:
164162
Yields:
165163
Error record
166164
"""
167-
for record in self._client.bulk_import_error_records(self.name):
168-
yield record
165+
yield from self._client.bulk_import_error_records(self.name)
169166

170167
def upload_part(self, part_name: str, bytes_or_stream: FileLike, size: int) -> bool:
171168
"""Upload a part to bulk import session

tdclient/client.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import datetime
66
import json
77
from collections.abc import Iterator
8-
from typing import Any, cast, Literal
9-
8+
from typing import Any, Literal, cast
109

1110
from tdclient import api, models
1211
from tdclient.types import (
@@ -103,7 +102,7 @@ def database(self, db_name: str) -> models.Database:
103102
for name, kwargs in databases.items():
104103
if name == db_name:
105104
return models.Database(self, name, **kwargs)
106-
raise api.NotFoundError("Database '%s' does not exist" % (db_name))
105+
raise api.NotFoundError(f"Database '{db_name}' does not exist")
107106

108107
def create_log_table(self, db_name: str, table_name: str) -> bool:
109108
"""
@@ -212,7 +211,7 @@ def table(self, db_name: str, table_name: str) -> models.Table:
212211
for table in tables:
213212
if table.table_name == table_name:
214213
return table
215-
raise api.NotFoundError("Table '%s.%s' does not exist" % (db_name, table_name))
214+
raise api.NotFoundError(f"Table '{db_name}.{table_name}' does not exist")
216215

217216
def tail(
218217
self,
@@ -281,7 +280,7 @@ def query(
281280
"""
282281
# for compatibility, assume type is hive unless specifically specified
283282
if type not in ["hive", "pig", "impala", "presto", "trino"]:
284-
raise ValueError("The specified query type is not supported: %s" % (type))
283+
raise ValueError(f"The specified query type is not supported: {type}")
285284
# Cast type to expected literal since we've validated it
286285
query_type = cast(Literal["hive", "presto", "trino", "bulkload"], type)
287286
job_id = self.api.query(
@@ -359,8 +358,7 @@ def job_result_each(self, job_id: str | int) -> Iterator[dict[str, Any]]:
359358
Returns:
360359
an iterator of result set
361360
"""
362-
for row in self.api.job_result_each(str(job_id)):
363-
yield row
361+
yield from self.api.job_result_each(str(job_id))
364362

365363
def job_result_format(
366364
self, job_id: str | int, format: ResultFormat, header: bool = False
@@ -397,14 +395,13 @@ def job_result_format_each(
397395
Returns:
398396
an iterator of rows in result set
399397
"""
400-
for row in self.api.job_result_format_each(
398+
yield from self.api.job_result_format_each(
401399
str(job_id),
402400
format,
403401
header=header,
404402
store_tmpfile=store_tmpfile,
405403
num_threads=num_threads,
406-
):
407-
yield row
404+
)
408405

409406
def download_job_result(
410407
self, job_id: str | int, path: str, num_threads: int = 4
@@ -561,8 +558,7 @@ def bulk_import_error_records(self, name: str) -> Iterator[dict[str, Any]]:
561558
Returns:
562559
an iterator of error records
563560
"""
564-
for record in self.api.bulk_import_error_records(name):
565-
yield record
561+
yield from self.api.bulk_import_error_records(name)
566562

567563
def bulk_import(self, name: str) -> models.BulkImport:
568564
"""Get a bulk import session

0 commit comments

Comments
 (0)