|
# -*- encoding: utf-8 -*-
"""
==========================
Fit a single configuration
==========================
*Auto-PyTorch* searches for the best combination of machine learning algorithms
and their hyper-parameter configuration for a given task.

This example shows how one can fit one of these pipelines with a user-defined
configuration. Here the default configuration of the dataset's search space is
used; a configuration sampled at random from the configuration space can be
passed to ``fit_pipeline`` in the same way.
The pipelines that Auto-PyTorch fits are compatible with Scikit-Learn API. You can
get further documentation about Scikit-Learn models here: `<https://scikit-learn.org/stable/getting_started.html>`_
"""
import os
import tempfile as tmp
import warnings

# These variables are set before the numerical libraries below are imported,
# so the single-thread limits and the joblib temp folder are already in place
# when those libraries load.
os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

# Keep the example output readable.
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

import sklearn.datasets
import sklearn.metrics
# Imported explicitly: ``train_test_split`` below lives in this submodule and
# must not depend on another package importing it as a side effect.
import sklearn.model_selection

from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes


if __name__ == '__main__':
    ############################################################################
    # Data Loading
    # ============

    # OpenML dataset 3 is fitted below under the name 'kr-vs-kp'.
    X, y = sklearn.datasets.fetch_openml(data_id=3, return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X, y, test_size=0.5, random_state=3
    )

    ############################################################################
    # Define an estimator
    # ===================

    # Hold out 33% of the training data for validation.
    estimator = TabularClassificationTask(
        resampling_strategy=HoldoutValTypes.holdout_validation,
        resampling_strategy_args={'val_share': 0.33}
    )

    ############################################################################
    # Get the default configuration of the pipeline for the current dataset
    # =====================================================================

    # The search space depends on the dataset, so the dataset object is built
    # first and the configuration is taken from its search space.
    dataset = estimator.get_dataset(
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
    )
    configuration = estimator.get_search_space(dataset).get_default_configuration()

    ############################################################################
    # Fit the configuration
    # =====================

    pipeline, run_info, run_value, dataset = estimator.fit_pipeline(
        X_train=X_train, y_train=y_train,
        dataset_name='kr-vs-kp',
        X_test=X_test, y_test=y_test,
        disable_file_output=False,
        configuration=configuration,
    )

    # This object complies with Scikit-Learn Pipeline API.
    # https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html
    print(pipeline.named_steps)

    # The fit_pipeline command also returns a named tuple with the pipeline constraints
    print(run_info)

    # The fit_pipeline command also returns a named tuple with train/test performance
    print(run_value)

    print("Passed Configuration:", pipeline.config)
    print("Network:", pipeline.named_steps['network'].network)
0 commit comments