88import matplotlib
99
1010import numpy as np
11+ import numpy .typing as npt
1112from matplotlib import cm , colors , pyplot as plt
1213from matplotlib .axes import Axes
1314from matplotlib .collections import LineCollection
@@ -47,11 +48,11 @@ class VisualizeSign(Enum):
4748 all = 4
4849
4950
50- def _prepare_image (attr_visual : ndarray ) -> ndarray :
51+ def _prepare_image (attr_visual : npt . NDArray ) -> npt . NDArray :
5152 return np .clip (attr_visual .astype (int ), 0 , 255 )
5253
5354
54- def _normalize_scale (attr : ndarray , scale_factor : float ) -> ndarray :
55+ def _normalize_scale (attr : npt . NDArray , scale_factor : float ) -> npt . NDArray :
5556 assert scale_factor != 0 , "Cannot normalize by scale factor = 0"
5657 if abs (scale_factor ) < 1e-5 :
5758 warnings .warn (
@@ -64,23 +65,26 @@ def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
6465 return np .clip (attr_norm , - 1 , 1 )
6566
6667
67- def _cumulative_sum_threshold (values : ndarray , percentile : Union [int , float ]) -> float :
68+ def _cumulative_sum_threshold (
69+ values : npt .NDArray , percentile : Union [int , float ]
70+ ) -> float :
6871 # given values should be non-negative
6972 assert percentile >= 0 and percentile <= 100 , (
7073 "Percentile for thresholding must be " "between 0 and 100 inclusive."
7174 )
7275 sorted_vals = np .sort (values .flatten ())
7376 cum_sums = np .cumsum (sorted_vals )
7477 threshold_id = np .where (cum_sums >= cum_sums [- 1 ] * 0.01 * percentile )[0 ][0 ]
78+ # pyre-fixme[7]: Expected `float` but got `ndarray[typing.Any, dtype[typing.Any]]`.
7579 return sorted_vals [threshold_id ]
7680
7781
7882def _normalize_attr (
79- attr : ndarray ,
83+ attr : npt . NDArray ,
8084 sign : str ,
8185 outlier_perc : Union [int , float ] = 2 ,
8286 reduction_axis : Optional [int ] = None ,
83- ) -> ndarray :
87+ ) -> npt . NDArray :
8488 attr_combined = attr
8589 if reduction_axis is not None :
8690 attr_combined = np .sum (attr , axis = reduction_axis )
@@ -130,7 +134,7 @@ def _initialize_cmap_and_vmin_vmax(
130134
131135def _visualize_original_image (
132136 plt_axis : Axes ,
133- original_image : Optional [ndarray ],
137+ original_image : Optional [npt . NDArray ],
134138 ** kwargs : Any ,
135139) -> None :
136140 assert (
@@ -143,7 +147,7 @@ def _visualize_original_image(
143147
144148def _visualize_heat_map (
145149 plt_axis : Axes ,
146- norm_attr : ndarray ,
150+ norm_attr : npt . NDArray ,
147151 cmap : Union [str , Colormap ],
148152 vmin : float ,
149153 vmax : float ,
@@ -155,8 +159,8 @@ def _visualize_heat_map(
155159
156160def _visualize_blended_heat_map (
157161 plt_axis : Axes ,
158- original_image : ndarray ,
159- norm_attr : ndarray ,
162+ original_image : npt . NDArray ,
163+ norm_attr : npt . NDArray ,
160164 cmap : Union [str , Colormap ],
161165 vmin : float ,
162166 vmax : float ,
@@ -176,8 +180,8 @@ def _visualize_blended_heat_map(
176180def _visualize_masked_image (
177181 plt_axis : Axes ,
178182 sign : str ,
179- original_image : ndarray ,
180- norm_attr : ndarray ,
183+ original_image : npt . NDArray ,
184+ norm_attr : npt . NDArray ,
181185 ** kwargs : Any ,
182186) -> None :
183187 assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
@@ -190,8 +194,8 @@ def _visualize_masked_image(
190194def _visualize_alpha_scaling (
191195 plt_axis : Axes ,
192196 sign : str ,
193- original_image : ndarray ,
194- norm_attr : ndarray ,
197+ original_image : npt . NDArray ,
198+ norm_attr : npt . NDArray ,
195199 ** kwargs : Any ,
196200) -> None :
197201 assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
@@ -210,8 +214,8 @@ def _visualize_alpha_scaling(
210214
211215
212216def visualize_image_attr (
213- attr : ndarray ,
214- original_image : Optional [ndarray ] = None ,
217+ attr : npt . NDArray ,
218+ original_image : Optional [npt . NDArray ] = None ,
215219 method : str = "heat_map" ,
216220 sign : str = "absolute_value" ,
217221 plt_fig_axis : Optional [Tuple [Figure , Axes ]] = None ,
@@ -417,8 +421,8 @@ def visualize_image_attr(
417421
418422
419423def visualize_image_attr_multiple (
420- attr : ndarray ,
421- original_image : Union [None , ndarray ],
424+ attr : npt . NDArray ,
425+ original_image : Union [None , npt . NDArray ],
422426 methods : List [str ],
423427 signs : List [str ],
424428 titles : Optional [List [str ]] = None ,
@@ -526,9 +530,9 @@ def visualize_image_attr_multiple(
526530
527531
528532def visualize_timeseries_attr (
529- attr : ndarray ,
530- data : ndarray ,
531- x_values : Optional [ndarray ] = None ,
533+ attr : npt . NDArray ,
534+ data : npt . NDArray ,
535+ x_values : Optional [npt . NDArray ] = None ,
532536 method : str = "overlay_individual" ,
533537 sign : str = "absolute_value" ,
534538 channel_labels : Optional [List [str ]] = None ,
0 commit comments