diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 24391e1..acfa2b9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,7 +27,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest numpy torch matplotlib + python -m pip install flake8 pytest numpy torch matplotlib torchjd if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | diff --git a/.gitignore b/.gitignore index dc584c5..6dd8349 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +launch.json .vscode # C extensions *.so diff --git a/pinn/pinn_1d.py b/pinn/pinn_1d.py index 2599d8d..fd9f2c9 100644 --- a/pinn/pinn_1d.py +++ b/pinn/pinn_1d.py @@ -45,10 +45,12 @@ import torch import torch.nn as nn import torch.optim as optim +import torchjd +from torchjd import aggregation as agg import numpy as np from enum import Enum from utils import parse_args, get_activation, print_args, save_frame, make_video_from_frames -from utils import is_notebook, cleanfiles, fourier_analysis, get_scheduler_generator, scheduler_step +from utils import is_notebook, cleanfiles, fourier_analysis, get_scheduler_generator, scheduler_step, monitor_aggregator, get_aggregator # from SOAP.soap import SOAP # torch.set_default_dtype(torch.float64) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -280,6 +282,8 @@ def __init__(self, loss_type, loss_func=nn.MSELoss(), bc_weight=1.0): self.name = "PINN Loss" elif self.type == 1: self.name = "DRM Loss" + elif self.type == 2: + self.name = "PINN+DRM Loss" else: raise ValueError(f"Unknown loss type: {self.type}") self.bc_weight = bc_weight @@ -289,7 +293,7 @@ def super_loss(self, model, mesh, loss_func): x = mesh.x_train u = model.get_solution(x) loss = loss_func(u, mesh.u_ex) - return loss + return loss, [loss,] # "PINN" loss def pinn_loss(self, model, mesh, loss_func): @@ -301,15 +305,15 @@ def pinn_loss(self, model, mesh, loss_func): # Internal loss pde = mesh.pde - loss = loss_func(d2u_dx2[1:-1] + mesh.f[1:-1], pde.r * u[1:-1]) + loss_pinn = loss_func(d2u_dx2[1:-1] + mesh.f[1:-1], pde.r * u[1:-1]) # Boundary loss if not model.enforce_bc: u_bc = u[[0, -1]] u_ex_bc = mesh.u_ex[[0, -1]] loss_b = loss_func(u_bc, u_ex_bc) - loss += self.bc_weight * loss_b - - return loss + loss = loss_pinn + self.bc_weight * loss_b + return loss, [loss_pinn, loss_b] + return loss_pinn, [loss_pinn,] def drm_loss(self, model, mesh: Mesh): """Deep Ritz Method loss""" @@ -327,17 +331,27 @@ def drm_loss(self, model, mesh: Mesh): fu_prod = f_val * u integrand_values = 0.5 * grad_u_pred_sq[1:-1] + 0.5 * mesh.pde.r * u_pred_sq[1:-1] - fu_prod[1:-1] - loss = torch.mean(integrand_values) - - # Boundary loss - u_bc = u[[0, -1]] - u_ex_bc = mesh.u_ex[[0, -1]] - loss_b = self.loss_func(u_bc, u_ex_bc) - loss += self.bc_weight * loss_b - xs.requires_grad_(False) # Disable gradient tracking for x - return loss + loss_drm = torch.mean(integrand_values) + if not model.enforce_bc: + # Boundary loss + u_bc = u[[0,-1]] + u_ex_bc = mesh.u_ex[[0,-1]] + loss_b = self.loss_func(u_bc, u_ex_bc) + loss = loss_drm + self.bc_weight * loss_b + return loss, [loss_drm, loss_b] + + #xs.requires_grad_(False) # Disable gradient tracking for x + return loss_drm, [loss_drm,] + def drmpinn_loss(self, model, mesh): + """Combined Deep Ritz Method and PINN loss""" + loss_p, loss_ps = self.pinn_loss(model=model, mesh=mesh, loss_func=self.loss_func) + loss_d, loss_ds = self.drm_loss(model=model, mesh=mesh) + # Combine losses + loss = loss_p + loss_d + multi_loss = [*loss_ps, loss_ds[0]] + return loss, multi_loss def loss(self, model, mesh): if self.type == -1: loss_value = self.super_loss(model=model, mesh=mesh, loss_func=self.loss_func) @@ -345,16 +359,20 @@ def loss(self, model, mesh): loss_value = self.pinn_loss(model=model, mesh=mesh, loss_func=self.loss_func) elif self.type == 1: loss_value = self.drm_loss(model=model, mesh=mesh) + elif self.type == 2: + loss_value = self.drmpinn_loss(model=model, mesh=mesh) else: raise ValueError(f"Unknown loss type: {self.type}") return loss_value - # %% # Define the training loop def train(model, mesh, criterion, iterations, adam_iterations, learning_rate, num_check, num_plots, sweep_idx, - level_idx, frame_dir, scheduler_gen): + level_idx, frame_dir, scheduler_gen, aggregator:str='None', do_monitor_aggregator:bool=False): optimizer = optim.Adam(model.parameters(), lr=learning_rate) + aggregator = None if aggregator == 'None' else get_aggregator(aggregator) + if (aggregator is not None) and (do_monitor_aggregator): + monitor_aggregator(aggregator) # optimizer = SOAP(model.parameters(), lr = 3e-3, betas=(.95, .95), weight_decay=.01, # precondition_frequency=10) scheduler = scheduler_gen(optimizer) @@ -375,7 +393,7 @@ def to_np(t): return t.detach().cpu().numpy() def closure(): optimizer.zero_grad() - loss = criterion.loss(model=model, mesh=mesh) + loss, _ = criterion.loss(model=model, mesh=mesh) loss.backward() return loss @@ -385,10 +403,13 @@ def closure(): # we need to set to zero the gradients of all model parameters (PyTorch accumulates grad by default) optimizer.zero_grad() # compute the loss value for the current batch of data - loss = criterion.loss(model=model, mesh=mesh) + loss, multiloss = criterion.loss(model=model, mesh=mesh) # backpropagation to compute gradients of model param respect to the loss. computes dloss/dx # for every parameter x which has requires_grad=True. - loss.backward() + if aggregator is None: + loss.backward() + else: + torchjd.backward(multiloss, aggregator=aggregator,) # update the model param doing an optim step using the computed gradients and learning rate optimizer.step() # @@ -478,7 +499,7 @@ def main(args=None): train(model=model, mesh=mesh, criterion=loss, iterations=args.epochs, adam_iterations=args.adam_epochs, learning_rate=args.lr, num_check=args.num_checks, num_plots=num_plots, - sweep_idx=i, level_idx=lev, frame_dir=frame_dir, scheduler_gen=scheduler_gen) + sweep_idx=i, level_idx=lev, frame_dir=frame_dir, scheduler_gen=scheduler_gen, aggregator=args.aggregator, do_monitor_aggregator=args.monitor_aggregator) # Turn PNGs into a video using OpenCV if args.plot: make_video_from_frames(frame_dir=frame_dir, name_prefix="Model_Outputs", diff --git a/pinn/utils.py b/pinn/utils.py index b7720b8..da4d0e7 100644 --- a/pinn/utils.py +++ b/pinn/utils.py @@ -22,7 +22,10 @@ import numpy as np import torch import ast +from torch.nn.functional import cosine_similarity +from torchjd import aggregation as agg +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # %% def cleanfiles(dir_name): @@ -78,8 +81,8 @@ def parse_args(args=None): help="Learning rate for the optimizer.") parser.add_argument('--levels', type=int, default=4, help="Number of levels in multilevel training.") - parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0, 1], - help="Loss type: -1 for supervised (true solution), 0 for PINN loss.") + parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0, 1, 2], + help="Loss type: -1 for supervised (true solution), 0 for PINN loss. 1 for DRM, 2 for DRN+PINN") parser.add_argument('--activation', type=str, default='tanh', choices=['tanh', 'silu', 'relu', 'gelu', 'softmax'], help="Activation function to use.") @@ -93,6 +96,9 @@ def parse_args(args=None): help="If set, enforce the BC in solution.") parser.add_argument('--bc_weight', type=float, default=1.0, help="Weight for the loss of BC.") + parser.add_argument('--aggregator', type=str, nargs='+', default='None', help="Aggregator for the loss function. See https://torchjd.org/stable/docs/aggregation/ for options") + + parser.add_argument('--monitor_aggregator', action='store_true', help="If set, monitor gradient. This need to set up aggregator") parser.add_argument("--scheduler", type=str, default="StepLR", help="Learning rate scheduler to use. " "See https://docs.pytorch.org/docs/stable/optim.html for full list of schedulers") @@ -230,8 +236,28 @@ def make_video_from_frames(frame_dir, name_prefix, output_file, fps=10): video.release() print(f" Video saved as {output_file_path}") - -# %% +def get_aggregator(name: str): + if isinstance(name, str): + return getattr(agg, name)() + elif name[0]=="Constant": + return getattr(agg, name[0])(torch.tensor([ast.literal_eval(i) for i in name[1:]]).to(device)) + else: + return getattr(agg, name[0])(*[ast.literal_eval(i) for i in name[1:]]) + +def monitor_aggregator(aggregator): + def print_weights(_, __, weights: torch.Tensor) -> None: + """Prints the extracted weights.""" + print(f"Weights: {weights}") + + def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None: + """Prints the cosine similarity between the aggregation and the average gradient.""" + matrix = inputs[0] + gd_output = matrix.mean(dim=0) + similarity = cosine_similarity(aggregation, gd_output, dim=0) + print(f"Cosine similarity: {similarity.item():.4f}") + aggregator.weighting.register_forward_hook(print_weights) + aggregator.register_forward_hook(print_gd_similarity) + def fourier_analysis(x, y): """ Compute the magnitude spectrum using the Fast Fourier Transform (FFT).