Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[report]
show_missing = true

[run]
omit = gpucbc/_version.py,gpucbc/test/*
3 changes: 2 additions & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ jobs:
run: |
conda update pip setuptools
conda install numpy astropy bilby python-lal python-lalsimulation
conda install pytest-cov
conda install jax
conda install pytest-cov parameterized
pip install .
- name: Test with pytest
run: |
Expand Down
34 changes: 14 additions & 20 deletions gpucbc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
from . import likelihood, waveforms
from . import backend, likelihood, pn, waveforms


def disable_cupy():
import numpy
from scipy.special import i0e
likelihood.xp = numpy
likelihood.i0e = i0e
waveforms.xp = numpy
from ._version import __version__


def enable_cupy():
try:
import cupy
from .cupy_utils import i0e
likelihood.xp = cupy
likelihood.i0e = i0e
waveforms.xp = cupy
except ImportError:
print("Cannot import cupy")
disable_cupy()


enable_cupy()
def set_backend(numpy):

scipy = dict(
numpy="scipy",
cupy="cupyx.scipy",
jax="jax.scipy",
).get(numpy, None)
numpy = dict(jax="jax.numpy").get(numpy, numpy)
BACKEND = backend.Backend(numpy=numpy, scipy=scipy)
backend.BACKEND = BACKEND
likelihood.B = BACKEND
pn.B = BACKEND
waveforms.B = BACKEND
29 changes: 29 additions & 0 deletions gpucbc/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from importlib import import_module


class Backend:

def __init__(self, numpy, scipy=None):
self.module = numpy
try:
self.np = import_module(numpy)
if scipy is not None:
self.special = import_module(f"{scipy}.special")
else:
self.special = None
except ImportError as e:
raise ImportError(f"Cannot initialize backend for {numpy} {scipy}.\n{e}")

def to_numpy(self, array):
if self.module == "numpy":
return array
elif self.module == "jax.numpy":
import numpy as np
return np.asarray(array)
elif self.module == "cupy":
return self.np.asnumpy(array)
else:
return array


BACKEND = Backend("numpy", "scipy")
72 changes: 28 additions & 44 deletions gpucbc/likelihood.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import numpy as np
from bilby.core.likelihood import Likelihood

try:
import cupy as xp
from cupyx.special import i0e, logsumexp
except ImportError:
xp = np
from scipy.special import i0e, logsumexp
from .backend import BACKEND as B


class CUPYGravitationalWaveTransient(Likelihood):
Expand Down Expand Up @@ -61,25 +56,17 @@ def __init__(

def _data_to_gpu(self):
for ifo in self.interferometers:
self.psds[ifo.name] = xp.asarray(
self.psds[ifo.name] = B.np.asarray(
ifo.power_spectral_density_array[ifo.frequency_mask]
)
self.strain[ifo.name] = xp.asarray(
self.strain[ifo.name] = B.np.asarray(
ifo.frequency_domain_strain[ifo.frequency_mask]
)
self.frequency_array[ifo.name] = xp.asarray(
self.frequency_array[ifo.name] = B.np.asarray(
ifo.frequency_array[ifo.frequency_mask]
)
self.duration = ifo.strain_data.duration

def __repr__(self):
return (
self.__class__.__name__
+ "(interferometers={},\n\twaveform_generator={})".format(
self.interferometers, self.waveform_generator
)
)

def noise_log_likelihood(self):
"""Calculates the real part of noise log-likelihood

Expand All @@ -95,7 +82,7 @@ def noise_log_likelihood(self):
log_l -= (
2.0
/ self.duration
* xp.sum(xp.abs(self.strain[name]) ** 2 / self.psds[name])
* (B.np.abs(self.strain[name]) ** 2 / self.psds[name]).sum()
)
self._noise_log_l = float(log_l)
return self._noise_log_l
Expand Down Expand Up @@ -134,29 +121,26 @@ def log_likelihood_ratio(self):
d_inner_h=d_inner_h, h_inner_h=h_inner_h
)
else:
log_l = - h_inner_h / 2 + xp.real(d_inner_h)
log_l = - h_inner_h / 2 + d_inner_h
return float(log_l.real)

def calculate_snrs(self, interferometer, waveform_polarizations):
name = interferometer.name
signal_ifo = xp.sum(
xp.vstack(
[
waveform_polarizations[mode]
* float(
interferometer.antenna_response(
self.parameters["ra"],
self.parameters["dec"],
self.parameters["geocent_time"],
self.parameters["psi"],
mode,
)
signal_ifo = B.np.vstack(
[
waveform_polarizations[mode]
* float(
interferometer.antenna_response(
self.parameters["ra"],
self.parameters["dec"],
self.parameters["geocent_time"],
self.parameters["psi"],
mode,
)
for mode in waveform_polarizations
]
),
axis=0,
)[interferometer.frequency_mask]
)
for mode in waveform_polarizations
]
).sum(axis=0)[interferometer.frequency_mask]

time_delay = (
self.parameters["geocent_time"]
Expand All @@ -168,10 +152,10 @@ def calculate_snrs(self, interferometer, waveform_polarizations):
)
)

signal_ifo *= xp.exp(-2j * np.pi * time_delay * self.frequency_array[name])
signal_ifo *= B.np.exp(-2j * np.pi * time_delay * self.frequency_array[name])

d_inner_h = xp.sum(xp.conj(signal_ifo) * self.strain[name] / self.psds[name])
h_inner_h = xp.sum(xp.abs(signal_ifo) ** 2 / self.psds[name])
d_inner_h = (signal_ifo.conj() * self.strain[name] / self.psds[name]).sum()
h_inner_h = (B.np.abs(signal_ifo) ** 2 / self.psds[name]).sum()
d_inner_h *= 4 / self.duration
h_inner_h *= 4 / self.duration
return d_inner_h, h_inner_h
Expand All @@ -192,13 +176,13 @@ def distance_marglinalized_likelihood(self, d_inner_h, h_inner_h):
d_inner_h=d_inner_h_array, h_inner_h=h_inner_h_array
)
else:
log_l_array = - h_inner_h_array / 2 + xp.real(d_inner_h_array)
log_l_array = - h_inner_h_array / 2 + d_inner_h_array.real
log_l = logsumexp(log_l_array, b=self.distance_prior_array)
return log_l

def phase_marginalized_likelihood(self, d_inner_h, h_inner_h):
d_inner_h = xp.abs(d_inner_h)
d_inner_h = xp.log(i0e(d_inner_h)) + d_inner_h
d_inner_h = B.np.abs(d_inner_h)
d_inner_h = B.np.log(B.special.i0e(d_inner_h)) + d_inner_h
log_l = - h_inner_h / 2 + d_inner_h
return log_l

Expand All @@ -208,10 +192,10 @@ def _setup_distance_marginalization(self):
self.priors["luminosity_distance"].maximum,
10000,
)
self.distance_prior_array = xp.asarray(
self.distance_prior_array = B.np.asarray(
self.priors["luminosity_distance"].prob(self.distance_array)
) * (self.distance_array[1] - self.distance_array[0])
self.distance_array = xp.asarray(self.distance_array)
self.distance_array = B.np.asarray(self.distance_array)

def generate_posterior_sample_from_marginalized_likelihood(self):
return self.parameters.copy()
4 changes: 3 additions & 1 deletion gpucbc/pn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from .backend import BACKEND as B

PI = np.pi


Expand Down Expand Up @@ -135,7 +137,7 @@ def taylor_f2_phase_6(args):
+ args.eta * (-15737765635 / 3048192 + 2255 / 12 * PI ** 2)
+ args.eta ** 2 * 76055 / 1728
- args.eta ** 3 * 127825 / 1296
+ taylor_f2_phase_6l(args) * np.log(4)
+ taylor_f2_phase_6l(args) * B.np.log(4)
)
phase += (32675 / 112 + 5575 / 18 * args.eta) * args.eta * args.chi_1 * args.chi_2
for m_on_m, chi, qm_def in zip(
Expand Down
Empty file added gpucbc/test/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions gpucbc/test/test_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest

import bilby
import gpucbc
import jax.numpy as jnp
from jax.scipy import special as jsp


class TestBackend(unittest.TestCase):

def test_setting_jax(self):
gpucbc.set_backend("jax")
self.assertEqual(gpucbc.pn.B.np, jnp)
self.assertEqual(gpucbc.pn.B.special, jsp)

def test_unknown_backend_raises_error(self):
with self.assertRaises(ImportError):
gpucbc.set_backend("unknown")

def test_no_scipy_backend(self):
gpucbc.set_backend("bilby")
self.assertEqual(gpucbc.pn.B.np, bilby)
self.assertEqual(gpucbc.pn.B.special, None)
38 changes: 19 additions & 19 deletions test/test_waveform.py → gpucbc/test/test_waveform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env python3

import unittest
import pytest
from parameterized import parameterized

import numpy as np
import pandas as pd
Expand All @@ -15,6 +17,8 @@
from bilby.gw.utils import noise_weighted_inner_product
from bilby.gw.waveform_generator import WaveformGenerator

import gpucbc
from gpucbc import set_backend
from gpucbc.waveforms import TF2WFG, TF2


Expand Down Expand Up @@ -51,7 +55,10 @@ def setUp(self) -> None:
parameter_conversion=convert_to_lal_binary_neutron_star_parameters,
)

def test_native_phasing(self):
@parameterized.expand(["numpy", "jax", "cupy"])
def test_native_phasing(self, backend):
pytest.importorskip(backend)
set_backend(backend)
priors = PriorDict()
priors["mass_1"] = Uniform(1, 100)
priors["mass_2"] = Uniform(1, 100)
Expand All @@ -66,21 +73,21 @@ def test_native_phasing(self):
priors["luminosity_distance"] = Uniform(10, 200)

wf = TF2(**priors.sample())
TF2.pn_tidal_order = 15
lal_phasing = wf._lal_phasing_coefficients()
my_phasing = wf.phasing_coefficients()
self.assertLess(max(abs(lal_phasing.v - my_phasing.v)), 1e-8)
self.assertLess(max(abs(lal_phasing.vlogv - my_phasing.vlogv)), 1e-8)
self.assertLess(max(abs(lal_phasing.vlogvsq - my_phasing.vlogvsq)), 1e-8)

def test_absolute_overlap(self):
self.assertLess(max(abs(lal_phasing.v - my_phasing.v)), 1e-5)
self.assertLess(max(abs(lal_phasing.vlogv - my_phasing.vlogv)), 1e-5)
self.assertLess(max(abs(lal_phasing.vlogvsq - my_phasing.vlogvsq)), 1e-5)

@parameterized.expand(["numpy", "jax", "cupy"])
def test_absolute_overlap(self, backend):
pytest.importorskip(backend)
set_backend(backend)
np.random.seed(42)
priors = BNSPriorDict(aligned_spin=True)
priors["mass_1"] = Uniform(1, 100)
priors["mass_2"] = Uniform(1, 100)
# priors["total_mass"] = Uniform(2, 100)
# priors["mass_ratio"] = Uniform(name="mass_ratio", minimum=0.5, maximum=1)
# priors["mass_1"] = Constraint(name="mass_1", minimum=1, maximum=50)
# priors["mass_2"] = Constraint(name="mass_2", minimum=1, maximum=50)
del priors["mass_ratio"], priors["chirp_mass"]
priors["luminosity_distance"] = UniformSourceFrame(
name="luminosity_distance", minimum=1e2, maximum=5e3
Expand All @@ -104,8 +111,7 @@ def test_absolute_overlap(self):
)
priors["lambda_1"] = Uniform(name="lambda_1", minimum=0, maximum=5000)
priors["lambda_2"] = Uniform(name="lambda_2", minimum=0, maximum=5000)
priors["geocent_time"] = 0
# priors["geocent_time"] = Uniform(-10, 10)
priors["geocent_time"] = Uniform(-10, 10)

n_samples = 100
all_parameters = pd.DataFrame(priors.sample(n_samples))
Expand Down Expand Up @@ -144,6 +150,7 @@ def test_absolute_overlap(self):
)
bilby_pols = dict(plus=lal_strain[0].data.data, cross=lal_strain[1].data.data)
gpu_pols = self.gpu_wfg.frequency_domain_strain(parameters)
gpu_pols = {key: gpucbc.backend.BACKEND.to_numpy(value) for key, value in gpu_pols.items()}

bilby_strain = self.ifo.get_detector_response(
waveform_polarizations=bilby_pols, parameters=parameters
Expand All @@ -164,12 +171,5 @@ def test_absolute_overlap(self):
/ self.ifo.optimal_snr_squared(signal=bilby_strain) ** 0.5
/ self.ifo.optimal_snr_squared(signal=gpu_strain) ** 0.5
)
# print(self.ifo.optimal_snr_squared(signal=bilby_strain) ** 0.5, self.ifo.optimal_snr_squared(signal=gpu_strain) ** 0.5)
# print(np.max(abs(
# np.fft.fft(bilby_strain[:-1] * gpu_strain.conj().real[:-1]) / self.ifo.power_spectral_density_array[:-1]
# / self.ifo.optimal_snr_squared(signal=bilby_strain) ** 0.5
# / self.ifo.optimal_snr_squared(signal=gpu_strain) ** 0.5
# )))
overlaps.append(overlap)
print(overlaps)
self.assertTrue(min(np.abs(overlaps)) > 0.99)
Loading