-
Notifications
You must be signed in to change notification settings - Fork 3
[Feature] Adding blackjax sampling functionality #367
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
40e4f8f to
d370250
Compare
4f963d1 to
2786469
Compare
…rior is now returning a dictionary
cf7e3b4 to
10fe4ee
Compare
| nnpdf = { git = "https://github.com/NNPDF/nnpdf" } | ||
| anesthetic = "^2.10.2" | ||
| tfp-nightly = { extras = ["jax"], version = "*" } | ||
| blackjax = {git = "https://github.com/handley-lab/blackjax", rev = "nested_sampling"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As we are preparing this I would suggest revisiting if we can switch to the tagged version:
https://github.com/handley-lab/blackjax/releases/tag/nested-sampling-beta.1
There are some changes incoming to this main PR branch as we finalise it that may be breaking, so it would be better if possible to point to a stable release so we can manage the (hopefully) finished merged PR integration with colibri at a later date. I think in theory it should be a like for like swap but I think there were some conda env problems when we last looked at this
Adds a JAX-native nested sampling fitter using the
blackjaxlibrary, enabling end-to-end JAX-based Bayesian inference.Summary of current changes:
blackjax_fit.py, which implements the nested sampling loop usingblackjax.nss.bayesian_priorfunction is refactored. It now returns a dictionary containingprior_transform(for UltraNest), andlog_probandsamplefunctions (for BlackJAX).blackjax,tensorflow-probability(for prior distributions), andanesthetic(for results handling). The last two of these are optional but are included into the pyproject.toml for ease of useblackjax_settingsin runcards.lh_fit_closure_test_blackjax.yaml, is included to demonstrate usage.Todo