Skip to content

Conversation

stevengogogo
Copy link
Collaborator

@stevengogogo stevengogogo commented Jun 18, 2025

Vecotrize the calculation of sine terms, and avoid speed bottleneck caused by for-loop in PDE problem 1. This is especially useful with many sine terms.

 

Before After
Model_Outputs_00000000_00000000_00000006 image

Verification and performance

A script to check the performance and accuracy:

168X speed up with 1000 high frequency terms

This script uses @timeit on jupyter

import torch
device =  "cpu"

class PDE:
    def __init__(self, high=None, mu=70, r=0, problem=1, device=device):
        # omega = [high]
        omega = list(range(1, high + 1, 2))
        # omega += [i + 50 for i in omega]
        # omega = list(range(2, high + 1, 2))
        # omega = [2**i for i in range(high.bit_length()) if 2**i <= high]
        coeff = [1] * len(omega)

        self.w = torch.asarray(omega, device=device)
        self.c = torch.asarray(coeff, device=device)
        self.mu = mu
        self.r = r
        if problem == 1:
            self.f = self.f_1
            self.u_ex = self.u_ex_1
        else:
            self.f = self.f_2
            self.u_ex = self.u_ex_2

    # Source term
    @staticmethod
    def sin_series(w:float, c:float, x:float, r:float)->float:
        """
        return c  (4 w^2  \pi^2 + r)  \sin(2 w  \pi  x)$
        """
        pi_w = 2*w*torch.pi
        sin_term = c * (pi_w ** 2 + r) * torch.sin(pi_w * x)
        return sin_term
    
    def f_1(self, x):
        """
        x: shape (nx, 1)
        """
        y = torch.zeros_like(x)
        #for w, c in zip(self.w, self.c):
        #    pi_w = 2 * torch.pi * w
        #    y += c * (pi_w ** 2 + self.r) * torch.sin(pi_w * x)
        sin_terms = torch.func.vmap(self.sin_series, in_dims=(0,0,None,None))(self.w, self.c, x, self.r)
        y = torch.sum(sin_terms, dim=0)
        return y

        # Source term
    def f_1_org(self, x):
        y = torch.zeros_like(x)
        for w, c in zip(self.w, self.c):
            pi_w = 2 * torch.pi * w
            y += c * (pi_w ** 2 + self.r) * torch.sin(pi_w * x)
        return y

    def f_2(self, x):
        z = x ** 2
        a = self.r + 4 * z * (self.mu ** 2 - 1) + 2
        b = self.mu * z
        c = 8 * b - 2 * self.mu
        return torch.exp(-z) * (a * torch.sin(b) + c * torch.cos(b))

    # Analytical solution
    def u_ex_1(self, x):
        y = torch.zeros_like(x)
        for w, c in zip(self.w, self.c):
            y += c * torch.sin(2 * w * torch.pi * x)
        return y

    def u_ex_2(self, x):
        return torch.exp(-x**2) * torch.sin(self.mu * x ** 2)


xs = torch.linspace(0, 1, 100, device=device).view(-1, 1)
pde = PDE(high=1000, r=3, device=device)
print("Measure New Implementation")
%timeit y = pde.f_1(xs)
print("Measure original implementation")
%timeit y2 = pde.f_1_org(xs)
print("Are two method produce same answer?")
torch.allclose(y, y2)
Measure New Implementation
69.7 µs ± 337 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Measure original implementation
11.6 ms ± 39 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Are two method produce same answer?
True

@stevengogogo stevengogogo added the enhancement New feature or request label Jun 18, 2025
@stevengogogo stevengogogo requested a review from liruipeng June 18, 2025 23:58
@stevengogogo stevengogogo marked this pull request as ready for review June 18, 2025 23:58
@stevengogogo stevengogogo added the RFR ready for review label Oct 5, 2025
@siuwuncheung siuwuncheung self-requested a review October 10, 2025 17:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request RFR ready for review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant