Skip to content

Commit 8048c15

Browse files
committed
[fix] Add first draft of the PR for issue#349
1 parent f612f46 commit 8048c15

File tree

2 files changed

+177
-265
lines changed

2 files changed

+177
-265
lines changed

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
import warnings
44
from multiprocessing.queues import Queue
5-
from typing import Any, Dict, List, Optional, Tuple, Union, no_type_check
5+
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union, no_type_check
66

77
from ConfigSpace import Configuration
88

@@ -54,6 +54,16 @@
5454
]
5555

5656

57+
class EvaluationResults(NamedTuple):
58+
opt_loss: Dict[str, float]
59+
train_loss: Dict[str, float]
60+
opt_pred: np.ndarray
61+
status: StatusType
62+
valid_pred: Optional[np.ndarray] = None
63+
test_pred: Optional[np.ndarray] = None
64+
additional_run_info: Optional[Dict] = None
65+
66+
5767
class MyTraditionalTabularClassificationPipeline(BaseEstimator):
5868
"""
5969
A wrapper class that holds a pipeline for traditional classification.
@@ -662,11 +672,7 @@ def _loss(self, y_true: np.ndarray, y_hat: np.ndarray) -> Dict[str, float]:
662672
return calculate_loss(
663673
y_true, y_hat, self.task_type, metrics)
664674

665-
def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
666-
opt_pred: np.ndarray, valid_pred: Optional[np.ndarray],
667-
test_pred: Optional[np.ndarray], additional_run_info: Optional[Dict],
668-
file_output: bool, status: StatusType
669-
) -> Optional[Tuple[float, float, int, Dict]]:
675+
def finish_up(self, results: EvaluationResults, file_output: bool) -> Optional[Tuple[float, float, int, Dict]]:
670676
"""This function does everything necessary after the fitting is done:
671677
672678
* predicting
@@ -711,37 +717,37 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
711717

712718
if file_output:
713719
loss_, additional_run_info_ = self.file_output(
714-
opt_pred, valid_pred, test_pred,
720+
results.opt_pred, results.valid_pred, results.test_pred,
715721
)
716722
else:
717723
loss_ = None
718724
additional_run_info_ = {}
719725

720726
validation_loss, test_loss = self.calculate_auxiliary_losses(
721-
valid_pred, test_pred
727+
results.valid_pred, results.test_pred
722728
)
723729

724730
if loss_ is not None:
725731
return self.duration, loss_, self.seed, additional_run_info_
726732

727-
cost = loss[self.metric.name]
733+
cost = results.opt_loss[self.metric.name]
728734

729735
additional_run_info = (
730-
{} if additional_run_info is None else additional_run_info
736+
{} if results.additional_run_info is None else results.additional_run_info
731737
)
732-
additional_run_info['opt_loss'] = loss
738+
additional_run_info['opt_loss'] = results.opt_loss
733739
additional_run_info['duration'] = self.duration
734740
additional_run_info['num_run'] = self.num_run
735-
if train_loss is not None:
736-
additional_run_info['train_loss'] = train_loss
741+
if results.train_loss is not None:
742+
additional_run_info['train_loss'] = results.train_loss
737743
if validation_loss is not None:
738744
additional_run_info['validation_loss'] = validation_loss
739745
if test_loss is not None:
740746
additional_run_info['test_loss'] = test_loss
741747

742748
rval_dict = {'loss': cost,
743749
'additional_run_info': additional_run_info,
744-
'status': status}
750+
'status': results.status}
745751

746752
self.queue.put(rval_dict)
747753
return None

0 commit comments

Comments
 (0)