Skip to content

Commit 7ca847b

Browse files
authored
Merge pull request #247 from lucasimi/add-docstrings-common
Add docstrings common
2 parents 32563e0 + 0c3a420 commit 7ca847b

File tree

4 files changed

+259
-29
lines changed

4 files changed

+259
-29
lines changed

src/tdamapper/_common.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919

2020

2121
def deprecated(msg: str) -> Callable[..., Any]:
22+
"""
23+
Decorator to mark functions as deprecated.
24+
25+
:param msg: The deprecation message to be shown in the warning.
26+
:return: A decorator that wraps the function to issue a deprecation warning.
27+
"""
28+
2229
def deprecated_func(func: Callable[..., Any]) -> Callable[..., Any]:
2330
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
2431
warnings.warn(msg, DeprecationWarning, stacklevel=2)
@@ -30,18 +37,45 @@ def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
3037

3138

3239
def warn_user(msg: str) -> None:
40+
"""
41+
Issue a warning to the user.
42+
"""
3343
warnings.warn(msg, UserWarning, stacklevel=2)
3444

3545

3646
class EstimatorMixin:
47+
"""
48+
Mixin to add common functionalities to estimators, such as validation of
49+
input data and setting the number of features.
50+
51+
This mixin is intended to be used with estimators that follow the scikit-learn
52+
interface, particularly those that implement the `fit` method.
53+
It provides methods to validate input data, check for sparsity, and set the
54+
number of features in the input data.
55+
"""
3756

3857
def _is_sparse(self, X: ArrayRead[Any]) -> bool:
58+
"""
59+
Check if the input data `X` is sparse.
60+
61+
:param X: Input data, can be a list, numpy array, or similar.
62+
:return: True if `X` is sparse, False otherwise.
63+
"""
3964
# simple alternative use scipy.sparse.issparse
4065
return hasattr(X, "toarray")
4166

4267
def _validate_X_y(
4368
self, X: ArrayRead[Any], y: ArrayRead[Any]
4469
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
70+
"""
71+
Validate the input data `X` and target `y`.
72+
73+
:param X: Input data, can be a list, numpy array, or similar.
74+
:param y: Target values, can be a list, numpy array, or similar.
75+
:return: Tuple of validated numpy arrays for `X` and `y`.
76+
:raises ValueError: If the input data is invalid, such as being empty,
77+
having NaNs or infinite values, or being complex.
78+
"""
4579
if self._is_sparse(X):
4680
raise ValueError("Sparse data not supported.")
4781

@@ -83,6 +117,11 @@ def _validate_X_y(
83117
return X_, y_
84118

85119
def _set_n_features_in(self, X: Array[Any]) -> None:
120+
"""
121+
Set the number of features in the input data `X`.
122+
123+
:param X: Input data, can be a list, numpy array, or similar.
124+
"""
86125
if hasattr(X, "shape"):
87126
self.n_features_in_ = X.shape[1]
88127

@@ -163,6 +202,14 @@ def clone(obj: Any) -> Any:
163202

164203

165204
def profile(n_lines: int = 10) -> Callable[..., Any]:
205+
"""
206+
Decorator to profile a function using cProfile and print the top `n_lines`
207+
cumulative time statistics.
208+
209+
:param n_lines: The number of lines to print from the profiling statistics.
210+
:return: A decorator that wraps the function to profile its execution.
211+
"""
212+
166213
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
167214
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
168215
profiler = cProfile.Profile()

src/tdamapper/_run_app.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
"""
2+
This module is the entry point for running the application.
3+
"""
4+
15
from tdamapper.app import main
26

37

48
def run() -> None:
9+
"""
10+
Run the application.
11+
"""
512
main()
613

714

0 commit comments

Comments
 (0)