Skip to content
Open
199 changes: 185 additions & 14 deletions pinn/pinn_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
import torch.nn as nn
import torch.optim as optim
import numpy as np
import itertools
from enum import Enum
from typing import Union, Tuple, Callable
from utils import parse_args, get_activation, print_args, save_frame, make_video_from_frames
from utils import is_notebook, cleanfiles, fourier_analysis, get_scheduler_generator, scheduler_step
from cheby import generate_chebyshev_features
Expand All @@ -57,6 +59,147 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# %%
# Helper functions from the new BC implementation
def _calculate_laplacian_1d(func: Callable[[torch.Tensor], torch.Tensor], x_val: float) -> torch.Tensor:
x_tensor = torch.tensor([[x_val]], dtype=torch.float32, requires_grad=True)
u = func(x_tensor)
grad_u = torch.autograd.grad(u, x_tensor, grad_outputs=torch.ones_like(u), create_graph=True, retain_graph=True)[0]
laplacian_u = torch.autograd.grad(grad_u, x_tensor, grad_outputs=torch.ones_like(grad_u), create_graph=False, retain_graph=False)[0]
return laplacian_u

def get_g0_func(
    u_exact_func: Callable[[torch.Tensor], torch.Tensor],
    domain_dim: int,
    domain_bounds: Union[Tuple[float, float], Tuple[Tuple[float, float], ...]],
    g0_type: str = "multilinear"
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Construct the boundary-extension function g0 for BC enforcement.

    The returned callable maps an (N, domain_dim) tensor of points to an
    (N, 1) tensor. 'multilinear' interpolates the values of ``u_exact_func``
    at the corners of the bounding box; 'hermite_cubic_2nd_deriv' (1D only)
    builds the cubic matching the exact values and second derivatives at the
    two endpoints.
    """
    bounds = domain_bounds
    # Normalize a flat 1D (lo, hi) pair into the nested ((lo, hi),) form.
    if domain_dim == 1 and not isinstance(domain_bounds[0], (tuple, list)):
        bounds = (domain_bounds,)
    lo = torch.tensor([b[0] for b in bounds], dtype=torch.float32)
    hi = torch.tensor([b[1] for b in bounds], dtype=torch.float32)

    if g0_type == "hermite_cubic_2nd_deriv":
        if domain_dim != 1:
            raise ValueError("Hermite cubic interpolation with 2nd derivatives is only supported for 1D problems.")
        x0, x1 = lo.item(), hi.item()
        h = x1 - x0
        u0 = u_exact_func(torch.tensor([[x0]], dtype=torch.float32)).item()
        u1 = u_exact_func(torch.tensor([[x1]], dtype=torch.float32)).item()
        d2u0 = _calculate_laplacian_1d(u_exact_func, x0).item()
        d2u1 = _calculate_laplacian_1d(u_exact_func, x1).item()
        # Cubic p(x) = a0 + a1*x + a2*x^2 + a3*x^3 with
        # p(x0)=u0, p(x1)=u1, p''(x0)=d2u0, p''(x1)=d2u1.
        a3 = (d2u1 - d2u0) / (6 * h)
        a2 = d2u0 / 2 - 3 * a3 * x0
        a1 = (u1 - u0) / h - a2 * (x1 + x0) - a3 * (x1**2 + x1 * x0 + x0**2)
        a0 = u0 - a1 * x0 - a2 * x0**2 - a3 * x0**3
        poly = torch.tensor([a0, a1, a2, a3], dtype=torch.float32)

        def g0_hermite_cubic_val(x: torch.Tensor) -> torch.Tensor:
            xs = x[:, 0]
            vals = poly[0] + poly[1] * xs + poly[2] * (xs**2) + poly[3] * (xs**3)
            return vals.unsqueeze(1)

        return g0_hermite_cubic_val

    if g0_type == "multilinear":
        # Evaluate the exact solution once on every corner of the box;
        # keys are the raw float coordinates of each corner.
        corner_vals = {}
        axis_endpoints = [[lo[d].item(), hi[d].item()] for d in range(domain_dim)]
        for corner in itertools.product(*axis_endpoints):
            point = torch.tensor(corner, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                corner_vals[corner] = u_exact_func(point).item()

        def g0_multilinear_val(x: torch.Tensor) -> torch.Tensor:
            n = x.shape[0]
            # Normalized coordinates in [0, 1] per axis, clamped for safety.
            xi = (x - lo.to(x.device)) / (hi.to(x.device) - lo.to(x.device))
            xi = torch.clamp(xi, 0.0, 1.0)
            result = torch.zeros((n, 1), device=x.device)
            # Standard multilinear interpolation: each corner contributes its
            # value weighted by the product of per-axis hat weights.
            for label in itertools.product([0, 1], repeat=domain_dim):
                key_coords = []
                weight = torch.ones((n, 1), device=x.device)
                for d in range(domain_dim):
                    if label[d] == 0:
                        key_coords.append(lo[d].item())
                        weight *= (1 - xi[:, d]).unsqueeze(1)
                    else:
                        key_coords.append(hi[d].item())
                        weight *= xi[:, d].unsqueeze(1)
                result += corner_vals[tuple(key_coords)] * weight
            return result

        return g0_multilinear_val

    raise ValueError(f"Unknown g0_type: {g0_type}. Choose 'multilinear' or 'hermite_cubic_2nd_deriv'.")

def _psi_tensor(t: torch.Tensor) -> torch.Tensor:
return torch.where(t <= 0, torch.tensor(0.0, dtype=t.dtype, device=t.device), torch.exp(-1.0 / t))

def get_d_func(domain_dim: int,
               domain_bounds: Union[Tuple[float, float], Tuple[Tuple[float, float], ...]],
               d_type: str = "sin_half_period") -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a distance-like factor d(x) that vanishes on the domain boundary.

    The returned callable maps an (N, domain_dim) tensor of points to an
    (N, 1) tensor. 'quadratic_bubble' and 'inf_smooth_bump' work in any
    dimension; the remaining variants are 1D only.
    """
    bounds = domain_bounds
    # Normalize a flat 1D (lo, hi) pair into the nested ((lo, hi),) form.
    if domain_dim == 1 and not isinstance(domain_bounds[0], (tuple, list)):
        bounds = (domain_bounds,)
    lo = torch.tensor([b[0] for b in bounds], dtype=torch.float32)
    hi = torch.tensor([b[1] for b in bounds], dtype=torch.float32)
    length = (hi[0] - lo[0]).item() if domain_dim == 1 else None

    if d_type == "quadratic_bubble":
        def d_quadratic_bubble(x: torch.Tensor) -> torch.Tensor:
            # Product of per-axis parabolas (x - lo)(hi - x); zero on all faces.
            prod = torch.ones_like(x[:, 0], dtype=torch.float32, device=x.device)
            for dim in range(domain_dim):
                a, b = bounds[dim]
                coord = x[:, dim]
                prod *= (coord - a) * (b - coord)
            return prod.unsqueeze(1)
        return d_quadratic_bubble

    if d_type == "inf_smooth_bump":
        def d_inf_smooth_bump(x: torch.Tensor) -> torch.Tensor:
            # Product of per-axis C-infinity bumps psi(R^2 - (x - c)^2).
            prod = torch.ones((x.shape[0],), dtype=x.dtype, device=x.device)
            for dim in range(domain_dim):
                center = (lo[dim] + hi[dim]) / 2.0
                radius = (hi[dim] - lo[dim]) / 2.0
                prod *= _psi_tensor(radius**2 - (x[:, dim] - center)**2)
            return prod.unsqueeze(1)
        return d_inf_smooth_bump

    if d_type == "abs_dist_complement":
        if domain_dim != 1:
            raise ValueError(f"d_type '{d_type}' is only supported for 1D problems.")
        def d_abs_dist_complement(x: torch.Tensor) -> torch.Tensor:
            t = (x[:, 0] - lo[0]) / length
            return (1.0 - torch.sqrt(t**2 + (1.0 - t)**2)).unsqueeze(1)
        return d_abs_dist_complement

    if d_type == "ratio_bubble_dist":
        if domain_dim != 1:
            raise ValueError(f"d_type '{d_type}' is only supported for 1D problems.")
        def d_ratio_bubble_dist(x: torch.Tensor) -> torch.Tensor:
            t = (x[:, 0] - lo[0]) / length
            num = t * (1.0 - t)
            den = torch.sqrt(t**2 + (1.0 - t)**2)
            return (num / den).unsqueeze(1)
        return d_ratio_bubble_dist

    if d_type == "sin_half_period":
        if domain_dim != 1:
            raise ValueError(f"d_type '{d_type}' is only supported for 1D problems.")
        if length is None:
            raise ValueError("Domain length must be defined for 'sin_half_period' d_type.")
        def d_sin_half_period(x: torch.Tensor) -> torch.Tensor:
            # Half sine wave: zero at both endpoints, peak at the midpoint.
            return torch.sin((torch.pi / length) * (x[:, 0] - lo[0])).unsqueeze(1)
        return d_sin_half_period

    raise ValueError(f"Unknown d_type: {d_type}. Choose from 'quadratic_bubble', 'inf_smooth_bump', 'abs_dist_complement', 'ratio_bubble_dist', or 'sin_half_period'.")

# %%
# Define PDE
class PDE:
Expand Down Expand Up @@ -188,8 +331,8 @@ def __init__(self, mesh: Mesh, num_levels: int, dim_inputs, dim_outputs, dim_hid
act: nn.Module = nn.ReLU(), enforce_bc: bool = False,
g0_type: str = "multilinear", d_type: str = "sin_half_period",
use_chebyshev_basis: bool = False,
chebyshev_freq_min: int = 0,
chebyshev_freq_max: int = 0) -> None:
chebyshev_freq_min: np.ndarray = None,
chebyshev_freq_max: np.ndarray = None) -> None:
super().__init__()
self.mesh = mesh
# currently the same model on each level
Expand All @@ -214,15 +357,14 @@ def __init__(self, mesh: Mesh, num_levels: int, dim_inputs, dim_outputs, dim_hid
print(f"BCs will be enforced using g0_type: {g0_type} and d_type: {d_type}")

self.use_chebyshev_basis = use_chebyshev_basis
self.chebyshev_freqs = np.round(np.linspace(chebyshev_freq_min, chebyshev_freq_max, num_levels + 1)).astype(int)
self.models = nn.ModuleList([
Level(dim_inputs=dim_inputs, dim_outputs=dim_outputs, dim_hidden=dim_hidden, act=act,
use_chebyshev_basis=use_chebyshev_basis,
chebyshev_freq_min=self.chebyshev_freqs[i],
chebyshev_freq_max=self.chebyshev_freqs[i+1])
chebyshev_freq_min=chebyshev_freq_min[i],
chebyshev_freq_max=chebyshev_freq_max[i])
for i in range(num_levels)
])

# All levels start as "off"
self.level_status = [LevelStatus.OFF] * num_levels

Expand Down Expand Up @@ -475,16 +617,37 @@ def main(args=None):
torch.manual_seed(0)
# Parse args
args = parse_args(args=args)
# Ensure chebyshev_freq_max is at least chebyshev_freq_min for range to be valid
if args.use_chebyshev_basis and args.chebyshev_freq_max < args.chebyshev_freq_min:
raise ValueError("chebyshev_freq_max must be >= chebyshev_freq_min when using Chebyshev basis.")
print_args(args=args, output_file=f"results_pinn_1d_{ts}/args.txt")
# PDE
pde = PDE(high=args.high_freq, mu=args.mu, r=args.gamma,
problem=args.problem_id)
# Loss function [supervised with analytical solution (-1) or PINN loss (0)]
loss = Loss(loss_type=args.loss_type, bc_weight=args.bc_weight)
print(f"Using loss: {loss.name}")
losses = []
losses.append(Loss(loss_type=-1, bc_weight=args.bc_weight))
losses.append(Loss(loss_type=0, bc_weight=args.bc_weight))
losses.append(Loss(loss_type=1, bc_weight=args.bc_weight))

if args.use_chebyshev_basis:
if len(args.chebyshev_freq_min) == 1:
chebyshev_freq_min = np.ones(args.levels, dtype=int) * args.chebyshev_freq_min
else:
chebyshev_freq_min = np.array(args.chebyshev_freq_min).astype(int)
print(f"Chebyshev frequencies lower bounds = {chebyshev_freq_min}")

if len(args.chebyshev_freq_max) == 1:
chebyshev_freq_max = np.ones(args.levels, dtype=int) * args.chebyshev_freq_max
else:
chebyshev_freq_max = np.array(args.chebyshev_freq_max).astype(int)
print(f"Chebyshev frequencies upper bounds = {chebyshev_freq_max}")
else:
chebyshev_freq_min = np.ones(args.levels, dtype=int) * -1
chebyshev_freq_max = np.ones(args.levels, dtype=int) * -1

if len(args.epochs) == 1:
epochs = np.ones(args.levels, dtype=int) * args.epochs
else:
epochs = np.array(args.epochs).astype(int)

# scheduler gen takes optimizer to return scheduler
scheduler_gen = get_scheduler_generator(args)
# 1-D mesh
Expand All @@ -494,6 +657,7 @@ def main(args=None):
# Input and output dimension: x -> u(x)
dim_inputs = 1
dim_outputs = 1

model = MultiLevelNN(mesh=mesh,
num_levels=args.levels,
dim_inputs=dim_inputs, dim_outputs=dim_outputs,
Expand All @@ -503,8 +667,8 @@ def main(args=None):
g0_type=args.bc_extension,
d_type=args.distance,
use_chebyshev_basis=args.use_chebyshev_basis,
chebyshev_freq_min=args.chebyshev_freq_min,
chebyshev_freq_max=args.chebyshev_freq_max)
chebyshev_freq_min=chebyshev_freq_min,
chebyshev_freq_max=chebyshev_freq_max)
print(model)
model.to(device)
# Plotting
Expand All @@ -530,7 +694,14 @@ def main(args=None):
scale = lev + 1
model.set_scale(level_idx=lev, scale=scale)
# Crank that !@#$ up
train(model=model, mesh=mesh, criterion=loss, iterations=args.epochs,
if args.loss_type < 2:
loss = losses[args.loss_type+1]
else:
if lev == 0: # DRM
loss = losses[2]
else: # PINN
loss = losses[1]
train(model=model, mesh=mesh, criterion=loss, iterations=epochs[lev],
adam_iterations=args.adam_epochs,
learning_rate=args.lr, num_check=args.num_checks, num_plots=num_plots,
sweep_idx=i, level_idx=lev, frame_dir=frame_dir, scheduler_gen=scheduler_gen)
Expand Down
22 changes: 9 additions & 13 deletions pinn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def parse_args(args=None):
help="Number of evaluation checkpoints during training.")
parser.add_argument('--num_plots', type=int, default=10,
help="Number of plotting points during training.")
parser.add_argument('--epochs', type=int, default=10000,
parser.add_argument('--epochs', type=int, nargs='+', default=10000,
help="Number of training epochs per sweep.")
parser.add_argument('--adam_epochs', type=int, default=None,
help="Number of training epochs using Adam per sweep. Defaults to --epochs if not set.")
Expand All @@ -78,29 +78,31 @@ def parse_args(args=None):
help="Learning rate for the optimizer.")
parser.add_argument('--levels', type=int, default=4,
help="Number of levels in multilevel training.")
parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0, 1],
help="Loss type: -1 for supervised (true solution), 0 for PINN loss.")
parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0, 1, 2],
help="Loss type: -1 for supervised (true solution), 0 for PINN loss, 1 for DRM loss, 2 for mixed.")
parser.add_argument('--activation', type=str, default='tanh',
choices=['tanh', 'silu', 'relu', 'gelu', 'softmax'],
help="Activation function to use.")
parser.add_argument('--enforce_bc', action='store_true',
help="If set, enforce the BC in solution.")
parser.add_argument('--bc_extension', type=str, default='hermite_cubic_2nd_deriv',
choices=['multilinear', 'hermite_cubic_2nd_deriv'],
help='Boundary value extension function.')
parser.add_argument('--distance', type=str, default='sin_half_period',
choices=['quadratic_bubble', 'inf_smooth_bump', 'abs_dist_complement', 'ratio_bubble_dist', 'sin_half_period'],
help='Distance function.')
parser.add_argument('--chebyshev_freq_min', type=int, default=-1,
parser.add_argument('--use_chebyshev_basis', action='store_true',
help="If set, use Chebyshev features.")
parser.add_argument('--chebyshev_freq_min', type=int, nargs='+',
help='Minimum frequency for Chebyshev polynomials.')
parser.add_argument('--chebyshev_freq_max', type=int, default=-1,
parser.add_argument('--chebyshev_freq_max', type=int, nargs='+',
help='Maximum frequency for Chebyshev polynomials.')
parser.add_argument('--plot', action='store_true',
help="If set, generate plots during or after training.")
parser.add_argument('--no-clear', action='store_false', dest='clear',
help="If set, do not remove plot files generated before.")
parser.add_argument('--problem_id', type=int, default=1, choices=[1, 2],
help="PDE problem to solve: 1 or 2.")
parser.add_argument('--enforce_bc', action='store_true',
help="If set, enforce the BC in solution.")
parser.add_argument('--bc_weight', type=float, default=1.0,
help="Weight for the loss of BC.")
parser.add_argument("--scheduler", type=str, default="StepLR",
Expand All @@ -118,12 +120,6 @@ def parse_args(args=None):
if args.adam_epochs is None:
args.adam_epochs = args.epochs

if (1 <= args.chebyshev_freq_min <= args.chebyshev_freq_max):
print(f"Chebyshev basis of frequency {args.chebyshev_freq_min} to {args.chebyshev_freq_max} are used")
args.use_chebyshev_basis = True
else:
args.use_chebyshev_basis = False

return args


Expand Down
Loading