Skip to content

Commit 25691f5

Browse files
Authored by Yony Bresler (YonyBresler), with co-authors pre-commit-ci[bot], Manu Joseph V (manujosephv), and Jirka Borovec (Borda)

Add multi target classification (#441)
* First pass at Multi-Target classifier. Core functionality works, but failing other tests
* Updated base model to support custom metrics on multi-target
* Fix to init metrics param config in multi-target
* Updated pytests to include multi-target classification
* Preliminary fix for combine_prediction
* Documentation updates
* Linter cleanup
* Bugfix for metrics in multi-target classification
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Added new tutorial for multi-target classification
* Minor update to documentation for multi-target classification

Co-authored-by: Yony Bresler <yony@craterlabs.io>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Manu Joseph V <manujosephv@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
1 parent fc6060e commit 25691f5

22 files changed

+2160
-61
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ from pytorch_tabular.config import (
100100
data_config = DataConfig(
101101
target=[
102102
"target"
103-
], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
103+
], # target should always be a list.
104104
continuous_cols=num_col_names,
105105
categorical_cols=cat_col_names,
106106
)

docs/gs_usage.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ from pytorch_tabular.config import (
1414
data_config = DataConfig(
1515
target=[
1616
"target"
17-
], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
17+
], # target should always be a list.
1818
continuous_cols=num_col_names,
1919
categorical_cols=cat_col_names,
2020
)

docs/tutorials/01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@
532532
"data_config = DataConfig(\n",
533533
" target=[\n",
534534
" target_col\n",
535-
" ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
535+
" ], # target should always be a list\n",
536536
" continuous_cols=num_col_names,\n",
537537
" categorical_cols=cat_col_names,\n",
538538
")\n",

docs/tutorials/15-Multi Target Classification.ipynb

Lines changed: 2008 additions & 0 deletions
Large diffs are not rendered by default.

examples/__only_for_dev__/adhoc_scaffold.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def print_metrics(y_true, y_pred, tag):
5353
from pytorch_tabular.models import GatedAdditiveTreeEnsembleConfig # noqa: E402
5454

5555
data_config = DataConfig(
56-
# target should always be a list. Multi-targets are only supported for regression.
57-
# Multi-Task Classification is not implemented
56+
# target should always be a list.
5857
target=["target"],
5958
continuous_cols=num_col_names,
6059
categorical_cols=cat_col_names,

examples/__only_for_dev__/to_test_dae.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ def print_metrics(y_true, y_pred, tag):
145145
lr = 1e-3
146146

147147
data_config = DataConfig(
148-
# target should always be a list. Multi-targets are only supported for regression.
149-
# Multi-Task Classification is not implemented
148+
# target should always be a list.
150149
target=[target_name],
151150
continuous_cols=num_col_names,
152151
categorical_cols=cat_col_names,

src/pytorch_tabular/config/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ class InferredConfig:
197197
198198
output_dim (Optional[int]): The number of output targets
199199
200+
output_cardinality (Optional[int]): The number of unique values in classification output
201+
200202
categorical_cardinality (Optional[List[int]]): The number of unique values in categorical features
201203
202204
embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a
@@ -216,6 +218,10 @@ class InferredConfig:
216218
default=None,
217219
metadata={"help": "The number of output targets"},
218220
)
221+
output_cardinality: Optional[List[int]] = field(
222+
default=None,
223+
metadata={"help": "The number of unique values in classification output"},
224+
)
219225
categorical_cardinality: Optional[List[int]] = field(
220226
default=None,
221227
metadata={"help": "The number of unique values in categorical features"},

src/pytorch_tabular/models/base_model.py

Lines changed: 72 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -122,23 +122,43 @@ def __init__(
122122
config.metrics_params.append(vars(metric))
123123
if config.task == "classification":
124124
config.metrics_prob_input = self.custom_metrics_prob_inputs
125+
for i, mp in enumerate(config.metrics_params):
126+
mp.sub_params_list = []
127+
for j, num_classes in enumerate(inferred_config.output_cardinality):
128+
config.metrics_params[i].sub_params_list.append(
129+
OmegaConf.create(
130+
{
131+
"task": mp.get("task", "multiclass"),
132+
"num_classes": mp.get("num_classes", num_classes),
133+
}
134+
)
135+
)
136+
125137
# Updating default metrics in config
126138
elif config.task == "classification":
127139
# Adding metric_params to config for classification task
128140
for i, mp in enumerate(config.metrics_params):
129-
# For classification task, output_dim == number of classses
130-
config.metrics_params[i]["task"] = mp.get("task", "multiclass")
131-
config.metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim)
132-
if config.metrics[i] in (
133-
"accuracy",
134-
"precision",
135-
"recall",
136-
"precision_recall",
137-
"specificity",
138-
"f1_score",
139-
"fbeta_score",
140-
):
141-
config.metrics_params[i]["top_k"] = mp.get("top_k", 1)
141+
mp.sub_params_list = []
142+
for j, num_classes in enumerate(inferred_config.output_cardinality):
143+
# config.metrics_params[i][j]["task"] = mp.get("task", "multiclass")
144+
# config.metrics_params[i][j]["num_classes"] = mp.get("num_classes", num_classes)
145+
146+
config.metrics_params[i].sub_params_list.append(
147+
OmegaConf.create(
148+
{"task": mp.get("task", "multiclass"), "num_classes": mp.get("num_classes", num_classes)}
149+
)
150+
)
151+
152+
if config.metrics[i] in (
153+
"accuracy",
154+
"precision",
155+
"recall",
156+
"precision_recall",
157+
"specificity",
158+
"f1_score",
159+
"fbeta_score",
160+
):
161+
config.metrics_params[i].sub_params_list[j]["top_k"] = mp.get("top_k", 1)
142162

143163
if self.custom_optimizer is not None:
144164
config.optimizer = str(self.custom_optimizer.__class__.__name__)
@@ -267,7 +287,22 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
267287
)
268288
else:
269289
# TODO loss fails with batch size of 1?
270-
computed_loss = self.loss(y_hat.squeeze(), y.squeeze()) + reg_loss
290+
computed_loss = reg_loss
291+
start_index = 0
292+
for i in range(len(self.hparams.output_cardinality)):
293+
end_index = start_index + self.hparams.output_cardinality[i]
294+
_loss = self.loss(y_hat[:, start_index:end_index], y[:, i])
295+
computed_loss += _loss
296+
if self.hparams.output_dim > 1:
297+
self.log(
298+
f"{tag}_loss_{i}",
299+
_loss,
300+
on_epoch=True,
301+
on_step=False,
302+
logger=True,
303+
prog_bar=False,
304+
)
305+
start_index = end_index
271306
self.log(
272307
f"{tag}_loss",
273308
computed_loss,
@@ -325,11 +360,29 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
325360
_metrics.append(_metric)
326361
avg_metric = torch.stack(_metrics, dim=0).sum()
327362
else:
328-
y_hat = nn.Softmax(dim=-1)(y_hat.squeeze())
329-
if prob_inp:
330-
avg_metric = metric(y_hat, y.squeeze(), **metric_params)
331-
else:
332-
avg_metric = metric(torch.argmax(y_hat, dim=-1), y.squeeze(), **metric_params)
363+
_metrics = []
364+
start_index = 0
365+
for i, cardinality in enumerate(self.hparams.output_cardinality):
366+
end_index = start_index + cardinality
367+
y_hat_i = nn.Softmax(dim=-1)(y_hat[:, start_index:end_index].squeeze())
368+
if prob_inp:
369+
_metric = metric(y_hat_i, y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i])
370+
else:
371+
_metric = metric(
372+
torch.argmax(y_hat_i, dim=-1), y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i]
373+
)
374+
if len(self.hparams.output_cardinality) > 1:
375+
self.log(
376+
f"{tag}_{metric_str}_{i}",
377+
_metric,
378+
on_epoch=True,
379+
on_step=False,
380+
logger=True,
381+
prog_bar=False,
382+
)
383+
_metrics.append(_metric)
384+
start_index = end_index
385+
avg_metric = torch.stack(_metrics, dim=0).sum()
333386
metrics.append(avg_metric)
334387
self.log(
335388
f"{tag}_{metric_str}",

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,13 +282,21 @@ def _update_config(self, config) -> InferredConfig:
282282
if config.task == "regression":
283283
# self._output_dim_reg = len(config.target) if config.target else None if self.train is not None:
284284
output_dim = len(config.target) if config.target else None
285+
output_cardinality = None
285286
elif config.task == "classification":
286287
# self._output_dim_clf = len(np.unique(self.train_dataset.y)) if config.target else None
287288
if self.train is not None:
288-
output_dim = len(np.unique(self.train[config.target[0]])) if config.target else None
289+
output_cardinality = (
290+
self.train[config.target].fillna("NA").nunique().tolist() if config.target else None
291+
)
292+
output_dim = sum(output_cardinality)
289293
else:
290-
output_dim = len(np.unique(self.train_dataset.y)) if config.target else None
294+
output_cardinality = (
295+
self.train_dataset.data[config.target].fillna("NA").nunique().tolist() if config.target else None
296+
)
297+
output_dim = sum(output_cardinality)
291298
elif config.task == "ssl":
299+
output_cardinality = None
292300
output_dim = None
293301
else:
294302
raise ValueError(f"{config.task} is an unsupported task.")
@@ -308,6 +316,7 @@ def _update_config(self, config) -> InferredConfig:
308316
categorical_dim=categorical_dim,
309317
continuous_dim=continuous_dim,
310318
output_dim=output_dim,
319+
output_cardinality=output_cardinality,
311320
categorical_cardinality=categorical_cardinality,
312321
embedding_dims=embedding_dims,
313322
)
@@ -381,11 +390,14 @@ def _label_encode_target(self, data: DataFrame, stage: str) -> DataFrame:
381390
if self.config.task != "classification":
382391
return data
383392
if stage == "fit" or self.label_encoder is None:
384-
self.label_encoder = LabelEncoder()
385-
data[self.config.target[0]] = self.label_encoder.fit_transform(data[self.config.target[0]])
393+
self.label_encoder = [None] * len(self.config.target)
394+
for i in range(len(self.config.target)):
395+
self.label_encoder[i] = LabelEncoder()
396+
data[self.config.target[i]] = self.label_encoder[i].fit_transform(data[self.config.target[i]])
386397
else:
387-
if self.config.target[0] in data.columns:
388-
data[self.config.target[0]] = self.label_encoder.transform(data[self.config.target[0]])
398+
for i in range(len(self.config.target)):
399+
if self.config.target[i] in data.columns:
400+
data[self.config.target[i]] = self.label_encoder[i].transform(data[self.config.target[i]])
389401
return data
390402

391403
def _target_transform(self, data: DataFrame, stage: str) -> DataFrame:
@@ -818,7 +830,8 @@ def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
818830
# TODO Is the target encoding necessary?
819831
if len(set(self.target) - set(df.columns)) > 0:
820832
if self.config.task == "classification":
821-
df.loc[:, self.target] = np.array([self.label_encoder.classes_[0]] * len(df)).reshape(-1, 1)
833+
for i in range(len(self.target)):
834+
df.loc[:, self.target[i]] = np.array([self.label_encoder[i].classes_[0]] * len(df)).reshape(-1, 1)
822835
else:
823836
df.loc[:, self.target] = np.zeros((len(df), len(self.target)))
824837
df, _ = self.preprocess_data(df, stage="inference")

src/pytorch_tabular/tabular_model.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,6 @@ def num_params(self):
211211

212212
def _run_validation(self):
213213
"""Validates the Config params and throws errors if something is wrong."""
214-
if self.config.task == "classification":
215-
if len(self.config.target) > 1:
216-
raise NotImplementedError("Multi-Target Classification is not implemented.")
217214
if self.config.task == "regression":
218215
if self.config.target_range is not None:
219216
if (
@@ -1291,12 +1288,16 @@ def _format_predicitons(
12911288
pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1)
12921289

12931290
elif self.config.task == "classification":
1294-
point_predictions = nn.Softmax(dim=-1)(point_predictions).numpy()
1295-
for i, class_ in enumerate(self.datamodule.label_encoder.classes_):
1296-
pred_df[f"{class_}_probability"] = point_predictions[:, i]
1297-
pred_df["prediction"] = self.datamodule.label_encoder.inverse_transform(
1298-
np.argmax(point_predictions, axis=1)
1299-
)
1291+
start_index = 0
1292+
for i, target_col in enumerate(self.config.target):
1293+
end_index = start_index + self.datamodule._inferred_config.output_cardinality[i]
1294+
prob_prediction = nn.Softmax(dim=-1)(point_predictions[:, start_index:end_index]).numpy()
1295+
start_index = end_index
1296+
for j, class_ in enumerate(self.datamodule.label_encoder[i].classes_):
1297+
pred_df[f"{target_col}_{class_}_probability"] = prob_prediction[:, j]
1298+
pred_df[f"{target_col}_prediction"] = self.datamodule.label_encoder[i].inverse_transform(
1299+
np.argmax(prob_prediction, axis=1)
1300+
)
13001301
warnings.warn(
13011302
"Classification prediction column will be renamed to"
13021303
" `{target_col}_prediction` in the next release to maintain"
@@ -2046,23 +2047,21 @@ def _combine_predictions(
20462047
elif callable(aggregate):
20472048
bagged_pred = aggregate(pred_prob_l)
20482049
if self.config.task == "classification":
2049-
classes = self.datamodule.label_encoder.classes_
2050+
# FIXME need to iterate .label_encoder[x]
2051+
classes = self.datamodule.label_encoder[0].classes_
20502052
if aggregate == "hard_voting":
20512053
pred_df = pd.DataFrame(
20522054
np.concatenate(pred_prob_l, axis=1),
2053-
columns=[
2054-
f"{c}_probability_fold_{i}"
2055-
for i in range(len(pred_prob_l))
2056-
for c in self.datamodule.label_encoder.classes_
2057-
],
2055+
columns=[f"{c}_probability_fold_{i}" for i in range(len(pred_prob_l)) for c in classes],
20582056
index=pred_idx,
20592057
)
20602058
pred_df["prediction"] = classes[final_pred]
20612059
else:
20622060
final_pred = classes[np.argmax(bagged_pred, axis=1)]
20632061
pred_df = pd.DataFrame(
20642062
bagged_pred,
2065-
columns=[f"{c}_probability" for c in self.datamodule.label_encoder.classes_],
2063+
# FIXME
2064+
columns=[f"{c}_probability" for c in self.datamodule.label_encoder[0].classes_],
20662065
index=pred_idx,
20672066
)
20682067
pred_df["prediction"] = final_pred

0 commit comments

Comments
 (0)