
Sine Gaussian Waveform #23

Merged: 15 commits into tedwards2412:main on Jun 5, 2024
Conversation

@ravioli1369 (Contributor)

This PR adds the SineGaussian waveform to ripple, along with a detailed Python notebook showing the mismatch between the LALInference and JAX implementations.

@mcoughlin

@tedwards2412 I hope you don't mind us hopping on board. We envision this being helpful for PE follow-up to Burst searches. @ravioli1369 is a student working with me, @ThibeauWouters, and others.

times = jnp.arange(num, dtype=jnp.float64) / sample_rate
times -= duration / 2.0

# add dimension for calculating waveforms in batch
Collaborator

I am not sure whether this is needed for jax?

@ThibeauWouters (Collaborator)

Hi all,

Thanks a lot for the contributions, and great to see the extensive checks of the code! I quickly went over this and will jot down some comments from looking at the main .py file; @tedwards2412 can indicate whether he agrees or not:

  • Remove the jax_enable_x64: it has been decided before to allow users to use float16 or float32, see Capability of using float16/float32 in waveforms #10
  • Please remove the flax import if you do not plan to use it
  • Rename gen_SineGaussian to gen_SineGaussian_hphc to agree with other waveform files, and also make sure to return in that order (first plus, then cross)
  • I am not sure whether you have to reshape the input as done in L64 and following, this should be handled well by the jax.vmap functionalities without these extra lines. It would be great if someone can confirm that, either from more jax knowledge or by simply running this as an experiment.
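To illustrate the point about jax.vmap handling the batch dimension, here is a minimal sketch; the waveform function below is a hypothetical stand-in, not ripple's API:

```python
import jax
import jax.numpy as jnp

# hypothetical single-waveform function, written for one parameter vector only
def waveform(times, theta):
    frequency, amplitude = theta
    return amplitude * jnp.sin(2 * jnp.pi * frequency * times)

times = jnp.linspace(-1.0, 1.0, 1024)
thetas = jnp.array([[10.0, 1.0], [20.0, 0.5], [30.0, 2.0]])  # (n_batch, n_params)

# vmap maps over the leading axis of thetas; times is shared (in_axes=None)
batched = jax.vmap(waveform, in_axes=(None, 0))
out = batched(times, thetas)
print(out.shape)  # (3, 1024)
```

No manual reshaping is needed: vmap inserts the batch dimension for you.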

@ravioli1369 (Contributor, Author)

ravioli1369 commented Jun 1, 2024

I have made most of the changes, but when I removed 64-bit precision from my notebook and reran it, I noticed a large drop in the accuracy of the implementation. Following that, I added a comparison of the 64- and 32-bit precision waveforms against the LALInference implementation (refer to the code cells just below the markdown heading Vary each parameter independently by fixing others for both the 64- and 32-bit cases, as well as the Conclusion at the very end).
Is this sort of variation expected?

As for the jax.vmap implementation, I was initially trying to port the code 1:1 from the existing torch implementation of sine gaussians in the ml4gw repo (https://github.com/ML4GW/ml4gw/blob/main/ml4gw/waveforms/sine_gaussian.py). I'll look into using vmaps and hopefully commit that soon.

@kazewong kazewong added the enhancement New feature or request label Jun 1, 2024
@kazewong kazewong self-requested a review June 1, 2024 20:42
@kazewong (Collaborator)

kazewong commented Jun 1, 2024

Regarding float64 vs float32, I think one has to scale the signal accordingly to avoid loss of accuracy. I don't see why the sine Gaussian would need float64 accuracy, so it should be okay to refactor it into float32. @tedwards2412 can comment more on this.

A side note going forward: I think we should start incorporating tests into the code base as it grows.

@ThibeauWouters ThibeauWouters mentioned this pull request Jun 1, 2024
@tedwards2412 (Owner)

Sorry for the delay, this looks great so far! @mcoughlin no problem at all, I'm just happy that people are starting to find the code useful and want to contribute :)

A couple of comments:

  1. I think that to keep function signatures similar across waveforms we should try to always have gen_X_hphc(f, theta, kwargs) where f is the frequency grid and theta are the parameters associated with that particular waveform. Can we change the function to have a similar structure here? I'm not super familiar with the SineGaussian waveform structure though so don't know best how to do this.
  2. You indeed shouldn't need to do any reshaping of the arrays, you should simply use the vmap if you want to evaluate waveforms in parallel and this will add the batch dimension for you.
  3. It's to be expected that dropping to float32 precision will truncate some of the accuracy. What's likely happening is that some quantities are going outside the total range supported by float32. For example, in the match, when you multiply the two waveforms in the numerator, the product goes outside of the supported range and so clips to 0. Maybe you're finding something similar to this? It shouldn't be a problem as long as you can scale everything accordingly. We therefore don't want to fix float64 by default, in case the user wants to use float32. We might want to add a warning, though, so that the user is aware of this?
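As an illustration of this kind of float32 clipping, here is a small sketch with numbers chosen purely for demonstration, roughly at gravitational-wave strain scale:

```python
import numpy as np

# strain-scale values: the product of two strains underflows float32
h = np.float32(1e-23)
prod = h * h               # ~1e-46, below float32's smallest subnormal (~1.4e-45)
print(prod)                # 0.0

# rescale first and the product stays comfortably in range
scale = np.float32(1e20)
hs = h * scale             # ~1e-3
print(hs * hs)             # ~1e-6
```

The same ~1e-23 values are perfectly representable on their own; it is only their product that falls out of range.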

@ravioli1369 (Contributor, Author)

I had a few doubts regarding this:

  1. I'm also not sure how to go about constructing the waveform in the frequency domain. The waveform was meant to be a 1:1 implementation of LALInference's, and is therefore in the time domain. Converting ours to a frequency-domain implementation would make it quite different from how it is called using LAL, which could be confusing for those familiar with LAL.

  2. I'm a bit new to this, so to clarify: I can remove the reshaping lines and instead call the gen_SineGaussian function itself with a vmap if I need to do batched computations?

  3. I'm not sure what is meant by scaling the waveforms, but I have done something to try to reduce the mismatch:

hcross_ripple.append(hcross[i][start:stop]*scale)
hplus_ripple.append(hplus[i][start:stop]*scale)
hcross_lal.append(hcross_*scale)
hplus_lal.append(hplus_*scale)

I pushed the results back to the notebook towards the end (https://github.com/ravioli1369/ripple/blob/sine-gaussian/notebooks/check_SineGaussian.ipynb). Is this correct, or should I have done something else?

@ThibeauWouters (Collaborator)

@ravioli1369 @tedwards2412 Perhaps it would be good to start dividing the ripple source code into frequency-domain and time-domain waveforms? I believe that time-domain waveforms will see more support in the future for other use cases as well. Thomas can indicate whether he agrees with this.

@tedwards2412 (Owner)

  1. Ahh sorry @ravioli1369, I didn't realize it was a time domain waveform. In this case, I still think the function signature should be gen_X_hphc(t, theta, kwargs) where t is the frequency grid. Overall the goal of ripple is not meant to be a 1:1 recreation of LAL in jax, we try to follow sensible design choices that are easy to use and provide some utility. The reason we like to package up the parameters into a single theta is because typically you'd like to do autodiff with respect to all these parameters and this provides a clean way to do that.
  2. Yeah, you should just remove the reshaping lines and instead do vmap(func), which returns a new function that can be given a grid of parameters of shape (n_batch, n_theta). Here is an example of how this works for a different waveform:
    func = vmap(waveform)
  3. What you did looks sensible to me. All I mean by scaling is that you multiply the waveform by an arbitrary constant to make sure all the values used in any computations stay within the range allowed by float32, and then scale back at the end if necessary. For example, in the match you have a quantity like h*h/PSD. If you just multiply the PSD and the individual h's by a factor of 1e20 before calculating the match, the ratio is unchanged but all the numbers remain within the float32 range. Does this make more sense?
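A minimal sketch of this scaling trick, with illustrative values only; note that if h is scaled by a factor s, scaling the PSD by s**2 leaves the ratio h*h/PSD exactly unchanged:

```python
import jax.numpy as jnp

# illustrative inner-product-style quantity, h * h / PSD, computed in float32
def inner(h1, h2, psd):
    return jnp.sum(h1 * h2 / psd)

h = jnp.full(100, 1e-23, dtype=jnp.float32)    # strain-scale values
psd = jnp.full(100, 1e-36, dtype=jnp.float32)

r_raw = inner(h, h, psd)                       # h * h underflows to 0 in float32

# scale h by s and the PSD by s**2 so the ratio is exactly preserved
s = 1e18
r_scaled = inner(h * s, h * s, psd * s**2)
print(r_raw, r_scaled)                         # 0 vs the correctly computed value
```

The unscaled version silently returns 0 because each h*h product underflows before the division, while the scaled version keeps every intermediate within float32's range.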

@ThibeauWouters I agree that splitting the waveforms makes sense and will probably make things easier to track as things grow.
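The autodiff motivation for packing the parameters into a single theta (point 1 above) can be sketched as follows; the waveform here is a hypothetical stand-in:

```python
import jax
import jax.numpy as jnp

# hypothetical waveform taking one packed parameter vector
def waveform_at(t, theta):
    frequency, amplitude, phase = theta
    return amplitude * jnp.sin(2 * jnp.pi * frequency * t + phase)

theta = jnp.array([20.0, 1.0, 0.3])

# gradient with respect to all parameters at once, via the packed theta
grad_fn = jax.grad(waveform_at, argnums=1)
g = grad_fn(0.1, theta)
print(g.shape)  # (3,) -- one derivative per parameter
```

With separate scalar arguments you would need one argnums entry per parameter; the packed vector gives the full gradient in a single call.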

@ravioli1369 (Contributor, Author)

  1. Understood, I've made the necessary change.

    t is the frequency grid

    I assume this is supposed to be time grid?

  2. I had done a speed comparison between regular jax, jit, and vmap here (https://github.com/ravioli1369/sine-gaussian/blob/main/speed-comparision.ipynb) and found that vmap is ~2 times slower than running the function without it. I'm not fully sure why that is, but I removed the reshape lines, tried this test again, and found similar results:
    Regular Jax
    10.9 ms ± 163 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
    Jax with JIT
    1.49 ms ± 47.3 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)
    Using vmap
    19.8 ms ± 766 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
    Using vmap with JIT
    1.49 ms ± 47.4 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)
    A thing to note is that the performance of vmap+jit and of jitting the regular function is identical. Is it worth removing the reshaping in favor of vmaps in this case?

  3. I understand how the float32 could go out of range during the calculation of the mismatch, but I don't think that's what's happening here, and the difference would just get scaled back to the original values if I rescaled the waveforms.

@tedwards2412 tedwards2412 merged commit d18afa3 into tedwards2412:main Jun 5, 2024
@tedwards2412 (Owner)

tedwards2412 commented Jun 5, 2024

  1. Ok great, thanks this looks much cleaner now. I've merged the current version. And yes, that was supposed to say time grid :)

  2. For the timing, I'm not exactly sure what is going on here but a few things could be going wrong. Firstly, especially when running on a GPU, you need to make sure to call block_until_ready() to ensure that the computation is actually complete. See here for more details: https://jax.readthedocs.io/en/latest/faq.html. Overall though, vmap should be the default choice for vectorizing in Jax and this shouldn't be done manually. It's much cleaner to just use vmap and I suspect that it will also be faster once the timings are sorted out.

  3. Ok let me have a look in more detail when I have time next week and I can see if it's clear what is going on. There will obviously be some loss in precision going to float32 but I would have guessed it's not an issue at current detector sensitivities. In my tests on the match I definitely found a reduction of a few orders of magnitude in accuracy but it was still well below detectable levels.
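The block_until_ready() pattern mentioned in point 2 can be sketched as follows; a generic jitted function stands in for the waveform here:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) * jnp.cos(x)

x = jnp.linspace(0.0, 1.0, 100_000)
f(x).block_until_ready()        # warm-up call so compilation isn't timed

start = time.perf_counter()
y = f(x).block_until_ready()    # block until the async dispatch actually finishes
elapsed_ms = (time.perf_counter() - start) * 1e3
print(f"{elapsed_ms:.3f} ms")
```

Without block_until_ready(), the timer can stop before the computation completes, because JAX dispatches work asynchronously and returns control to Python immediately.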

@ravioli1369 (Contributor, Author)

I ran the timing benchmarks again with the block_until_ready() command and found results similar to before: the vmapped version is ~2 times slower. This again has no impact on the jitted versions of both the vmapped and non-vmapped functions.

Regular Jax
13.5 ms ± 204 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)

Jax with JIT
1.77 ms ± 24.9 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)

Using vmap
25.2 ms ± 273 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)

Using vmap with JIT
1.75 ms ± 41.3 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)

The notebook (https://github.com/ravioli1369/sine-gaussian/blob/main/speed-comparision.ipynb) has more details on how I ran the benchmark. I even changed the output of the sine gaussian function to give a single array so that I don't have to evaluate it in a list comprehension, but the results of that were also similar.

@tedwards2412 (Owner)

I tried myself with the merged version of the SineGaussian waveform in ripple and find that vmap is instead quicker. Here is the code:

import jax
import numpy as np
import jax.numpy as jnp
from ripplegw.waveforms import SineGaussian
import time
from jax import vmap

duration = 10.0
sampling_frequency = 4096
dt = 1 / sampling_frequency
times = jnp.arange(-duration / 2, duration / 2, dt)
print(times.shape)

n_waveforms = 1000

quality = jnp.linspace(3, 100, n_waveforms)
frequency = jnp.logspace(1, 3, n_waveforms)
hrss = jnp.logspace(-23, -6, n_waveforms)
phase = jnp.linspace(0, 2 * np.pi, n_waveforms)
eccentricity = jnp.linspace(0, 0.99, n_waveforms)

theta_ripple = np.array([quality, frequency, hrss, phase, eccentricity]).T

print(theta_ripple.shape)


@jax.jit
def waveform(theta):
    return SineGaussian.gen_SineGaussian_hphc(times, theta)


print("JIT compiling")
waveform(theta_ripple[0])[0].block_until_ready()
print("Finished JIT compiling")

start = time.time()
for t in theta_ripple:
    waveform(t)[0].block_until_ready()
end = time.time()

print("Ripple waveform call takes: %.6f ms" % ((end - start) * 1000 / n_waveforms))

func = vmap(waveform)
func(theta_ripple)[0].block_until_ready()

start = time.time()
hp_batch = func(theta_ripple)[0].block_until_ready()
end = time.time()

print(
    "Vmapped ripple waveform call takes: %.6f ms" % ((end - start) * 1000 / n_waveforms)
)

print(hp_batch.shape)

My output gives:

> (1000, 5)
> JIT compiling
> Finished JIT compiling
> Ripple waveform call takes: 0.189530 ms
> Vmapped ripple waveform call takes: 0.071345 ms
> (1000, 40960)

Note I just did this directly on my laptop on the CPU so the effect will be further magnified when using a GPU. Also, I had to remove the reshapes which now incorrectly add dimensions once you use vmap. The resulting array should be shape (batch_dimension, time_grid).

@ravioli1369 (Contributor, Author)

It does indeed look like vmap is faster than running through the parameters in a for loop. The way I tested it was to send all the parameters into the function and call reshape inside it; this was faster than removing the reshape and running vmap, and I'm not sure why. The jitted versions (with and without reshape) give identical results, so I think it should be fine to leave it this way, although it does warrant some investigation into why un-jitted vmap performs worse than reshaping.

@tedwards2412 (Owner)

I think this overall makes sense: once you add the jit on top of your manual reshaping, you are essentially vectorizing the function by hand, so it should perform similarly to vmap + jit. Overall though, it's not good practice in Jax to add this kind of manual reshaping when you can instead use vmap :) Note that vmapping doesn't do the jit for you, so jit is still required to make it fast!
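A minimal sketch of the vmap-then-jit pattern described here; the per-sample function below is a hypothetical stand-in for a single-waveform call:

```python
import jax
import jax.numpy as jnp

# hypothetical per-sample computation over one parameter vector
def f(theta):
    return jnp.sum(jnp.sin(theta))

thetas = jnp.ones((1000, 5))   # (n_batch, n_params)

batched = jax.vmap(f)          # vectorizes over the leading axis, but does not compile
fast = jax.jit(batched)        # jit the vmapped function to get the speed back
out = fast(thetas)
print(out.shape)  # (1000,)
```

jax.vmap alone only rewrites the computation to be batched; wrapping the result in jax.jit is what removes the per-call Python tracing overhead.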
