13 changes: 7 additions & 6 deletions docs/source/tutorials/ptf_V2_example.ipynb
@@ -412,20 +412,21 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"id": "xOsEucZnzCkN"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from pytorch_forecasting.metrics import MAE, SMAPE\n",
"from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -446,7 +447,7 @@
"source": [
"# Initialise the Model\n",
"model = TFT(\n",
" loss=MAE(),\n",
" loss=nn.L1Loss(),\n",
" logging_metrics=[MAE(), SMAPE()],\n",
" optimizer=\"adam\",\n",
" optimizer_params={\"lr\": 1e-3},\n",
@@ -493,7 +494,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -791,7 +792,7 @@
"provenance": []
},
"kernelspec": {
"display_name": ".venv (3.12.3)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -805,7 +806,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.6"
}
},
"nbformat": 4,
124 changes: 124 additions & 0 deletions pytorch_forecasting/data/data_module.py
@@ -339,6 +339,81 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
else torch.zeros((features.shape[0], 0))
)

if self._scalers and self.continuous_indices:
for i, orig_idx in enumerate(self.continuous_indices):
col_name = self.time_series_metadata["cols"]["x"][orig_idx]
if col_name in self._scalers:
scaler = self._scalers[col_name]
feature_data = continuous[:, i]
try:
if isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
continuous[:, i] = scaler.transform(feature_data)
elif isinstance(scaler, (StandardScaler, RobustScaler)):
                            # sklearn scalers expect a 2D numpy array
                            # as input
requires_grad = feature_data.requires_grad
device = feature_data.device
feature_data_np = (
feature_data.cpu().detach().numpy().reshape(-1, 1)
Member:

I have a doubt: wouldn't using detach again detach the tensor from the computation graph? That would again lead to the same issue?

Contributor Author:

As far as my knowledge of PyTorch goes, it is good practice to call .detach() before converting a tensor to a NumPy array. In any case, the NumPy array does not track gradients, so this does not matter here.
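A minimal sketch of the point under discussion (hypothetical tensor, not code from this PR): PyTorch refuses to convert a tensor that still requires grad directly to NumPy, so `.detach()` is needed before `.numpy()`, and the resulting array carries no autograd history either way.

```python
import torch

t = torch.randn(4, requires_grad=True)

# t.numpy() would raise:
# RuntimeError: Can't call numpy() on Tensor that requires grad.
# Use tensor.detach().numpy() instead.

arr = t.detach().cpu().numpy()  # detach, move to CPU, then convert;
                                # the NumPy array tracks no gradients
```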

) # noqa: E501
scaled_feature_np = scaler.transform(feature_data_np)
scaled_tensor = torch.from_numpy(
scaled_feature_np.flatten()
).to(device)
if requires_grad:
scaled_tensor = scaled_tensor.requires_grad_(True)
continuous[:, i] = scaled_tensor
except Exception as e:
import warnings

warnings.warn(
f"Failed to transform feature '{col_name}' with scaler: {e}. " # noqa: E501
f"Using unscaled values.",
UserWarning,
)
continue

if self._target_normalizer is not None:
try:
if isinstance(
self._target_normalizer, (TorchNormalizer, EncoderNormalizer)
):
# automatically handle multiple targets.
target = self._target_normalizer.transform(target)
elif isinstance(
self._target_normalizer, (StandardScaler, RobustScaler)
):
requires_grad = target.requires_grad
device = target.device
if target.ndim == 2: # (seq_len, n_targets)
target_scaled = []
for i in range(target.shape[1]):
target_np = (
target[:, i].detach().cpu().numpy().reshape(-1, 1)
) # noqa: E501
scaled = self._target_normalizer.transform(target_np)
scaled_tensor = torch.from_numpy(scaled.flatten()).to(
device
) # noqa: E501
if requires_grad:
scaled_tensor = scaled_tensor.requires_grad_(True)
target_scaled.append(scaled_tensor)
target = torch.stack(target_scaled, dim=1)
else:
target_np = target.detach().cpu().numpy().reshape(-1, 1)
target_scaled = self._target_normalizer.transform(target_np)
target = torch.from_numpy(target_scaled.flatten()).to(device)
if requires_grad:
target = target.requires_grad_(True)
except Exception as e:
import warnings

warnings.warn(
f"Failed to transform target with scaler: {e}. " # noqa: E501
f"Using unscaled values.",
UserWarning,
)

return {
"features": {"categorical": categorical, "continuous": continuous},
"target": target,
@@ -623,6 +698,52 @@ def _create_windows(self, indices: torch.Tensor) -> list[tuple[int, int, int, in

return windows

def _fit_scalers(self, train_indices: torch.Tensor):
"""Fit scaler on the training dataset.

Parameters
----------
train_indices : torch.Tensor
Indices of the training time series in `time_series_dataset`.
"""

train_targets = []
train_features = []

for series_idx in train_indices:
sample = self.time_series_dataset[series_idx.item()]
target = sample["y"]
features = sample["x"]

train_targets.append(target)
train_features.append(features)

train_targets = torch.cat(train_targets, dim=0)
train_features = torch.cat(train_features, dim=0)

if self._target_normalizer is not None:
if isinstance(
self._target_normalizer, (TorchNormalizer, EncoderNormalizer)
):
self._target_normalizer.fit(train_targets)
elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)):
target_np = train_targets.numpy()
if target_np.ndim == 1:
target_np = target_np.reshape(-1, 1)
self._target_normalizer.fit(target_np)

if self._scalers and self.continuous_indices:
for i, orig_idx in enumerate(self.continuous_indices):
col_name = self.time_series_metadata["cols"]["x"][orig_idx]
if col_name in self._scalers:
scaler = self._scalers[col_name]
feature_data = train_features[:, orig_idx]

if isinstance(scaler, (StandardScaler, RobustScaler)):
feature_data = feature_data.numpy().reshape(-1, 1)

scaler.fit(feature_data)

def setup(self, stage: Optional[str] = None):
"""Prepare the datasets for training, validation, testing, or prediction.

@@ -647,6 +768,9 @@ def setup(self, stage: Optional[str] = None):
]
self._test_indices = self._split_indices[self._train_size + self._val_size :]

if (stage is None or stage == "fit") and len(self._train_indices) > 0:
self._fit_scalers(self._train_indices)

if stage is None or stage == "fit":
if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
self.train_windows = self._create_windows(self._train_indices)