Skip to content

Commit 4e0f373

Browse files
committed
refactor: now mypy checks tests
1 parent b21312a commit 4e0f373

File tree

5 files changed

+21
-46
lines changed

5 files changed

+21
-46
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,3 @@ repos:
6666
- numpy>=2
6767
- scipy>=1.13
6868
- pytest>=8
69-
exclude: '^tests/.*'

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ exclude_lines = [
9393
]
9494

9595
[tool.mypy]
96+
files = [ "src", "tests" ]
9697
python_version = "3.12"
9798
strict = true
9899
warn_unused_configs = true
@@ -103,3 +104,10 @@ no_implicit_optional = true
103104
namespace_packages = true
104105
explicit_package_bases = true
105106
mypy_path = [ "src" ]
107+
108+
[[tool.mypy.overrides]]
109+
module = [ "tests.*" ]
110+
disallow_untyped_defs = false
111+
check_untyped_defs = true
112+
warn_return_any = false
113+
implicit_reexport = true

tests/unit/families/test_basic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ class TestBaseFamily:
1111
CDF: GenericCharacteristicName = "cdf"
1212
PPF: GenericCharacteristicName = "mean"
1313

14-
def make_default_family(self, distr_characteristics=None) -> ParametricFamily:
14+
def make_default_family(
15+
self,
16+
distr_characteristics: dict[GenericCharacteristicName, dict[str, object]] | None = None,
17+
) -> ParametricFamily:
1518
if distr_characteristics is None:
1619
distr_characteristics = {
1720
self.PDF: {"base": lambda p, x: x},
@@ -22,7 +25,7 @@ def make_default_family(self, distr_characteristics=None) -> ParametricFamily:
2225
name="Default",
2326
distr_type=UnivariateContinuous,
2427
distr_parametrizations=["base", "alt"],
25-
distr_characteristics=distr_characteristics,
28+
distr_characteristics=distr_characteristics, # type: ignore[arg-type]
2629
sampling_strategy=MockSamplingStrategy(),
2730
)
2831

@@ -35,6 +38,6 @@ class Alt(Parametrization):
3538
value: float
3639

3740
def transform_to_base_parametrization(self) -> Parametrization:
38-
return Base(value=self.value)
41+
return Base(value=self.value) # type: ignore[call-arg]
3942

4043
return fam

tests/unit/families/test_distribution_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ def test_cache_auto_invalidation(self) -> None:
3333
assert computations1 is computations1_again # cache hit
3434

3535
# Replacing with a *new* object of the same parametrization should rebuild the cache
36-
distribution.parameters = family.parametrizations["alt"](value=5.0)
36+
distribution.parameters = family.parametrizations["alt"](value=5.0) # type: ignore[call-arg]
3737
computations2 = distribution.analytical_computations
3838
assert computations2 is not computations1
3939

4040
# Switching to the base parametrization should also rebuild the cache
41-
distribution.parameters = family.parametrizations["base"](value=7.0)
41+
distribution.parameters = family.parametrizations["base"](value=7.0) # type: ignore[call-arg]
4242
computations3 = distribution.analytical_computations
4343
assert computations3 is not computations2
4444

tests/unit/families/test_parameters.py

Lines changed: 5 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from __future__ import annotations
22

3-
import pytest
3+
from typing import Any
44

55
from pysatl_core.families import (
66
ParametricFamily,
77
Parametrization,
88
ParametrizationConstraint,
99
constraint,
10-
parametrization,
1110
)
1211
from pysatl_core.types import UnivariateContinuous
1312
from tests.unit.families.test_basic import TestBaseFamily
13+
from tests.utils.mocks import MockSamplingStrategy
1414

1515

1616
class TestParametrizationAPI(TestBaseFamily):
@@ -30,7 +30,7 @@ def test_constraint_decorator_marks_function(self) -> None:
3030
"""Decorator should tag a function so Parametrization.validate() can discover it."""
3131

3232
@constraint("Value must be positive")
33-
def check_positive(self) -> bool: # noqa: ANN001 (test signature)
33+
def check_positive(self: Any) -> bool: # noqa: ANN001 (test signature)
3434
return getattr(self, "value", 0) > 0
3535

3636
# Different code versions may use single or double underscore attributes.
@@ -52,42 +52,7 @@ def test_free_function_parametrization_decorator(self) -> None:
5252
distr_type=UnivariateContinuous,
5353
distr_parametrizations=["base"],
5454
distr_characteristics={},
55-
sampling_strategy=lambda n, d, **_: __import__("numpy").random.random((n, 1)), # type: ignore[assignment]
56-
computation_strategy=lambda: None, # type: ignore[assignment]
57-
)
58-
59-
@parametrization(family=family, name="base")
60-
class Base(Parametrization):
61-
value: float
62-
63-
@constraint("Value must be positive")
64-
def check_positive(self) -> bool:
65-
return self.value > 0
66-
67-
instance = Base(value=5.0) # type: ignore[call-arg]
68-
assert instance.name == "base"
69-
assert instance.parameters == {"value": 5.0}
70-
assert getattr(Base, "__family__", None) is family
71-
assert getattr(Base, "__param_name__", None) == "base"
72-
assert hasattr(Base, "__dataclass_fields__")
73-
74-
# Validation succeeds
75-
instance.validate()
76-
77-
# Validation fails
78-
invalid = Base(value=-1.0) # type: ignore[call-arg]
79-
with pytest.raises(ValueError, match="Constraint.*does not hold"):
80-
invalid.validate()
81-
82-
def test_method_style_parametrization_decorator(self) -> None:
83-
"""family.parametrization(name=...) should behave the same as the free decorator."""
84-
family = ParametricFamily(
85-
name="MethodDecoratorFamily",
86-
distr_type=UnivariateContinuous,
87-
distr_parametrizations=["kind"],
88-
distr_characteristics={},
89-
sampling_strategy=lambda n, d, **_: __import__("numpy").random.random((n, 1)), # type: ignore[assignment]
90-
computation_strategy=lambda: None, # type: ignore[assignment]
55+
sampling_strategy=MockSamplingStrategy(),
9156
)
9257

9358
@family.parametrization(name="kind")
@@ -118,4 +83,4 @@ def test_get_base_parameters_uses_family_logic(self) -> None:
11883
base_from_alt = family.get_base_parameters(alt_params)
11984
assert isinstance(base_from_alt, BaseCls)
12085
# Our default factory maps Alt(value=v) → Base(value=v)
121-
assert base_from_alt.value == 3.0
86+
assert base_from_alt.value == 3.0 # type: ignore[attr-defined]

0 commit comments

Comments
 (0)