-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sine Gaussian Waveform #23
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
42897ae
add sine gaussian waveform
ravioli1369 2121c0a
refactored to match with the rest of the codebase
ravioli1369 793b5f0
moved comparision notebook to correct folder
ravioli1369 fada52b
add mean and mean squared error comparisions
ravioli1369 ec27bd7
remove unnecessary imports
ravioli1369 975731e
rename to gen_SineGaussian_hphc
ravioli1369 2ec3a9c
remove 64bit floats
ravioli1369 9492d13
update notebook accordingly
ravioli1369 563fe7a
add comparision between 64bit and 32bit results
ravioli1369 0797388
remove warning
ravioli1369 a403216
remove unnecessary line
ravioli1369 08ac42d
fix typo
ravioli1369 fb7a1d2
refactor: Update gen_SineGaussian_hphc to use time_grid to make it ji…
ravioli1369 be431f3
add scaling
ravioli1369 249f0e8
refactor to make input args as t, theta
ravioli1369 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import jax.numpy as jnp | ||
from jax.lax import complex | ||
|
||
from ..constants import PI | ||
from ..typing import Array | ||
|
||
|
||
def semi_major_minor_from_e(e: Array) -> tuple[Array, Array]: | ||
""" | ||
Calculate the semi-major and semi-minor axes of an ellipse given the | ||
eccentricity of the ellipse. | ||
|
||
Args: | ||
e: Eccentricity of the ellipse | ||
Returns: | ||
Semi-major (a) and semi-minor (b) axes of the ellipse | ||
""" | ||
a = 1.0 / jnp.sqrt(2.0 - (e * e)) | ||
b = a * jnp.sqrt(1.0 - (e * e)) | ||
return a, b | ||
|
||
|
||
def gen_SineGaussian_hphc( | ||
t: Array, | ||
theta: Array, | ||
) -> tuple[Array, Array]: | ||
""" | ||
Generate lalinference implementation of a sine-Gaussian waveform in Jax. | ||
See | ||
git.ligo.org/lscsoft/lalsuite/-/blob/master/lalinference/lib/LALInferenceBurstRoutines.c#L381 | ||
for details on parameter definitions. | ||
|
||
Args: | ||
-------- | ||
t: | ||
Time grid (centered at t=0) on which to evaluate the waveform. | ||
Create it using `jax.numpy.arange(-duration/2, duration/2, 1/fs)` | ||
where `duration` is the duration of the waveform (in seconds) and `fs` | ||
is the sample rate at which the waveform is evaluated. | ||
theta: | ||
Array of waveform parameters [quality, frequency, hrss, phase, eccentricity] | ||
quality: | ||
Quality factor of the sine-Gaussian waveform | ||
frequency: | ||
Central frequency of the sine-Gaussian waveform | ||
hrss: | ||
Hrss of the sine-Gaussian waveform | ||
phase: | ||
Phase of the sine-Gaussian waveform | ||
eccentricity: | ||
Eccentricity of the sine-Gaussian waveform. | ||
Controls the relative amplitudes of the | ||
hplus and hcross polarizations. | ||
|
||
Returns: | ||
-------- | ||
Jax Arrays of plus and cross polarizations (in that order) | ||
""" | ||
|
||
quality, frequency, hrss, phase, eccentricity = theta | ||
|
||
# add dimension for calculating waveforms in batch | ||
quality = quality.reshape(-1, 1) | ||
frequency = frequency.reshape(-1, 1) | ||
hrss = hrss.reshape(-1, 1) | ||
phase = phase.reshape(-1, 1) | ||
eccentricity = eccentricity.reshape(-1, 1) | ||
|
||
pi = jnp.array([PI]) | ||
|
||
# calculate relative hplus / hcross amplitudes based on eccentricity | ||
# as well as normalization factors | ||
a, b = semi_major_minor_from_e(eccentricity) | ||
norm_prefactor = quality / (4.0 * frequency * jnp.sqrt(pi)) | ||
cosine_norm = norm_prefactor * (1.0 + jnp.exp(-quality * quality)) | ||
sine_norm = norm_prefactor * (1.0 - jnp.exp(-quality * quality)) | ||
|
||
cos_phase, sin_phase = jnp.cos(phase), jnp.sin(phase) | ||
|
||
h0_plus = ( | ||
hrss * a / jnp.sqrt(cosine_norm * (cos_phase**2) + sine_norm * (sin_phase**2)) | ||
) | ||
h0_cross = ( | ||
hrss * b / jnp.sqrt(cosine_norm * (sin_phase**2) + sine_norm * (cos_phase**2)) | ||
) | ||
|
||
# cast the phase to a complex number | ||
phi = 2 * pi * frequency * t | ||
complex_phase = complex(jnp.zeros_like(phi), (phi - phase)) | ||
|
||
# calculate the waveform and apply a tukey | ||
# window to taper the waveform | ||
fac = jnp.exp(phi**2 / (-2.0 * quality**2) + complex_phase) | ||
|
||
cross = fac.imag * h0_cross | ||
plus = fac.real * h0_plus | ||
|
||
return plus, cross |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure whether this is needed for jax?