|
2 | 2 | import time |
3 | 3 | import warnings |
4 | 4 | from multiprocessing.queues import Queue |
5 | | -from typing import Any, Dict, List, Optional, Tuple, Union, no_type_check |
| 5 | +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union, no_type_check |
6 | 6 |
|
7 | 7 | from ConfigSpace import Configuration |
8 | 8 |
|
|
54 | 54 | ] |
55 | 55 |
|
56 | 56 |
|
class EvaluationResults(NamedTuple):
    """Bundle of everything a single pipeline evaluation produces.

    Passed as one unit to ``finish_up`` instead of seven loose arguments.

    Attributes:
        opt_loss: Per-metric loss on the optimization (validation) split;
            ``finish_up`` reads the entry for the primary metric as the cost.
        train_loss: Per-metric loss on the training split, or ``None`` when
            it was not computed. Optional because ``finish_up`` explicitly
            guards ``if results.train_loss is not None`` before reporting it.
        opt_pred: Predictions on the optimization split.
        status: SMAC run status for this evaluation.
        valid_pred: Predictions on the validation split, if any.
        test_pred: Predictions on the test split, if any.
        additional_run_info: Extra key/value diagnostics to attach to the run
            (``finish_up`` substitutes ``{}`` when this is ``None``).
    """
    opt_loss: Dict[str, float]
    # Optional (but still a required positional field) so the annotation
    # agrees with the `is not None` check in finish_up.
    train_loss: Optional[Dict[str, float]]
    opt_pred: np.ndarray
    status: StatusType
    valid_pred: Optional[np.ndarray] = None
    test_pred: Optional[np.ndarray] = None
    additional_run_info: Optional[Dict] = None
| 65 | + |
| 66 | + |
57 | 67 | class MyTraditionalTabularClassificationPipeline(BaseEstimator): |
58 | 68 | """ |
59 | 69 | A wrapper class that holds a pipeline for traditional classification. |
@@ -662,11 +672,7 @@ def _loss(self, y_true: np.ndarray, y_hat: np.ndarray) -> Dict[str, float]: |
662 | 672 | return calculate_loss( |
663 | 673 | y_true, y_hat, self.task_type, metrics) |
664 | 674 |
|
665 | | - def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float], |
666 | | - opt_pred: np.ndarray, valid_pred: Optional[np.ndarray], |
667 | | - test_pred: Optional[np.ndarray], additional_run_info: Optional[Dict], |
668 | | - file_output: bool, status: StatusType |
669 | | - ) -> Optional[Tuple[float, float, int, Dict]]: |
| 675 | + def finish_up(self, results: EvaluationResults, file_output: bool) -> Optional[Tuple[float, float, int, Dict]]: |
670 | 676 | """This function does everything necessary after the fitting is done: |
671 | 677 |
|
672 | 678 | * predicting |
@@ -711,37 +717,37 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float], |
711 | 717 |
|
712 | 718 | if file_output: |
713 | 719 | loss_, additional_run_info_ = self.file_output( |
714 | | - opt_pred, valid_pred, test_pred, |
| 720 | + results.opt_pred, results.valid_pred, results.test_pred, |
715 | 721 | ) |
716 | 722 | else: |
717 | 723 | loss_ = None |
718 | 724 | additional_run_info_ = {} |
719 | 725 |
|
720 | 726 | validation_loss, test_loss = self.calculate_auxiliary_losses( |
721 | | - valid_pred, test_pred |
| 727 | + results.valid_pred, results.test_pred |
722 | 728 | ) |
723 | 729 |
|
724 | 730 | if loss_ is not None: |
725 | 731 | return self.duration, loss_, self.seed, additional_run_info_ |
726 | 732 |
|
727 | | - cost = loss[self.metric.name] |
| 733 | + cost = results.opt_loss[self.metric.name] |
728 | 734 |
|
729 | 735 | additional_run_info = ( |
730 | | - {} if additional_run_info is None else additional_run_info |
| 736 | + {} if results.additional_run_info is None else results.additional_run_info |
731 | 737 | ) |
732 | | - additional_run_info['opt_loss'] = loss |
| 738 | + additional_run_info['opt_loss'] = results.opt_loss |
733 | 739 | additional_run_info['duration'] = self.duration |
734 | 740 | additional_run_info['num_run'] = self.num_run |
735 | | - if train_loss is not None: |
736 | | - additional_run_info['train_loss'] = train_loss |
| 741 | + if results.train_loss is not None: |
| 742 | + additional_run_info['train_loss'] = results.train_loss |
737 | 743 | if validation_loss is not None: |
738 | 744 | additional_run_info['validation_loss'] = validation_loss |
739 | 745 | if test_loss is not None: |
740 | 746 | additional_run_info['test_loss'] = test_loss |
741 | 747 |
|
742 | 748 | rval_dict = {'loss': cost, |
743 | 749 | 'additional_run_info': additional_run_info, |
744 | | - 'status': status} |
| 750 | + 'status': results.status} |
745 | 751 |
|
746 | 752 | self.queue.put(rval_dict) |
747 | 753 | return None |
|
0 commit comments