|
22 | 22 | from pandas import DataFrame |
23 | 23 | from pytorch_lightning import seed_everything |
24 | 24 | from pytorch_lightning.callbacks import RichProgressBar |
25 | | -from pytorch_lightning.callbacks.gradient_accumulation_scheduler import ( |
26 | | - GradientAccumulationScheduler, |
27 | | -) |
| 25 | +from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler |
28 | 26 | from pytorch_lightning.tuner.tuning import Tuner |
29 | 27 | from pytorch_lightning.utilities.model_summary import summarize |
30 | 28 | from rich import print as rich_print |
|
43 | 41 | ) |
44 | 42 | from pytorch_tabular.config.config import InferredConfig |
45 | 43 | from pytorch_tabular.models.base_model import BaseModel, _CaptumModel, _GenericModel |
46 | | -from pytorch_tabular.models.common.layers.embeddings import ( |
47 | | - Embedding1dLayer, |
48 | | - Embedding2dLayer, |
49 | | - PreEncoded1dLayer, |
50 | | -) |
| 44 | +from pytorch_tabular.models.common.layers.embeddings import Embedding1dLayer, Embedding2dLayer, PreEncoded1dLayer |
51 | 45 | from pytorch_tabular.tabular_datamodule import TabularDatamodule |
52 | 46 | from pytorch_tabular.utils import ( |
53 | 47 | OOMException, |
@@ -232,13 +226,6 @@ def _run_validation(self): |
232 | 226 | " two(min,max). The length of the list should be equal to hte" |
233 | 227 | " length of target columns" |
234 | 228 | ) |
235 | | - if self.config.task == "ssl": |
236 | | - assert not self.config.handle_unknown_categories, ( |
237 | | - "SSL only supports handle_unknown_categories=False. Please set this" " in your DataConfig" |
238 | | - ) |
239 | | - assert not self.config.handle_missing_values, ( |
240 | | - "SSL only supports handle_missing_values=False. Please set this in" " your DataConfig" |
241 | | - ) |
242 | 229 |
|
243 | 230 | def _read_parse_config(self, config, cls): |
244 | 231 | if isinstance(config, str): |
|
0 commit comments