Skip to content

Commit 047648e

Browse files
authored
Merge pull request #57 from kaboroevich/typing_refactor
Typing refactor
2 parents 8950442 + 9887f95 commit 047648e

File tree

13 files changed

+37
-42
lines changed

13 files changed

+37
-42
lines changed
File renamed without changes.
File renamed without changes.

pyDeepInsight/image_transformer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import inspect
1313
import warnings
1414

15-
from .utils import sparse_assignment
15+
from .utils._assignment import sparse_assignment
1616

1717

1818
class ImageTransformer:
@@ -25,13 +25,13 @@ class ImageTransformer:
2525
Attributes:
2626
_fe (ManifoldLearner): The feature extraction method used for
2727
dimensionality reduction. It must have a `fit_transform` method.
28-
_dm (function): The discretization method used to assign data points to
28+
_dm (Callable): The discretization method used to assign data points to
2929
pixel coordinates.
30-
_pixels (tuple): The dimensions of the image matrix (height, width) to
30+
_pixels (Tuple[int, int]): The dimensions of the image matrix (height, width) to
3131
which the data will be mapped.
32-
_xrot (ndarray): The rotated coordinates of the data after
32+
_xrot (np.ndarray): The rotated coordinates of the data after
3333
dimensionality reduction.
34-
_coords (ndarray): The final pixel coordinates assigned to the data
34+
_coords (np.ndarray): The final pixel coordinates assigned to the data
3535
points after discretization.
3636
DISCRETIZATION_OPTIONS (dict): A dictionary mapping discretization
3737
method names to their corresponding class methods for pixel

stubs/pyDeepInsight/image_transformer.pyi renamed to pyDeepInsight/image_transformer.pyi

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from typing import Any, Optional, Callable
22
from typing_extensions import Protocol
33
from numpy.typing import ArrayLike
44
import numpy as np
5-
import torch
5+
from torch import Tensor
66

77

88
class ManifoldLearner(Protocol):
@@ -85,10 +85,10 @@ class ImageTransformer:
8585
def _calculate_coords(self) -> None: ...
8686

8787
def transform(self, X: np.ndarray, img_format: str = 'rgb',
88-
empty_value: int = 0) -> np.ndarray | torch.Tensor: ...
88+
empty_value: int = 0) -> np.ndarray | Tensor: ...
8989

9090
def fit_transform(self, X: np.ndarray, **kwargs: Any
91-
) -> np.ndarray | torch.Tensor: ...
91+
) -> np.ndarray | Tensor: ...
9292

9393
def inverse_transform(self, img: np.ndarray) -> np.ndarray: ...
9494

@@ -104,7 +104,7 @@ class ImageTransformer:
104104
def _mat_to_rgb(mat: np.ndarray) -> np.ndarray: ...
105105

106106
@staticmethod
107-
def _mat_to_pytorch(mat: np.ndarray) -> torch.Tensor: ...
107+
def _mat_to_pytorch(mat: np.ndarray) -> Tensor: ...
108108

109109

110110
class MRepImageTransformer:
@@ -134,15 +134,15 @@ class MRepImageTransformer:
134134
empty_value: int = 0, collate: str = 'sample',
135135
return_index: bool = True
136136
) -> (np.ndarray
137-
| torch.Tensor
137+
| Tensor
138138
| tuple[np.ndarray, np.ndarray, np.ndarray]
139-
| tuple[torch.Tensor, np.ndarray, np.ndarray]): ...
139+
| tuple[Tensor, np.ndarray, np.ndarray]): ...
140140

141141
def fit_transform(self, X: np.ndarray, **kwargs: Any
142142
) -> (np.ndarray
143-
| torch.Tensor
143+
| Tensor
144144
| tuple[np.ndarray, np.ndarray, np.ndarray]
145-
| tuple[torch.Tensor, np.ndarray, np.ndarray]): ...
145+
| tuple[Tensor, np.ndarray, np.ndarray]): ...
146146

147147
@staticmethod
148148
def prediction_reduction(input: np.ndarray, index: np.ndarray,

pyDeepInsight/py.typed

Whitespace-only changes.
File renamed without changes.

pyDeepInsight/utils/_assignment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def _sparsify_top_percentile(arr, p):
99
smallest top-percentile values in each row.
1010
1111
Args:
12-
arr (ndarray): 2D array of shape (n_rows, n_cols) representing the
12+
arr (np.ndarray): 2D array of shape (n_rows, n_cols) representing the
1313
input cost matrix.
1414
p (float): The fraction (0 < p <= 1) of the smallest values to retain
1515
per row.
@@ -34,7 +34,7 @@ def sparse_assignment(cost_matrix, p=0.1):
3434
smallest top-percentile values in each row.
3535
3636
Args:
37-
cost_matrix (ndarray): 2D array of shape (n_rows, n_cols)
37+
cost_matrix (np.ndarray): 2D array of shape (n_rows, n_cols)
3838
representing the cost matrix.
3939
p (float): The fraction (0 < p <= 1) of the smallest values to
4040
retain per row during sparsification.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from typing import Tuple
2+
import numpy as np
3+
4+
5+
def _sparsify_top_percentile(arr: np.ndarray, p: float) -> np.ndarray: ...
6+
7+
def sparse_assignment(cost_matrix: np.ndarray, p: float) -> Tuple[np.ndarray, np.ndarray]: ...
8+
File renamed without changes.
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
import torch
2+
from torch import nn, Tensor
33

44

55
def step_blur_kernel(kernel_size: int, amplification: float) -> np.ndarray: ...
@@ -9,14 +9,14 @@ def apply_blur_kernel(img: np.ndarray, kernel: np.ndarray) -> np.ndarray: ...
99
def step_blur(img: np.ndarray, kernel_size: int,
1010
amplification: float) -> np.ndarray: ...
1111

12-
class StepBlur2d(torch.nn.Module):
12+
class StepBlur2d(nn.Module):
1313

1414
def_amp: float
15-
kernel: torch.tensor
15+
kernel: Tensor
1616

17-
def __init__(self, kernel_size: torch.tensor, amplification: float): ...
17+
def __init__(self, kernel_size: Tensor, amplification: float): ...
1818

19-
def forward(self, input: torch.tensor) -> torch.tensor: ...
19+
def forward(self, input: Tensor) -> Tensor: ...
2020

2121
@staticmethod
2222
def step_kernel(kernel_size: int, amplification: float): ...
@@ -25,14 +25,14 @@ def imgaborfilt(image: np.ndarray, wavelength: float, orientation: float,
2525
SpatialFrequencyBandwidth: float,
2626
SpatialAspectRatio: float) -> np.ndarray: ...
2727

28-
class GaborFilter2d(torch.nn.Module):
28+
class GaborFilter2d(nn.Module):
2929

3030
def __init__(self, wavelength: float, orientation: float): ...
3131

32-
def forward(self, img: torch.tensor) -> torch.tensor: ...
32+
def forward(self, img: Tensor) -> Tensor: ...
3333

3434
@staticmethod
35-
def pil_to_tensor(img: np.ndarray) -> torch.tensor: ...
35+
def pil_to_tensor(img: np.ndarray) -> Tensor: ...
3636

3737
@staticmethod
38-
def tensor_to_pil(img: torch.tensor) -> np.ndarray: ...
38+
def tensor_to_pil(img: Tensor) -> np.ndarray: ...

0 commit comments

Comments
 (0)