
Commit 046ae5f

[refactor] Refactor __init__ of abstract evaluator
[refactor] Collect shared variables in NamedTuples
[fix] Copy the budget passed to the evaluator params
[refactor] Add cross validation result manager for separate management
[refactor] Separate pipeline classes from abstract evaluator
[refactor] Refactor tae.py
[refactor] Increase the safety level of pipeline config
[test] Fix test_evaluation.py
[test] Fix test_abstract_evaluator.py 1 -- 3
[test] Add default pipeline config
[test] Modify queue.empty in a safer way
[test] Fix test_api.py
[test] Fix test_train_evaluator.py
[refactor] Refactor test_api.py before adding new tests
[refactor] Refactor test_tabular_xxx
[fix] Find the error in test_tabular_xxx
    Since the pipeline is updated after the evaluations and the previous code updated self.pipeline in the predict method, the dummy class only needed to override that method. The new code updates the pipeline separately, so the get_pipeline method is now overridden instead to reproduce the same results.
[fix] Fix the shape issue in regression and add a bug comment in a test
[refactor] Use keyword args to avoid unexpected bugs
[fix] Fix the ground truth of test_cv
    The weighting strategy for cross validation in the validation phase was changed so that the performance of each model is weighted proportionally to the size of each VALIDATION split, so the expected value had to change (a small numeric sketch follows below). Previously, performance was weighted proportionally to the TRAINING splits in both the training and validation phases.
[fix] Change qsize --> Empty since qsize might not be reliable
[refactor] Add cost for crash in autoPyTorchMetrics
[test] Remove self.pipeline since it duplicates self.pipelines
[fix] Fix attribute errors caused by the last change in curve extraction
[fix] Fix the issue when taking num_classes from a regression task
[fix] Deactivate saving the cv model in the case of holdout
1 parent 8048c15 commit 046ae5f
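
The test_cv fix above reflects the new weighting rule: fold scores are averaged with weights proportional to the size of each VALIDATION split. A minimal numeric sketch of that arithmetic, with illustrative values only (this is not the project's code):

import numpy as np

# Hypothetical per-fold results: validation-split sizes and validation scores
val_sizes = np.array([100, 100, 50])
val_scores = np.array([0.80, 0.90, 0.60])

# New behaviour: weight each fold's score by its validation-split size
weights = val_sizes / val_sizes.sum()             # [0.4, 0.4, 0.2]
weighted_score = float(np.sum(weights * val_scores))
print(weighted_score)                             # 0.80 (a plain mean would give ~0.767)

Previously the same weights would have been derived from the TRAINING split sizes, which is why the expected value in test_cv had to change.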

File tree

14 files changed (+1683, -2255 lines)


autoPyTorch/api/base_task.py

Lines changed: 16 additions & 15 deletions
@@ -43,8 +43,9 @@
 from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
 from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
 from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
-from autoPyTorch.evaluation.abstract_evaluator import fit_and_suppress_warnings
-from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash
+from autoPyTorch.evaluation.abstract_evaluator import fit_pipeline
+from autoPyTorch.evaluation.pipeline_class_collection import get_default_pipeline_config
+from autoPyTorch.evaluation.tae import TargetAlgorithmQuery
 from autoPyTorch.evaluation.utils import DisableFileOutputParameters
 from autoPyTorch.optimizer.smbo import AutoMLSMBO
 from autoPyTorch.pipeline.base_pipeline import BasePipeline
@@ -669,22 +670,23 @@ def _do_dummy_prediction(self) -> None:
         # already be generated here!
         stats = Stats(scenario_mock)
         stats.start_timing()
-        ta = ExecuteTaFuncWithQueue(
+        taq = TargetAlgorithmQuery(
             pynisher_context=self._multiprocessing_context,
             backend=self._backend,
             seed=self.seed,
             metric=self._metric,
             logger_port=self._logger_port,
-            cost_for_crash=get_cost_of_crash(self._metric),
+            cost_for_crash=self._metric._cost_of_crash,
             abort_on_first_run_crash=False,
             initial_num_run=num_run,
+            pipeline_config=get_default_pipeline_config(choice='dummy'),
             stats=stats,
             memory_limit=memory_limit,
             disable_file_output=self._disable_file_output,
             all_supported_metrics=self._all_supported_metrics
         )

-        status, _, _, additional_info = ta.run(num_run, cutoff=self._time_for_task)
+        status, _, _, additional_info = taq.run(num_run, cutoff=self._time_for_task)
         if status == StatusType.SUCCESS:
             self._logger.info("Finished creating dummy predictions.")
         else:
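
In the hunk above the crash cost is now read directly from the metric (cost_for_crash=self._metric._cost_of_crash) instead of calling get_cost_of_crash(). A minimal sketch of how such a property could be derived from a metric's optimum, worst possible result, and sign; the class and attribute layout below is an assumption for illustration, not the autoPyTorchMetrics implementation:

class MetricWithCrashCost:
    # Illustrative only: a scorer wrapper that knows its own crash cost.
    def __init__(self, optimum: float, worst_possible_result: float, sign: float):
        self._optimum = optimum
        self._worst_possible_result = worst_possible_result
        self._sign = sign  # positive if greater is better, negative otherwise

    @property
    def _cost_of_crash(self) -> float:
        # Costs are minimized, so a crashed run is charged the worst achievable cost.
        if self._sign < 0:
            return self._worst_possible_result
        return self._optimum - self._worst_possible_result

Attaching the value to the metric itself removes the need for every call site to import a helper from tae.py.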
@@ -753,13 +755,13 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
         # already be generated here!
         stats = Stats(scenario_mock)
         stats.start_timing()
-        ta = ExecuteTaFuncWithQueue(
+        taq = TargetAlgorithmQuery(
             pynisher_context=self._multiprocessing_context,
             backend=self._backend,
             seed=self.seed,
             metric=self._metric,
             logger_port=self._logger_port,
-            cost_for_crash=get_cost_of_crash(self._metric),
+            cost_for_crash=self._metric._cost_of_crash,
             abort_on_first_run_crash=False,
             initial_num_run=self._backend.get_next_num_run(),
             stats=stats,
@@ -770,7 +772,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
                 dask_futures.append([
                     classifier,
                     self._dask_client.submit(
-                        ta.run, config=classifier,
+                        taq.run, config=classifier,
                         cutoff=func_eval_time_limit_secs,
                     )
                 ])
@@ -1060,7 +1062,7 @@ def _search(

         # Here the budget is set to max because the SMAC intensifier can be:
         # Hyperband: in this case the budget is determined on the fly and overwritten
-        # by the ExecuteTaFuncWithQueue
+        # by the TargetAlgorithmQuery
         # SimpleIntensifier (and others): in this case, we use max_budget as a target
         # budget, and hece the below line is honored
         self.pipeline_options[budget_type] = max_budget
@@ -1344,7 +1346,7 @@ def refit(
                 dataset_properties=dataset_properties,
                 dataset=dataset,
                 split_id=split_id)
-            fit_and_suppress_warnings(self._logger, model, X, y=None)
+            fit_pipeline(self._logger, model, X, y=None)

         self._clean_logger()

@@ -1555,27 +1557,26 @@ def fit_pipeline(

         stats.start_timing()

-        tae = ExecuteTaFuncWithQueue(
+        taq = TargetAlgorithmQuery(
             backend=self._backend,
             seed=self.seed,
             metric=metric,
             logger_port=self._logger_port,
-            cost_for_crash=get_cost_of_crash(metric),
+            cost_for_crash=metric._cost_of_crash,
             abort_on_first_run_crash=False,
             initial_num_run=self._backend.get_next_num_run(),
             stats=stats,
             memory_limit=memory_limit,
             disable_file_output=disable_file_output,
             all_supported_metrics=all_supported_metrics,
-            budget_type=budget_type,
             include=include_components,
             exclude=exclude_components,
             search_space_updates=search_space_updates,
             pipeline_config=pipeline_options,
             pynisher_context=self._multiprocessing_context
         )

-        run_info, run_value = tae.run_wrapper(
+        run_info, run_value = taq.run_wrapper(
             RunInfo(config=configuration,
                     budget=budget,
                     seed=self.seed,
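
The commit message also notes a fix that copies the budget passed to the evaluator params. A small sketch of that defensive-copy pattern, with hypothetical names (the actual helper inside the evaluator is not shown in this diff):

import copy
from typing import Any, Dict

def build_evaluator_params(pipeline_config: Dict[str, Any], budget: float) -> Dict[str, Any]:
    # Copy first so that per-run budget overrides (e.g. those set by Hyperband)
    # never mutate the caller's shared pipeline configuration.
    params = copy.deepcopy(pipeline_config)
    params['budget'] = budget
    return params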
@@ -1587,7 +1588,7 @@ def fit_pipeline(

         fitted_pipeline = self._get_fitted_pipeline(
             dataset_name=dataset.dataset_name,
-            pipeline_idx=run_info.config.config_id + tae.initial_num_run,
+            pipeline_idx=run_info.config.config_id + taq.initial_num_run,
             run_info=run_info,
             run_value=run_value,
             disable_file_output=disable_file_output
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+{
+    "budget_type": "epochs",
+    "epochs": 1,
+    "runtime": 1
+}
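
The five added lines above are the new default pipeline config used for dummy runs (the file path is not visible in this extract). A hedged sketch of how a helper such as get_default_pipeline_config(choice='dummy') might hand this configuration to callers like _do_dummy_prediction; only the dictionary contents come from the diff, everything else is illustrative:

from typing import Any, Dict

_DEFAULT_PIPELINE_CONFIGS: Dict[str, Dict[str, Any]] = {
    # Mirrors the JSON added by this commit: a one-epoch budget for dummy runs.
    'dummy': {'budget_type': 'epochs', 'epochs': 1, 'runtime': 1},
}

def get_default_pipeline_config(choice: str = 'dummy') -> Dict[str, Any]:
    if choice not in _DEFAULT_PIPELINE_CONFIGS:
        raise ValueError(f'Unknown pipeline config choice: {choice}')
    # Return a copy so callers cannot mutate the shared defaults.
    return dict(_DEFAULT_PIPELINE_CONFIGS[choice])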
