additional features for NPSE #1370
I've requested review now. Batched sampling for score-based posteriors is now possible and tested, but IID sampling is still not supported; after talking to @manuelgloeckler about this, maybe it can be done in a new PR. Other than that, I've also noticed while testing that sampling from the posterior via the ODE can be much less accurate than sampling via diffusion (the SDE). So the test
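For readers unfamiliar with the two sampling routes compared above, here is a minimal NumPy sketch, not sbi code. It assumes a toy 1D Gaussian target and a variance-preserving SDE with constant noise rate, so the score is known in closed form; the names `BETA`, `M`, `S`, `score` and `sample` are all hypothetical. Because the score is exact here, the sketch does not reproduce the accuracy gap reported in this comment; it only shows how the reverse SDE and the probability-flow ODE differ in their update step.

```python
import numpy as np

BETA = 8.0       # constant noise rate of a toy VP-SDE (assumption)
M, S = 2.0, 0.5  # mean and std of the toy 1D target (assumption)


def alpha(t):
    """Signal scaling of the forward process dx = -0.5*BETA*x dt + sqrt(BETA) dW."""
    return np.exp(-0.5 * BETA * t)


def score(x, t):
    """Exact score of the diffused marginal N(M*alpha_t, S^2*alpha_t^2 + 1 - alpha_t^2)."""
    a = alpha(t)
    var = (S * a) ** 2 + 1.0 - a ** 2
    return -(x - M * a) / var


def sample(num_samples, num_steps=500, mode="sde", seed=0):
    """Integrate from t=1 down to t=0 with Euler (ODE) or Euler-Maruyama (SDE)."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / num_steps
    x = rng.standard_normal(num_samples)  # approximate marginal at t=1
    for i in range(num_steps):
        t = 1.0 - i * dt
        drift = -0.5 * BETA * x
        if mode == "sde":
            # Reverse SDE: full score term plus fresh noise at every step.
            x = x - (drift - BETA * score(x, t)) * dt
            x = x + np.sqrt(BETA * dt) * rng.standard_normal(num_samples)
        else:
            # Probability-flow ODE: half the score term, no noise.
            x = x - (drift - 0.5 * BETA * score(x, t)) * dt
    return x


x_sde = sample(20_000, mode="sde", seed=1)
x_ode = sample(20_000, mode="ode", seed=1)
print("sde:", x_sde.mean(), x_sde.std())
print("ode:", x_ode.mean(), x_ode.std())
```

With a learned score the two routes respond differently to estimation error, since the ODE has no noise injection along the trajectory, so a comparison on a real task would need the trained estimator rather than this closed-form stand-in.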
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1370 +/- ##
===========================================
- Coverage 89.31% 78.18% -11.13%
===========================================
Files 119 119
Lines 8779 8844 +65
===========================================
- Hits 7841 6915 -926
- Misses 938 1929 +991
I started integrating the IID stuff into the current version of this branch and created a new PR for it (#1381). So let's first get this merged; the IID PR still requires some work from my side.
Great! Thanks a lot for addressing all these issues with the current NPSE. 🚀
I left a couple of suggestions and questions for my understanding. Happy to discuss in person if needed.
Thanks for the comments @janfb! I've tried to answer some of your questions, and will make the appropriate changes soon!
Thanks for addressing my questions and comments!
I have some final minor things and a renaming suggestion, plus an idea for better type checking.
Also, I would suggest that we add more info about how to use the score-based methods in the tutorials as well, e.g., the convergence details about the validation times, about sampling via sde vs ode, how to use MAP, how to do i.i.d. inference. (could also be a separate PR). I think this will be important for users to actually be able to use all these new features.
Thanks @janfb! I agree that adding a tutorial on the new features would be important - as we plan to rework the tutorials in the hackathon, I think this would belong in a later PR 😄
looks good now. great work @gmoss13 ! 🎉 🚀
What does this implement/fix? Explain your changes
This introduces some additional features for score estimation named in #1226, namely:
- enable_transform = True for score-based potentials
- converged() method for NPSE
Does this close any currently open issues?
#1226
Any relevant code examples, logs, error output, etc?
Any other comments?
- score_based_posterior.map() is still quite slow. We get the gradient of the log-probs with respect to theta from the score estimator, but still compute the log-probs explicitly in gradient_ascent, which is more expensive. To get around this, we save a low-accuracy ode_flow to calculate the log-probs more quickly. Ideally, we might want to write a custom gradient_ascent function for calculating the MAP for score estimators to avoid doing this altogether.
- linearGaussian_npse_test.py::test_npse_map - as far as I can tell, the reason this failed with the lower tolerance is not the MAP calculation, but that score-based posteriors are currently slightly less accurate (at least for our test tasks).
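To illustrate the custom gradient-ascent idea from the first comment, here is a minimal NumPy sketch that climbs towards the mode using only the score, without ever evaluating log-probs. It is not the sbi implementation: the posterior is replaced by a toy 2D Gaussian whose score is known analytically, and `gaussian_score` and `score_gradient_ascent` are hypothetical names introduced for this example.

```python
import numpy as np


def gaussian_score(theta, mean, cov_inv):
    """Gradient of log N(theta; mean, cov) w.r.t. theta (toy stand-in for a score estimator)."""
    return -(theta - mean) @ cov_inv


def score_gradient_ascent(score_fn, theta_init, step_size=0.1, num_steps=500):
    """Plain gradient ascent that uses the score directly as the gradient."""
    theta = np.array(theta_init, dtype=float)
    for _ in range(num_steps):
        # No log-prob evaluation: the score already is the ascent direction.
        theta = theta + step_size * score_fn(theta)
    return theta


# Toy "posterior" (assumption): a correlated 2D Gaussian whose mode is its mean.
mean = np.array([1.0, -2.0])
cov = np.array([[1.0, 0.3], [0.3, 0.5]])
cov_inv = np.linalg.inv(cov)

theta_map = score_gradient_ascent(
    lambda theta: gaussian_score(theta, mean, cov_inv),
    theta_init=np.zeros(2),
)
print(theta_map)
```

The trade-off is that, without log-prob values, candidates from several initialisations cannot be ranked against each other, which may be why the current approach keeps a cheap ode_flow around for that purpose.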