diff --git a/pinn/bc.py b/pinn/bc.py new file mode 100644 index 0000000..95de5eb --- /dev/null +++ b/pinn/bc.py @@ -0,0 +1,157 @@ +import torch +from typing import Union, Tuple, Callable +import itertools + +# %% [markdown] +# Helper functions from BC implementation +# The approach to enforce BC is +# $$u(x) = g_0(x) + d(x) u_{NN}(x)$$ +# where $g_0$ satisfies BC at the boundary and $d$ is zero at the boundary. + +# %% +def _calculate_laplacian_1d(func: Callable[[torch.Tensor], torch.Tensor], x_val: float) -> torch.Tensor: + x_tensor = torch.tensor([[x_val]], dtype=torch.float32, requires_grad=True) + u = func(x_tensor) + grad_u = torch.autograd.grad(u, x_tensor, grad_outputs=torch.ones_like(u), create_graph=True, retain_graph=True)[0] + laplacian_u = torch.autograd.grad(grad_u, x_tensor, grad_outputs=torch.ones_like(grad_u), create_graph=False, retain_graph=False)[0] + return laplacian_u + + +def get_g0_func( + u_exact_func: Callable[[torch.Tensor], torch.Tensor], + domain_dim: int, + domain_bounds: Union[Tuple[float, float], Tuple[Tuple[float, float], ...]], + g0_type: str = "multilinear" +) -> Callable[[torch.Tensor], torch.Tensor]: + domain_bounds_tuple = domain_bounds + if domain_dim == 1 and not isinstance(domain_bounds[0], (tuple, list)): + domain_bounds_tuple = (domain_bounds,) + min_bounds = torch.tensor([b[0] for b in domain_bounds_tuple], dtype=torch.float32) + max_bounds = torch.tensor([b[1] for b in domain_bounds_tuple], dtype=torch.float32) + + if g0_type == "hermite_cubic_2nd_deriv": + if domain_dim != 1: + raise ValueError("Hermite cubic interpolation with 2nd derivatives is only supported for 1D problems.") + x0, x1 = min_bounds.item(), max_bounds.item() + h = x1 - x0 + u_x0 = u_exact_func(torch.tensor([[x0]], dtype=torch.float32)).item() + u_x1 = u_exact_func(torch.tensor([[x1]], dtype=torch.float32)).item() + u_prime_prime_x0 = _calculate_laplacian_1d(u_exact_func, x0).item() + u_prime_prime_x1 = _calculate_laplacian_1d(u_exact_func, x1).item() + a3 = (u_prime_prime_x1 - u_prime_prime_x0) / (6 * h) + a2 = u_prime_prime_x0 / 2 - 3 * a3 * x0 + a1 = (u_x1 - u_x0) / h - a2 * (x1 + x0) - a3 * (x1**2 + x1 * x0 + x0**2) + a0 = u_x0 - a1 * x0 - a2 * x0**2 - a3 * x0**3 + coeffs = torch.tensor([a0, a1, a2, a3], dtype=torch.float32) + + def g0_hermite_cubic_val(x: torch.Tensor) -> torch.Tensor: + x_flat = x[:, 0] + g0_vals = coeffs[0] + coeffs[1] * x_flat + coeffs[2] * (x_flat**2) + coeffs[3] * (x_flat**3) + return g0_vals.unsqueeze(1) + return g0_hermite_cubic_val + + if g0_type == "multilinear": + boundary_values = {} + dim_ranges = [[min_bounds[d].item(), max_bounds[d].item()] for d in range(domain_dim)] + for corner_coords in itertools.product(*dim_ranges): + corner_coords_tensor = torch.tensor(corner_coords, dtype=torch.float32).unsqueeze(0) + with torch.no_grad(): + boundary_values[corner_coords] = u_exact_func(corner_coords_tensor).item() + + def g0_multilinear_val(x: torch.Tensor) -> torch.Tensor: + num_points = x.shape[0] + xi = (x - min_bounds.to(x.device)) / (max_bounds.to(x.device) - min_bounds.to(x.device)) + xi = torch.clamp(xi, 0.0, 1.0) + g0_vals = torch.zeros((num_points, 1), device=x.device) + for corner_label in itertools.product([0, 1], repeat=domain_dim): + current_corner_key_list = [] + weight_factors = torch.ones((num_points, 1), device=x.device) + for d in range(domain_dim): + if corner_label[d] == 0: + current_corner_key_list.append(min_bounds[d].item()) + weight_factors *= (1 - xi[:, d]).unsqueeze(1) + else: + 
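# corner_label[d] == 1: this corner lies on the max bound in dimension d, so its + # multilinear weight gains a factor xi_d (the 0-label branch above uses 1 - xi_d). + # E.g. in 2D, summing corner_value * weight_factors over the four corners gives + # g0(x) = v00*(1-xi0)*(1-xi1) + v01*(1-xi0)*xi1 + v10*xi0*(1-xi1) + v11*xi0*xi1, + # where vij denotes u_exact at the corner with labels (i, j). +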
current_corner_key_list.append(max_bounds[d].item()) + weight_factors *= xi[:, d].unsqueeze(1) + corner_key_tuple = tuple(current_corner_key_list) + corner_value = boundary_values[corner_key_tuple] + g0_vals += corner_value * weight_factors + return g0_vals + return g0_multilinear_val + + raise ValueError(f"Unknown g0_type: {g0_type}. Choose 'multilinear' or 'hermite_cubic_2nd_deriv'.") + + +def _psi_tensor(t: torch.Tensor) -> torch.Tensor: + return torch.where(t <= 0, torch.tensor(0.0, dtype=t.dtype, device=t.device), torch.exp(-1.0 / t)) + + +def get_d_func(domain_dim: int, domain_bounds: Union[Tuple[float, float], Tuple[Tuple[float, float], ...]], + d_type: str = "sin_half_period") -> Callable[[torch.Tensor], torch.Tensor]: + domain_bounds_tuple = domain_bounds + if domain_dim == 1 and not isinstance(domain_bounds[0], (tuple, list)): + domain_bounds_tuple = (domain_bounds,) + min_bounds = torch.tensor([b[0] for b in domain_bounds_tuple], dtype=torch.float32) + max_bounds = torch.tensor([b[1] for b in domain_bounds_tuple], dtype=torch.float32) + domain_length = (max_bounds[0] - min_bounds[0]).item() if domain_dim == 1 else None + + if d_type == "quadratic_bubble": + def d_func_val(x: torch.Tensor) -> torch.Tensor: + d_vals = torch.ones_like(x[:, 0], dtype=torch.float32, device=x.device) + for i in range(domain_dim): + x_i = x[:, i] + min_val, max_val = domain_bounds_tuple[i] + d_vals *= (x_i - min_val) * (max_val - x_i) + return d_vals.unsqueeze(1) + return d_func_val + + if d_type == "inf_smooth_bump": + def d_inf_smooth_bump_val(x: torch.Tensor) -> torch.Tensor: + product_terms = torch.ones((x.shape[0],), dtype=x.dtype, device=x.device) + for i in range(domain_dim): + x_i = x[:, i] + min_val_i = min_bounds[i] + max_val_i = max_bounds[i] + x_c_i = (min_val_i + max_val_i) / 2.0 + R_i = (max_val_i - min_val_i) / 2.0 + R_i_squared = R_i**2 + arg_for_psi = R_i_squared - (x_i - x_c_i)**2 + product_terms *= _psi_tensor(arg_for_psi) + return product_terms.unsqueeze(1) + return d_inf_smooth_bump_val + + if d_type == "abs_dist_complement": + if domain_dim != 1: + raise ValueError(f"d_type '{d_type}' is only supported for 1D problems.") + def d_abs_dist_complement_val(x: torch.Tensor) -> torch.Tensor: + x_val = x[:, 0] + x_norm = (x_val - min_bounds[0]) / domain_length + sqrt_term = torch.sqrt(x_norm**2 + (1.0 - x_norm)**2) + return (1.0 - sqrt_term).unsqueeze(1) + return d_abs_dist_complement_val + + if d_type == "ratio_bubble_dist": + if domain_dim != 1: + raise ValueError(f"d_type '{d_type}' is only supported for 1D problems.") + def d_ratio_bubble_dist_val(x: torch.Tensor) -> torch.Tensor: + x_val = x[:, 0] + x_norm = (x_val - min_bounds[0]) / domain_length + numerator = x_norm * (1.0 - x_norm) + denominator = torch.sqrt(x_norm**2 + (1.0 - x_norm)**2) + return (numerator / denominator).unsqueeze(1) + return d_ratio_bubble_dist_val + + if d_type == "sin_half_period": + if domain_dim != 1: + raise ValueError(f"d_type '{d_type}' is only supported for 1D problems.") + if domain_length is None: + raise ValueError("Domain length must be defined for 'sin_half_period' d_type.") + def d_sin_half_period_val(x: torch.Tensor) -> torch.Tensor: + x_val = x[:, 0] + argument = (torch.pi / domain_length) * (x_val - min_bounds[0]) + return torch.sin(argument).unsqueeze(1) + return d_sin_half_period_val + + raise ValueError(f"Unknown d_type: {d_type}." 
+ "Choose from 'quadratic_bubble', 'inf_smooth_bump', 'abs_dist_complement', " + "'ratio_bubble_dist', or 'sin_half_period'.") diff --git a/pinn/cheby.py b/pinn/cheby.py index 26c14e9..dc7c9af 100644 --- a/pinn/cheby.py +++ b/pinn/cheby.py @@ -15,48 +15,32 @@ # %% -def chebyshev_transformed_features(x, chebyshev_freq_min, chebyshev_freq_max): - chebyshev_features = [] +def generate_chebyshev_features(x, chebyshev_freq_min, chebyshev_freq_max): theta = torch.pi * x[:, 0] - sin_theta = torch.sin(theta) - cos_theta = torch.cos(theta) - left_end = torch.abs(theta) < 1e-8 - right_end = torch.abs(theta - torch.pi) < 1e-8 - - u_k_minus_2 = torch.sin((chebyshev_freq_min) * theta) / sin_theta - u_k_minus_2[left_end] = float(chebyshev_freq_min) - u_k_minus_2[right_end] = float((chebyshev_freq_min) * (-1)**(chebyshev_freq_min - 1)) - - u_k_minus_1 = torch.sin((chebyshev_freq_min + 1) * theta) / sin_theta - u_k_minus_1[left_end] = float(chebyshev_freq_min + 1) - u_k_minus_1[right_end] = float((chebyshev_freq_min + 1) * (-1)**(chebyshev_freq_min)) - - for k_current_degree in range(chebyshev_freq_min, chebyshev_freq_max + 1): - if k_current_degree == chebyshev_freq_min: - current_chebyshev_u = u_k_minus_2 - elif k_current_degree == chebyshev_freq_min + 1: - current_chebyshev_u = u_k_minus_1 - else: - current_chebyshev_u = 2 * cos_theta * u_k_minus_1 - u_k_minus_2 - u_k_minus_2 = u_k_minus_1 - u_k_minus_1 = current_chebyshev_u - chebyshev_features.append(current_chebyshev_u.unsqueeze(1)) - - return torch.cat(chebyshev_features, dim=1) + cos_theta = torch.cos(theta).unsqueeze(1) + if chebyshev_freq_min > chebyshev_freq_max: + return torch.empty((x.shape[0], 0), device=x.device) -# %% -def chebyshev_transformed_features2(x, chebyshev_freq_min, chebyshev_freq_max): + # cheby poly 2nd kind Uₖ(cos(πx)) + # U0 = 1 + U_k_minus_2 = torch.ones_like(cos_theta) + # U₁ = 2 * x + U_k_minus_1 = 2 * cos_theta + # all cheby features chebyshev_features = [] - theta = torch.pi * x[:, 0] - cos_theta = torch.cos(theta) - - chebyshev_features.append(torch.ones_like(cos_theta)) - chebyshev_features.append(2 * cos_theta) + # degree k loop + for k in range(chebyshev_freq_max + 1): + if k == 0: + U_k = U_k_minus_2 + elif k == 1: + U_k = U_k_minus_1 + else: + U_k = 2 * cos_theta * U_k_minus_1 - U_k_minus_2 + U_k_minus_2 = U_k_minus_1 + U_k_minus_1 = U_k - for degree in range(2, chebyshev_freq_max): - u = 2 * cos_theta * chebyshev_features[degree - 1] - chebyshev_features[degree - 2] - chebyshev_features.append(u) - x = torch.stack(chebyshev_features).T + if k >= chebyshev_freq_min: + chebyshev_features.append(U_k) - return x[:, chebyshev_freq_min-1:] + return torch.cat(chebyshev_features, dim=1) diff --git a/pinn/pinn_1d.py b/pinn/pinn_1d.py index 9929d49..a889101 100644 --- a/pinn/pinn_1d.py +++ b/pinn/pinn_1d.py @@ -49,7 +49,9 @@ from enum import Enum from utils import parse_args, get_activation, print_args, save_frame, make_video_from_frames from utils import is_notebook, cleanfiles, fourier_analysis, get_scheduler_generator, scheduler_step -from cheby import chebyshev_transformed_features, chebyshev_transformed_features2 # noqa F401 +from cheby import generate_chebyshev_features +from bc import get_d_func, get_g0_func +from datetime import datetime # from SOAP.soap import SOAP # torch.set_default_dtype(torch.float64) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -115,7 +117,9 @@ def __init__(self, ntrain, neval, ax, bx): self.x_train = torch.linspace(self.ax, self.bx, self.ntrain + 1, 
device=device)[:-1].unsqueeze(-1) self.x_eval = torch.linspace(self.ax, self.bx, self.neval + 1, device=device)[:-1].unsqueeze(-1) self.pde = None + # source term self.f = None + # analytical solution self.u_ex = None def set_pde(self, pde: PDE): @@ -125,7 +129,6 @@ def set_pde(self, pde: PDE): # analytical solution self.u_ex = pde.u_ex(self.x_train) - # %% # Define one level NN class Level(nn.Module): @@ -158,7 +161,7 @@ def __init__(self, dim_inputs, dim_outputs, dim_hidden: list, def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_chebyshev_basis: - x_features = chebyshev_transformed_features(x, self.chebyshev_freq_min, self.chebyshev_freq_max) + x_features = generate_chebyshev_features(x, self.chebyshev_freq_min, self.chebyshev_freq_max) else: x_features = x @@ -183,22 +186,42 @@ class LevelStatus(Enum): class MultiLevelNN(nn.Module): def __init__(self, mesh: Mesh, num_levels: int, dim_inputs, dim_outputs, dim_hidden: list, act: nn.Module = nn.ReLU(), enforce_bc: bool = False, + g0_type: str = "multilinear", d_type: str = "sin_half_period", use_chebyshev_basis: bool = False, chebyshev_freq_min: int = 0, chebyshev_freq_max: int = 0) -> None: super().__init__() self.mesh = mesh # currently the same model on each level + self.dim_inputs = dim_inputs + self.dim_outputs = dim_outputs + self.enforce_bc = enforce_bc + + self.g0_func = None + self.d_func = None + if self.enforce_bc: + self.g0_func = get_g0_func( + u_exact_func=self.mesh.pde.u_ex, + domain_dim=1, + domain_bounds=(self.mesh.ax, self.mesh.bx), + g0_type=g0_type + ) + self.d_func = get_d_func( + domain_dim=1, + domain_bounds=(self.mesh.ax, self.mesh.bx), + d_type=d_type + ) + print(f"BCs will be enforced using g0_type: {g0_type} and d_type: {d_type}") + + self.use_chebyshev_basis = use_chebyshev_basis + self.chebyshev_freqs = np.round(np.linspace(chebyshev_freq_min, chebyshev_freq_max, num_levels + 1)).astype(int) self.models = nn.ModuleList([ Level(dim_inputs=dim_inputs, dim_outputs=dim_outputs, dim_hidden=dim_hidden, act=act, use_chebyshev_basis=use_chebyshev_basis, - chebyshev_freq_min=chebyshev_freq_min, - chebyshev_freq_max=chebyshev_freq_max) - for _ in range(num_levels) + chebyshev_freq_min=self.chebyshev_freqs[i], + chebyshev_freq_max=self.chebyshev_freqs[i+1]) + for i in range(num_levels) ]) - self.dim_inputs = dim_inputs - self.dim_outputs = dim_outputs - self.enforce_bc = enforce_bc # All levels start as "off" self.level_status = [LevelStatus.OFF] * num_levels @@ -255,7 +278,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ys = [] for i, model in enumerate(self.models): if self.level_status[i] != LevelStatus.OFF: - x_scale = self.scales[i] * x + if self.use_chebyshev_basis: + x_scale = x + else: + x_scale = self.scales[i] * x y = model.forward(x=x_scale) ys.append(y) if not ys: @@ -268,6 +294,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def get_solution(self, x: torch.Tensor) -> torch.Tensor: y = self.forward(x) + n_active = self.num_active_levels() # reshape to [batch_size, num_levels, dim_outputs] # and sum over levels @@ -275,12 +302,16 @@ def get_solution(self, x: torch.Tensor) -> torch.Tensor: y = y.view(-1, n_active, self.dim_outputs) y = y.sum(dim=1) # shape: (n, dim_outputs) # + if self.enforce_bc: - g0 = self.mesh.u_ex[0].item() - g1 = self.mesh.u_ex[-1].item() - # in domain x in [0, 1] - y = g0 * (1 - x) + g1 * x + x * (1 - x) * y - # y = g0 + (x-0)/(1-0)*(g1 - g0) + (1 - torch.exp(0-x)) * (1 - torch.exp(x-1)) * y + g0_vals = self.g0_func(x) + d_vals = self.d_func(x) + mask = 
torch.abs(d_vals) < 1e-8 + signs = torch.sign(d_vals) + replacement = signs * 1e-8 + d_vals = torch.where(mask, replacement, d_vals) + y = g0_vals + d_vals * y + return y # def _init_weights(self, m): @@ -290,7 +321,6 @@ def get_solution(self, x: torch.Tensor) -> torch.Tensor: # if type(m) == nn.Linear: # torch.nn.init.xavier_uniform(m.weight) # - # %% # Define Loss class Loss: @@ -325,6 +355,7 @@ def pinn_loss(self, model, mesh, loss_func): # Internal loss pde = mesh.pde loss = loss_func(d2u_dx2[1:-1] + mesh.f[1:-1], pde.r * u[1:-1]) + # Boundary loss if not model.enforce_bc: u_bc = u[[0, -1]] @@ -336,29 +367,18 @@ def pinn_loss(self, model, mesh, loss_func): def drm_loss(self, model, mesh: Mesh): """Deep Ritz Method loss""" - xs = mesh.x_train.requires_grad_(True) - u = model(xs) - - grad_u_pred = torch.autograd.grad(u, xs, - grad_outputs=torch.ones_like(u), - create_graph=True)[0] - - u_pred_sq = torch.sum(u**2, dim=1, keepdim=True) - grad_u_pred_sq = torch.sum(grad_u_pred**2, dim=1, keepdim=True) + if not model.enforce_bc: + raise NotImplementedError("Deep Ritz loss only supports enforce_bc") - f_val = mesh.pde.f(xs) - fu_prod = f_val * u + x = mesh.x_train.requires_grad_(True) + u = model.get_solution(x) + du_dx, = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True) - integrand_values = 0.5 * grad_u_pred_sq[1:-1] + 0.5 * mesh.pde.r * u_pred_sq[1:-1] - fu_prod[1:-1] - loss = torch.mean(integrand_values) + du_dx_sq = torch.sum(du_dx**2, dim=1, keepdim=True) - # Boundary loss - u_bc = u[[0, -1]] - u_ex_bc = mesh.u_ex[[0, -1]] - loss_b = self.loss_func(u_bc, u_ex_bc) - loss += self.bc_weight * loss_b + #loss = torch.mean(0.5 * du_dx_sq[1:-1] + 0.5 * mesh.pde.r * u[1:-1]**2 - mesh.f[1:-1] * u[1:-1]) + loss = torch.mean(0.5 * du_dx_sq + 0.5 * mesh.pde.r * u**2 - mesh.f * u) - xs.requires_grad_(False) # Disable gradient tracking for x return loss def loss(self, model, mesh): @@ -372,7 +392,6 @@ def loss(self, model, mesh): raise ValueError(f"Unknown loss type: {self.type}") return loss_value - # %% # Define the training loop def train(model, mesh, criterion, iterations, adam_iterations, learning_rate, num_check, num_plots, sweep_idx, @@ -412,6 +431,7 @@ def closure(): # backpropagation to compute gradients of model param respect to the loss. computes dloss/dx # for every parameter x which has requires_grad=True. 
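+ # note: backward() accumulates into each parameter's .grad, so gradients must be + # zeroed at the start of each step; closure-based optimizers such as torch.optim.LBFGS + # may call closure(), and hence this backward pass, several times per optimizer step.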
loss.backward() + + # update the model parameters with an optimizer step using the computed gradients and learning rate optimizer.step() + # @@ -447,6 +467,10 @@ def closure(): # %% # Define the main function def main(args=None): + # Register run by timestamp + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + os.makedirs(f"results_pinn_1d_{ts}", exist_ok=True) + # For reproducibility torch.manual_seed(0) # Parse args @@ -454,7 +478,7 @@ def main(args=None): # Ensure chebyshev_freq_max is at least chebyshev_freq_min for range to be valid if args.use_chebyshev_basis and args.chebyshev_freq_max < args.chebyshev_freq_min: raise ValueError("chebyshev_freq_max must be >= chebyshev_freq_min when using Chebyshev basis.") - print_args(args=args) + print_args(args=args, output_file=f"results_pinn_1d_{ts}/args.txt") # PDE pde = PDE(high=args.high_freq, mu=args.mu, r=args.gamma, problem=args.problem_id) @@ -476,13 +500,15 @@ def main(args=None): dim_hidden=args.hidden_dims, act=get_activation(args.activation), enforce_bc=args.enforce_bc, + g0_type=args.bc_extension, + d_type=args.distance, use_chebyshev_basis=args.use_chebyshev_basis, chebyshev_freq_min=args.chebyshev_freq_min, chebyshev_freq_max=args.chebyshev_freq_max) print(model) model.to(device) # Plotting - frame_dir = "./frames" + frame_dir = f"results_pinn_1d_{ts}/frames" os.makedirs(frame_dir, exist_ok=True) if args.clear: cleanfiles(frame_dir) @@ -523,7 +549,7 @@ def main(args=None): # can run it like normal: python filename.py if __name__ == "__main__": if is_notebook(): - err = main(['--levels', '4', '--epochs', '10000', '--sweeps', '1', '--plot']) + err = main(['--levels', '4', '--epochs', '10000', '--sweeps', '1', '--plot', '--enforce_bc', '--bc_extension', 'hermite_cubic_2nd_deriv', '--distance', 'sin_half_period']) else: err = main() try: diff --git a/pinn/utils.py b/pinn/utils.py index 6375914..1828149 100644 --- a/pinn/utils.py +++ b/pinn/utils.py @@ -83,6 +83,12 @@ def parse_args(args=None): parser.add_argument('--activation', type=str, default='tanh', choices=['tanh', 'silu', 'relu', 'gelu', 'softmax'], help="Activation function to use.") + parser.add_argument('--bc_extension', type=str, default='hermite_cubic_2nd_deriv', + choices=['multilinear', 'hermite_cubic_2nd_deriv'], + help='Boundary value extension function.') + parser.add_argument('--distance', type=str, default='sin_half_period', + choices=['quadratic_bubble', 'inf_smooth_bump', 'abs_dist_complement', 'ratio_bubble_dist', 'sin_half_period'], + help='Distance function.') parser.add_argument('--chebyshev_freq_min', type=int, default=-1, help='Minimum frequency for Chebyshev polynomials.') parser.add_argument('--chebyshev_freq_max', type=int, default=-1, @@ -169,11 +175,18 @@ def scheduler_step(scheduler, loss, epoch=None): # %% -def print_args(args): +def print_args(args, output_file=None): + if output_file: + f = open(output_file, 'w') + else: + f = None print("Options used:") for key, value in vars(args).items(): print(f" --{key}: {value}") - + if f: + print(f" --{key}: {value}", file=f) + if f: + f.close() # %% def get_activation(name: str): diff --git a/tony/config/poisson.py b/tony/config/elliptic.py similarity index 84% rename from tony/config/poisson.py rename to tony/config/elliptic.py index f2f4164..e16eeef 100644 --- a/tony/config/poisson.py +++ b/tony/config/elliptic.py @@ -1,7 +1,7 @@ import numpy as np import torch -class PoissonSolverConfig: +class EllipticSolverConfig: def __init__(self, **kwargs): device_arg = kwargs.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
if isinstance(device_arg, str): @@ -11,6 +11,8 @@ def __init__(self, **kwargs): else: raise TypeError(f"Device must be a string or torch.device object, got {type(device_arg)}") + self.problem = kwargs.get('problem', 'poisson') + # Domain parameters self.domain_dim = kwargs.get('domain_dim', 1) # 1, 2, or 3 self.domain_bounds = kwargs.get('domain_bounds', (0.0, 1.0)) @@ -25,6 +27,16 @@ def __init__(self, **kwargs): self.activation = kwargs.get('activation', 'tanh') # e.g., 'tanh', 'relu', 'sigmoid' self.bc_extension = kwargs.get('bc_extension', 'hermite_cubic_2nd_hermite') self.distance = kwargs.get('distance', 'sin_half_period') + self.use_positional_encoding = False + self.pe_freq_min = kwargs.get('pe_freq_min', -1) # Minimum positional encoding frequency power + self.pe_freq_max = kwargs.get('pe_freq_max', -1) # Maximum positional encoding frequency power + if (0 <= self.pe_freq_min <= self.pe_freq_max): + if self.domain_dim != 1: + print("Warning: Positional encoding is only implemented for 1D problems. Disabling it.") + else: + print(f"Positional encoding with frequency powers {self.pe_freq_min} to {self.pe_freq_max} is used") + self.use_positional_encoding = True + self.use_chebyshev_basis = False self.chebyshev_freq_min = kwargs.get('chebyshev_freq_min', -1) # Minimum Chebyshev frequency self.chebyshev_freq_max = kwargs.get('chebyshev_freq_max', -1) # Maximum Chebyshev frequency @@ -35,6 +47,9 @@ def __init__(self, **kwargs): print(f"Chebyshev basis of frequency {self.chebyshev_freq_min} to {self.chebyshev_freq_max} are used") self.use_chebyshev_basis = True + if self.use_positional_encoding and self.use_chebyshev_basis: + raise RuntimeError("Positional encoding is not compatible with the Chebyshev basis.") + # Training Parameters self.num_epochs = kwargs.get('num_epochs', 5000) self.batch_size = kwargs.get('batch_size', 256) diff --git a/tony/loss/reaction_diffusion.py b/tony/loss/reaction_diffusion.py new file mode 100644 index 0000000..205ef02 --- /dev/null +++ b/tony/loss/reaction_diffusion.py @@ -0,0 +1,79 @@ +import torch + +def calculate_drm_loss(u_nn_model, f_exact_func, a_exact_func, c_exact_func, domain_points): + """ + Calculates the DRM (Deep Ritz Method) loss for the Reaction-Diffusion equation. + The energy functional is derived from the weak form of the PDE. + For -div(a grad u) + c*u = f, the associated energy functional is: + Integral over domain [ 0.5 * a * |grad u|^2 + 0.5 * c * u^2 - f * u ] dX + The DRM loss approximates this integral by the mean over the sampled batch. + """ + domain_points.requires_grad_(True) + u_pred = u_nn_model(domain_points) + + # Get a(x) and c(x) values at domain_points + a_val = a_exact_func(domain_points) # (batch_size, 1) or (batch_size, dim) + c_val = c_exact_func(domain_points) # (batch_size, 1) + + grad_u_pred = torch.autograd.grad(u_pred, domain_points, + grad_outputs=torch.ones_like(u_pred), + create_graph=True)[0] + + # Calculate |grad u|^2 + grad_u_pred_sq = torch.sum(grad_u_pred**2, dim=1, keepdim=True) + + f_val = f_exact_func(domain_points) + + # The integrand for DRM based on the variational formulation of the PDE + # For -div(a grad u) + c*u = f, the energy functional involves: + # 0.5 * a * |grad u|^2 + 0.5 * c * u^2 - f * u + # We will use this form.
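+ # Sanity check: in 1D with a(x) = 1 and c(x) = 0 the integrand below reduces to + # 0.5 * u'(x)**2 - f(x) * u(x), the classical Dirichlet energy, whose minimizer + # (boundary conditions being built into u_nn_model via g0_func and d_func) satisfies -u'' = f.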
+ integrand_values = 0.5 * a_val * grad_u_pred_sq + 0.5 * c_val * u_pred**2 - f_val * u_pred + loss = torch.mean(integrand_values) # Mean over the batch as an approximation of the integral + + domain_points.requires_grad_(False) + return loss + +def calculate_pinn_loss(u_nn_model, f_exact_func, a_exact_func, c_exact_func, domain_points, domain_dim): + """ + Calculates the PINN (Physics-Informed Neural Network) loss for the Reaction-Diffusion equation: + -div(a(x) grad u) + c(x)u = f(x) + The residual is: R(x) = -div(a(x) grad u_nn) + c(x)u_nn - f(x) + """ + domain_points.requires_grad_(True) + u_pred = u_nn_model(domain_points) + + # Get a(x) and c(x) values at domain_points + a_val = a_exact_func(domain_points) # (batch_size, 1) + c_val = c_exact_func(domain_points) # (batch_size, 1) + + # First derivatives (gradients) + grad_u_pred = torch.autograd.grad(u_pred, domain_points, + grad_outputs=torch.ones_like(u_pred), + create_graph=True)[0] + + # Compute a * grad u_pred + # Assuming a_val is (batch_size, 1) or broadcastable (batch_size, domain_dim) + a_grad_u_pred = a_val * grad_u_pred # Element-wise product + + # Compute divergence of (a * grad u_pred) + # div(V) = dV0/dx0 + dV1/dx1 + ... + div_a_grad_u_pred = torch.zeros_like(u_pred) # (batch_size, 1) + + for i in range(domain_dim): + # Compute gradient of (a_grad_u_pred[:, i]) with respect to domain_points[:, i] + d_a_grad_u_pred_dxi = torch.autograd.grad(a_grad_u_pred[:, i], domain_points, + grad_outputs=torch.ones_like(a_grad_u_pred[:, i]), + create_graph=True)[0][:, i] + div_a_grad_u_pred += d_a_grad_u_pred_dxi.unsqueeze(1) + + # Get the exact source term f(x) + f_val = f_exact_func(domain_points) + + # Residual for the Reaction-Diffusion equation: R(x) = -div(a grad u_nn) + c*u_nn - f + residual = -div_a_grad_u_pred + c_val * u_pred - f_val + + loss = torch.mean(residual**2) + + domain_points.requires_grad_(False) + return loss diff --git a/tony/model/mlp.py b/tony/model/mlp.py new file mode 100644 index 0000000..3608bc9 --- /dev/null +++ b/tony/model/mlp.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import math + +class SinActivation(nn.Module): + def forward(self, x): + return torch.sin(x) + +class CosActivation(nn.Module): + def forward(self, x): + return torch.cos(x) + +class NNModel(nn.Module): + def __init__(self, input_dim, output_dim, hidden_neurons, activation='tanh', + g0_func=None, d_func=None, + use_positional_encoding=False, pe_freq_min=0, pe_freq_max=6, + use_chebyshev_basis=False, chebyshev_freq_min=-1, chebyshev_freq_max=-1): + super(NNModel, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.g0_func = g0_func + self.d_func = d_func + + self.use_positional_encoding = use_positional_encoding + self.pe_freq_min = pe_freq_min + self.pe_freq_max = pe_freq_max + + self.use_chebyshev_basis = use_chebyshev_basis + self.chebyshev_freq_min = chebyshev_freq_min + self.chebyshev_freq_max = chebyshev_freq_max + + self.activation_fns = [] + if isinstance(activation, str): + # All layers have the same activation function + self.activation_fns = [self._get_activation_fn(activation)] * len(hidden_neurons) + elif isinstance(activation, (list, tuple)): + # Each layer has its own activation function + if len(activation) == len(hidden_neurons): + self.activation_fns = [self._get_activation_fn(act_str) for act_str in activation] + elif len(activation) == len(hidden_neurons) + 1: + print("Warning: Activation list length is one longer than hidden layers. " + "Assuming the last activation is intended for an output layer, " + "but for regression problems, the output layer typically has no activation in this model.") + self.activation_fns = [self._get_activation_fn(act_str) for act_str in activation[:-1]] + else: + raise ValueError(f"Activation list length ({len(activation)}) must match " + f"hidden_neurons length ({len(hidden_neurons)}), or be one longer.") + else: + raise ValueError("Activation must be a string or a list/tuple of strings.") + + # Determine the effective input dimension for the trainable layers + if self.use_positional_encoding: + trainable_input_dim = input_dim * (self.pe_freq_max - self.pe_freq_min + 1) + elif self.use_chebyshev_basis: + trainable_input_dim = self.chebyshev_freq_max - self.chebyshev_freq_min + 1 + else: + trainable_input_dim = input_dim + + if self.use_chebyshev_basis: + trainable_output_dim = self.chebyshev_freq_max - self.chebyshev_freq_min + 1 + else: + trainable_output_dim = output_dim + + layers = [] + prev_neurons = trainable_input_dim + for i, num_neurons in enumerate(hidden_neurons): + layers.append(nn.Linear(prev_neurons, num_neurons)) + layers.append(self.activation_fns[i]) + prev_neurons = num_neurons + layers.append(nn.Linear(prev_neurons, trainable_output_dim)) # Output layer + self.layers = nn.Sequential(*layers) + + # Initialize weights + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # Create and initialize the final combiner used with the Chebyshev basis + if use_chebyshev_basis: + self.final_combiner = nn.Linear(trainable_output_dim, output_dim, bias=False) + nn.init.xavier_normal_(self.final_combiner.weight) + + def _get_activation_fn(self, activation_str): + if activation_str == 'relu': + return nn.ReLU() + elif activation_str == 'sin': + return SinActivation() + elif activation_str == 'tanh': + return nn.Tanh() + elif activation_str == 'sigmoid': + return nn.Sigmoid() + elif activation_str == 'elu': + return nn.ELU() + elif activation_str == 'leaky_relu': + return nn.LeakyReLU() + elif activation_str == 'silu': + return nn.SiLU() + elif activation_str == 'linear' or activation_str is None: + return nn.Identity() + else: + raise ValueError(f"Unknown activation function: {activation_str}") + + def _positional_encode(self, x): + """ + Applies positional encoding to the input tensor x. + Encodes each dimension of x with cosine features at frequencies + 2^pe_freq_min*pi, 2^(pe_freq_min+1)*pi, ..., 2^pe_freq_max*pi. + For example, with pe_freq_min=0 and pe_freq_max=2, each coordinate x + contributes [cos(pi*x), cos(2*pi*x), cos(4*pi*x)]. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + torch.Tensor: Encoded tensor of shape (batch_size, input_dim * (pe_freq_max - pe_freq_min + 1)).
+ """ + pe_features = [] + for i in range(self.pe_freq_min, self.pe_freq_max + 1): + freq_scale = 2**i * math.pi + pe_features.append(torch.cos(freq_scale * x)) + + # Concatenate features along the last dimension (for each input_dim) + return torch.cat(pe_features, dim=-1) + + def _generate_chebyshev_features(self, x): + chebyshev_features = [] + theta = math.pi * x + dx = torch.sin(theta) + left_end = torch.abs(theta) < 1e-8 + right_end = torch.abs(theta - math.pi) < 1e-8 + chebyshev_arg = torch.cos(math.pi * x) + uk_minus_2 = torch.sin((self.chebyshev_freq_min + 1) * math.pi * x) / dx + uk_minus_2[left_end] = self.chebyshev_freq_min + 1 + uk_minus_2[right_end] = (self.chebyshev_freq_min + 1) * (-1)**self.chebyshev_freq_min + uk_minus_1 = torch.sin((self.chebyshev_freq_min + 2) * math.pi * x) / dx + uk_minus_1[left_end] = self.chebyshev_freq_min + 2 + uk_minus_1[right_end] = (self.chebyshev_freq_min + 2) * (-1)**(self.chebyshev_freq_min + 1) + for k in range(self.chebyshev_freq_min, self.chebyshev_freq_max + 1): + if k == self.chebyshev_freq_min: + current_uk = uk_minus_2 + elif k == self.chebyshev_freq_min + 1: + current_uk = uk_minus_1 + else: + current_uk = 2 * chebyshev_arg * uk_minus_1 - uk_minus_2 + uk_minus_2 = uk_minus_1 + uk_minus_1 = current_uk + chebyshev_features.append(current_uk.unsqueeze(1)) + return torch.cat(chebyshev_features, dim=1) + + def forward(self, x_raw): + x = x_raw.float() + if self.use_positional_encoding: + x = self._positional_encode(x) + elif self.use_chebyshev_basis: + chebyshev_features_s = self._generate_chebyshev_features(x[:, 0]) + y = self.layers(chebyshev_features_s) + if y.shape[1] != chebyshev_features_s.shape[1]: + raise RuntimeError("Mismatch between learned features and Chebyshev features dimensions.") + combined_features = chebyshev_features_s + y + raw_nn_output = self.final_combiner(combined_features) + else: + raw_nn_output = self.layers(x) + + # Incorporate boundary conditions if g0_func and d_func are provided + if self.g0_func is None or self.d_func is None: + return raw_nn_output + else: + return self.g0_func(x_raw) + self.d_func(x_raw) * raw_nn_output diff --git a/tony/model/nn.py b/tony/model/nn.py deleted file mode 100644 index 78175ca..0000000 --- a/tony/model/nn.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -import torch.nn as nn -import math - -class SinActivation(nn.Module): - def forward(self, x): - return torch.sin(x) - -class NNModel(nn.Module): - def __init__(self, input_dim, output_dim, hidden_neurons, activation='tanh', - g0_func=None, d_func=None, - use_chebyshev_basis=False, chebyshev_freq_min=-1, chebyshev_freq_max=-1): - super(NNModel, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.g0_func = g0_func - self.d_func = d_func - - self.use_chebyshev_basis = use_chebyshev_basis - self.chebyshev_freq_min = chebyshev_freq_min - self.chebyshev_freq_max = chebyshev_freq_max - - self.activation_fns = [] - if isinstance(activation, str): - # All layers have the same activation function - self.activation_fns = [self._get_activation_fn(activation)] * len(hidden_neurons) - elif isinstance(activation, (list, tuple)): - # Each layer has its own activation function - if len(activation) == len(hidden_neurons): - self.activation_fns = [self._get_activation_fn(act_str) for act_str in activation] - elif len(activation) == len(hidden_neurons) + 1: - print("Warning: Activation list length is one longer than hidden layers. 
" - "Assuming the last activation is intended for an output layer, " - "but for regression problems, the output layer typically has no activation in this model.") - self.activation_fns = [self._get_activation_fn(act_str) for act_str in activation[:-1]] - else: - raise ValueError(f"Activation list length ({len(activation)}) must match " - f"hidden_neurons length ({len(hidden_neurons)}), or be one longer.") - else: - raise ValueError("Activation must be a string or a list/tuple of strings.") - - # Determine the effective input dimension for the trainable layers - effective_input_dim = input_dim - if self.use_chebyshev_basis and self.input_dim >= 1: # Can be applied to the first dimension for multi-D - # Calculate total features: Chebyshev features for dim 0 + remaining raw dimensions - num_chebyshev_features = self.chebyshev_freq_max - self.chebyshev_freq_min + 1 - effective_input_dim = num_chebyshev_features + (self.input_dim - 1 if self.input_dim > 1 else 0) - - layers = [] - prev_neurons = effective_input_dim - for i, num_neurons in enumerate(hidden_neurons): - layers.append(nn.Linear(prev_neurons, num_neurons)) - layers.append(self.activation_fns[i]) - prev_neurons = num_neurons - - layers.append(nn.Linear(prev_neurons, output_dim)) # Output layer - self.layers = nn.Sequential(*layers) - - # Initialize weights - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_normal_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - - def _get_activation_fn(self, activation_str): - if activation_str == 'relu': - return nn.ReLU() - elif activation_str == 'sin': - return SinActivation() - elif activation_str == 'tanh': - return nn.Tanh() - elif activation_str == 'sigmoid': - return nn.Sigmoid() - elif activation_str == 'elu': - return nn.ELU() - elif activation_str == 'leaky_relu': - return nn.LeakyReLU() - elif activation_str == 'linear' or activation_str is None: - return nn.Identity() - else: - raise ValueError(f"Unknown activation function: {activation_str}") - - def forward(self, x_raw): - x = x_raw.float() - if self.use_chebyshev_basis: - chebyshev_arg = torch.cos(math.pi * x[:, 0]) # Argument for U_k(x) - chebyshev_features = [] - # Initialize for recurrence - uk_minus_2 = torch.ones_like(chebyshev_arg) # U_0(x_mapped) - uk_minus_1 = 2 * chebyshev_arg # U_1(x_mapped) - for k in range(self.chebyshev_freq_min, self.chebyshev_freq_max + 1): - if k == 0: - current_uk = uk_minus_2 # This is U_0 - elif k == 1: - current_uk = uk_minus_1 # This is U_1 - else: - # Compute U_k using the recurrence relation - current_uk = 2 * chebyshev_arg * uk_minus_1 - uk_minus_2 - # Update for next iteration - uk_minus_2 = uk_minus_1 - uk_minus_1 = current_uk - chebyshev_features.append(current_uk.unsqueeze(1)) - processed_input = torch.cat(chebyshev_features, dim=1) - if self.input_dim > 1: # Append other raw dimensions if multi-D - processed_input = torch.cat([processed_input, x[:, 1:]], dim=1) - else: - processed_input = x - raw_nn_output = self.layers(processed_input) - - # Incorporate boundary conditions if g0_func and d_func are provided - if self.g0_func is None or self.d_func is None: - return raw_nn_output - else: - return self.g0_func(x_raw) + self.d_func(x_raw) * raw_nn_output diff --git a/tony/pde/poisson.py b/tony/pde/poisson.py index 134af56..d588d81 100644 --- a/tony/pde/poisson.py +++ b/tony/pde/poisson.py @@ -20,6 +20,10 @@ def u_exact(x): # x : (batch_size, domain_dim) return (torch.sin(3 * torch.pi * x0) + torch.sin(7 * torch.pi * x0)).unsqueeze(1) elif case_number == 
5: return (torch.sin(3 * torch.pi * x0) * torch.sin(7 * torch.pi * x0)).unsqueeze(1) + elif case_number == 6: + return (torch.sin(torch.pi * x0) + torch.sin(2 * torch.pi * x0) + torch.sin(3 * torch.pi * x0) + torch.sin(4 * torch.pi * x0) + torch.sin(5 * torch.pi * x0) + torch.sin(6 * torch.pi * x0) + torch.sin(7 * torch.pi * x0) + torch.sin(8 * torch.pi * x0)).unsqueeze(1) + elif case_number == 7: + return (torch.sin(2 * torch.pi * x0) + torch.sin(20 * torch.pi * x0)).unsqueeze(1) else: raise ValueError(f"Manufactured solution case {case_number} not defined for 1D.") elif domain_dim == 2: diff --git a/tony/pde/reaction_diffusion.py b/tony/pde/reaction_diffusion.py new file mode 100644 index 0000000..8c12a4c --- /dev/null +++ b/tony/pde/reaction_diffusion.py @@ -0,0 +1,83 @@ +import torch +import numpy as np + +def get_manufactured_solution(case_number, domain_dim): + """ + Generates the manufactured solution u_exact, the diffusion coefficient a_exact, + the reaction coefficient c_exact, and the source term f_exact for + the Reaction-Diffusion equation: -div(a(x) grad u) + c(x)u = f(x). + """ + + def u_exact(x): # x : (batch_size, domain_dim) + if domain_dim == 1: + x0 = x[:, 0] + if case_number == 1: + return (torch.sin(torch.pi * x0) + torch.sin(2 * torch.pi * x0) + torch.sin(3 * torch.pi * x0) + torch.sin(4 * torch.pi * x0) + torch.sin(5 * torch.pi * x0) + torch.sin(6 * torch.pi * x0) + torch.sin(7 * torch.pi * x0) + torch.sin(8 * torch.pi * x0)).unsqueeze(1) + elif case_number == 2: + return (torch.sin(2 * torch.pi * x0) + torch.sin(20 * torch.pi * x0)).unsqueeze(1) + elif case_number == 3: + return (torch.sin(torch.pi * x0) + 0.1 * torch.sin(40 * torch.pi * x0)).unsqueeze(1) + else: + raise ValueError(f"Manufactured solution case {case_number} not defined for 1D.") + else: + raise ValueError(f"Manufactured solution case {case_number} not defined for dimension {domain_dim}") + + def a_exact(x): # diffusion coefficient + if domain_dim == 1: + x0 = x[:, 0] + if case_number == 1: + return (1.0 + 0.5 * torch.sin(2 * torch.pi * x0)).unsqueeze(1) + elif case_number == 2: + return (1.0 + 0.5 * torch.sin(2 * torch.pi * x0)).unsqueeze(1) + elif case_number == 3: + return (1.0 + 0.05 * torch.sin(100 * torch.pi * x0)).unsqueeze(1) + else: + raise ValueError(f"Diffusion coefficient case {case_number} not defined for 1D.") + else: + raise ValueError(f"Diffusion coefficient case {case_number} not defined for dimension {domain_dim}") + + def c_exact(x): # reaction coefficient + if domain_dim == 1: + x0 = x[:, 0] + if case_number == 1: + return (1.0 + 0.5 * torch.cos(torch.pi * x0)).unsqueeze(1) + elif case_number == 2: + return (1.0 + 0.5 * torch.cos(torch.pi * x0)).unsqueeze(1) + elif case_number == 3: + return 0.0 * torch.zeros_like(x0).unsqueeze(1) + else: + raise ValueError(f"Reaction coefficient case {case_number} not defined for 1D.") + else: + raise ValueError(f"Reaction coefficient case {case_number} not defined for dimension {domain_dim}") + + def f_exact(x): + x_clone = x.clone().detach().requires_grad_(True) + + u_val = u_exact(x_clone) # (batch_size, 1) + a_val = a_exact(x_clone) # (batch_size, 1) - assuming scalar a(x) + c_val = c_exact(x_clone) # (batch_size, 1) + + # Compute grad u + grad_u = torch.autograd.grad(u_val, x_clone, grad_outputs=torch.ones_like(u_val), create_graph=True)[0] # (batch_size, domain_dim) + + # Compute grad(a * grad u) + a_grad_u = a_val * grad_u # Element-wise product: (batch_size, domain_dim) + + # Compute divergence of (a * grad u) + div_a_grad_u = 
torch.zeros_like(u_val) # (batch_size, 1) + + for i in range(domain_dim): + # Compute d/dxi (a * du/dxi) + d_a_grad_u_dxi = torch.autograd.grad(a_grad_u[:, i], x_clone, + grad_outputs=torch.ones_like(a_grad_u[:, i]), + create_graph=True)[0][:, i] + div_a_grad_u += d_a_grad_u_dxi.unsqueeze(1) + + # f = -div(a grad u) + c*u + f_val = -div_a_grad_u + c_val * u_val + + x_clone.requires_grad_(False) + + return f_val + + return u_exact, f_exact, a_exact, c_exact diff --git a/tony/run.tuo b/tony/run.tuo index 769c416..27f5d2f 100644 --- a/tony/run.tuo +++ b/tony/run.tuo @@ -1,8 +1,5 @@ #!/bin/sh -#flux: --output='Poisson.{{id}}.out' -#flux: --error='Poisson.{{id}}.err' #flux: -N 1 -#flux: -n 1 #flux: -t 60 #flux: -q pdebug #flux: --exclusive @@ -12,7 +9,13 @@ job_ts=`date +"%Y%m%d_%H%M%S"` source NN_PDE_venv/bin/activate #flux run -N1 -n 8 --verbose --exclusive python3 test_poisson.py --base_log_dir ${job_ts}_logs --base_model_save_dir ${job_ts}_models --case 2 --drm_weight 1.0 --pinn_weight 0.0 --num_uniform_partition 2048 --epochs 20000 --logging_freq 50 &> ${job_ts}_output.txt -#flux run -N1 -n 8 --verbose --exclusive python3 test_poisson.py --base_log_dir ${job_ts}_logs --base_model_save_dir ${job_ts}_models --case 5 --drm_steps_per_cycle 10000 --pinn_steps_per_cycle 10000 --num_uniform_partition 2048 --epochs 20000 --logging_freq 50 &> ${job_ts}_output.txt -flux run -N1 -n 8 --verbose --exclusive python3 test_poisson.py --base_log_dir ${job_ts}_logs --base_model_save_dir ${job_ts}_models --case 5 --drm_steps_per_cycle 10000 --pinn_steps_per_cycle 10000 --num_uniform_partition 2048 --epochs 20000 --logging_freq 50 --chebyshev_freq_min 1 --chebyshev_freq_max 10 --distance sin_half_period --bc_extension multilinear &> ${job_ts}_output.txt + +#python3 test_poisson.py --base_log_dir ${job_ts}_logs --base_model_save_dir ${job_ts}_models --case 6 --drm_steps_per_cycle 100 --pinn_steps_per_cycle 900 --num_uniform_partition 2048 --epochs 20000 --logging_freq 50 --chebyshev_freq_min 1 --chebyshev_freq_max 4 --distance sin_half_period --bc_extension multilinear &> ${job_ts}_output.txt +#python3 test_poisson.py --base_log_dir ${job_ts}_logs --base_model_save_dir ${job_ts}_models --case 6 --drm_steps_per_cycle 10000 --pinn_steps_per_cycle 10000 --num_uniform_partition 2048 --epochs 20000 --logging_freq 50 --chebyshev_freq_min 1 --chebyshev_freq_max 4 --distance sin_half_period --bc_extension multilinear &> ${job_ts}_output.txt + +#python3 test_reaction_diffusion.py --base_log_dir ${job_ts}_logs --base_model_save_dir ${job_ts}_models --case 2 --drm_steps_per_cycle 100 --pinn_steps_per_cycle 900 --num_uniform_partition 2048 --epochs 20000 --logging_freq 50 --chebyshev_freq_min 1 --chebyshev_freq_max 20 --distance sin_half_period --bc_extension multilinear &> ${job_ts}_output.txt +#python3 test_reaction_diffusion.py --base_log_dir ${job_ts}_logs --base_model_save_dir ${job_ts}_models --case 1 --drm_steps_per_cycle 2000 --pinn_steps_per_cycle 18000 --num_uniform_partition 2048 --epochs 20000 --logging_freq 50 --chebyshev_freq_min 1 --chebyshev_freq_max 20 --distance sin_half_period --bc_extension multilinear --activation sin &> ${job_ts}_output.txt +python3 test_reaction_diffusion_new.py --base_log_dir ${job_ts}_logs --base_model_save_dir ${job_ts}_models --case 1 --drm_steps_per_cycle 2000 --pinn_steps_per_cycle 18000 --num_uniform_partition 2048 --epochs 20000 --logging_freq 50 --chebyshev_freq_min 1 --chebyshev_freq_max 20 --distance sin_half_period --bc_extension multilinear --activation sin &> 
${job_ts}_output.txt rsync -av --exclude 'config.json' --exclude 'plots_run*/' ${job_ts}_logs /p/lustre5/cheung26/scp_local +cp ${job_ts}_output.txt /p/lustre5/cheung26/scp_local/${job_ts}_logs diff --git a/tony/test_poisson.py b/tony/test_poisson.py index 95ed383..6def3a9 100644 --- a/tony/test_poisson.py +++ b/tony/test_poisson.py @@ -9,9 +9,9 @@ from mpi4py import MPI from torch.optim.lr_scheduler import ReduceLROnPlateau -from config.poisson import PoissonSolverConfig +from config.elliptic import EllipticSolverConfig from pde.poisson import get_manufactured_solution -from model.nn import NNModel +from model.mlp import NNModel from model.bc import get_g0_func, get_d_func from loss.poisson import calculate_drm_loss, calculate_pinn_loss from data.samplers import generate_uniform_grid_points @@ -164,7 +164,7 @@ def evaluate_and_log(epoch, u_nn_model, history, config, f_exact_func, u_exact_f return eval_points_for_plot_np, u_exact_plot_data_flat # --- Main Training Loop for a Single Experiment Run --- -def run_experiment(config: PoissonSolverConfig, rank_val): +def run_experiment(config: EllipticSolverConfig, rank_val): print(f"Rank {rank_val}: Initializing experiment") # Each rank has its own random seed, config file, and model sub-directory @@ -205,6 +205,9 @@ def run_experiment(config: PoissonSolverConfig, rank_val): activation=config.activation, g0_func=g0_func, d_func=d_func, + use_positional_encoding = config.use_positional_encoding, + pe_freq_min = config.pe_freq_min, + pe_freq_max = config.pe_freq_max, use_chebyshev_basis=config.use_chebyshev_basis, chebyshev_freq_min=config.chebyshev_freq_min, chebyshev_freq_max=config.chebyshev_freq_max @@ -227,7 +230,8 @@ def run_experiment(config: PoissonSolverConfig, rank_val): mode=config.lr_patience_mode, factor=config.lr_factor, patience=config.lr_patience, - min_lr=config.lr_min + min_lr=config.lr_min, + verbose=True ) else: lr_scheduler = None @@ -285,7 +289,7 @@ def run_experiment(config: PoissonSolverConfig, rank_val): current_pinn_weight = 1.0 elif epoch > config.drm_steps_per_cycle / 2: current_angle = 2 * np.pi / config.steps_per_cycle * (epoch - config.drm_steps_per_cycle) - current_drm_weight = 1 / (1 + np.exp(0.01 * config.steps_per_cycle * np.sin(current_angle))) + current_drm_weight = 1 / (1 + np.exp(config.steps_per_cycle * np.sin(current_angle))) current_pinn_weight = 1.0 - current_drm_weight # Log the current weights @@ -337,7 +341,6 @@ def run_experiment(config: PoissonSolverConfig, rank_val): metric = history['total_loss'][-1] # Use the last logged total loss as the metric lr_scheduler.step(metric) - if epoch % config.checkpoint_freq == 0: # Check for checkpoint frequency torch.save(u_nn_model.state_dict(), os.path.join(rank_model_path, f'model_epoch_{epoch}.pth')) logger.log_message(f"Checkpoint saved at epoch {epoch}") @@ -357,16 +360,16 @@ def run_experiment(config: PoissonSolverConfig, rank_val): # --- Main Execution Block with Argument Parsing --- if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Run Poisson PINN experiments with configurable parameters.') + parser = argparse.ArgumentParser(description='Run Poisson experiments with configurable parameters.') # Core experiment parameters parser.add_argument('--dim', type=int, default=1, choices=[1, 2, 3], - help='Dimension of the Poisson problem (1, 2, or 3).') + help='Dimension of the problem (1, 2, or 3).') parser.add_argument('--epochs', type=int, default=1000, help='Number of training epochs.') parser.add_argument('--seed', type=int, 
default=42, help='Base random seed for reproducibility; actual seed will be seed + rank.') - parser.add_argument('--case', type=int, default=1, + parser.add_argument('--case', type=int, default=4, help='Manufactured solution case number.') # Neural Network Architecture @@ -380,6 +383,10 @@ def run_experiment(config: PoissonSolverConfig, rank_val): parser.add_argument('--distance', type=str, default='sin_half_period', choices=['quadratic_bubble', 'inf_smooth_bump', 'abs_dist_complement', 'ratio_bubble_dist', 'sin_half_period'], help='Distance function.') + parser.add_argument('--pe_freq_min', type=int, default=-1, + help='Minimum frequency power for positional encoding.') + parser.add_argument('--pe_freq_max', type=int, default=-1, + help='Maximum frequency power for positional encoding.') parser.add_argument('--chebyshev_freq_min', type=int, default=-1, help='Minimum frequency for Chebyshev polynomials.') parser.add_argument('--chebyshev_freq_max', type=int, default=-1, @@ -431,9 +438,9 @@ def run_experiment(config: PoissonSolverConfig, rank_val): help='Resolution for plotting grids') parser.add_argument('--log_fourier_coeffs', type=bool, default=True, help='Whether to log Fourier coefficients in the real basis.') - parser.add_argument('--use_sine_series', type=bool, default=True, # Renamed in analysis/fourier.py + parser.add_argument('--use_sine_series', type=bool, default=True, help='Whether to use sine series expansion instead of full Fourier series expansion.') - parser.add_argument('--fourier_freq', type=int, nargs='+', default=[1, 4, 9], + parser.add_argument('--fourier_freq', type=int, nargs='+', default=[1, 3, 7], help='Frequency of Fourier coefficients to log.') parser.add_argument('--base_log_dir', type=str, default='logs', help='Base directory for saving experiment logs.') @@ -459,6 +466,7 @@ def run_experiment(config: PoissonSolverConfig, rank_val): # Define common configuration arguments, including device base_config_kwargs = { 'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"), + 'problem': 'Poisson', 'domain_dim': args.dim, 'case_number': args.case, 'random_seed': args.seed, @@ -466,6 +474,8 @@ def run_experiment(config: PoissonSolverConfig, rank_val): 'activation': parsed_activation, 'bc_extension': args.bc_extension, 'distance': args.distance, + 'pe_freq_min': args.pe_freq_min, + 'pe_freq_max': args.pe_freq_max, 'chebyshev_freq_min': args.chebyshev_freq_min, 'chebyshev_freq_max': args.chebyshev_freq_max, 'drm_weight': args.drm_weight, @@ -506,7 +516,7 @@ def run_experiment(config: PoissonSolverConfig, rank_val): base_config_kwargs['domain_bounds'] = ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)) # Create the config object for THIS rank - rank_config = PoissonSolverConfig(**base_config_kwargs) + rank_config = EllipticSolverConfig(**base_config_kwargs) print(f"Rank {rank}: Starting its experiment run{rank+1}.") model_state_dict, history = run_experiment(rank_config, rank) diff --git a/tony/test_reaction_diffusion.py b/tony/test_reaction_diffusion.py new file mode 100644 index 0000000..4dadc48 --- /dev/null +++ b/tony/test_reaction_diffusion.py @@ -0,0 +1,558 @@ +import argparse +import datetime +import os +import json +import time +import torch +import numpy as np +import torch.optim as optim +from mpi4py import MPI +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from config.elliptic import EllipticSolverConfig +from pde.reaction_diffusion import get_manufactured_solution +from model.mlp import NNModel +from model.bc import get_g0_func, get_d_func +from 
loss.reaction_diffusion import calculate_drm_loss, calculate_pinn_loss +from data.samplers import generate_uniform_grid_points +from analysis.fourier import calculate_fourier_coefficients +from utils.log import ExperimentLogger +from utils.visualize import plot_ensemble_norm_errors, plot_solution_video, plot_fourier_coefficients, plot_fourier_coefficient_errors, plot_norm_errors + +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +def evaluate_and_log(epoch, u_nn_model, history, config, f_exact_func, u_exact_func, a_exact_func, c_exact_func, + logger, rank_val, eval_points_for_plot_np, u_exact_plot_data_flat, + full_uniform_grid_points, current_drm_weight, current_pinn_weight): + u_nn_model.eval() + + eval_points_for_errors = full_uniform_grid_points.requires_grad_(True).to(config.device) + + drm_loss = calculate_drm_loss(u_nn_model, f_exact_func, a_exact_func, c_exact_func, eval_points_for_errors) + pinn_loss = calculate_pinn_loss(u_nn_model, f_exact_func, a_exact_func, c_exact_func, eval_points_for_errors, config.domain_dim) + + total_loss = current_drm_weight * drm_loss + current_pinn_weight * pinn_loss + + if epoch == 0: + history['total_loss'].append(total_loss.item()) + history['drm_loss'].append(drm_loss.item()) + history['pinn_loss'].append(pinn_loss.item()) + logger.log_scalar('Weights/DRM_Weight', current_drm_weight, step=epoch) + logger.log_scalar('Weights/PINN_Weight', current_pinn_weight, step=epoch) + + + logger.log_scalar('Loss/Total_Loss', total_loss.item(), step=epoch) + logger.log_scalar('Loss/DRM_Loss', drm_loss.item(), step=epoch) + logger.log_scalar('Loss/PINN_Loss', pinn_loss.item(), step=epoch) + + # L2 norm error + u_pred_eval_l2 = u_nn_model(eval_points_for_errors) + u_exact_eval_l2 = u_exact_func(eval_points_for_errors).detach() + l2_error_u = torch.mean((u_pred_eval_l2 - u_exact_eval_l2)**2) + + # H1 seminorm error + eval_points_for_derivs = eval_points_for_errors.clone().detach().requires_grad_(True) + u_pred_for_derivs = u_nn_model(eval_points_for_derivs) + u_exact_for_derivs = u_exact_func(eval_points_for_derivs) + grad_u_pred_eval = torch.autograd.grad(u_pred_for_derivs, eval_points_for_derivs, + grad_outputs=torch.ones_like(u_pred_for_derivs), + create_graph=True, allow_unused=True)[0] + grad_u_exact_eval = torch.autograd.grad(u_exact_for_derivs, eval_points_for_derivs, + grad_outputs=torch.ones_like(u_exact_for_derivs), + create_graph=True, allow_unused=True)[0] + if grad_u_pred_eval is None: grad_u_pred_eval = torch.zeros_like(eval_points_for_derivs) + if grad_u_exact_eval is None: grad_u_exact_eval = torch.zeros_like(eval_points_for_derivs) + + h1_seminorm_error_u = torch.mean(torch.sum((grad_u_pred_eval - grad_u_exact_eval)**2, dim=1)) + + # H2 seminorm error + laplacian_u_pred_eval = torch.zeros_like(u_pred_for_derivs, device=config.device) + for i in range(config.domain_dim): + d2u_dxi2 = torch.autograd.grad(grad_u_pred_eval[:, i], eval_points_for_derivs, + grad_outputs=torch.ones_like(grad_u_pred_eval[:, i]), + create_graph=True, allow_unused=True)[0][:, i] + laplacian_u_pred_eval += d2u_dxi2.unsqueeze(1) + + laplacian_u_exact_eval = torch.zeros_like(u_exact_for_derivs, device=config.device) + for i in range(config.domain_dim): + d2u_dxi2_star = torch.autograd.grad(grad_u_exact_eval[:, i], eval_points_for_derivs, + grad_outputs=torch.ones_like(grad_u_exact_eval[:, i]), + create_graph=False, allow_unused=True)[0][:, i] + laplacian_u_exact_eval += d2u_dxi2_star.unsqueeze(1) + + h2_seminorm_error_u = 
torch.mean((laplacian_u_pred_eval - laplacian_u_exact_eval)**2) + + eval_points_for_errors.requires_grad_(False) + eval_points_for_derivs.requires_grad_(False) + + history['epochs_logged'].append(epoch) + history['l2_error_u'].append(l2_error_u.item()) + history['h1_seminorm_error_u'].append(h1_seminorm_error_u.item()) + history['h2_seminorm_error_u'].append(h2_seminorm_error_u.item()) + + if config.domain_dim != 1: + print(f"Fourier coefficient calculation for real basis is only supported for 1D. " + f"Current domain dimension is {config.domain_dim}. Skipping calculation.") + config.log_fourier_coefficients = False + + if config.log_fourier_coefficients: + fourier_data, u_exact_plot_data_flat_current, u_pred_plot_data_flat_current, eval_points_for_plot_np_current = \ + calculate_fourier_coefficients(config, u_nn_model, u_exact_func) + + if history['fourier_frequencies_logged'] is None: + history['fourier_frequencies_logged'] = fourier_data['frequencies'] + for freq_idx_tuple in history['fourier_frequencies_logged']: + freq_key = str(freq_idx_tuple) + history['fourier_coeffs_nn_magnitudes'][freq_key] = [] + history['fourier_coeffs_true_magnitudes'][freq_key] = [] + history['fourier_coeffs_error_magnitudes'][freq_key] = [] + + for i, freq_idx_tuple in enumerate(history['fourier_frequencies_logged']): + freq_key = str(freq_idx_tuple) + history['fourier_coeffs_nn_magnitudes'][freq_key].append(fourier_data['nn_coeffs'][i]) + history['fourier_coeffs_true_magnitudes'][freq_key].append(fourier_data['true_coeffs'][i]) + history['fourier_coeffs_error_magnitudes'][freq_key].append(fourier_data['error_coeffs'][i]) + + logger.log_scalar(f'Fourier/NN_Coeff_Magnitude_Freq_{freq_key}', fourier_data['nn_coeffs'][i], step=epoch) + logger.log_scalar(f'Fourier/True_Coeff_Magnitude_Freq_{freq_key}', fourier_data['true_coeffs'][i], step=epoch) + logger.log_scalar(f'Fourier/Error_Coeff_Magnitude_Freq_{freq_key}', fourier_data['error_coeffs'][i], step=epoch) + + if eval_points_for_plot_np is None: + eval_points_for_plot_np = eval_points_for_plot_np_current + u_exact_plot_data_flat = u_exact_plot_data_flat_current + logger.save_plot_data({'eval_points': eval_points_for_plot_np, 'u_exact_data': u_exact_plot_data_flat}, + 'true_solution_data', epoch) + logger.log_message(f"Saved true solution and evaluation points for plotting.") + + if u_pred_plot_data_flat_current is not None and u_pred_plot_data_flat_current.size > 0: + logger.save_plot_data({'u_nn_data': u_pred_plot_data_flat_current}, 'nn_solution_data', epoch) + history['solution_snapshots_epochs'].append(epoch) + else: + logger.log_message(f"Warning: u_pred_plot_data_flat_current is empty/None, skipping saving nn_solution_data for epoch {epoch}.") + + current_lr = None + if epoch == 0: + current_lr = config.learning_rate + logger.log_scalar('Learning_Rate', current_lr, step=epoch) + elif hasattr(u_nn_model, 'optimizer') and u_nn_model.optimizer is not None and u_nn_model.optimizer.param_groups: + current_lr = u_nn_model.optimizer.param_groups[0]['lr'] + logger.log_scalar('Learning_Rate', current_lr, step=epoch) + + print_str = f"Rank {rank_val}, Epoch {epoch}/{config.num_epochs}, " \ + f"L2 Error: {l2_error_u.item():.6e}, " \ + f"H1 Error: {h1_seminorm_error_u.item():.6e}, " \ + f"H2 Error: {h2_seminorm_error_u.item():.6e}, " \ + f"Total Loss: {total_loss.item():.6e}, " \ + f"DRM Loss: {drm_loss.item():.6e}, " \ + f"PINN Loss: {pinn_loss.item():.6e}" + if current_lr is not None: + print_str += f", LR: {current_lr:.2e}" + print(print_str) + + return 
+
+def run_experiment(config: EllipticSolverConfig, rank_val):
+    print(f"Rank {rank_val}: Initializing experiment")
+
+    config.random_seed += rank_val
+    torch.manual_seed(config.random_seed)
+    np.random.seed(config.random_seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(config.random_seed)
+
+    config_dict = config.to_dict()
+    logger = ExperimentLogger(config.base_log_dir, rank_val)
+    logger.log_message(f"\n--- Experiment Configuration (Rank {rank_val}/{size-1}) ---")
+    for key, value in config_dict.items():
+        logger.log_message(f"  {key}: {value}")
+    logger.log_message("------------------------------\n")
+    logger.save_config(config_dict)
+
+    rank_model_path = os.path.join(config.base_model_save_dir, f"run{rank_val+1}")
+    os.makedirs(rank_model_path, exist_ok=True)
+
+    u_exact_func, f_exact_func, a_exact_func, c_exact_func = get_manufactured_solution(config.case_number, config.domain_dim)
+    g0_func = get_g0_func(u_exact_func, config.domain_dim, config.domain_bounds, config.bc_extension)
+    d_func = get_d_func(config.domain_dim, config.domain_bounds, config.distance)
+
+    full_uniform_grid_points = generate_uniform_grid_points(
+        config.domain_bounds, config.num_uniform_partition
+    ).to(config.device)
+    total_domain_points_in_grid = full_uniform_grid_points.shape[0]
+    logger.log_message(f"Total domain points in grid: {total_domain_points_in_grid}")
+
+    u_nn_model = NNModel(
+        input_dim=config.domain_dim,
+        output_dim=1,
+        hidden_neurons=config.hidden_neurons,
+        activation=config.activation,
+        g0_func=g0_func,
+        d_func=d_func,
+        use_positional_encoding=config.use_positional_encoding,
+        pe_freq_min=config.pe_freq_min,
+        pe_freq_max=config.pe_freq_max,
+        use_chebyshev_basis=config.use_chebyshev_basis,
+        chebyshev_freq_min=config.chebyshev_freq_min,
+        chebyshev_freq_max=config.chebyshev_freq_max
+    ).to(config.device)
+
+    optimizer_class = getattr(optim, config.optimizer_type)
+    optimizer = optimizer_class(u_nn_model.parameters(), lr=config.learning_rate)
+
+    # Attach the optimizer so evaluate_and_log can read the current learning rate.
+    u_nn_model.optimizer = optimizer
+
+    if config.lr_scheduler_type == 'ExponentialLR':
+        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config.lr_decay_gamma)
+    elif config.lr_scheduler_type == 'MultiStepLR':
+        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.lr_step_milestones, gamma=config.lr_step_gamma)
+    elif config.lr_scheduler_type == 'ReduceLROnPlateau':
+        lr_scheduler = ReduceLROnPlateau(
+            optimizer,
+            mode=config.lr_patience_mode,
+            factor=config.lr_factor,
+            patience=config.lr_patience,
+            min_lr=config.lr_min,
+            verbose=True
+        )
+    else:
+        lr_scheduler = None
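The two scheduler families set up above are stepped differently later in the script: epoch-based schedulers advance unconditionally, while ReduceLROnPlateau must be fed the monitored metric. A short reminder sketch of that API contract (example values, not the script's config):

import torch

model = torch.nn.Linear(1, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Epoch-based schedulers advance once per call, unconditionally:
exp_sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999)
exp_sched.step()

# Plateau schedulers only reduce the LR when the metric stalls,
# so step() takes the metric explicitly:
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=50)
plateau.step(0.123)  # e.g. the epoch-averaged total loss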
+
+    d_min = 1
+    d_max = max(1, total_domain_points_in_grid // config.batch_size)
+
+    logger.log_message(f"Starting training for {config.num_epochs} epochs...")
+
+    history = {
+        'epochs_logged': [], 'total_loss': [], 'drm_loss': [], 'pinn_loss': [],
+        'l2_error_u': [], 'h1_seminorm_error_u': [], 'h2_seminorm_error_u': [],
+        'fourier_coeffs_nn_magnitudes': {}, 'fourier_coeffs_true_magnitudes': {},
+        'fourier_coeffs_error_magnitudes': {}, 'fourier_frequencies_logged': None,
+        'solution_snapshots_epochs': [],
+    }
+    eval_points_for_plot_np = None
+    u_exact_plot_data_flat = None
+
+    start_total_time = time.time()
+
+    if config.steps_per_cycle > 0:
+        current_drm_weight = 1.0
+        current_pinn_weight = 0.0
+    else:
+        current_drm_weight = config.drm_weight
+        current_pinn_weight = config.pinn_weight
+
+    logger.log_message("--- Epoch 0 ---")
+    eval_points_for_plot_np, u_exact_plot_data_flat = evaluate_and_log(
+        epoch=0,
+        u_nn_model=u_nn_model,
+        history=history,
+        config=config,
+        f_exact_func=f_exact_func,
+        u_exact_func=u_exact_func,
+        a_exact_func=a_exact_func,
+        c_exact_func=c_exact_func,
+        logger=logger,
+        rank_val=rank_val,
+        eval_points_for_plot_np=eval_points_for_plot_np,
+        u_exact_plot_data_flat=u_exact_plot_data_flat,
+        full_uniform_grid_points=full_uniform_grid_points,
+        current_drm_weight=current_drm_weight,
+        current_pinn_weight=current_pinn_weight
+    )
+    torch.save(u_nn_model.state_dict(), os.path.join(rank_model_path, 'model_epoch_0.pth'))
+    logger.log_message("Checkpoint saved at epoch 0")
+
+    for epoch in range(1, config.num_epochs + 1):
+        u_nn_model.train()
+
+        if config.steps_per_cycle > 0:
+            if epoch > config.num_epochs - config.drm_steps_per_cycle / 2:
+                current_drm_weight = 0.0
+                current_pinn_weight = 1.0
+            elif epoch > config.drm_steps_per_cycle / 2:
+                current_angle = 2 * np.pi / config.steps_per_cycle * (epoch - config.drm_steps_per_cycle)
+                current_drm_weight = 1 / (1 + np.exp(config.steps_per_cycle * np.sin(current_angle)))
+                current_pinn_weight = 1.0 - current_drm_weight
+            else:
+                current_drm_weight = 1.0
+                current_pinn_weight = 0.0
+
+        current_drm_weight = max(0.0, min(1.0, current_drm_weight))
+        current_pinn_weight = max(0.0, min(1.0, current_pinn_weight))
+
+        logger.log_scalar('Weights/DRM_Weight', current_drm_weight, step=epoch)
+        logger.log_scalar('Weights/PINN_Weight', current_pinn_weight, step=epoch)
+
+        # Interpolate the stride between d_min (DRM-dominant) and d_max (PINN-dominant);
+        # the weights are already clamped to [0, 1] above.
+        dynamic_d = max(1, int(d_min + (d_max - d_min) * (1 - current_drm_weight)))
+
+        all_strided_indices = []
+        for p in range(dynamic_d):
+            all_strided_indices.extend(range(p, total_domain_points_in_grid, dynamic_d))
+
+        all_strided_indices_tensor = torch.tensor(all_strided_indices, dtype=torch.long)
+        all_strided_indices_tensor = all_strided_indices_tensor[torch.randperm(all_strided_indices_tensor.shape[0])]
+
+        num_samples_for_epoch = all_strided_indices_tensor.shape[0]
+        num_batches_per_epoch = (num_samples_for_epoch + config.batch_size - 1) // config.batch_size
+
+        epoch_total_loss_sum = 0.0
+        epoch_drm_loss_sum = 0.0
+        epoch_pinn_loss_sum = 0.0
+        num_batches_actual = 0
+
+        for batch_idx_in_epoch in range(num_batches_per_epoch):
+            optimizer.zero_grad()
+
+            batch_start = batch_idx_in_epoch * config.batch_size
+            batch_end = min((batch_idx_in_epoch + 1) * config.batch_size, num_samples_for_epoch)
+
+            current_batch_indices = all_strided_indices_tensor[batch_start:batch_end]
+
+            if current_batch_indices.numel() == 0:
+                continue
+
+            sampled_domain_points_batch = full_uniform_grid_points[current_batch_indices].requires_grad_(True).to(config.device)
+
+            drm_loss = calculate_drm_loss(u_nn_model, f_exact_func, a_exact_func, c_exact_func, sampled_domain_points_batch)
+            pinn_loss = calculate_pinn_loss(u_nn_model, f_exact_func, a_exact_func, c_exact_func, sampled_domain_points_batch, config.domain_dim)
+
+            total_loss = current_drm_weight * drm_loss + current_pinn_weight * pinn_loss
+
+            total_loss.backward()
+            optimizer.step()
+
+            epoch_total_loss_sum += total_loss.item()
+            epoch_drm_loss_sum += drm_loss.item()
+            epoch_pinn_loss_sum += pinn_loss.item()
+            num_batches_actual += 1
+
+        if lr_scheduler and config.lr_scheduler_type != 'ReduceLROnPlateau':
+            lr_scheduler.step()
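A note on the strided index construction above: for any dynamic_d, the union of the residue classes mod dynamic_d covers every grid index exactly once, so after the torch.randperm shuffle each epoch still visits the full grid. A tiny pure-Python sketch with hypothetical sizes:

# Hypothetical sizes to illustrate the index construction in the epoch loop.
total_points, batch_size = 16, 4
d_min, d_max = 1, max(1, total_points // batch_size)

for w_drm in (1.0, 0.5, 0.0):
    dynamic_d = max(1, int(d_min + (d_max - d_min) * (1 - w_drm)))
    idx = []
    for p in range(dynamic_d):
        idx.extend(range(p, total_points, dynamic_d))
    print(f"w_drm={w_drm}: d={dynamic_d}, indices={idx}")
# Every index 0..total_points-1 appears exactly once for any d, so the
# subsequent shuffle makes each epoch a permutation of the whole grid.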
+
+        if num_batches_actual > 0:
+            history['total_loss'].append(epoch_total_loss_sum / num_batches_actual)
+            history['drm_loss'].append(epoch_drm_loss_sum / num_batches_actual)
+            history['pinn_loss'].append(epoch_pinn_loss_sum / num_batches_actual)
+        else:
+            history['total_loss'].append(history['total_loss'][-1] if history['total_loss'] else 0.0)
+            history['drm_loss'].append(history['drm_loss'][-1] if history['drm_loss'] else 0.0)
+            history['pinn_loss'].append(history['pinn_loss'][-1] if history['pinn_loss'] else 0.0)
+
+        if epoch % config.logging_freq == 0:
+            logger.log_message(f"--- Epoch {epoch} ---")
+            eval_points_for_plot_np, u_exact_plot_data_flat = evaluate_and_log(
+                epoch=epoch,
+                u_nn_model=u_nn_model,
+                history=history,
+                config=config,
+                f_exact_func=f_exact_func,
+                u_exact_func=u_exact_func,
+                a_exact_func=a_exact_func,
+                c_exact_func=c_exact_func,
+                logger=logger,
+                rank_val=rank_val,
+                eval_points_for_plot_np=eval_points_for_plot_np,
+                u_exact_plot_data_flat=u_exact_plot_data_flat,
+                full_uniform_grid_points=full_uniform_grid_points,
+                current_drm_weight=current_drm_weight,
+                current_pinn_weight=current_pinn_weight
+            )
+            # The plateau scheduler steps on the epoch-averaged loss, only at logging epochs.
+            if lr_scheduler and config.lr_scheduler_type == 'ReduceLROnPlateau':
+                metric = history['total_loss'][-1]
+                lr_scheduler.step(metric)
+
+        if epoch % config.checkpoint_freq == 0:
+            torch.save(u_nn_model.state_dict(), os.path.join(rank_model_path, f'model_epoch_{epoch}.pth'))
+            logger.log_message(f"Checkpoint saved at epoch {epoch}")
+
+    end_total_time = time.time()
+    total_duration = end_total_time - start_total_time
+    logger.log_message(f"Training finished in {total_duration:.2f} seconds.")
+    torch.save(u_nn_model.state_dict(), os.path.join(rank_model_path, 'model_final.pth'))
+
+    logger.close()
+
+    return u_nn_model.state_dict(), history
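The cyclical weighting in run_experiment uses a sigmoid of a sine, which for large steps_per_cycle behaves like a smoothed square wave alternating between DRM-dominant and PINN-dominant phases. steps_per_cycle itself is not set in this diff; it presumably lives on EllipticSolverConfig (plausibly drm_steps_per_cycle + pinn_steps_per_cycle). A replay of the schedule with hypothetical values:

import numpy as np

# Hypothetical values; in the script these come from the config.
steps_per_cycle = 200
drm_steps_per_cycle = 100
num_epochs = 1000

for epoch in range(0, 401, 25):
    if epoch > num_epochs - drm_steps_per_cycle / 2:
        w_drm = 0.0
    elif epoch > drm_steps_per_cycle / 2:
        angle = 2 * np.pi / steps_per_cycle * (epoch - drm_steps_per_cycle)
        w_drm = 1.0 / (1.0 + np.exp(steps_per_cycle * np.sin(angle)))
    else:
        w_drm = 1.0
    print(f"epoch {epoch:4d}: drm={w_drm:.3f}, pinn={1 - w_drm:.3f}")
# Phases of ~drm_steps_per_cycle epochs alternate, with w_drm = 0.5 at the crossovers.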
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Run reaction-diffusion experiments with configurable parameters.')
+
+    parser.add_argument('--dim', type=int, default=1, choices=[1, 2, 3],
+                        help='Dimension of the problem (1, 2, or 3).')
+    parser.add_argument('--epochs', type=int, default=1000,
+                        help='Number of training epochs.')
+    parser.add_argument('--seed', type=int, default=42,
+                        help='Base random seed for reproducibility; actual seed will be seed + rank.')
+    parser.add_argument('--case', type=int, default=4,
+                        help='Manufactured solution case number.')
+
+    parser.add_argument('--hidden_neurons', type=int, nargs='+', default=[30, 30, 30],
+                        help='List of integers for the number of neurons in each hidden layer.')
+    parser.add_argument('--activation', type=str, nargs='+', default=['tanh'],
+                        help='Activation function(s). Can be a single string (e.g., "relu") or a list (e.g., "tanh" "relu").')
+    parser.add_argument('--bc_extension', type=str, default='hermite_cubic_2nd_deriv',
+                        choices=['multilinear', 'hermite_cubic_2nd_deriv'],
+                        help='Boundary value extension function.')
+    parser.add_argument('--distance', type=str, default='sin_half_period',
+                        choices=['quadratic_bubble', 'inf_smooth_bump', 'abs_dist_complement', 'ratio_bubble_dist', 'sin_half_period'],
+                        help='Distance function.')
+    parser.add_argument('--pe_freq_min', type=int, default=-1,
+                        help='Minimum frequency power for positional encoding.')
+    parser.add_argument('--pe_freq_max', type=int, default=-1,
+                        help='Maximum frequency power for positional encoding.')
+    parser.add_argument('--chebyshev_freq_min', type=int, default=-1,
+                        help='Minimum frequency for Chebyshev polynomials.')
+    parser.add_argument('--chebyshev_freq_max', type=int, default=-1,
+                        help='Maximum frequency for Chebyshev polynomials.')
+
+    parser.add_argument('--num_uniform_partition', type=int, default=64,
+                        help='Number of subintervals along each dimension for uniform partitioning.')
+    parser.add_argument('--batch_size', type=int, default=64,
+                        help='Number of points in each batch.')
+
+    parser.add_argument('--drm_weight', type=float, default=0.0,
+                        help='Weight for the DRM loss term (used if drm_steps_per_cycle is 0).')
+    parser.add_argument('--pinn_weight', type=float, default=1.0,
+                        help='Weight for the PINN (PDE residual) loss term (used if pinn_steps_per_cycle is 0).')
+    parser.add_argument('--drm_steps_per_cycle', type=int, default=0,
+                        help='Number of epochs to train with DRM loss active in each cycle.')
+    parser.add_argument('--pinn_steps_per_cycle', type=int, default=0,
+                        help='Number of epochs to train with PINN loss active in each cycle.')
+
+    parser.add_argument('--optimizer_type', type=str, default='Adam', choices=['Adam', 'SGD', 'LBFGS'],
+                        help='Type of optimizer to use (e.g., Adam, SGD, LBFGS).')
+    parser.add_argument('--lr', type=float, default=1e-3,
+                        help='Learning rate for the optimizer.')
+    parser.add_argument('--lr_scheduler_type', type=str, default='ReduceLROnPlateau',
+                        choices=['ExponentialLR', 'MultiStepLR', 'ReduceLROnPlateau', 'None'],
+                        help='Type of learning rate scheduler.')
+    parser.add_argument('--lr_decay_gamma', type=float, default=0.999,
+                        help='Gamma for ExponentialLR (decay rate per epoch).')
+    parser.add_argument('--lr_step_milestones', type=int, nargs='+', default=[2000, 4000],
+                        help='List of epochs when the learning rate should drop for MultiStepLR.')
+    parser.add_argument('--lr_step_gamma', type=float, default=0.1,
+                        help='Factor by which to multiply the learning rate at milestones for MultiStepLR.')
+    parser.add_argument('--lr_patience', type=int, default=50,
+                        help='Number of epochs with no improvement after which the learning rate will be reduced for ReduceLROnPlateau.')
+    parser.add_argument('--lr_factor', type=float, default=0.5,
+                        help='Factor by which the learning rate will be reduced for ReduceLROnPlateau.')
+    parser.add_argument('--lr_min', type=float, default=1e-6,
+                        help='Minimum learning rate for ReduceLROnPlateau.')
+    parser.add_argument('--lr_patience_mode', type=str, default='min', choices=['min', 'max'],
+                        help='Mode for ReduceLROnPlateau (e.g., "min" for loss, "max" for accuracy).')
+
+    parser.add_argument('--logging_freq', type=int, default=50,
+                        help='Frequency (in epochs) to log metrics and save snapshots.')
+    parser.add_argument('--plot_resolution', type=int, default=256,
+                        help='Resolution for plotting grids.')
+    # Note: type=bool would treat any non-empty string (including "False") as True;
+    # BooleanOptionalAction (Python 3.9+) parses these flags correctly.
+    parser.add_argument('--log_fourier_coeffs', action=argparse.BooleanOptionalAction, default=True,
+                        help='Whether to log Fourier coefficients in the real basis.')
+    parser.add_argument('--use_sine_series', action=argparse.BooleanOptionalAction, default=True,
+                        help='Whether to use a sine series expansion instead of a full Fourier series expansion.')
+    parser.add_argument('--fourier_freq', type=int, nargs='+', default=[1, 2, 10, 20],
+                        help='Frequencies of the Fourier coefficients to log.')
+    parser.add_argument('--base_log_dir', type=str, default='logs',
+                        help='Base directory for saving experiment logs.')
+    parser.add_argument('--base_model_save_dir', type=str, default='models',
+                        help='Base directory for saving model checkpoints.')
+    parser.add_argument('--checkpoint_freq', type=int, default=1000,
+                        help='Frequency (in epochs) to save model checkpoints.')
+
+    parser.add_argument('--slice_plane_dim', type=int, default=2, choices=[0, 1, 2],
+                        help='For 3D, which dimension to slice (0 for x, 1 for y, 2 for z).')
+    parser.add_argument('--slice_plane_val', type=float, default=0.5,
+                        help='For 3D, value along the sliced dimension (e.g., z=0.5).')
+
+    args = parser.parse_args()
+
+    if len(args.activation) == 1:
+        parsed_activation = args.activation[0]
+    else:
+        parsed_activation = args.activation
+
+    base_config_kwargs = {
+        'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+        'problem': 'reaction-diffusion',
+        'domain_dim': args.dim,
+        'case_number': args.case,
+        'random_seed': args.seed,
+        'hidden_neurons': args.hidden_neurons,
+        'activation': parsed_activation,
+        'bc_extension': args.bc_extension,
+        'distance': args.distance,
+        'pe_freq_min': args.pe_freq_min,
+        'pe_freq_max': args.pe_freq_max,
+        'chebyshev_freq_min': args.chebyshev_freq_min,
+        'chebyshev_freq_max': args.chebyshev_freq_max,
+        'drm_weight': args.drm_weight,
+        'pinn_weight': args.pinn_weight,
+        'drm_steps_per_cycle': args.drm_steps_per_cycle,
+        'pinn_steps_per_cycle': args.pinn_steps_per_cycle,
+        'num_uniform_partition': args.num_uniform_partition,
+        'batch_size': args.batch_size,
+        'num_epochs': args.epochs,
+        'learning_rate': args.lr,
+        'optimizer_type': args.optimizer_type,
+        'lr_scheduler_type': args.lr_scheduler_type,
+        'lr_decay_gamma': args.lr_decay_gamma,
+        'lr_step_milestones': args.lr_step_milestones,
+        'lr_step_gamma': args.lr_step_gamma,
+        'lr_patience': args.lr_patience,
+        'lr_factor': args.lr_factor,
+        'lr_min': args.lr_min,
+        'lr_patience_mode': args.lr_patience_mode,
+        'plot_resolution': args.plot_resolution,
+        'log_fourier_coefficients': args.log_fourier_coeffs,
+        'use_sine_series': args.use_sine_series,
+        'fourier_freq': args.fourier_freq,
+        'base_log_dir': args.base_log_dir,
+        'base_model_save_dir': args.base_model_save_dir,
+        'logging_freq': args.logging_freq,
+        'checkpoint_freq': args.checkpoint_freq,
+        'slice_plane_dim': args.slice_plane_dim,
+        'slice_plane_val': args.slice_plane_val,
+    }
+
+    if args.dim == 1:
+        base_config_kwargs['domain_bounds'] = (0.0, 1.0)
+    elif args.dim == 2:
+        base_config_kwargs['domain_bounds'] = ((0.0, 1.0), (0.0, 1.0))
+    elif args.dim == 3:
+        base_config_kwargs['domain_bounds'] = ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0))
+
+    rank_config = EllipticSolverConfig(**base_config_kwargs)
+
+    print(f"Rank {rank}: Starting its experiment run{rank+1}.")
+    model_state_dict, history = run_experiment(rank_config, rank)
+    print(f"Rank {rank}: Finished its experiment run{rank+1}.")
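On the BooleanOptionalAction change above: the original type=bool would have made --log_fourier_coeffs False evaluate truthy, since bool("False") is True. BooleanOptionalAction creates a paired negative flag automatically:

import argparse

parser = argparse.ArgumentParser()
# BooleanOptionalAction (Python 3.9+) adds a matching --no-... flag.
parser.add_argument('--log_fourier_coeffs', action=argparse.BooleanOptionalAction, default=True)

print(parser.parse_args([]).log_fourier_coeffs)                           # True (default)
print(parser.parse_args(['--no-log_fourier_coeffs']).log_fourier_coeffs)  # False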
+
+    print(f"Rank {rank}: Generating individual plots for run{rank+1}.")
+    plot_solution_video(args.base_log_dir, rank_config, rank,
+                        output_filename=f'solution_evolution_run{rank+1}.mp4')
+    plot_norm_errors(args.base_log_dir, history, rank_config,
+                     output_filename=f'norm_errors_evolution_run{rank+1}.png')
+    if args.log_fourier_coeffs:
+        plot_fourier_coefficients(args.base_log_dir, history, rank_config,
+                                  output_filename=f'fourier_coefficients_evolution_run{rank+1}.png')
+        plot_fourier_coefficient_errors(args.base_log_dir, history, rank_config,
+                                        output_filename=f'fourier_coefficient_errors_evolution_run{rank+1}.png')
+    print(f"Rank {rank}: Finished generating individual plots for run{rank+1}.")
+
+    if size > 1:
+        print(f"Rank {rank}: Gathering results for ensemble visualization")
+        all_histories = comm.gather(history, root=0)
+
+        if rank == 0:
+            print(f"Rank 0: Generating ensemble visualizations for {size} runs in total")
+            plot_ensemble_norm_errors(args.base_log_dir, all_histories, rank_config,
+                                      output_filename='ensemble_norm_errors.png')
+            print("Rank 0: Ensemble plots complete.")
+
+    comm.Barrier()
+    print(f"Rank {rank}: Exiting.")
diff --git a/tony/utils/visualize.py b/tony/utils/visualize.py
index 5b44e80..ea6ca63 100644
--- a/tony/utils/visualize.py
+++ b/tony/utils/visualize.py
@@ -58,7 +58,7 @@ def plot_solution_video(base_log_dir, config_obj, rank_val, output_filename='sol
         ax.plot(x_coords, u_exact_data_flat, 'k--', label='True Solution')
         line, = ax.plot([], [], 'r-', label='NN Approximation')
-        ax.set_title(f'NN Solution Evolution (1D Poisson - Case {config_obj.case_number})')
+        ax.set_title(f'NN Solution Evolution (1D {config_obj.problem} - Case {config_obj.case_number})')
         ax.set_xlabel('x')
         ax.set_ylabel('u(x)')
         ax.legend()
@@ -85,7 +85,7 @@ def animate_1d(i):
         surf = ax.plot_surface(x_coords, y_coords, np.zeros_like(u_exact_reshaped),
                                cmap='plasma', alpha=0.9)
-        ax.set_title(f'NN Solution Evolution (2D Poisson - Case {config_obj.case_number})')
+        ax.set_title(f'NN Solution Evolution (2D {config_obj.problem} - Case {config_obj.case_number})')
         ax.set_xlabel('x')
         ax.set_ylabel('y')
         ax.set_zlabel('u(x,y)')
@@ -141,7 +141,7 @@ def animate_2d(i):
                            config.domain_bounds[plot_dims[1]][0], config_obj.domain_bounds[plot_dims[1]][1]],
                           cmap='plasma', vmin=c_min, vmax=c_max)
-        ax.set_title(f'NN Solution Evolution (3D Poisson - Case {config_obj.case_number})\nSlice at {chr(ord("x")+slice_dim)}={actual_slice_val:.2f}')
+        ax.set_title(f'NN Solution Evolution (3D {config_obj.problem} - Case {config_obj.case_number})\nSlice at {chr(ord("x")+slice_dim)}={actual_slice_val:.2f}')
         ax.set_xlabel(f'{chr(ord("x")+plot_dims[0])}')
         ax.set_ylabel(f'{chr(ord("x")+plot_dims[1])}')
         plt.colorbar(img_nn, ax=ax, label='u value')
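A note on the ensemble step in the main script above: mpi4py's comm.gather returns the assembled list only on the root rank and None everywhere else, which is why only rank 0 builds the ensemble plots. A minimal self-contained sketch of that pattern:

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# gather returns the full list on root and None on every other rank.
local_result = {'rank': rank, 'value': rank ** 2}
gathered = comm.gather(local_result, root=0)
if rank == 0:
    print(gathered)  # e.g. [{'rank': 0, ...}, {'rank': 1, ...}, ...]
comm.Barrier()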