Skip to content

Commit b8e0f4e

Browse files
committed
Fixed types
1 parent 8e3a51b commit b8e0f4e

File tree

4 files changed

+40
-31
lines changed

4 files changed

+40
-31
lines changed

src/tdamapper/_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def deprecated(msg: str) -> Callable[..., Any]:
2727
"""
2828

2929
def deprecated_func(func: Callable[..., Any]) -> Callable[..., Any]:
30-
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
30+
def wrapper(*args: list[Any], **kwargs: Any) -> Any:
3131
warnings.warn(msg, DeprecationWarning, stacklevel=2)
3232
return func(*args, **kwargs)
3333

@@ -213,7 +213,7 @@ def profile(n_lines: int = 10) -> Callable[..., Any]:
213213
"""
214214

215215
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
216-
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
216+
def wrapper(*args: list[Any], **kwargs: Any) -> Any:
217217
profiler = cProfile.Profile()
218218
profiler.enable()
219219
result = func(*args, **kwargs)

src/tdamapper/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _umap(X: NDArray[np.float_]) -> NDArray[np.float_]:
202202

203203

204204
def run_mapper(
205-
df: pd.DataFrame, **kwargs: dict[str, Any]
205+
df: pd.DataFrame, **kwargs: Any
206206
) -> Optional[tuple[nx.Graph, pd.DataFrame]]:
207207
"""
208208
Runs the Mapper algorithm on the provided DataFrame and returns the Mapper
@@ -301,7 +301,7 @@ def create_mapper_figure(
301301
df_y: pd.DataFrame,
302302
df_target: pd.DataFrame,
303303
mapper_graph: nx.Graph,
304-
**kwargs: dict[str, Any],
304+
**kwargs: Any,
305305
) -> go.Figure:
306306
"""
307307
Renders the Mapper graph as a Plotly figure.

src/tdamapper/utils/metrics.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_supported_metrics() -> list[MetricLiteral]:
5858
return list(get_args(MetricLiteral))
5959

6060

61-
def euclidean(**kwargs: dict[str, Any]) -> Metric[Any]:
61+
def euclidean() -> Metric[Any]:
6262
"""
6363
Return the Euclidean distance function for vectors.
6464
@@ -70,7 +70,7 @@ def euclidean(**kwargs: dict[str, Any]) -> Metric[Any]:
7070
return _metrics.euclidean
7171

7272

73-
def manhattan(**kwargs: dict[str, Any]) -> Metric[Any]:
73+
def manhattan() -> Metric[Any]:
7474
"""
7575
Return the Manhattan distance function for vectors.
7676
@@ -82,7 +82,7 @@ def manhattan(**kwargs: dict[str, Any]) -> Metric[Any]:
8282
return _metrics.manhattan
8383

8484

85-
def chebyshev(**kwargs: dict[str, Any]) -> Metric[Any]:
85+
def chebyshev() -> Metric[Any]:
8686
"""
8787
Return the Chebyshev distance function for vectors.
8888
@@ -94,7 +94,7 @@ def chebyshev(**kwargs: dict[str, Any]) -> Metric[Any]:
9494
return _metrics.chebyshev
9595

9696

97-
def minkowski(**kwargs: dict[str, Any]) -> Metric[Any]:
97+
def minkowski(p: Union[int, float]) -> Metric[Any]:
9898
"""
9999
Return the Minkowski distance function for order p on vectors.
100100
@@ -106,9 +106,6 @@ def minkowski(**kwargs: dict[str, Any]) -> Metric[Any]:
106106
:param p: The order of the Minkowski distance.
107107
:return: The Minkowski distance function.
108108
"""
109-
p = kwargs.get("p", 2)
110-
if not isinstance(p, (int, float)):
111-
raise TypeError("p must be an integer or a float")
112109
if p == 1:
113110
return manhattan()
114111
if p == 2:
@@ -122,7 +119,7 @@ def dist(x: Any, y: Any) -> float:
122119
return dist
123120

124121

125-
def cosine(**kwargs: dict[str, Any]) -> Metric[Any]:
122+
def cosine() -> Metric[Any]:
126123
"""
127124
Return the cosine distance function for vectors.
128125
@@ -145,9 +142,7 @@ def cosine(**kwargs: dict[str, Any]) -> Metric[Any]:
145142
return _metrics.cosine
146143

147144

148-
def get_metric(
149-
metric: Union[MetricLiteral, Metric[Any]], **kwargs: dict[str, Any]
150-
) -> Metric[Any]:
145+
def get_metric(metric: Union[MetricLiteral, Metric[Any]], **kwargs: Any) -> Metric[Any]:
151146
"""
152147
Return a distance function based on the specified string or callable.
153148

tests/test_unit_metrics.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,34 @@
3131
RANDOM = dataset_random(2, 100)
3232

3333

34+
def _check_values(m1, m2, a, b):
35+
m1_div_by_zero = False
36+
m1_is_nan = False
37+
38+
m2_div_by_zero = False
39+
m2_is_nan = False
40+
41+
try:
42+
m1_value = m1(a, b)
43+
if np.isnan(m1_value):
44+
m1_is_nan = True
45+
except ZeroDivisionError:
46+
m1_div_by_zero = True
47+
try:
48+
m2_value = m2(a, b)
49+
if np.isnan(m2_value):
50+
m2_is_nan = True
51+
except ZeroDivisionError:
52+
m2_div_by_zero = True
53+
assert m1_div_by_zero == m2_div_by_zero
54+
assert m1_is_nan == m2_is_nan
55+
if m1_div_by_zero or m2_div_by_zero:
56+
return True
57+
if m1_is_nan or m2_is_nan:
58+
return True
59+
return math.isclose(m1_value, m2_value)
60+
61+
3462
@pytest.mark.parametrize("data", [SIMPLE, TWO_LINES, GRID, RANDOM])
3563
@pytest.mark.parametrize(
3664
"m1, m2",
@@ -39,28 +67,14 @@
3967
(manhattan(), get_metric("manhattan")),
4068
(chebyshev(), get_metric("chebyshev")),
4169
(minkowski(p=3), get_metric("minkowski", p=3)),
70+
(minkowski(p=2.5), get_metric("minkowski", p=2.5)),
4271
(cosine(), get_metric("cosine")),
4372
],
4473
)
4574
def test_metrics(m1, m2, data):
4675
for a in data:
4776
for b in data:
48-
m1_fail = False
49-
m2_fail = False
50-
m1_value = 0.0
51-
m2_value = 0.0
52-
try:
53-
m1_value = m1(a, b)
54-
except Exception:
55-
m1_fail = True
56-
try:
57-
m2_value = m2(a, b)
58-
except Exception:
59-
m2_fail = True
60-
assert m1_fail == m2_fail
61-
if np.isnan(m1_value) and np.isnan(m2_value):
62-
return
63-
assert math.isclose(m1_value, m2_value)
77+
assert _check_values(m1, m2, a, b)
6478

6579

6680
def test_supported_metrics():

0 commit comments

Comments
 (0)