Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions sqlalchemy_utils/primitives/country.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

from functools import total_ordering
from typing import Any, Union

from .. import i18n
from ..utils import str_coercible
Expand Down Expand Up @@ -52,7 +55,7 @@ class Country:
assert hash(Country('FI')) == hash('FI')

"""
def __init__(self, code_or_country):
def __init__(self, code_or_country: Union[Country, str]) -> None:
if isinstance(code_or_country, Country):
self.code = code_or_country.code
elif isinstance(code_or_country, str):
Expand All @@ -67,11 +70,11 @@ def __init__(self, code_or_country):
)

@property
def name(self):
def name(self) -> str:
return i18n.get_locale().territories[self.code]

@classmethod
def validate(self, code):
def validate(self, code: str) -> None:
try:
i18n.babel.Locale('en').territories[code]
except KeyError:
Expand All @@ -82,29 +85,29 @@ def validate(self, code):
# As babel is optional, we may raise an AttributeError accessing it
pass

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, Country):
return self.code == other.code
elif isinstance(other, str):
return self.code == other
else:
return NotImplemented

def __hash__(self):
def __hash__(self) -> int:
return hash(self.code)

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not (self == other)

def __lt__(self, other):
def __lt__(self, other: Union[Country, str]) -> bool:
if isinstance(other, Country):
return self.code < other.code
elif isinstance(other, str):
return self.code < other
return NotImplemented

def __repr__(self):
def __repr__(self) -> str:
return '%s(%r)' % (self.__class__.__name__, self.code)

def __unicode__(self):
def __unicode__(self) -> str:
return self.name
24 changes: 14 additions & 10 deletions sqlalchemy_utils/primitives/currency.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .. import i18n, ImproperlyConfigured
from __future__ import annotations

from typing import Any, Union

from .. import ImproperlyConfigured, i18n
from ..utils import str_coercible


Expand Down Expand Up @@ -50,7 +54,7 @@ class Currency:


"""
def __init__(self, code):
def __init__(self, code: Union[Currency, str]) -> None:
if i18n.babel is None:
raise ImproperlyConfigured(
"'babel' package is required in order to use Currency class."
Expand All @@ -68,7 +72,7 @@ def __init__(self, code):
)

@classmethod
def validate(self, code):
def validate(self, code: str) -> None:
try:
i18n.babel.Locale('en').currencies[code]
except KeyError:
Expand All @@ -78,32 +82,32 @@ def validate(self, code):
pass

@property
def symbol(self):
def symbol(self) -> str:
return i18n.babel.numbers.get_currency_symbol(
self.code,
i18n.get_locale()
)

@property
def name(self):
def name(self) -> str:
return i18n.get_locale().currencies[self.code]

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, Currency):
return self.code == other.code
elif isinstance(other, str):
return self.code == other
else:
return NotImplemented

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not (self == other)

def __hash__(self):
def __hash__(self) -> int:
return hash(self.code)

def __repr__(self):
def __repr__(self) -> str:
return '%s(%r)' % (self.__class__.__name__, self.code)

def __unicode__(self):
def __unicode__(self) -> str:
return self.code
37 changes: 20 additions & 17 deletions sqlalchemy_utils/primitives/ltree.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import re
from typing import Any, Iterable, Optional, Union

from ..utils import str_coercible

path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$')
path_matcher: re.Pattern = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$')


@str_coercible
Expand Down Expand Up @@ -92,7 +95,7 @@ class Ltree:
assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2')
"""

def __init__(self, path_or_ltree):
def __init__(self, path_or_ltree: Union[Ltree, str]) -> None:
if isinstance(path_or_ltree, Ltree):
self.path = path_or_ltree.path
elif isinstance(path_or_ltree, str):
Expand All @@ -107,24 +110,24 @@ def __init__(self, path_or_ltree):
)

@classmethod
def validate(cls, path):
def validate(cls, path: str) -> None:
if path_matcher.match(path) is None:
raise ValueError(
"'{0}' is not a valid ltree path.".format(path)
)

def __len__(self):
def __len__(self) -> int:
return len(self.path.split('.'))

def index(self, other):
def index(self, other: Union[Ltree, str]) -> int:
subpath = Ltree(other).path.split('.')
parts = self.path.split('.')
for index, _ in enumerate(parts):
if parts[index:len(subpath) + index] == subpath:
return index
raise ValueError('subpath not found')

def descendant_of(self, other):
def descendant_of(self, other: Union[Ltree, str]) -> bool:
"""
is left argument a descendant of right (or equal)?

Expand All @@ -135,7 +138,7 @@ def descendant_of(self, other):
subpath = self[:len(Ltree(other))]
return subpath == other

def ancestor_of(self, other):
def ancestor_of(self, other: Union[Ltree, str]) -> bool:
"""
is left argument an ancestor of right (or equal)?

Expand All @@ -146,7 +149,7 @@ def ancestor_of(self, other):
subpath = Ltree(other)[:len(self)]
return subpath == self

def __getitem__(self, key):
def __getitem__(self, key: Union[int, slice]) -> Ltree:
if isinstance(key, int):
return Ltree(self.path.split('.')[key])
elif isinstance(key, slice):
Expand All @@ -157,7 +160,7 @@ def __getitem__(self, key):
)
)

def lca(self, *others):
def lca(self, *others: Union[Ltree, str]) -> Optional[Ltree]:
"""
Lowest common ancestor, i.e., longest common prefix of paths

Expand All @@ -178,31 +181,31 @@ def lca(self, *others):
return None
return Ltree('.'.join(parts[0:index]))

def __add__(self, other):
def __add__(self, other: Union[Ltree, str]) -> Ltree:
return Ltree(self.path + '.' + Ltree(other).path)

def __radd__(self, other):
def __radd__(self, other: Union[Ltree, str]) -> Ltree:
return Ltree(other) + self

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, Ltree):
return self.path == other.path
elif isinstance(other, str):
return self.path == other
else:
return NotImplemented

def __hash__(self):
def __hash__(self) -> int:
return hash(self.path)

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not (self == other)

def __repr__(self):
def __repr__(self) -> str:
return '%s(%r)' % (self.__class__.__name__, self.path)

def __unicode__(self):
def __unicode__(self) -> str:
return self.path

def __contains__(self, label):
def __contains__(self, label: Iterable) -> bool:
return label in self.path.split('.')
23 changes: 13 additions & 10 deletions sqlalchemy_utils/primitives/weekday.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

from functools import total_ordering
from typing import Any

from .. import i18n
from ..utils import str_coercible
Expand All @@ -7,34 +10,34 @@
@str_coercible
@total_ordering
class WeekDay:
NUM_WEEK_DAYS = 7
NUM_WEEK_DAYS: int = 7

def __init__(self, index):
def __init__(self, index: int) -> None:
if not (0 <= index < self.NUM_WEEK_DAYS):
raise ValueError(
"index must be between 0 and %d" % self.NUM_WEEK_DAYS
)
self.index = index

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, WeekDay):
return self.index == other.index
else:
return NotImplemented

def __hash__(self):
def __hash__(self) -> int:
return hash(self.index)

def __lt__(self, other):
def __lt__(self, other: WeekDay) -> bool:
return self.position < other.position

def __repr__(self):
def __repr__(self) -> str:
return '%s(%r)' % (self.__class__.__name__, self.index)

def __unicode__(self):
def __unicode__(self) -> str:
return self.name

def get_name(self, width='wide', context='format'):
def get_name(self, width: str = 'wide', context: str = 'format') -> str:
names = i18n.babel.dates.get_day_names(
width,
context,
Expand All @@ -43,11 +46,11 @@ def get_name(self, width='wide', context='format'):
return names[self.index]

@property
def name(self):
def name(self) -> str:
return self.get_name()

@property
def position(self):
def position(self) -> int:
return (
self.index -
i18n.get_locale().first_week_day
Expand Down
22 changes: 14 additions & 8 deletions sqlalchemy_utils/primitives/weekdays.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from ..utils import str_coercible
from __future__ import annotations

from typing import Any, Generator, Hashable, Union

from .weekday import WeekDay
from ..utils import str_coercible


@str_coercible
class WeekDays:
def __init__(self, bit_string_or_week_days):
_days: set

def __init__(self, bit_string_or_week_days: Union[str, WeekDays, Hashable]) -> None:
if isinstance(bit_string_or_week_days, str):
self._days = set()

Expand All @@ -27,31 +33,31 @@ def __init__(self, bit_string_or_week_days):
else:
self._days = set(bit_string_or_week_days)

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, WeekDays):
return self._days == other._days
elif isinstance(other, str):
return self.as_bit_string() == other
else:
return NotImplemented

def __iter__(self):
def __iter__(self) -> Generator[WeekDay, None, None]:
for day in sorted(self._days):
yield day

def __contains__(self, value):
def __contains__(self, value: Any) -> bool:
return value in self._days

def __repr__(self):
def __repr__(self) -> str:
return '%s(%r)' % (
self.__class__.__name__,
self.as_bit_string()
)

def __unicode__(self):
def __unicode__(self) -> str:
return ', '.join(str(day) for day in self)

def as_bit_string(self):
def as_bit_string(self) -> str:
return ''.join(
'1' if WeekDay(index) in self._days else '0'
for index in range(WeekDay.NUM_WEEK_DAYS)
Expand Down