
Commit 95a5969

[ADD] Calculate memory of dataset after one hot encoding (pytorch embedding) (#437)
* add updates for apt1.0+reg_cocktails
* debug loggers for checking data and network memory usage
* add support for pandas, test for data passing, remove debug loggers
* remove unwanted changes
* :
* Adjust formula to account for embedding columns
* Apply suggestions from code review
  Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
* remove unwanted additions
* Update autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py
  Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent 9573358 commit 95a5969
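In short, the dataset-compression logic no longer budgets X by its raw in-memory size: each categorical column is costed by what it will occupy after preprocessing. Columns with fewer than MIN_CATEGORIES_FOR_EMBEDDING_MAX (7) categories are assumed to be one-hot encoded (one column per category), while wider columns are assumed to go through a PyTorch embedding and stay a single column. A minimal, self-contained sketch of that per-row formula (the toy data and column names are illustrative, not from the commit):

import numpy as np
import pandas as pd

MIN_CATEGORIES_FOR_EMBEDDING_MAX = 7  # mirrors autoPyTorch.constants


def approx_mem_usage_mb(df: pd.DataFrame, categorical_columns, n_categories_per_cat_column) -> float:
    """Estimate dataset size in MB as if narrow categoricals were already one-hot encoded."""
    dtypes = df.dtypes.to_dict()
    # numerical columns keep exactly one value per row
    size_one_row = sum(dtype.itemsize for col, dtype in dtypes.items() if col not in categorical_columns)
    for col, num_cat in zip(categorical_columns, n_categories_per_cat_column):
        if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX:
            size_one_row += num_cat * dtypes[col].itemsize  # one-hot: one column per category
        else:
            size_one_row += dtypes[col].itemsize  # embedded: stays a single column
    return df.shape[0] * size_one_row / (2 ** 20)


df = pd.DataFrame({
    "age": np.random.rand(1000).astype(np.float32),              # numerical
    "colour": np.random.randint(0, 3, 1000).astype(np.int8),     # 3 categories -> one-hot
    "country": np.random.randint(0, 10, 1000).astype(np.int8),   # 10 categories -> embedding
})
print(approx_mem_usage_mb(df, ["colour", "country"], [3, 10]))   # ~0.0076 MB

With this toy frame the estimate is 1000 * (4 + 3*1 + 1) bytes (about 0.0076 MB), whereas the raw buffer of the three stored columns is only 1000 * 6 bytes.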

File tree

11 files changed: 98 additions & 41 deletions


autoPyTorch/api/base_task.py

Lines changed: 0 additions & 18 deletions
@@ -111,23 +111,6 @@ def send_warnings_to_log(
     return prediction


-def get_search_updates(categorical_indicator: List[bool]) -> HyperparameterSearchSpaceUpdates:
-    """
-    These updates mimic the autopytorch tabular paper.
-    Returns:
-    ________
-    search_space_updates - HyperparameterSearchSpaceUpdates
-        The search space updates like setting different hps to different values or ranges.
-    """
-
-    # has_cat_features = any(categorical_indicator)
-    # has_numerical_features = not all(categorical_indicator)
-
-    search_space_updates = HyperparameterSearchSpaceUpdates()
-
-    return search_space_updates
-
-
 class BaseTask(ABC):
     """
     Base class for the tasks that serve as API to the pipelines.
@@ -200,7 +183,6 @@ def __init__(
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
         task_type: Optional[str] = None,
-        categorical_indicator: Optional[List[bool]] = None
     ) -> None:

         if isinstance(resampling_strategy, NoResamplingStrategyTypes) and ensemble_size != 0:

autoPyTorch/api/tabular_classification.py

Lines changed: 0 additions & 2 deletions
@@ -98,7 +98,6 @@ def __init__(
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         backend: Optional[Backend] = None,
         search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
-        categorical_indicator: Optional[List[bool]] = None
     ):
         super().__init__(
             seed=seed,
@@ -119,7 +118,6 @@ def __init__(
             resampling_strategy_args=resampling_strategy_args,
             search_space_updates=search_space_updates,
             task_type=TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION],
-            categorical_indicator=categorical_indicator
         )

     def build_pipeline(

autoPyTorch/constants.py

Lines changed: 2 additions & 0 deletions
@@ -54,3 +54,5 @@

 CLASSIFICATION_OUTPUTS = [BINARY, MULTICLASS, MULTICLASSMULTIOUTPUT]
 REGRESSION_OUTPUTS = [CONTINUOUS, CONTINUOUSMULTIOUTPUT]
+
+MIN_CATEGORIES_FOR_EMBEDDING_MAX = 7

autoPyTorch/data/tabular_validator.py

Lines changed: 2 additions & 0 deletions
@@ -104,6 +104,8 @@ def _compress_dataset(
            y=y,
            is_classification=self.is_classification,
            random_state=self.seed,
+           categorical_columns=self.feature_validator.categorical_columns,
+           n_categories_per_cat_column=self.feature_validator.num_categories_per_col,
            **self.dataset_compression  # type: ignore [arg-type]
        )
        self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype

autoPyTorch/data/utils.py

Lines changed: 46 additions & 7 deletions
@@ -25,6 +25,7 @@
 from sklearn.utils import _approximate_mode, check_random_state
 from sklearn.utils.validation import _num_samples, check_array

+from autoPyTorch.constants import MIN_CATEGORIES_FOR_EMBEDDING_MAX
 from autoPyTorch.data.base_target_validator import SupportedTargetTypes
 from autoPyTorch.utils.common import ispandas

@@ -459,8 +460,8 @@ def _subsample_by_indices(
     return X, y


-def megabytes(arr: DatasetCompressionInputType) -> float:
-
+def get_raw_memory_usage(arr: DatasetCompressionInputType) -> float:
+    memory_in_bytes: float
     if isinstance(arr, np.ndarray):
         memory_in_bytes = arr.nbytes
     elif issparse(arr):
@@ -470,19 +471,57 @@ def megabytes(arr: DatasetCompressionInputType) -> float:
     else:
         raise ValueError(f"Unrecognised data type of X, expected data type to "
                          f"be in (np.ndarray, spmatrix, pd.DataFrame) but got :{type(arr)}")
+    return memory_in_bytes
+
+
+def get_approximate_mem_usage_in_mb(
+    arr: DatasetCompressionInputType,
+    categorical_columns: List,
+    n_categories_per_cat_column: Optional[List[int]] = None
+) -> float:
+
+    err_msg = "Value number of categories per categorical is required when the data has categorical columns"
+    if ispandas(arr):
+        arr_dtypes = arr.dtypes.to_dict()
+        multipliers = [dtype.itemsize for col, dtype in arr_dtypes.items() if col not in categorical_columns]
+        if len(categorical_columns) > 0:
+            if n_categories_per_cat_column is None:
+                raise ValueError(err_msg)
+            for col, num_cat in zip(categorical_columns, n_categories_per_cat_column):
+                if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX:
+                    multipliers.append(num_cat * arr_dtypes[col].itemsize)
+                else:
+                    multipliers.append(arr_dtypes[col].itemsize)
+        size_one_row = sum(multipliers)
+
+    elif isinstance(arr, (np.ndarray, spmatrix)):
+        n_cols = arr.shape[-1] - len(categorical_columns)
+        multiplier = arr.dtype.itemsize
+        if len(categorical_columns) > 0:
+            if n_categories_per_cat_column is None:
+                raise ValueError(err_msg)
+            # multiply num categories with the size of the column to capture memory after one hot encoding
+            n_cols += sum(num_cat if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX else 1 for num_cat in n_categories_per_cat_column)
+        size_one_row = n_cols * multiplier
+    else:
+        raise ValueError(f"Unrecognised data type of X, expected data type to "
+                         f"be in (np.ndarray, spmatrix, pd.DataFrame), but got :{type(arr)}")

-    return float(memory_in_bytes / (2**20))
+    return float(arr.shape[0] * size_one_row / (2**20))


 def reduce_dataset_size_if_too_large(
     X: DatasetCompressionInputType,
     memory_allocation: Union[int, float],
     is_classification: bool,
     random_state: Union[int, np.random.RandomState],
+    categorical_columns: List,
+    n_categories_per_cat_column: Optional[List[int]] = None,
     y: Optional[SupportedTargetTypes] = None,
     methods: List[str] = ['precision', 'subsample'],
 ) -> DatasetCompressionInputType:
-    f""" Reduces the size of the dataset if it's too close to the memory limit.
+    f"""
+    Reduces the size of the dataset if it's too close to the memory limit.

     Follows the order of the operations passed in and retains the type of its
     input.
@@ -513,7 +552,6 @@ def reduce_dataset_size_if_too_large(
             Reduce the amount of samples of the dataset such that it fits into the allocated
             memory. Ensures stratification and that unique labels are present

-
         memory_allocation (Union[int, float]):
             The amount of memory to allocate to the dataset. It should specify an
             absolute amount.
@@ -524,7 +562,7 @@ def reduce_dataset_size_if_too_large(
     """

     for method in methods:
-        if megabytes(X) <= memory_allocation:
+        if get_approximate_mem_usage_in_mb(X, categorical_columns, n_categories_per_cat_column) <= memory_allocation:
             break

         if method == 'precision':
@@ -540,7 +578,8 @@ def reduce_dataset_size_if_too_large(
         # into the allocated memory, we subsample it so that it does

         n_samples_before = X.shape[0]
-        sample_percentage = memory_allocation / megabytes(X)
+        sample_percentage = memory_allocation / get_approximate_mem_usage_in_mb(
+            X, categorical_columns, n_categories_per_cat_column)

         # NOTE: type ignore
         #
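The renamed get_raw_memory_usage now returns the buffer size as stored (in bytes), while the new get_approximate_mem_usage_in_mb is what the compression loop compares against memory_allocation. A hedged usage sketch of the two helpers; it assumes autoPyTorch at this commit is importable, and the column layout and category count below are made up:

import numpy as np

# assumes autoPyTorch (at this commit) is importable; the data is illustrative only
from autoPyTorch.data.utils import get_approximate_mem_usage_in_mb, get_raw_memory_usage

rng = np.random.RandomState(1)
# three float64 columns: two numerical and one categorical with 5 distinct values
X = np.hstack([rng.rand(1000, 2), rng.randint(0, 5, size=(1000, 1))])

print(get_raw_memory_usage(X))  # raw buffer: 1000 * 3 * 8 = 24000 bytes
# the 5-category column is budgeted as 5 one-hot columns (5 < MIN_CATEGORIES_FOR_EMBEDDING_MAX),
# so the estimate covers 1000 * (2 + 5) * 8 bytes, roughly 0.053 MB
print(get_approximate_mem_usage_in_mb(X, categorical_columns=[2], n_categories_per_cat_column=[5]))

Whether reduce_dataset_size_if_too_large then reduces precision or subsamples depends on how this estimate compares to the memory_allocation budget in MB.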

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N
         self.add_fit_requirements([
             FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True),
             FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True)])
+

     def get_column_transformer(self) -> ColumnTransformer:
         """

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/ColumnSplitter.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@

 import numpy as np

-
+from autoPyTorch.constants import MIN_CATEGORIES_FOR_EMBEDDING_MAX
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import \
     autoPyTorchTabularPreprocessingComponent
@@ -72,7 +72,7 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None,
         min_categories_for_embedding: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="min_categories_for_embedding",
-            value_range=(3, 7),
+            value_range=(3, MIN_CATEGORIES_FOR_EMBEDDING_MAX),
             default_value=3,
             log=True),
     ) -> ConfigurationSpace:
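Tying the upper end of the searchable min_categories_for_embedding range to the same constant the memory estimate uses presumably keeps the two in sync: any column the worst-case estimate costs as one-hot (fewer than 7 categories) could indeed be one-hot encoded under some sampled threshold, while wider columns always take the embedding path. A small illustration of the resulting range using plain ConfigSpace (the diff itself uses autoPyTorch's HyperparameterSearchSpace wrapper):

from ConfigSpace.hyperparameters import UniformIntegerHyperparameter

MIN_CATEGORIES_FOR_EMBEDDING_MAX = 7  # mirrors autoPyTorch.constants

# In the ColumnSplitter, columns with at least this many categories are meant for the
# embedding path; narrower ones stay with one-hot encoding. The tunable range now tops
# out at the same constant that get_approximate_mem_usage_in_mb uses as its cut-off.
min_categories_for_embedding = UniformIntegerHyperparameter(
    "min_categories_for_embedding",
    lower=3,
    upper=MIN_CATEGORIES_FOR_EMBEDDING_MAX,
    default_value=3,
    log=True,
)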

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OneHotEncoder.py

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEncoder:
             # It is safer to have the OHE produce a 0 array than to crash a good configuration
             categories='auto',
             sparse=False,
-            handle_unknown='ignore')
+            handle_unknown='ignore',
+            dtype=np.float32)
         return self

     @staticmethod

test/test_api/test_api.py

Lines changed: 2 additions & 2 deletions
@@ -500,10 +500,10 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
     del estimator


-@pytest.skip("Fix with new portfolio PR")
 @unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function',
                      new=dummy_eval_train_function)
 @pytest.mark.parametrize('openml_id', (40981, ))
+@pytest.mark.skip(reason="Fix with new portfolio PR")
 def test_portfolio_selection(openml_id, backend, n_samples):

     # Get the data and check that contents of data-manager make sense
@@ -543,7 +543,7 @@ def test_portfolio_selection(openml_id, backend, n_samples):
     assert any(successful_config in portfolio_configs for successful_config in successful_configs)


-@pytest.skip("Fix with new portfolio PR")
+@pytest.mark.skip(reason="Fix with new portfolio PR")
 @unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function',
                      new=dummy_eval_train_function)
 @pytest.mark.parametrize('openml_id', (40981, ))

test/test_data/test_utils.py

Lines changed: 16 additions & 5 deletions
@@ -25,7 +25,7 @@
 from autoPyTorch.data.utils import (
     default_dataset_compression_arg,
     get_dataset_compression_mapping,
-    megabytes,
+    get_raw_memory_usage,
     reduce_dataset_size_if_too_large,
     reduce_precision,
     subsample,
@@ -45,13 +45,14 @@ def test_reduce_dataset_if_too_large(openmlid, as_frame, n_samples):
         X.copy(),
         y=y.copy(),
         is_classification=True,
+        categorical_columns=[],
         random_state=1,
-        memory_allocation=0.001)
+        memory_allocation=0.01)

     assert X_converted.shape[0] < X.shape[0]
     assert y_converted.shape[0] < y.shape[0]

-    assert megabytes(X_converted) < megabytes(X)
+    assert get_raw_memory_usage(X_converted) < get_raw_memory_usage(X)


 @pytest.mark.parametrize("X", [np.asarray([[1, 1, 1]] * 30)])
@@ -211,8 +212,18 @@ def test_unsupported_errors():
                   ['a', 'b', 'c', 'a', 'b', 'c'],
                   ['a', 'b', 'd', 'r', 'b', 'c']])
     with pytest.raises(ValueError, match=r'X.dtype = .*'):
-        reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0)
+        reduce_dataset_size_if_too_large(
+            X,
+            is_classification=True,
+            categorical_columns=[],
+            random_state=1,
+            memory_allocation=0)

     X = [[1, 2], [2, 3]]
     with pytest.raises(ValueError, match=r'Unrecognised data type of X, expected data type to be in .*'):
-        reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0)
+        reduce_dataset_size_if_too_large(
+            X,
+            is_classification=True,
+            categorical_columns=[],
+            random_state=1,
+            memory_allocation=0)
