Skip to content

Commit

Permalink
ENH: implement time marginalization into multiband likelihood (#842)
Browse files Browse the repository at this point in the history
- bilby/gw/likelihood/multiband.py: implement time marginalization and resampling of time
- test/gw/likelihood/marginalization_test.py: add marginalization tests for multiband likelihood
  • Loading branch information
SMorisaki authored Nov 1, 2024
1 parent 8b2f6cf commit cf4c6c5
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 16 deletions.
102 changes: 88 additions & 14 deletions bilby/gw/likelihood/multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
recursively_save_dict_contents_to_group
)
from ..prior import CBCPriorDict
from ..utils import ln_i0


class MBGravitationalWaveTransient(GravitationalWaveTransient):
Expand Down Expand Up @@ -56,6 +57,16 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
prior is set to be a delta function at phase=0.
priors: dict, bilby.prior.PriorDict
A dictionary of priors containing at least the geocent_time prior
time_marginalization: bool, optional
If true, marginalize over time in the likelihood.
If using time marginalisation and jitter_time is True a "jitter"
parameter is added to the prior which modifies the position of the
grid of times.
jitter_time: bool, optional
Whether to introduce a `time_jitter` parameter. This avoids either
missing the likelihood peak, or introducing biases in the
reconstructed time posterior due to an insufficient sampling frequency.
Default is True.
distance_marginalization_lookup_table: (dict, str), optional
If a dict, dictionary containing the lookup_table, distance_array, (distance) prior_array, and
reference_distance used to construct the table. If a string the name of a file containing these quantities. The
Expand Down Expand Up @@ -83,13 +94,15 @@ def __init__(
linear_interpolation=True, accuracy_factor=5, time_offset=None, delta_f_end=None,
maximum_banding_frequency=None, minimum_banding_duration=0., weights=None,
distance_marginalization=False, phase_marginalization=False, priors=None,
distance_marginalization_lookup_table=None, reference_frame="sky", time_reference="geocenter"
time_marginalization=False, jitter_time=True, distance_marginalization_lookup_table=None,
reference_frame="sky", time_reference="geocenter"
):
super(MBGravitationalWaveTransient, self).__init__(
interferometers=interferometers, waveform_generator=waveform_generator, priors=priors,
distance_marginalization=distance_marginalization, phase_marginalization=phase_marginalization,
time_marginalization=False, distance_marginalization_lookup_table=distance_marginalization_lookup_table,
jitter_time=False, reference_frame=reference_frame, time_reference=time_reference
time_marginalization=time_marginalization,
distance_marginalization_lookup_table=distance_marginalization_lookup_table,
jitter_time=jitter_time, reference_frame=reference_frame, time_reference=time_reference
)
if weights is None:
self.reference_chirp_mass = reference_chirp_mass
Expand All @@ -108,6 +121,8 @@ def __init__(
with h5py.File(weights, 'r') as f:
weights = recursively_load_dict_contents_from_group(f, '/')
self.setup_multibanding_from_weights(weights)
if self.time_marginalization:
self._setup_time_marginalization_multiband()

@property
def reference_chirp_mass(self):
Expand Down Expand Up @@ -693,7 +708,23 @@ def setup_multibanding_from_weights(self, weights):
else:
setattr(self, key, value)

def calculate_snrs(self, waveform_polarizations, interferometer, return_array=False):
def _setup_time_marginalization_multiband(self):
"""This overwrites attributes set by _setup_time_marginalization of the base likelihood class"""
N = self.Nbs[-1] // 2
self._delta_tc = self.durations[0] / N
self._times = \
self.interferometers.start_time + np.arange(N) * self._delta_tc
self.time_prior_array = \
self.priors['geocent_time'].prob(self._times) * self._delta_tc
# allocate array which is FFTed at each likelihood evaluation
self._full_d_h = np.zeros(N, dtype=complex)
# idxs to convert full frequency points to banded frequency points, used for filling _full_d_h.
self._full_to_multiband = [int(f * self.durations[0]) for f in self.banded_frequency_points]
self._beam_pattern_reference_time = (
self.priors['geocent_time'].minimum + self.priors['geocent_time'].maximum
) / 2

def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True):
"""
Compute the snrs
Expand All @@ -706,36 +737,36 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Fa
return_array: bool
If true, calculate and return internal array objects
(d_inner_h_array and optimal_snr_squared_array), otherwise
these are returned as None. This parameter is ignored for the multiband
model as these arrays are never calculated.
these are returned as None.
Returns
-------
calculated_snrs: _CalculatedSNRs
An object containing the SNR quantities.
"""
if self.time_marginalization:
time_ref = self._beam_pattern_reference_time
else:
time_ref = self.parameters['geocent_time']

strain = np.zeros(len(self.banded_frequency_points), dtype=complex)
for mode in waveform_polarizations:
response = interferometer.antenna_response(
self.parameters['ra'], self.parameters['dec'],
self.parameters['geocent_time'], self.parameters['psi'],
mode
time_ref, self.parameters['psi'], mode
)
strain += waveform_polarizations[mode][self.unique_to_original_frequencies] * response

dt = interferometer.time_delay_from_geocenter(
self.parameters['ra'], self.parameters['dec'],
self.parameters['geocent_time'])
self.parameters['ra'], self.parameters['dec'], time_ref)
dt_geocent = self.parameters['geocent_time'] - interferometer.strain_data.start_time
ifo_time = dt_geocent + dt
strain *= np.exp(-1j * 2. * np.pi * self.banded_frequency_points * ifo_time)

calib_factor = interferometer.calibration_model.get_calibration_factor(
strain *= interferometer.calibration_model.get_calibration_factor(
self.banded_frequency_points, prefix='recalib_{}_'.format(interferometer.name), **self.parameters)

strain *= np.exp(-1j * 2. * np.pi * self.banded_frequency_points * ifo_time)
strain *= calib_factor

d_inner_h = np.conj(np.dot(strain, self.linear_coeffs[interferometer.name]))

if self.linear_interpolation:
Expand Down Expand Up @@ -765,12 +796,55 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Fa

complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5)

if return_array and self.time_marginalization:
self._full_d_h[self._full_to_multiband] *= 0
for b in range(self.number_of_bands):
start_idx, end_idx = self.start_end_idxs[b]
self._full_d_h[self._full_to_multiband[start_idx:end_idx + 1]] += \
strain[start_idx:end_idx + 1] * self.linear_coeffs[interferometer.name][start_idx:end_idx + 1]
d_inner_h_array = np.fft.fft(self._full_d_h)
else:
d_inner_h_array = None

return self._CalculatedSNRs(
d_inner_h=d_inner_h,
optimal_snr_squared=optimal_snr_squared.real,
complex_matched_filter_snr=complex_matched_filter_snr,
d_inner_h_array=d_inner_h_array,
)

def _rescale_signal(self, signal, new_distance):
for mode in signal:
signal[mode] *= self._ref_dist / new_distance

def generate_time_sample_from_marginalized_likelihood(self, signal_polarizations=None):
self.parameters.update(self.get_sky_frame_parameters())
if signal_polarizations is None:
signal_polarizations = \
self.waveform_generator.frequency_domain_strain(self.parameters)

snrs = self._CalculatedSNRs()

for interferometer in self.interferometers:
snrs += self.calculate_snrs(
waveform_polarizations=signal_polarizations,
interferometer=interferometer
)
d_inner_h = snrs.d_inner_h_array
h_inner_h = snrs.optimal_snr_squared

if self.distance_marginalization:
time_log_like = self.distance_marginalized_likelihood(
d_inner_h, h_inner_h)
elif self.phase_marginalization:
time_log_like = ln_i0(abs(d_inner_h)) - h_inner_h.real / 2
else:
time_log_like = (d_inner_h.real - h_inner_h.real / 2)

times = self._times
if self.jitter_time:
times = times + self.parameters["time_jitter"]
time_prior_array = self.priors['geocent_time'].prob(times)
time_post = np.exp(time_log_like - max(time_log_like)) * time_prior_array
time_post /= np.sum(time_post)
return np.random.choice(times, p=time_post)
23 changes: 21 additions & 2 deletions test/gw/likelihood/marginalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class TestMarginalizations(unittest.TestCase):
The `time_jitter` parameter makes this a weaker dependence during sampling.
"""
_parameters = product(
["regular", "roq", "relbin"],
["regular", "roq", "relbin", "multiband"],
["luminosity_distance", "geocent_time", "phase"],
[True, False],
[True, False],
Expand Down Expand Up @@ -268,6 +268,17 @@ def setUp(self):
)
)

self.multiband_waveform_generator = bilby.gw.WaveformGenerator(
duration=self.duration,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
start_time=1126259640,
waveform_arguments=dict(
reference_frequency=20.0,
waveform_approximant="IMRPhenomPv2",
)
)

def tearDown(self):
del self.duration
del self.sampling_frequency
Expand Down Expand Up @@ -313,6 +324,12 @@ def likelihood_kwargs(self, kind, time_marginalization, phase_marginalization, d
elif kind == "relbin":
kwargs["fiducial_parameters"] = deepcopy(self.parameters)
kwargs["waveform_generator"] = self.relbin_waveform_generator
elif kind == "multiband":
kwargs["waveform_generator"] = self.multiband_waveform_generator
kwargs["reference_chirp_mass"] = (
(self.parameters["mass_1"] * self.parameters["mass_2"])**0.6 /
(self.parameters["mass_1"] + self.parameters["mass_2"])**0.2
)
return kwargs

def get_likelihood(
Expand All @@ -334,6 +351,8 @@ def get_likelihood(
cls_ = bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient
kwargs["epsilon"] = 0.1
self.parameters["fiducial"] = 0
elif kind == "multiband":
cls_ = bilby.gw.likelihood.MBGravitationalWaveTransient
else:
raise ValueError(f"kind {kind} not understood")
like = cls_(**kwargs)
Expand Down Expand Up @@ -414,7 +433,7 @@ def test_time_marginalisation_full_segment(self, kind):
)

@parameterized.expand(
itertools.product(["regular", "roq", "relbin"], *itertools.repeat([True, False], 3)),
itertools.product(["regular", "roq", "relbin", "multiband"], *itertools.repeat([True, False], 3)),
name_func=lambda func, num, param: (
f"{func.__name__}_{num}__{param.args[0]}_" + "_".join([
["D", "P", "T"][ii] for ii, val
Expand Down

0 comments on commit cf4c6c5

Please sign in to comment.