2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -27,7 +27,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install flake8 pytest numpy torch matplotlib
+          python -m pip install flake8 pytest numpy torch matplotlib torchjd
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
       - name: Lint with flake8
         run: |
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
+launch.json
 .vscode
 # C extensions
 *.so
63 changes: 42 additions & 21 deletions pinn/pinn_1d.py
@@ -45,10 +45,12 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import torchjd
+from torchjd import aggregation as agg
 import numpy as np
 from enum import Enum
 from utils import parse_args, get_activation, print_args, save_frame, make_video_from_frames
-from utils import is_notebook, cleanfiles, fourier_analysis, get_scheduler_generator, scheduler_step
+from utils import is_notebook, cleanfiles, fourier_analysis, get_scheduler_generator, scheduler_step, monitor_aggregator, get_aggregator
 # from SOAP.soap import SOAP
 # torch.set_default_dtype(torch.float64)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -280,6 +282,8 @@ def __init__(self, loss_type, loss_func=nn.MSELoss(), bc_weight=1.0):
             self.name = "PINN Loss"
         elif self.type == 1:
             self.name = "DRM Loss"
+        elif self.type == 2:
+            self.name = "PINN+DRM Loss"
         else:
             raise ValueError(f"Unknown loss type: {self.type}")
         self.bc_weight = bc_weight
@@ -289,7 +293,7 @@ def super_loss(self, model, mesh, loss_func):
         x = mesh.x_train
         u = model.get_solution(x)
         loss = loss_func(u, mesh.u_ex)
-        return loss
+        return loss, [loss,]
Review comment (Collaborator Author): major change
# "PINN" loss
def pinn_loss(self, model, mesh, loss_func):
Expand All @@ -301,15 +305,15 @@ def pinn_loss(self, model, mesh, loss_func):

         # Internal loss
         pde = mesh.pde
-        loss = loss_func(d2u_dx2[1:-1] + mesh.f[1:-1], pde.r * u[1:-1])
+        loss_pinn = loss_func(d2u_dx2[1:-1] + mesh.f[1:-1], pde.r * u[1:-1])
         # Boundary loss
         if not model.enforce_bc:
             u_bc = u[[0, -1]]
             u_ex_bc = mesh.u_ex[[0, -1]]
             loss_b = loss_func(u_bc, u_ex_bc)
-            loss += self.bc_weight * loss_b
-
-        return loss
+            loss = loss_pinn + self.bc_weight * loss_b
+            return loss, [loss_pinn, loss_b]
+        return loss_pinn, [loss_pinn,]

     def drm_loss(self, model, mesh: Mesh):
         """Deep Ritz Method loss"""
@@ -327,34 +331,48 @@ def drm_loss(self, model, mesh: Mesh):
         fu_prod = f_val * u
 
         integrand_values = 0.5 * grad_u_pred_sq[1:-1] + 0.5 * mesh.pde.r * u_pred_sq[1:-1] - fu_prod[1:-1]
-        loss = torch.mean(integrand_values)
-
-        # Boundary loss
-        u_bc = u[[0, -1]]
-        u_ex_bc = mesh.u_ex[[0, -1]]
-        loss_b = self.loss_func(u_bc, u_ex_bc)
-        loss += self.bc_weight * loss_b
-
-        xs.requires_grad_(False)  # Disable gradient tracking for x
-        return loss
+        loss_drm = torch.mean(integrand_values)
+
+        if not model.enforce_bc:
+            # Boundary loss
+            u_bc = u[[0, -1]]
+            u_ex_bc = mesh.u_ex[[0, -1]]
+            loss_b = self.loss_func(u_bc, u_ex_bc)
+            loss = loss_drm + self.bc_weight * loss_b
+            return loss, [loss_drm, loss_b]
+
+        # xs.requires_grad_(False)  # Disable gradient tracking for x
+        return loss_drm, [loss_drm,]
+
+    def drmpinn_loss(self, model, mesh):
+        """Combined Deep Ritz Method and PINN loss"""
+        loss_p, loss_ps = self.pinn_loss(model=model, mesh=mesh, loss_func=self.loss_func)
+        loss_d, loss_ds = self.drm_loss(model=model, mesh=mesh)
+        # Combine losses
+        loss = loss_p + loss_d
+        multi_loss = [*loss_ps, loss_ds[0]]
+        return loss, multi_loss
+
     def loss(self, model, mesh):
         if self.type == -1:
             loss_value = self.super_loss(model=model, mesh=mesh, loss_func=self.loss_func)
         elif self.type == 0:
             loss_value = self.pinn_loss(model=model, mesh=mesh, loss_func=self.loss_func)
         elif self.type == 1:
             loss_value = self.drm_loss(model=model, mesh=mesh)
+        elif self.type == 2:
+            loss_value = self.drmpinn_loss(model=model, mesh=mesh)
         else:
             raise ValueError(f"Unknown loss type: {self.type}")
         return loss_value
 
 
 # %%
 # Define the training loop
 def train(model, mesh, criterion, iterations, adam_iterations, learning_rate, num_check, num_plots, sweep_idx,
-          level_idx, frame_dir, scheduler_gen):
+          level_idx, frame_dir, scheduler_gen, aggregator: str = 'None', do_monitor_aggregator: bool = False):
     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    aggregator = None if aggregator == 'None' else get_aggregator(aggregator)
+    if (aggregator is not None) and do_monitor_aggregator:
+        monitor_aggregator(aggregator)
     # optimizer = SOAP(model.parameters(), lr = 3e-3, betas=(.95, .95), weight_decay=.01,
     #                  precondition_frequency=10)
     scheduler = scheduler_gen(optimizer)
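Note on the interface change flagged by the review comment above: every loss method now returns a tuple `(scalar_loss, [per-term losses])`, so `train` can either backpropagate the weighted sum or hand the list of terms to a torchjd aggregator. Here is a minimal standalone sketch of that convention, using toy tensors rather than the PINN model; `UPGrad` is one of the aggregators listed in torchjd's aggregation docs, and the `torchjd.backward(..., aggregator=...)` call mirrors the usage in this diff:

```python
import torch
import torchjd
from torchjd import aggregation as agg

# Two toy objectives sharing one parameter, mimicking the
# (loss, [loss_pinn, loss_b]) tuples the loss methods now return.
theta = torch.randn(3, requires_grad=True)
loss_pinn = theta.square().sum()        # stand-in for the PDE residual term
loss_b = (theta - 1.0).square().sum()   # stand-in for the boundary term
loss, multi_loss = loss_pinn + loss_b, [loss_pinn, loss_b]

# aggregator is None  -> plain backprop on the weighted sum:
# loss.backward()
# aggregator is set   -> Jacobian descent over the individual terms:
torchjd.backward(multi_loss, aggregator=agg.UPGrad())
print(theta.grad)  # aggregated gradient combining both terms
```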
@@ -375,7 +393,7 @@ def to_np(t): return t.detach().cpu().numpy()

     def closure():
         optimizer.zero_grad()
-        loss = criterion.loss(model=model, mesh=mesh)
+        loss, _ = criterion.loss(model=model, mesh=mesh)
         loss.backward()
         return loss
 
@@ -385,10 +403,13 @@ def closure():
         # we need to set to zero the gradients of all model parameters (PyTorch accumulates grads by default)
         optimizer.zero_grad()
         # compute the loss value for the current batch of data
-        loss = criterion.loss(model=model, mesh=mesh)
+        loss, multiloss = criterion.loss(model=model, mesh=mesh)
         # backpropagation to compute the gradient of the loss w.r.t. every
         # parameter x that has requires_grad=True
-        loss.backward()
+        if aggregator is None:
+            loss.backward()
+        else:
+            torchjd.backward(multiloss, aggregator=aggregator)
         # update the model params with an optimizer step using the computed gradients and learning rate
         optimizer.step()
         #
@@ -478,7 +499,7 @@ def main(args=None):
         train(model=model, mesh=mesh, criterion=loss, iterations=args.epochs,
               adam_iterations=args.adam_epochs,
               learning_rate=args.lr, num_check=args.num_checks, num_plots=num_plots,
-              sweep_idx=i, level_idx=lev, frame_dir=frame_dir, scheduler_gen=scheduler_gen)
+              sweep_idx=i, level_idx=lev, frame_dir=frame_dir, scheduler_gen=scheduler_gen, aggregator=args.aggregator, do_monitor_aggregator=args.monitor_aggregator)
         # Turn PNGs into a video using OpenCV
         if args.plot:
             make_video_from_frames(frame_dir=frame_dir, name_prefix="Model_Outputs",
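A quick way to sanity-check the aggregator branch added to the training loop: summing the rows of the Jacobian reproduces ordinary backprop on the summed loss. A sketch, assuming torchjd also ships the simple `Sum` aggregator (it appears in the aggregation docs linked from this PR), so that path should match the `aggregator is None` path up to floating-point noise:

```python
import torch
import torchjd
from torchjd import aggregation as agg

theta = torch.randn(4, requires_grad=True)
terms = [theta.square().sum(), theta.cos().sum()]

# Plain backward on the summed loss (the aggregator-is-None path)...
sum(terms).backward(retain_graph=True)
g_plain = theta.grad.clone()
theta.grad = None

# ...should match torchjd with the Sum aggregator, which sums Jacobian rows.
torchjd.backward(terms, aggregator=agg.Sum())
print(torch.allclose(g_plain, theta.grad))  # expected: True
```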
34 changes: 30 additions & 4 deletions pinn/utils.py
@@ -22,7 +22,10 @@
 import numpy as np
 import torch
 import ast
+from torch.nn.functional import cosine_similarity
+from torchjd import aggregation as agg
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # %%
 def cleanfiles(dir_name):
@@ -78,8 +81,8 @@ def parse_args(args=None):
                         help="Learning rate for the optimizer.")
     parser.add_argument('--levels', type=int, default=4,
                         help="Number of levels in multilevel training.")
-    parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0, 1],
-                        help="Loss type: -1 for supervised (true solution), 0 for PINN loss.")
+    parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0, 1, 2],
+                        help="Loss type: -1 for supervised (true solution), 0 for PINN, 1 for DRM, 2 for DRM+PINN.")
     parser.add_argument('--activation', type=str, default='tanh',
                         choices=['tanh', 'silu', 'relu', 'gelu', 'softmax'],
                         help="Activation function to use.")
@@ -93,6 +96,9 @@ def parse_args(args=None):
                         help="If set, enforce the BC in solution.")
     parser.add_argument('--bc_weight', type=float, default=1.0,
                         help="Weight for the loss of BC.")
+    parser.add_argument('--aggregator', type=str, nargs='+', default='None', help="Aggregator used to combine the per-term loss gradients. See https://torchjd.org/stable/docs/aggregation/ for options.")
+
+    parser.add_argument('--monitor_aggregator', action='store_true', help="If set, monitor the aggregated gradient; requires --aggregator to be set.")
     parser.add_argument("--scheduler", type=str, default="StepLR",
                         help="Learning rate scheduler to use. "
                              "See https://docs.pytorch.org/docs/stable/optim.html for full list of schedulers")
@@ -230,8 +236,28 @@ def make_video_from_frames(frame_dir, name_prefix, output_file, fps=10):
     video.release()
     print(f" Video saved as {output_file_path}")
 
 
 # %%
+def get_aggregator(name):
+    # `name` is either a plain string or, from argparse with nargs='+',
+    # a list: the aggregator class name followed by constructor arguments.
+    if isinstance(name, str):
+        return getattr(agg, name)()
+    elif name[0] == "Constant":
+        # Constant expects a weight tensor, e.g. --aggregator Constant 0.5 0.5
+        return getattr(agg, name[0])(torch.tensor([ast.literal_eval(i) for i in name[1:]]).to(device))
+    else:
+        return getattr(agg, name[0])(*[ast.literal_eval(i) for i in name[1:]])
+
+def monitor_aggregator(aggregator):
+    def print_weights(_, __, weights: torch.Tensor) -> None:
+        """Prints the extracted weights."""
+        print(f"Weights: {weights}")
+
+    def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
+        """Prints the cosine similarity between the aggregation and the average gradient."""
+        matrix = inputs[0]
+        gd_output = matrix.mean(dim=0)
+        similarity = cosine_similarity(aggregation, gd_output, dim=0)
+        print(f"Cosine similarity: {similarity.item():.4f}")
+
+    aggregator.weighting.register_forward_hook(print_weights)
+    aggregator.register_forward_hook(print_gd_similarity)
+
 def fourier_analysis(x, y):
     """
     Compute the magnitude spectrum using the Fast Fourier Transform (FFT).
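Taken together, a hedged sketch of how these two helpers compose. It assumes torchjd aggregators are callable on a gradient matrix (one row per loss term) and expose a `.weighting` submodule, which is what the forward hooks above rely on:

```python
import torch
from utils import get_aggregator, monitor_aggregator

# As parsed from the CLI: a list of tokens.
aggregator = get_aggregator(["UPGrad"])                 # -> agg.UPGrad()
constant = get_aggregator(["Constant", "0.6", "0.4"])   # -> agg.Constant(tensor([0.6, 0.4]))

# Attach the hooks; every aggregation then prints its weights and its
# cosine similarity to the plain averaged gradient.
monitor_aggregator(aggregator)
matrix = torch.randn(2, 5)   # a toy Jacobian: one row per loss term
print(aggregator(matrix))    # triggers both hooks
```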