From 3f793cd85658f908a63dad4272d24d4081bbe780 Mon Sep 17 00:00:00 2001 From: Vladimir Date: Fri, 5 Aug 2022 20:07:06 +0300 Subject: [PATCH 1/2] add_annotations_to_primitives --- sqlalchemy_utils/primitives/country.py | 21 ++++++++------ sqlalchemy_utils/primitives/currency.py | 24 +++++++++------- sqlalchemy_utils/primitives/ltree.py | 37 +++++++++++++------------ sqlalchemy_utils/primitives/weekday.py | 23 ++++++++------- sqlalchemy_utils/primitives/weekdays.py | 22 +++++++++------ 5 files changed, 73 insertions(+), 54 deletions(-) diff --git a/sqlalchemy_utils/primitives/country.py b/sqlalchemy_utils/primitives/country.py index 8481fe5c..a3311b30 100644 --- a/sqlalchemy_utils/primitives/country.py +++ b/sqlalchemy_utils/primitives/country.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from functools import total_ordering +from typing import Any, Union from .. import i18n from ..utils import str_coercible @@ -52,7 +55,7 @@ class Country: assert hash(Country('FI')) == hash('FI') """ - def __init__(self, code_or_country): + def __init__(self, code_or_country: Union[Country, str]) -> None: if isinstance(code_or_country, Country): self.code = code_or_country.code elif isinstance(code_or_country, str): @@ -67,11 +70,11 @@ def __init__(self, code_or_country): ) @property - def name(self): + def name(self) -> str: return i18n.get_locale().territories[self.code] @classmethod - def validate(self, code): + def validate(self, code: str) -> None: try: i18n.babel.Locale('en').territories[code] except KeyError: @@ -82,7 +85,7 @@ def validate(self, code): # As babel is optional, we may raise an AttributeError accessing it pass - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Country): return self.code == other.code elif isinstance(other, str): @@ -90,21 +93,21 @@ def __eq__(self, other): else: return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(self.code) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - def __lt__(self, other): + def __lt__(self, other: Union[Country, str]) -> bool: if isinstance(other, Country): return self.code < other.code elif isinstance(other, str): return self.code < other return NotImplemented - def __repr__(self): + def __repr__(self) -> str: return '%s(%r)' % (self.__class__.__name__, self.code) - def __unicode__(self): + def __unicode__(self) -> str: return self.name diff --git a/sqlalchemy_utils/primitives/currency.py b/sqlalchemy_utils/primitives/currency.py index c7c71a6c..66fbcbec 100644 --- a/sqlalchemy_utils/primitives/currency.py +++ b/sqlalchemy_utils/primitives/currency.py @@ -1,4 +1,8 @@ -from .. import i18n, ImproperlyConfigured +from __future__ import annotations + +from typing import Any, Union + +from .. import ImproperlyConfigured, i18n from ..utils import str_coercible @@ -50,7 +54,7 @@ class Currency: """ - def __init__(self, code): + def __init__(self, code: Union[Currency, str]) -> None: if i18n.babel is None: raise ImproperlyConfigured( "'babel' package is required in order to use Currency class." @@ -68,7 +72,7 @@ def __init__(self, code): ) @classmethod - def validate(self, code): + def validate(self, code: str) -> None: try: i18n.babel.Locale('en').currencies[code] except KeyError: @@ -78,17 +82,17 @@ def validate(self, code): pass @property - def symbol(self): + def symbol(self) -> str: return i18n.babel.numbers.get_currency_symbol( self.code, i18n.get_locale() ) @property - def name(self): + def name(self) -> str: return i18n.get_locale().currencies[self.code] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Currency): return self.code == other.code elif isinstance(other, str): @@ -96,14 +100,14 @@ def __eq__(self, other): else: return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - def __hash__(self): + def __hash__(self) -> int: return hash(self.code) - def __repr__(self): + def __repr__(self) -> str: return '%s(%r)' % (self.__class__.__name__, self.code) - def __unicode__(self): + def __unicode__(self) -> str: return self.code diff --git a/sqlalchemy_utils/primitives/ltree.py b/sqlalchemy_utils/primitives/ltree.py index 1ac11717..c438313f 100644 --- a/sqlalchemy_utils/primitives/ltree.py +++ b/sqlalchemy_utils/primitives/ltree.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import re +from typing import Any, Iterable, Optional, Union from ..utils import str_coercible -path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$') +path_matcher: re.Pattern = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$') @str_coercible @@ -92,7 +95,7 @@ class Ltree: assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2') """ - def __init__(self, path_or_ltree): + def __init__(self, path_or_ltree: Union[Ltree, str]) -> None: if isinstance(path_or_ltree, Ltree): self.path = path_or_ltree.path elif isinstance(path_or_ltree, str): @@ -107,16 +110,16 @@ def __init__(self, path_or_ltree): ) @classmethod - def validate(cls, path): + def validate(cls, path: str) -> None: if path_matcher.match(path) is None: raise ValueError( "'{0}' is not a valid ltree path.".format(path) ) - def __len__(self): + def __len__(self) -> int: return len(self.path.split('.')) - def index(self, other): + def index(self, other: Union[Ltree, str]) -> int: subpath = Ltree(other).path.split('.') parts = self.path.split('.') for index, _ in enumerate(parts): @@ -124,7 +127,7 @@ def index(self, other): return index raise ValueError('subpath not found') - def descendant_of(self, other): + def descendant_of(self, other: Union[Ltree, str]) -> bool: """ is left argument a descendant of right (or equal)? @@ -135,7 +138,7 @@ def descendant_of(self, other): subpath = self[:len(Ltree(other))] return subpath == other - def ancestor_of(self, other): + def ancestor_of(self, other: Union[Ltree, str]) -> bool: """ is left argument an ancestor of right (or equal)? @@ -146,7 +149,7 @@ def ancestor_of(self, other): subpath = Ltree(other)[:len(self)] return subpath == self - def __getitem__(self, key): + def __getitem__(self, key: Union[int, slice]) -> Ltree: if isinstance(key, int): return Ltree(self.path.split('.')[key]) elif isinstance(key, slice): @@ -157,7 +160,7 @@ def __getitem__(self, key): ) ) - def lca(self, *others): + def lca(self, *others: Union[Ltree, str]) -> Optional[Ltree]: """ Lowest common ancestor, i.e., longest common prefix of paths @@ -178,13 +181,13 @@ def lca(self, *others): return None return Ltree('.'.join(parts[0:index])) - def __add__(self, other): + def __add__(self, other: Union[Ltree, str]) -> Ltree: return Ltree(self.path + '.' + Ltree(other).path) - def __radd__(self, other): + def __radd__(self, other: Union[Ltree, str]) -> Ltree: return Ltree(other) + self - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Ltree): return self.path == other.path elif isinstance(other, str): @@ -192,17 +195,17 @@ def __eq__(self, other): else: return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(self.path) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - def __repr__(self): + def __repr__(self) -> str: return '%s(%r)' % (self.__class__.__name__, self.path) - def __unicode__(self): + def __unicode__(self) -> str: return self.path - def __contains__(self, label): + def __contains__(self, label: Iterable) -> bool: return label in self.path.split('.') diff --git a/sqlalchemy_utils/primitives/weekday.py b/sqlalchemy_utils/primitives/weekday.py index 5b01efa7..38cef4c4 100644 --- a/sqlalchemy_utils/primitives/weekday.py +++ b/sqlalchemy_utils/primitives/weekday.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from functools import total_ordering +from typing import Any from .. import i18n from ..utils import str_coercible @@ -7,34 +10,34 @@ @str_coercible @total_ordering class WeekDay: - NUM_WEEK_DAYS = 7 + NUM_WEEK_DAYS: int = 7 - def __init__(self, index): + def __init__(self, index: int) -> None: if not (0 <= index < self.NUM_WEEK_DAYS): raise ValueError( "index must be between 0 and %d" % self.NUM_WEEK_DAYS ) self.index = index - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, WeekDay): return self.index == other.index else: return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(self.index) - def __lt__(self, other): + def __lt__(self, other: WeekDay) -> bool: return self.position < other.position - def __repr__(self): + def __repr__(self) -> str: return '%s(%r)' % (self.__class__.__name__, self.index) - def __unicode__(self): + def __unicode__(self) -> str: return self.name - def get_name(self, width='wide', context='format'): + def get_name(self, width: str = 'wide', context: str = 'format') -> str: names = i18n.babel.dates.get_day_names( width, context, @@ -43,11 +46,11 @@ def get_name(self, width='wide', context='format'): return names[self.index] @property - def name(self): + def name(self) -> str: return self.get_name() @property - def position(self): + def position(self) -> int: return ( self.index - i18n.get_locale().first_week_day diff --git a/sqlalchemy_utils/primitives/weekdays.py b/sqlalchemy_utils/primitives/weekdays.py index 3f68683f..855e8861 100644 --- a/sqlalchemy_utils/primitives/weekdays.py +++ b/sqlalchemy_utils/primitives/weekdays.py @@ -1,10 +1,16 @@ -from ..utils import str_coercible +from __future__ import annotations + +from typing import Any, Generator, Hashable, Union + from .weekday import WeekDay +from ..utils import str_coercible @str_coercible class WeekDays: - def __init__(self, bit_string_or_week_days): + _days: set + + def __init__(self, bit_string_or_week_days: Union[str, WeekDays, Hashable]) -> None: if isinstance(bit_string_or_week_days, str): self._days = set() @@ -27,7 +33,7 @@ def __init__(self, bit_string_or_week_days): else: self._days = set(bit_string_or_week_days) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, WeekDays): return self._days == other._days elif isinstance(other, str): @@ -35,23 +41,23 @@ def __eq__(self, other): else: return NotImplemented - def __iter__(self): + def __iter__(self) -> Generator[WeekDay, None, None]: for day in sorted(self._days): yield day - def __contains__(self, value): + def __contains__(self, value: Any) -> bool: return value in self._days - def __repr__(self): + def __repr__(self) -> str: return '%s(%r)' % ( self.__class__.__name__, self.as_bit_string() ) - def __unicode__(self): + def __unicode__(self) -> str: return ', '.join(str(day) for day in self) - def as_bit_string(self): + def as_bit_string(self) -> str: return ''.join( '1' if WeekDay(index) in self._days else '0' for index in range(WeekDay.NUM_WEEK_DAYS) From 73390da14712477d41445da5ac99eff5d2255c72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=92=D0=BB=D0=B0=D0=B4=D0=B8=D0=BC=D0=B8=D1=80=20=D0=92?= =?UTF-8?q?=D0=BE=D0=B9=D1=82=D0=B5=D0=BD=D0=BA=D0=BE?= Date: Fri, 3 Mar 2023 19:32:14 +0500 Subject: [PATCH 2/2] fix-formatting --- sqlalchemy_utils/primitives/currency.py | 2 +- sqlalchemy_utils/primitives/weekdays.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlalchemy_utils/primitives/currency.py b/sqlalchemy_utils/primitives/currency.py index 5b88b0b9..76fe9380 100644 --- a/sqlalchemy_utils/primitives/currency.py +++ b/sqlalchemy_utils/primitives/currency.py @@ -2,7 +2,7 @@ from typing import Any, Union -from .. import ImproperlyConfigured, i18n +from .. import i18n, ImproperlyConfigured from ..utils import str_coercible diff --git a/sqlalchemy_utils/primitives/weekdays.py b/sqlalchemy_utils/primitives/weekdays.py index 69bbd469..88aca36d 100644 --- a/sqlalchemy_utils/primitives/weekdays.py +++ b/sqlalchemy_utils/primitives/weekdays.py @@ -2,8 +2,8 @@ from typing import Any, Generator, Hashable, Union -from .weekday import WeekDay from ..utils import str_coercible +from .weekday import WeekDay @str_coercible