11#!/usr/bin/env python3
22import warnings
33from collections import namedtuple
4- from typing import Any , Callable , Dict , List , NamedTuple , Optional , Tuple , Union
4+ from typing import (
5+ Any ,
6+ Callable ,
7+ Dict ,
8+ Generic ,
9+ List ,
10+ NamedTuple ,
11+ Optional ,
12+ Tuple ,
13+ TypeVar ,
14+ Union ,
15+ cast ,
16+ )
517
618from torch import Tensor
719
1527
1628ORIGINAL_KEY = "Original"
1729
30+ MetricResultType = TypeVar (
31+ "MetricResultType" , float , Tensor , Tuple [Union [float , Tensor ], ...]
32+ )
33+
1834
1935class AttackInfo (NamedTuple ):
2036 attack_fn : Union [Perturbation , Callable ]
@@ -33,7 +49,7 @@ def agg_metric(inp):
3349 return inp
3450
3551
36- class AttackComparator :
52+ class AttackComparator ( Generic [ MetricResultType ]) :
3753 r"""
3854 Allows measuring model robustness for a given attack or set of attacks. This class
3955 can be used with any metric(s) as well as any set of attacks, either based on
@@ -44,7 +60,7 @@ class AttackComparator:
4460 def __init__ (
4561 self ,
4662 forward_func : Callable ,
47- metric : Callable [..., Union [ float , Tensor , Tuple [ Union [ float , Tensor ], ...]] ],
63+ metric : Callable [..., MetricResultType ],
4864 preproc_fn : Callable = None ,
4965 ) -> None :
5066 r"""
@@ -74,10 +90,10 @@ def model_metric(model_out: Tensor, **kwargs: Any)
7490 additional_forward_args provided to evaluate.
7591 """
7692 self .forward_func = forward_func
77- self .metric = metric
93+ self .metric : Callable = metric
7894 self .preproc_fn = preproc_fn
79- self .attacks = {}
80- self .summary_results = {}
95+ self .attacks : Dict [ str , AttackInfo ] = {}
96+ self .summary_results : Dict [ str , Summarizer ] = {}
8197 self .metric_aggregator = agg_metric
8298 self .batch_stats = [Mean , Min , Max ]
8399 self .aggregate_stats = [Mean ]
@@ -148,7 +164,7 @@ def add_attack(
148164
149165 def _format_summary (
150166 self , summary : Union [Dict , List [Dict ]]
151- ) -> Dict [str , Union [ float , Tuple [ float , ...]] ]:
167+ ) -> Dict [str , MetricResultType ]:
152168 r"""
153169 This method reformats a given summary; particularly for tuples,
154170 the Summarizer's summary format is a list of dictionaries,
@@ -159,12 +175,12 @@ def _format_summary(
159175 if isinstance (summary , dict ):
160176 return summary
161177 else :
162- summary_dict = {}
178+ summary_dict : Dict [ str , Tuple ] = {}
163179 for key in summary [0 ]:
164180 summary_dict [key ] = tuple (s [key ] for s in summary )
165181 if self .out_format :
166182 summary_dict [key ] = self .out_format (* summary_dict [key ])
167- return summary_dict
183+ return summary_dict # type: ignore
168184
169185 def _update_out_format (
170186 self , out_metric : Union [float , Tensor , Tuple [Union [float , Tensor ], ...]]
@@ -174,7 +190,9 @@ def _update_out_format(
174190 and isinstance (out_metric , tuple )
175191 and hasattr (out_metric , "_fields" )
176192 ):
177- self .out_format = namedtuple (type (out_metric ).__name__ , out_metric ._fields )
193+ self .out_format = namedtuple ( # type: ignore
194+ type (out_metric ).__name__ , cast (NamedTuple , out_metric )._fields
195+ )
178196
179197 def _evaluate_batch (
180198 self ,
@@ -212,13 +230,10 @@ def _evaluate_batch(
212230 def evaluate (
213231 self ,
214232 inputs : Any ,
215- additional_forward_args : Optional [ Tuple ] = None ,
233+ additional_forward_args : Any = None ,
216234 perturbations_per_eval : int = 1 ,
217235 ** kwargs ,
218- ) -> Dict [
219- str ,
220- Union [Tensor , Tuple [Tensor , ...], Dict [str , Union [Tensor , Tuple [Tensor , ...]]]],
221- ]:
236+ ) -> Dict [str , Union [MetricResultType , Dict [str , MetricResultType ]]]:
222237 r"""
223238 Evaluate model and attack performance on provided inputs
224239
@@ -385,45 +400,44 @@ def _check_and_evaluate(input_list, key_list):
385400
386401 def _parse_and_update_results (
387402 self , batch_summarizers : Dict [str , Summarizer ]
388- ) -> Dict [
389- str , Union [float , Tuple [float , ...], Dict [str , Union [float , Tuple [float , ...]]]]
390- ]:
391- results = {
392- ORIGINAL_KEY : self ._format_summary (batch_summarizers [ORIGINAL_KEY ].summary )[
393- "mean"
394- ]
403+ ) -> Dict [str , Union [MetricResultType , Dict [str , MetricResultType ]]]:
404+ results : Dict [str , Union [MetricResultType , Dict [str , MetricResultType ]]] = {
405+ ORIGINAL_KEY : self ._format_summary (
406+ cast (Union [Dict , List ], batch_summarizers [ORIGINAL_KEY ].summary )
407+ )["mean" ]
395408 }
396409 self .summary_results [ORIGINAL_KEY ].update (
397410 self .metric_aggregator (results [ORIGINAL_KEY ])
398411 )
399412 for attack_key in self .attacks :
400413 attack = self .attacks [attack_key ]
401- results [ attack . name ] = self ._format_summary (
402- batch_summarizers [attack .name ].summary
414+ attack_results = self ._format_summary (
415+ cast ( Union [ Dict , List ], batch_summarizers [attack .name ].summary )
403416 )
417+ results [attack .name ] = attack_results
404418
405- if len (results [ attack . name ] ) == 1 :
406- key = next (iter (results [ attack . name ] ))
419+ if len (attack_results ) == 1 :
420+ key = next (iter (attack_results ))
407421 if attack .name not in self .summary_results :
408422 self .summary_results [attack .name ] = Summarizer (
409423 [stat () for stat in self .aggregate_stats ]
410424 )
411425 self .summary_results [attack .name ].update (
412- self .metric_aggregator (results [ attack . name ] [key ])
426+ self .metric_aggregator (attack_results [key ])
413427 )
414428 else :
415- for key in results [ attack . name ] :
429+ for key in attack_results :
416430 summary_key = f"{ attack .name } { key .title ()} Attempt"
417431 if summary_key not in self .summary_results :
418432 self .summary_results [summary_key ] = Summarizer (
419433 [stat () for stat in self .aggregate_stats ]
420434 )
421435 self .summary_results [summary_key ].update (
422- self .metric_aggregator (results [ attack . name ] [key ])
436+ self .metric_aggregator (attack_results [key ])
423437 )
424438 return results
425439
426- def summary (self ) -> Dict [str , Dict [str , Union [ Tensor , Tuple [ Tensor , ...]] ]]:
440+ def summary (self ) -> Dict [str , Dict [str , MetricResultType ]]:
427441 r"""
428442 Returns average results over all previous batches evaluated.
429443
@@ -440,7 +454,9 @@ def summary(self) -> Dict[str, Dict[str, Union[Tensor, Tuple[Tensor, ...]]]]:
440454 per batch.
441455 """
442456 return {
443- key : self ._format_summary (self .summary_results [key ].summary )
457+ key : self ._format_summary (
458+ cast (Union [Dict , List ], self .summary_results [key ].summary )
459+ )
444460 for key in self .summary_results
445461 }
446462
0 commit comments