Skip to content

Commit fc6060e

Browse files
sorenmacbethSoren Macbethmanujosephv
authored
remove restriction for using missing and unknown (#470)
category handling in SSL models Co-authored-by: Soren Macbeth <soren@abracadaniel.local> Co-authored-by: Manu Joseph V <manujosephv@gmail.com>
1 parent 2957537 commit fc6060e

File tree

1 file changed

+2
-15
lines changed

1 file changed

+2
-15
lines changed

src/pytorch_tabular/tabular_model.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222
from pandas import DataFrame
2323
from pytorch_lightning import seed_everything
2424
from pytorch_lightning.callbacks import RichProgressBar
25-
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import (
26-
GradientAccumulationScheduler,
27-
)
25+
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
2826
from pytorch_lightning.tuner.tuning import Tuner
2927
from pytorch_lightning.utilities.model_summary import summarize
3028
from rich import print as rich_print
@@ -43,11 +41,7 @@
4341
)
4442
from pytorch_tabular.config.config import InferredConfig
4543
from pytorch_tabular.models.base_model import BaseModel, _CaptumModel, _GenericModel
46-
from pytorch_tabular.models.common.layers.embeddings import (
47-
Embedding1dLayer,
48-
Embedding2dLayer,
49-
PreEncoded1dLayer,
50-
)
44+
from pytorch_tabular.models.common.layers.embeddings import Embedding1dLayer, Embedding2dLayer, PreEncoded1dLayer
5145
from pytorch_tabular.tabular_datamodule import TabularDatamodule
5246
from pytorch_tabular.utils import (
5347
OOMException,
@@ -232,13 +226,6 @@ def _run_validation(self):
232226
" two(min,max). The length of the list should be equal to hte"
233227
" length of target columns"
234228
)
235-
if self.config.task == "ssl":
236-
assert not self.config.handle_unknown_categories, (
237-
"SSL only supports handle_unknown_categories=False. Please set this" " in your DataConfig"
238-
)
239-
assert not self.config.handle_missing_values, (
240-
"SSL only supports handle_missing_values=False. Please set this in" " your DataConfig"
241-
)
242229

243230
def _read_parse_config(self, config, cls):
244231
if isinstance(config, str):

0 commit comments

Comments
 (0)