diff --git a/sqlalchemy_utils/primitives/country.py b/sqlalchemy_utils/primitives/country.py index bb67dead..47ac63dd 100644 --- a/sqlalchemy_utils/primitives/country.py +++ b/sqlalchemy_utils/primitives/country.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from functools import total_ordering +from typing import Any, Union from .. import i18n from ..utils import str_coercible @@ -52,7 +55,7 @@ class Country: assert hash(Country('FI')) == hash('FI') """ - def __init__(self, code_or_country): + def __init__(self, code_or_country: Union[Country, str]) -> None: if isinstance(code_or_country, Country): self.code = code_or_country.code elif isinstance(code_or_country, str): @@ -67,11 +70,11 @@ def __init__(self, code_or_country): ) @property - def name(self): + def name(self) -> str: return i18n.get_locale().territories[self.code] @classmethod - def validate(self, code): + def validate(self, code: str) -> None: try: i18n.babel.Locale('en').territories[code] except KeyError: @@ -82,7 +85,7 @@ def validate(self, code): # As babel is optional, we may raise an AttributeError accessing it pass - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Country): return self.code == other.code elif isinstance(other, str): @@ -90,13 +93,13 @@ def __eq__(self, other): else: return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(self.code) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - def __lt__(self, other): + def __lt__(self, other: Union[Country, str]) -> bool: if isinstance(other, Country): return self.code < other.code elif isinstance(other, str): @@ -106,5 +109,5 @@ def __lt__(self, other): def __repr__(self): return f'{self.__class__.__name__}({self.code!r})' - def __unicode__(self): + def __unicode__(self) -> str: return self.name diff --git a/sqlalchemy_utils/primitives/currency.py b/sqlalchemy_utils/primitives/currency.py index ef8d46ad..76fe9380 100644 --- a/sqlalchemy_utils/primitives/currency.py +++ b/sqlalchemy_utils/primitives/currency.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any, Union + from .. import i18n, ImproperlyConfigured from ..utils import str_coercible @@ -50,7 +54,7 @@ class Currency: """ - def __init__(self, code): + def __init__(self, code: Union[Currency, str]) -> None: if i18n.babel is None: raise ImproperlyConfigured( "'babel' package is required in order to use Currency class." @@ -68,7 +72,7 @@ def __init__(self, code): ) @classmethod - def validate(self, code): + def validate(self, code: str) -> None: try: i18n.babel.Locale('en').currencies[code] except KeyError: @@ -78,17 +82,17 @@ def validate(self, code): pass @property - def symbol(self): + def symbol(self) -> str: return i18n.babel.numbers.get_currency_symbol( self.code, i18n.get_locale() ) @property - def name(self): + def name(self) -> str: return i18n.get_locale().currencies[self.code] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Currency): return self.code == other.code elif isinstance(other, str): @@ -96,14 +100,14 @@ def __eq__(self, other): else: return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - def __hash__(self): + def __hash__(self) -> int: return hash(self.code) def __repr__(self): return f'{self.__class__.__name__}({self.code!r})' - def __unicode__(self): + def __unicode__(self) -> str: return self.code diff --git a/sqlalchemy_utils/primitives/ltree.py b/sqlalchemy_utils/primitives/ltree.py index db7e42fe..51204c56 100644 --- a/sqlalchemy_utils/primitives/ltree.py +++ b/sqlalchemy_utils/primitives/ltree.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import re +from typing import Any, Iterable, Optional, Union from ..utils import str_coercible -path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$') +path_matcher: re.Pattern = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$') @str_coercible @@ -92,7 +95,7 @@ class Ltree: assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2') """ - def __init__(self, path_or_ltree): + def __init__(self, path_or_ltree: Union[Ltree, str]) -> None: if isinstance(path_or_ltree, Ltree): self.path = path_or_ltree.path elif isinstance(path_or_ltree, str): @@ -107,16 +110,16 @@ def __init__(self, path_or_ltree): ) @classmethod - def validate(cls, path): + def validate(cls, path: str) -> None: if path_matcher.match(path) is None: raise ValueError( f"'{path}' is not a valid ltree path." ) - def __len__(self): + def __len__(self) -> int: return len(self.path.split('.')) - def index(self, other): + def index(self, other: Union[Ltree, str]) -> int: subpath = Ltree(other).path.split('.') parts = self.path.split('.') for index, _ in enumerate(parts): @@ -124,7 +127,7 @@ def index(self, other): return index raise ValueError('subpath not found') - def descendant_of(self, other): + def descendant_of(self, other: Union[Ltree, str]) -> bool: """ is left argument a descendant of right (or equal)? @@ -135,7 +138,7 @@ def descendant_of(self, other): subpath = self[:len(Ltree(other))] return subpath == other - def ancestor_of(self, other): + def ancestor_of(self, other: Union[Ltree, str]) -> bool: """ is left argument an ancestor of right (or equal)? @@ -146,7 +149,7 @@ def ancestor_of(self, other): subpath = Ltree(other)[:len(self)] return subpath == self - def __getitem__(self, key): + def __getitem__(self, key: Union[int, slice]) -> Ltree: if isinstance(key, int): return Ltree(self.path.split('.')[key]) elif isinstance(key, slice): @@ -157,7 +160,7 @@ def __getitem__(self, key): ) ) - def lca(self, *others): + def lca(self, *others: Union[Ltree, str]) -> Optional[Ltree]: """ Lowest common ancestor, i.e., longest common prefix of paths @@ -178,13 +181,13 @@ def lca(self, *others): return None return Ltree('.'.join(parts[0:index])) - def __add__(self, other): + def __add__(self, other: Union[Ltree, str]) -> Ltree: return Ltree(self.path + '.' + Ltree(other).path) - def __radd__(self, other): + def __radd__(self, other: Union[Ltree, str]) -> Ltree: return Ltree(other) + self - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Ltree): return self.path == other.path elif isinstance(other, str): @@ -192,19 +195,19 @@ def __eq__(self, other): else: return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(self.path) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) def __repr__(self): return f'{self.__class__.__name__}({self.path!r})' - def __unicode__(self): + def __unicode__(self) -> str: return self.path - def __contains__(self, label): + def __contains__(self, label: Iterable) -> bool: return label in self.path.split('.') def __gt__(self, other): diff --git a/sqlalchemy_utils/primitives/weekday.py b/sqlalchemy_utils/primitives/weekday.py index 21038940..9ebed010 100644 --- a/sqlalchemy_utils/primitives/weekday.py +++ b/sqlalchemy_utils/primitives/weekday.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from functools import total_ordering +from typing import Any from .. import i18n from ..utils import str_coercible @@ -7,34 +10,34 @@ @str_coercible @total_ordering class WeekDay: - NUM_WEEK_DAYS = 7 + NUM_WEEK_DAYS: int = 7 - def __init__(self, index): + def __init__(self, index: int) -> None: if not (0 <= index < self.NUM_WEEK_DAYS): raise ValueError( "index must be between 0 and %d" % self.NUM_WEEK_DAYS ) self.index = index - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, WeekDay): return self.index == other.index else: return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(self.index) - def __lt__(self, other): + def __lt__(self, other: WeekDay) -> bool: return self.position < other.position def __repr__(self): return f'{self.__class__.__name__}({self.index!r})' - def __unicode__(self): + def __unicode__(self) -> str: return self.name - def get_name(self, width='wide', context='format'): + def get_name(self, width: str = 'wide', context: str = 'format') -> str: names = i18n.babel.dates.get_day_names( width, context, @@ -43,11 +46,11 @@ def get_name(self, width='wide', context='format'): return names[self.index] @property - def name(self): + def name(self) -> str: return self.get_name() @property - def position(self): + def position(self) -> int: return ( self.index - i18n.get_locale().first_week_day diff --git a/sqlalchemy_utils/primitives/weekdays.py b/sqlalchemy_utils/primitives/weekdays.py index a21a4df7..88aca36d 100644 --- a/sqlalchemy_utils/primitives/weekdays.py +++ b/sqlalchemy_utils/primitives/weekdays.py @@ -1,10 +1,16 @@ +from __future__ import annotations + +from typing import Any, Generator, Hashable, Union + from ..utils import str_coercible from .weekday import WeekDay @str_coercible class WeekDays: - def __init__(self, bit_string_or_week_days): + _days: set + + def __init__(self, bit_string_or_week_days: Union[str, WeekDays, Hashable]) -> None: if isinstance(bit_string_or_week_days, str): self._days = set() @@ -27,7 +33,7 @@ def __init__(self, bit_string_or_week_days): else: self._days = set(bit_string_or_week_days) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, WeekDays): return self._days == other._days elif isinstance(other, str): @@ -35,22 +41,23 @@ def __eq__(self, other): else: return NotImplemented - def __iter__(self): - yield from sorted(self._days) + def __iter__(self) -> Generator[WeekDay, None, None]: + for day in sorted(self._days): + yield day - def __contains__(self, value): + def __contains__(self, value: Any) -> bool: return value in self._days - def __repr__(self): - return '{}({!r})'.format( + def __repr__(self) -> str: + return '%s(%r)' % ( self.__class__.__name__, self.as_bit_string() ) - def __unicode__(self): + def __unicode__(self) -> str: return ', '.join(str(day) for day in self) - def as_bit_string(self): + def as_bit_string(self) -> str: return ''.join( '1' if WeekDay(index) in self._days else '0' for index in range(WeekDay.NUM_WEEK_DAYS)