Skip to content

Commit

Permalink
allow gpu when using an empirical prior for SNPE
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Aug 24, 2022
1 parent 6fed0f4 commit 893ac43
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ def append_simulations(
"run single-round inference with "
"`append_simulations(..., proposal=None)`."
)
theta_prior = self.get_simulations()[0]
self._prior = ImproperEmpirical(theta_prior, ones(theta_prior.shape[0]))
theta_prior = self.get_simulations()[0].to(self._device)
self._prior = ImproperEmpirical(
theta_prior, ones(theta_prior.shape[0], device=self._device)
)

return self

Expand Down

0 comments on commit 893ac43

Please sign in to comment.