Skip to content

Commit

Permalink
Merge branch 'github-main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
ColmTalbot authored Feb 26, 2024
2 parents 34b2aef + c5556bb commit c521917
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 26 deletions.
66 changes: 66 additions & 0 deletions test/core/result_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,5 +751,71 @@ def test_pp_plot_raises_error_with_wrong_number_of_confidence_intervals(self):
)


class SimpleGaussianLikelihood(bilby.core.likelihood.Likelihood):
def __init__(self, mean=0, sigma=1):
"""
A very simple Gaussian likelihood for testing
"""
from scipy.stats import norm
super().__init__(parameters=dict())
self.mean = mean
self.sigma = sigma
self.dist = norm(loc=mean, scale=sigma)

def log_likelihood(self):
return self.dist.logpdf(self.parameters["mu"])


class TestReweight(unittest.TestCase):

def setUp(self):
self.priors = bilby.core.prior.PriorDict(dict(
mu=bilby.core.prior.TruncatedNormal(0, 1, minimum=-5, maximum=5),
))
self.result = bilby.core.result.Result(
search_parameter_keys=list(self.priors.keys()),
priors=self.priors,
posterior=pd.DataFrame(self.priors.sample(1000)),
log_evidence=-np.log(10),
)

def _run_reweighting(self, sigma):
likelihood_1 = SimpleGaussianLikelihood()
likelihood_2 = SimpleGaussianLikelihood(sigma=sigma)
original_ln_likelihoods = list()
for ii in range(len(self.result.posterior)):
likelihood_1.parameters = self.result.posterior.iloc[ii]
original_ln_likelihoods.append(likelihood_1.log_likelihood())
self.result.posterior["log_prior"] = self.priors.ln_prob(self.result.posterior)
self.result.posterior["log_likelihood"] = original_ln_likelihoods
self.original_ln_likelihoods = original_ln_likelihoods
return bilby.core.result.reweight(
self.result, likelihood_1, likelihood_2, verbose_output=True
)

def test_reweight_same_likelihood_weights_1(self):
"""
When the likelihoods are the same, the weights should be 1.
"""
_, weights, _, _, _, _ = self._run_reweighting(sigma=1)
self.assertLess(min(abs(weights - 1)), 1e-10)

def test_reweight_different_likelihood_weights_correct(self):
"""
Test the known case where the target likelihood is a Gaussian with
sigma=0.5. The weights can be calculated analytically and the evidence
should be close to the original evidence within statistical error.
"""
from scipy.stats import norm
new, weights, _, _, _, _ = self._run_reweighting(sigma=0.5)
expected_weights = (
norm(0, 0.5).pdf(self.result.posterior["mu"])
/ norm(0, 1).pdf(self.result.posterior["mu"])
)
self.assertLess(min(abs(weights - expected_weights)), 1e-10)
self.assertLess(abs(new.log_evidence - self.result.log_evidence), 0.05)
self.assertNotEqual(new.log_evidence, self.result.log_evidence)


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions test/gw/likelihood/relative_binning_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ def setUp(self):
duration=duration, sampling_frequency=sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=fmin, minimum_frequency=fmin, approximant=approximant)
reference_frequency=fmin, minimum_frequency=fmin, waveform_approximant=approximant)
)
bin_wfg = bilby.gw.waveform_generator.WaveformGenerator(
duration=duration, sampling_frequency=sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning,
waveform_arguments=dict(
reference_frequency=fmin, approximant=approximant, minimum_frequency=fmin)
reference_frequency=fmin, waveform_approximant=approximant, minimum_frequency=fmin)
)
ifos.inject_signal(
parameters=self.test_parameters,
Expand Down
46 changes: 23 additions & 23 deletions test/gw/likelihood_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def setUp(self):
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
waveform_approximant="IMRPhenomPv2",
),
)

Expand All @@ -360,7 +360,7 @@ def setUp(self):
frequency_nodes_quadratic=fnodes_quadratic,
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
waveform_approximant="IMRPhenomPv2",
),
)

Expand Down Expand Up @@ -597,7 +597,7 @@ def test_rescaling(self):
frequency_nodes_quadratic=fnodes_quadratic,
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
waveform_approximant="IMRPhenomPv2",
),
)

Expand Down Expand Up @@ -1240,7 +1240,7 @@ def tearDown(self):
("IMRPhenomHM", False, 4, True, 1e-3)
])
def test_matches_original_likelihood(
self, approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
self, waveform_approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
):
"""
Check if multi-band likelihood values match original likelihood values
Expand All @@ -1249,7 +1249,7 @@ def test_matches_original_likelihood(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
Expand All @@ -1258,7 +1258,7 @@ def test_matches_original_likelihood(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
Expand All @@ -1284,12 +1284,12 @@ def test_large_accuracy_factor(self):
"""
Check if larger accuracy factor increases the accuracy.
"""
approximant = "IMRPhenomD"
waveform_approximant = "IMRPhenomD"
wfg = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
Expand All @@ -1298,7 +1298,7 @@ def test_large_accuracy_factor(self):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
Expand Down Expand Up @@ -1330,7 +1330,7 @@ def test_reference_chirp_mass_from_prior(self):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant="IMRPhenomD"
reference_frequency=self.fmin, waveform_approximant="IMRPhenomD"
)
)
likelihood1 = bilby.gw.likelihood.MBGravitationalWaveTransient(
Expand All @@ -1352,7 +1352,7 @@ def test_no_reference_chirp_mass(self):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant="IMRPhenomD"
reference_frequency=self.fmin, waveform_approximant="IMRPhenomD"
)
)
with self.assertRaises(TypeError):
Expand All @@ -1368,7 +1368,7 @@ def test_cannot_determine_reference_chirp_mass(self):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant="IMRPhenomD"
reference_frequency=self.fmin, waveform_approximant="IMRPhenomD"
)
)
for key in ["chirp_mass", "mass_1", "mass_2"]:
Expand All @@ -1385,12 +1385,12 @@ def test_inout_weights(self, linear_interpolation):
Check if multiband weights can be saved as a file, and a likelihood object constructed from the weights file
produces the same likelihood value.
"""
approximant = "IMRPhenomD"
waveform_approximant = "IMRPhenomD"
wfg = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(
Expand All @@ -1401,7 +1401,7 @@ def test_inout_weights(self, linear_interpolation):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
Expand All @@ -1424,7 +1424,7 @@ def test_inout_weights(self, linear_interpolation):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient(
Expand All @@ -1441,12 +1441,12 @@ def test_from_dict_weights(self, linear_interpolation):
"""
Check if a likelihood object constructed from dictionary-like weights produce the same likelihood value
"""
approximant = "IMRPhenomD"
waveform_approximant = "IMRPhenomD"
wfg = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(
Expand All @@ -1457,7 +1457,7 @@ def test_from_dict_weights(self, linear_interpolation):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
Expand All @@ -1474,7 +1474,7 @@ def test_from_dict_weights(self, linear_interpolation):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
weights = likelihood_mb.weights
Expand All @@ -1492,7 +1492,7 @@ def test_from_dict_weights(self, linear_interpolation):
("IMRPhenomHM", False, 4, False, 5e-3),
])
def test_matches_original_likelihood_low_maximum_frequency(
self, approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
self, waveform_approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
):
"""
Test for maximum frequency < sampling frequency / 2
Expand All @@ -1504,7 +1504,7 @@ def test_matches_original_likelihood_low_maximum_frequency(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
Expand All @@ -1513,7 +1513,7 @@ def test_matches_original_likelihood_low_maximum_frequency(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
Expand Down
2 changes: 1 addition & 1 deletion test/gw/source_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def setUp(self):
frequency_nodes_quadratic=fnodes_quadratic,
reference_frequency=50.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
waveform_approximant="IMRPhenomPv2",
)
self.frequency_array = bilby.core.utils.create_frequency_series(2048, 4)

Expand Down

0 comments on commit c521917

Please sign in to comment.