-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added plotting utilities and refactored hyperparams #143
Conversation
@ThibeauWouters The checks seem to have failed due to some missing hyperparameters, would you mind have a look of that? |
@kazewong I have pushed some updates for the failed checks, and have also added some additional postprocessing analysis code to check the Gelman-Rubin R statistic. |
My tests show something is still off, I need to double check things |
It should be fixed now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would not include r_hat related functionality in this PR. I am more inclined to have that as a notebook in example, since statistical metric like r_hat or correlation length needs to be used with context of user understanding the meaning of what they are measuring or monitoring.
src/flowMC/utils/postprocessing.py
Outdated
import jax.numpy as jnp | ||
from flowMC.sampler.Sampler import Sampler | ||
|
||
def plot_summary(sampler: Sampler, which: str = "training", **plotkwargs) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of taking the entire Sampler as input, these post-processing function should only interact with serialized output from the sampler to avoid run time complication.
This might need a bit more work in terms of tidying up what kind of output we are serialize from the sampler. I would suggest holding these modifications off or open a separate issue. I will get to it soonish
src/flowMC/sampler/Sampler.py
Outdated
@@ -113,11 +128,16 @@ def __init__( | |||
production["log_prob"] = jnp.empty((self.n_chains, 0)) | |||
production["local_accs"] = jnp.empty((self.n_chains, 0)) | |||
production["global_accs"] = jnp.empty((self.n_chains, 0)) | |||
|
|||
if self.track_gelman_rubin: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since R_hat can be compute in post-processing, I don't think we should have it in the main sampler.
src/flowMC/sampler/Sampler.py
Outdated
@@ -264,6 +284,15 @@ def sampling_loop( | |||
global_acceptance[:, 1::self.output_thinning], | |||
axis=1, | |||
) | |||
|
|||
if self.track_gelman_rubin: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above. Please remove this from the main sampler
src/flowMC/utils/postprocessing.py
Outdated
plt.ylim(0-eps, 1+eps) | ||
plt.savefig(f"{outdir}{name}_{which}.png", bbox_inches='tight') | ||
|
||
def gelman_rubin(chains: Float[Array, "n_chains n_steps n_dim"], discard_fraction: float = 0.1) -> Array: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Arviz actually support r_hat and many more different statistical metric. I would suggest to use that instead of writing our own function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took away the gelman rubin part of the previous PR. Will merge it if the test does not fail
This PR adds a few basic functionalities to simplify the treatment of the hyperparameters (now moved to a separate file) and by adding a few functionalities that allow to easily plot a few key quantities such as acceptance rates by a simple function call from any flowMC or jim script.