Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New integrator, and add some metadata to integrators.py #681

Merged
merged 56 commits into from
May 27, 2024

Conversation

reubenharry
Copy link
Contributor

@reubenharry reubenharry commented May 18, 2024

Addresses #679

  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;

@reubenharry reubenharry marked this pull request as ready for review May 19, 2024 17:00
@reubenharry reubenharry mentioned this pull request May 20, 2024
10 tasks
Copy link

codecov bot commented May 25, 2024

Codecov Report

Attention: Patch coverage is 46.00000% with 27 lines in your changes are missing coverage. Please review.

Project coverage is 97.79%. Comparing base (7cf4f9d) to head (06dd04d).
Report is 8 commits behind head on main.

Current head 06dd04d differs from pull request most recent head abe707c

Please upload reports for the commit abe707c to get more accurate results.

Files Patch % Lines
blackjax/mcmc/integrators.py 46.00% 27 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #681      +/-   ##
==========================================
- Coverage   98.87%   97.79%   -1.09%     
==========================================
  Files          59       59              
  Lines        2745     2806      +61     
==========================================
+ Hits         2714     2744      +30     
- Misses         31       62      +31     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@reubenharry
Copy link
Contributor Author

I see that code coverage is complaining about integrator_order, name_integrator and calls_per_integrator_step. These seem useful, but depending on your view, I could remove them

@junpenglao
Copy link
Member

I see that code coverage is complaining about integrator_order, name_integrator and calls_per_integrator_step. These seem useful, but depending on your view, I could remove them

Yeah please remove them, depending on what you want to check there should be better ways to return these properties.

Comment on lines 26 to 29
"velocity_verlet_coefficients",
"mclachlan_coefficients",
"yoshida_coefficients",
"omelyan_coefficients",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove? Are you planning to use it outside of integrators?

Copy link
Contributor Author

@reubenharry reubenharry May 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I have found it convenient to use them outside integrators. For example, if I want to run scripts that try different integrators, and I want access to their number of gradient calls. I suppose the other option would be to have a dictionary like {"velocity_verlet": {"num_grads": ..., "name": ..., "order": ...} }. Is that what you'd recommend?

As another example, I often want to do benchmarks against different integrators, but want to just iterate over X_coefficients, and then use generate_isokinetic_integrator for MCLMC and generate_euclidean_integrator for HMC

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it seems like code duplication to have isokinetic_velocity_verlet, velocity_verlet, etc, when we could just use the coefficients with generate_isokinetic_integrator and generate_euclidean_integrator

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it seems like code duplication to have isokinetic_velocity_verlet, velocity_verlet, etc, when we could just use the coefficients with generate_isokinetic_integrator and generate_euclidean_integrator

The design choice is motivated by how we are envisioning the usage of integrators in the library. Currently the design is to have integrator being a static set.

As another example, I often want to do benchmarks against different integrators, but want to just iterate over X_coefficients, and then use generate_isokinetic_integrator for MCLMC and generate_euclidean_integrator for HMC

You should iterate through the integrator objects instead

For example, if I want to run scripts that try different integrators, and I want access to their number of gradient calls. I suppose the other option would be to have a dictionary like {"velocity_verlet": {"num_grads": ..., "name": ..., "order": ...} }. Is that what you'd recommend?

Yes given that these are static, you should put them in your script as static parameters. I dont yet see those are useful in the library outside of benchmarking.

Copy link
Contributor Author

@reubenharry reubenharry May 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, that makes sense. I think my one point of disagreement is that if I don't expose the coefficients, nothing in the code "knows" that velocity_verlet and isokinetic_velocity_verlet are related. So I will have to have a dictionary of {"euclidean": velocity_verlet, "isokinetic": isokinetic_velocity_verlet, ...} when I want to compare each integrator on hmc vs mclmc, which I'm currently doing. This is a little painful, but not the end of the world

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, as per request

@junpenglao junpenglao merged commit 20666de into blackjax-devs:main May 27, 2024
5 checks passed
AdrienCorenflos added a commit to AdrienCorenflos/blackjax that referenced this pull request Aug 14, 2024
* Update README.md (blackjax-devs#638)

* Update README.md

Update citation.

* Update README.md

* Indexing the notebook showing how to reproduce the GIF. (blackjax-devs#640)

Co-authored-by: Junpeng Lao <[email protected]>

* Bump python version (blackjax-devs#645)

* Bump python version

* update bool inverse

* SMC: allow each mutation kernel to have different parameters. (blackjax-devs#649)

* vmaping over parameters in base

* switch from mcmc_factory to just passing in parameters

* pre-commit and typing

* CRU and docs improvement

* pre-commit

* code review updates

* pre-commit

* rename test

* Migrate from deprecated `host_callback` to `io_callback` (blackjax-devs#651)

* Migrate from deprecated `host_callback` to `io_callback`

Co-Authored-By:
George Necula <[email protected]>

* Format file

* Fix bug

* Fix MALA transition energy (blackjax-devs#653)

* Fix MALA transition energy

* Use a different logic.

* Change variable names (blackjax-devs#654)

* Replace iterative RNG split and carry with `jax.random.fold_in` (blackjax-devs#656)

* Replace iterative RNG split and carry with `jax.random.fold_in`

* revert unintended change

* file formatting

* change `jax.tree_map` to `jax.tree.map`

* revert unintended file

* fiddle with rng_key

* seed again

* Removal of Algorithm classes. (blackjax-devs#657)

* more

* removing export

* removal of classes, tests passing

* linter

* fix on test

* linter

* removing parametrization on test

* code review updates

* exporting as_top_level_api in dynamic_hmc

* linter

* code review update: replace imports

* Fix deprecated call to jnp.clip (blackjax-devs#664)

* Update jax version requirements (blackjax-devs#666)

Fix blackjax-devs#665

* Make tests pass on `aarch64-linux` (blackjax-devs#671)

* Enable fitlering of AdaptationInfo (blackjax-devs#674)

* enable AdaptationInfo filtering

* revert progress_bar

* fix pre-commit

* fix empty sets

* enable adapt info filtering for all adaptation algorithms

* fix precommit /progressbar=True

* change filter tuple to use tree_map

* Update `run_inference_algorithm` to split `initial_position` and `initial_state` (blackjax-devs#672)

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* ADD INITIAL_POSITION

* FIX TEST

* RENAME O

* FIX DOCSTRING

* PUT EXPECTATION AFTER TRANSFORM

* Preconditioned mclmc (blackjax-devs#673)

* TESTS

* TESTS

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* ADD PRECONDITIONING TO MCLMC

* ADD PRECONDITIONING TO TUNING FOR MCLMC

* UPDATE GITIGNORE

* UPDATE GITIGNORE

* UPDATE TESTS

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* STREAMING AVERAGE

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* GITIGNORE

* PRECOMMIT CLEAN UP

* ADD INITIAL_POSITION

* FIX TEST

* ADD TEST

* REMOVE BENCHMARKS

* BUG FIX

* CHANGE PRECISION

* CHANGE PRECISION

* RENAME O

* UPDATE STREAMING AVG

* UPDATE PR

* RENAME STD_MAT

* New integrator, and add some metadata to integrators.py (blackjax-devs#681)

* TESTS

* TESTS

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* ADD PRECONDITIONING TO MCLMC

* ADD PRECONDITIONING TO TUNING FOR MCLMC

* UPDATE GITIGNORE

* UPDATE GITIGNORE

* UPDATE TESTS

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* STREAMING AVERAGE

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* GITIGNORE

* PRECOMMIT CLEAN UP

* FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS

* TEMPORARILY ADD BENCHMARKS

* ADD INITIAL_POSITION

* FIX TEST

* CLEAN UP

* REMOVE BENCHMARKS

* ADD TEST

* REMOVE BENCHMARKS

* BUG FIX

* CHANGE PRECISION

* CHANGE PRECISION

* ADD OMELYAN TEST

* RENAME O

* UPDATE STREAMING AVG

* UPDATE PR

* RENAME STD_MAT

* MERGE MAIN

* REMOVE COEFFICIENT EXPORTS

* Minor formatting (blackjax-devs#685)

* Minor formatting

* formatting

* fix test

* formatting

* MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT (blackjax-devs#687)

* FIX KWARG BUG (blackjax-devs#686)

* FIX KWARG BUG

* FIX KWARG BUG

* Change isokinetic_integrator generation API (blackjax-devs#689)

* Apply function on pytree directly. (blackjax-devs#692)

* Apply function on pytree directly.

Avoiding unnecssary unpacking

* Fix kwarg

* Fix sampling test. (blackjax-devs#693)

* Enable shared mcmc parameters with tempered smc (blackjax-devs#694)

* add parameter filtering

* fix parameter split + docstring

* change extend_paramss

* convert to bit twiddling (blackjax-devs#696)

* Remove nightly release (blackjax-devs#699)

* Fix doc mistakes (blackjax-devs#701)

* Fix equation formatting

* Clarify JAX gradient error

* Fix punctuation + capitalization

* Fix grammar

Should not begin sentence with "i.e." in English.

* Fix math formatting error

* Fix typo

Change parallel _ensample_ chain adaptation to parallel _ensemble_ chain adaptation.

* Add SVGD citation to appear in doc

Currently the SVGD paper is only cited in the `kernel` function, which is defined _within_ the `build_kernel` function. Because of this nested function format, the SVGD paper is _not_ cited in the documentation.

To fix this, I added a citation to the SVGD paper in the `as_top_level_api` docstring.

* Fix grammar + clarify doc

* Fix typo

---------

Co-authored-by: Junpeng Lao <[email protected]>

* Update index.md (blackjax-devs#711)

The jitted step remained unused, leading to the example running with an uncompiled nuts.step. 

Changing this reduces the execution time by a factor of 30 on my system and showcases blackjax' speed.

* Enable progress bar under pmap (blackjax-devs#712)

* enable pmap progbar

* fix bar creation

* add locking

* fix formatting

* switch to using chain state

* remove labels (blackjax-devs#716)

* Simplify `run_inference_algorithm` (blackjax-devs#714)

* fix minor type errors

* storing only expectation values

* fixed memory efficient sampling

* clean up

* renaming vars

* precommit fixes

* fixing tests

* fixing tests

* fixing tests

* fixing tests

* fixing tests

* merge main

* burn in and fix tests

* burn in and fix tests

* minor fixes

* minor fixes

* minor fixes

---------

Co-authored-by: [email protected] <[email protected]>

* Harmonize Quickstart example (blackjax-devs#717)

* Update README.md (blackjax-devs#719)

---------

Co-authored-by: Junpeng Lao <[email protected]>
Co-authored-by: Carlos Iguaran <[email protected]>
Co-authored-by: ksnxr <[email protected]>
Co-authored-by: Gaétan Lepage <[email protected]>
Co-authored-by: Alberto Cabezas <[email protected]>
Co-authored-by: andrewdipper <[email protected]>
Co-authored-by: Reuben <[email protected]>
Co-authored-by: Gilad Turok <[email protected]>
Co-authored-by: johannahaffner <[email protected]>
Co-authored-by: [email protected] <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants