Changes from all commits (33 commits)
7d06b13
First pass updates for transformer engine support
coreyjadams Jun 27, 2025
bb770a3
Update Transolver to optionally enable transformer engine, and make t…
coreyjadams Jun 30, 2025
3d392c3
Clean up transolver architecture and code to enable irregular and 3D …
coreyjadams Jul 1, 2025
b0897ec
Merge branch 'NVIDIA:main' into transformer_engine
coreyjadams Jul 1, 2025
c0517e5
Beginning work on transolver example. Planning ahead for domain para…
coreyjadams Jul 2, 2025
a070165
Updates for the transolver example using the consolidated and updated…
coreyjadams Jul 2, 2025
b1aa22f
Add small function to convert matlab matrices to npz
coreyjadams Jul 2, 2025
25a37ab
Updates to the transolver model architecture
coreyjadams Jul 2, 2025
016512e
Merge branch 'NVIDIA:main' into transformer_engine
coreyjadams Jul 7, 2025
4c383a3
Update transolver model. Add external aerodynamics example with tran…
coreyjadams Jul 8, 2025
c02482e
Merge branch 'NVIDIA:main' into transformer_engine
coreyjadams Jul 10, 2025
d74089d
Merge branch 'NVIDIA:main' into transformer_engine
coreyjadams Jul 14, 2025
554a3ab
Enable volume training in transolver. Still needs to be validated an…
coreyjadams Jul 17, 2025
a4c3620
Merge branch 'NVIDIA:main' into transformer_engine
coreyjadams Jul 17, 2025
8c95135
Updating transolver example further, and including more readme info
coreyjadams Jul 18, 2025
858c537
Further integrate transformer engine into Transolver.
coreyjadams Jul 18, 2025
221cab1
Update readme to point out matlab to npz conversion of fixed dataset.
coreyjadams Jul 18, 2025
5ea7789
Update README and improve scripts
coreyjadams Jul 18, 2025
79df865
Merge branch 'main' into transformer_engine
coreyjadams Jul 18, 2025
ab33308
Merge branch 'NVIDIA:main' into transformer_engine
coreyjadams Jul 22, 2025
8b430db
Merge branch 'NVIDIA:main' into transformer_engine
coreyjadams Jul 28, 2025
2332a71
Update model to not block multiple layers per mlp
coreyjadams Jul 28, 2025
93d33c2
Enabling domain parallelism.
coreyjadams Jul 28, 2025
0a1082c
Isolate preprocess steps.
coreyjadams Jul 28, 2025
37d66e8
Fix bug in transolver physics attention base: the wrong normalization
coreyjadams Jul 30, 2025
8b61a37
Add script for computing normalization factors for transolver.
coreyjadams Jul 31, 2025
0253bf6
remove volume config for now.
coreyjadams Jul 31, 2025
0313c33
Bulk update of transolver model and example for CFD external aero.
coreyjadams Jul 31, 2025
afbd96c
Merge branch 'NVIDIA:main' into transformer_engine
coreyjadams Jul 31, 2025
4f4906a
Ensure embedding is not None before concat
coreyjadams Jul 31, 2025
ea6c2c9
Update transolver for minor details.
coreyjadams Jul 31, 2025
9fcd89c
Update transolver tests, clean up lingering details in the model.
coreyjadams Aug 1, 2025
ad8c8c8
Merge branch 'NVIDIA:main' into transformer_engine
coreyjadams Aug 5, 2025
4 changes: 4 additions & 0 deletions examples/cfd/darcy_transolver/README.md
@@ -44,6 +44,10 @@ and the data path should be added when `Darcy_2D_fix` dataset is constructed.
You can download the data
[here](https://huggingface.co/datasets/lkuang/example_data).

Training on the `fix` dataset (a fixed, precomputed dataset) requires converting
the data from MATLAB to NumPy format for faster training startup. Use the
`convert_mat_to_npz.py` script to convert your data.
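
A quick way to sanity-check a converted file (a minimal sketch; the file name
follows the config in this example, and the `coeff`/`sol` keys are the ones
`convert_mat_to_npz.py` writes):

```python
import numpy as np

# Open the converted archive and confirm both arrays are present.
data = np.load("piececonst_r421_N1024_smooth1.npz")
print(data["coeff"].shape, data["sol"].shape)
```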

## Model overview and architecture

## Getting Started
39 changes: 23 additions & 16 deletions examples/cfd/darcy_transolver/config_fix.yaml
@@ -20,23 +20,32 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

output_dir: ./output/darcy_transolver_fix
run_id: bf16_dev_r85_b8_s64

data:
  train_path: /user_data/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz
  test_path: /user_data/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth2.npz
  resolution: 85  # 421, 211, 141, 106, 85 all viable
  batch_size: 8  # This is the GLOBAL batch size

model:
  space_dim: 2
  n_layers: 8
  functional_dim: 1
  out_dim: 1
  embedding_dim: 2
  n_layers: 4
  n_hidden: 128
  dropout: 0.0
  n_head: 8
  Time_Input: False
  n_head: 4
  act: gelu
  mlp_ratio: 1
  fun_dim: 1
  out_dim: 1
  slice_dim: 32
  mlp_ratio: 4
  unified_pos: False
  ref: 8
  unified_pos: 1
  slice_num: 64
  use_te: True
  Time_Input: False


precision: bf16

normaliser:
  permeability:
@@ -48,17 +57,15 @@ normaliser:

scheduler:
  initial_lr: 1.E-3
  decay_rate: 1.E-5
  decay_rate: 5.E-5
  weight_decay: 1.E-5
  decay_pseudo_epochs: 8

training:
  resolution: 85
  batch_size: 4
  rec_results_freq: 100
  max_pseudo_epochs: 500
  pseudo_epoch_sample_size: 1000
  max_pseudo_epochs: 1000
  pseudo_epoch_sample_size: 1024

validation:
  sample_size: 200
  validation_pseudo_epochs: 2
  validation_pseudo_epochs: 1
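
The trainer below takes a `cfg: DictConfig`, so the file can also be read
standalone with OmegaConf. A minimal sketch, assuming the file above is saved
as `config_fix.yaml`:

```python
from omegaconf import OmegaConf

# Load the YAML above and pull out a few values the trainer relies on.
cfg = OmegaConf.load("config_fix.yaml")
print(cfg.data.resolution)   # 85 -- one of the viable strided resolutions
print(cfg.data.batch_size)   # 8  -- the GLOBAL batch size
print(cfg.model.use_te)      # True -- enables the Transformer Engine path
```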
47 changes: 47 additions & 0 deletions examples/cfd/darcy_transolver/convert_mat_to_npz.py
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import sys
import numpy as np
from scipy.io import loadmat


def main(mat_file, npz_file):
    # Load the .mat file
    data = loadmat(mat_file)

    # Extract 'coeff' and 'sol'
    coeff = data["coeff"]
    sol = data["sol"]

    # Save to .npz file
    np.savez(npz_file, coeff=coeff, sol=sol)
    print(f"Saved 'coeff' and 'sol' to {npz_file}")


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python convert_mat_to_npz.py input.mat")
        sys.exit(1)
    mat_file = sys.argv[1]
    npz_file = mat_file.replace(".mat", ".npz")
    main(mat_file, npz_file)
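
A hedged usage sketch for the script above (the `.mat` file name is an
assumption; any MATLAB file containing `coeff` and `sol` works):

```python
# Programmatic equivalent of: python convert_mat_to_npz.py input.mat
from convert_mat_to_npz import main

main("piececonst_r421_N1024_smooth1.mat", "piececonst_r421_N1024_smooth1.npz")
```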
137 changes: 75 additions & 62 deletions examples/cfd/darcy_transolver/darcy_datapipe_fix.py
@@ -14,33 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from dataclasses import dataclass
from typing import Dict, Tuple, Union

import numpy as np
import torch
import warp as wp
import scipy.io as scio
import torch.distributed as dist

from physicsnemo.datapipes.datapipe import Datapipe
from physicsnemo.datapipes.meta import DatapipeMetaData
from physicsnemo.datapipes.benchmarks.kernels.finite_difference import (
    darcy_mgrid_jacobi_iterative_batched_2d,
    mgrid_inf_residual_batched_2d,
)
from physicsnemo.datapipes.benchmarks.kernels.initialization import (
    init_uniform_random_4d,
)
from physicsnemo.datapipes.benchmarks.kernels.utils import (
    bilinear_upsample_batched_2d,
    fourier_to_array_batched_2d,
    threshold_3d,
)

Tensor = torch.Tensor
# TODO unsure if better to remove this. Keeping this in for now
wp.init()

from physicsnemo.utils.profiling import profile


class UnitTransformer:
Expand Down Expand Up @@ -138,6 +125,7 @@ class Darcy2D_fix(Datapipe):
        Incompatible multi-grid and resolution settings
    """

    @profile
    def __init__(
        self,
        resolution: int = 256,
@@ -148,13 +136,14 @@ def __init__(
        max_iterations: int = 30000,
        convergence_threshold: float = 1e-6,
        iterations_per_convergence_check: int = 1000,
        nr_multigrids: int = 4,
        normaliser: Union[Dict[str, Tuple[float, float]], None] = None,
        # nr_multigrids: int = 4,
        # normaliser: Union[Dict[str, Tuple[float, float]], None] = None,
        device: Union[str, torch.device] = "cuda",
        train_path: str = None,
        is_test: bool = False,
        x_normalizer: UnitTransformer = None,
        y_normalizer: UnitTransformer = None,
        downsample: int = 5,
    ):
        super().__init__(meta=MetaData())

@@ -167,15 +156,15 @@ def __init__(
        self.max_iterations = max_iterations
        self.convergence_threshold = convergence_threshold
        self.iterations_per_convergence_check = iterations_per_convergence_check
        self.nr_multigrids = nr_multigrids
        self.normaliser = normaliser
        # self.nr_multigrids = nr_multigrids
        # self.normaliser = normaliser

        # check normaliser keys
        if self.normaliser is not None:
            if not {"permeability", "darcy"}.issubset(set(self.normaliser.keys())):
                raise ValueError(
                    "normaliser need to have keys permeability and darcy with mean and std"
                )
        # # check normaliser keys
        # if self.normaliser is not None:
        #     if not {"permeability", "darcy"}.issubset(set(self.normaliser.keys())):
        #         raise ValueError(
        #             "normaliser need to have keys permeability and darcy with mean and std"
        #         )

        # Set up device for warp, warp has same naming convention as torch.
        if isinstance(device, torch.device):
@@ -185,42 +174,32 @@ def __init__(
        # spatial dims
        self.dx = 1.0 / (self.resolution + 1)  # pad edges by 1 for multi-grid
        self.dim = (self.batch_size, self.resolution + 1, self.resolution + 1)
        self.fourier_dim = (
            4,
            self.batch_size,
            self.nr_permeability_freq,
            self.nr_permeability_freq,
        )

        # assert resolution is compatible with multi-grid method
        # if (resolution % 2 ** (nr_multigrids - 1)) != 0:
        #     raise ValueError("Resolution is incompatible with number of sub grids.")

        # allocate arrays for constructing dataset
        self.darcy0 = wp.zeros(self.dim, dtype=float, device=self.device)
        self.darcy1 = wp.zeros(self.dim, dtype=float, device=self.device)
        self.permeability = wp.zeros(self.dim, dtype=float, device=self.device)
        self.rand_fourier = wp.zeros(self.fourier_dim, dtype=float, device=self.device)
        self.inf_residual = wp.zeros([1], dtype=float, device=self.device)
        self.train_path = train_path
        self.downsample = 5
        self.r = self.downsample
        self.h = int(((421 - 1) / self.r) + 1)
        self.s = self.h
        # print(f"=============={self.s}===============")
        self.native_resolution = 421  # Native grid size

        # Calculate downsampling factor
        if (self.native_resolution - 1) % (self.resolution - 1) != 0:
            raise ValueError(
                f"Resolution {self.resolution} is not achievable by strided sampling from native resolution {self.native_resolution}."
            )
        self.r = (self.native_resolution - 1) // (self.resolution - 1)
        self.s = self.resolution
        self.dx = 1.0 / self.s
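        # Worked example: with the config above (resolution=85) and the native
        # 421-point grid, (421 - 1) % (85 - 1) == 0, so the check passes and
        # r = (421 - 1) // (85 - 1) = 5 -- every 5th native grid point is kept.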

        # Output tensors
        self.output_k = None
        self.output_p = None

        self.is_test = is_test

        if not self.is_test:
            n_train = 1000
            self.n_train = 1024
        else:
            n_train = 200
            self.n_train = n_train
            self.n_train = 200

        if self.train_path is not None:
            self.__get_data__()
@@ -233,28 +212,58 @@ def __init__(
            self.y_train = self.y_normalizer.encode(self.y_train)
        else:
            self.x_train = x_normalizer.encode(self.x_train)
            self.y_train = y_normalizer.encode(self.y_train)

    @profile
    def __get_normalizer__(self):
        return self.x_normalizer, self.y_normalizer

    @profile
    def __get_data__(self):

        if self.train_path.endswith(".mat"):
            data_dict = scio.loadmat(self.train_path)
        elif self.train_path.endswith(".npz"):
            data_dict = np.load(self.train_path)

        # Extract data from dicts:
        self.x_train = data_dict["coeff"]
        self.y_train = data_dict["sol"]

        x = np.linspace(0, 1, self.s)
        y = np.linspace(0, 1, self.s)
        x, y = np.meshgrid(x, y)
        pos = np.c_[x.ravel(), y.ravel()]
        pos = torch.tensor(pos, dtype=torch.float).unsqueeze(0).cuda()
        self.x_train = scio.loadmat(self.train_path)["coeff"][
            : self.n_train, :: self.r, :: self.r
        ][:, : self.s, : self.s]
        pos = torch.tensor(pos, dtype=torch.float).cuda()

        # Downsampling logic
        if self.r > 1:
            # Downsample by slicing
            self.x_train = self.x_train[: self.n_train, :: self.r, :: self.r][
                :, : self.s, : self.s
            ]
            self.y_train = self.y_train[: self.n_train, :: self.r, :: self.r][
                :, : self.s, : self.s
            ]
        else:
            # No downsampling, use full resolution
            self.x_train = self.x_train[: self.n_train, : self.s, : self.s]
            self.y_train = self.y_train[: self.n_train, : self.s, : self.s]

        # Flatten them:
        self.x_train = self.x_train.reshape(self.n_train, -1)
        self.x_train = torch.from_numpy(self.x_train).float().cuda()
        self.y_train = scio.loadmat(self.train_path)["sol"][
            : self.n_train, :: self.r, :: self.r
        ][:, : self.s, : self.s]
        self.y_train = self.y_train.reshape(self.n_train, -1)
        self.y_train = torch.from_numpy(self.y_train).float().cuda()
        self.pos_train = pos.repeat(self.n_train, 1, 1)

        self.x_train = torch.from_numpy(self.x_train).float().cuda()
        self.y_train = torch.from_numpy(self.y_train).float().cuda()
        # Why are we repeating the position?
        # print(f"pos shape: {pos.shape}")
        self.pos_train = pos
        self.pos_train_batched = pos.repeat(self.batch_size, 1, 1).cuda()
        # print(f"pos shape post repeat: {self.pos_train.shape}")
        # self.pos_train = pos

    @profile
    def __iter__(self):
        """
        Yields
@@ -263,13 +272,17 @@ def __iter__(self):
        Infinite iterator that returns a batch of (permeability, darcy pressure)
        fields of size [batch, resolution, resolution]
        """
        # infinite generator

        while True:
            idx = np.random.choice(200, self.batch_size)
            # Sample batch_size indices from this rank's shard
            idx = np.random.choice(self.n_train, self.batch_size)
            # All tensors are already on GPU, so no .cuda() needed
            x = self.x_train[idx]
            y = self.y_train[idx]
            pos = self.pos_train[idx]
            yield pos, x, y
            yield self.pos_train_batched, x, y

    def __getitem__(self, idx):
        return self.pos_train, self.x_train[idx], self.y_train[idx]

    def __len__(self):
        return self.n_train // self.batch_size
        return self.n_train
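
Putting it together, a minimal sketch of driving this datapipe (the path
follows the config above; constructor arguments not visible in this diff, such
as `batch_size`, are assumptions):

```python
from itertools import islice

# Build the training pipe; normalizers are fit internally when none are passed.
datapipe = Darcy2D_fix(
    resolution=85,
    batch_size=8,
    train_path="/user_data/datasets/darcy_fix/example_data/piececonst_r421_N1024_smooth1.npz",
)
x_normalizer, y_normalizer = datapipe.__get_normalizer__()

# __iter__ is an infinite generator, so bound the loop explicitly.
for pos, x, y in islice(datapipe, 2):
    print(pos.shape, x.shape, y.shape)
```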
14 changes: 8 additions & 6 deletions examples/cfd/darcy_transolver/train_transolver_darcy.py
@@ -54,22 +54,24 @@ def darcy_trainer(cfg: DictConfig) -> None:

    # define model, loss, optimiser, scheduler, data loader
    model = Transolver(
        space_dim=cfg.model.space_dim,
        functional_dim=cfg.model.functional_dim,
        out_dim=cfg.model.out_dim,
        embedding_dim=cfg.model.embedding_dim,
        n_layers=cfg.model.n_layers,
        n_hidden=cfg.model.n_hidden,
        dropout=cfg.model.dropout,
        n_head=cfg.model.n_head,
        Time_Input=cfg.model.Time_Input,
        act=cfg.model.act,
        mlp_ratio=cfg.model.mlp_ratio,
        fun_dim=cfg.model.fun_dim,
        out_dim=cfg.model.out_dim,
        slice_num=cfg.model.slice_num,
        ref=cfg.model.ref,
        unified_pos=cfg.model.unified_pos,
        H=cfg.training.resolution,
        W=cfg.training.resolution,
        ref=cfg.model.ref,
        structured_shape=[cfg.data.resolution, cfg.data.resolution],
        use_te=cfg.model.use_te,
        Time_Input=cfg.model.Time_Input,
    ).to(dist.device)

    loss_fun = TestLoss(size_average=False)
    optimizer = Adam(model.parameters(), lr=cfg.scheduler.initial_lr)
    scheduler = lr_scheduler.LambdaLR(