Skip to content

Commit

Permalink
renaming and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
gmoss13 committed Feb 18, 2025
1 parent e9d9af4 commit 205d59e
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 6 deletions.
6 changes: 3 additions & 3 deletions sbi/inference/posteriors/score_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def sample(

if self.sample_with == "ode":
samples = rejection.accept_reject_sample(
proposal=self.sample_via_zuko,
proposal=self.sample_via_ode,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
Expand Down Expand Up @@ -241,7 +241,7 @@ def _sample_via_diffusion(

return samples

def sample_via_zuko(
def sample_via_ode(
self,
sample_shape: Shape = torch.Size(),
) -> Tensor:
Expand Down Expand Up @@ -333,7 +333,7 @@ def sample_batched(

if self.sample_with == "ode":
samples = rejection.accept_reject_sample(

Check warning on line 335 in sbi/inference/posteriors/score_posterior.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/posteriors/score_posterior.py#L334-L335

Added lines #L334 - L335 were not covered by tests
proposal=self.sample_via_zuko,
proposal=self.sample_via_ode,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
num_xos=batch_size,
Expand Down
1 change: 1 addition & 0 deletions sbi/inference/potentials/score_based_potential.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from functools import partial
from typing import Optional, Tuple

Expand Down
3 changes: 0 additions & 3 deletions sbi/samplers/score/diffuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,6 @@ def initialize(self, num_samples: int) -> Tensor:
# iid_bridge is not accurate.
num_batch = self.batch_shape.numel()
init_shape = (num_samples, num_batch) + self.input_shape
# init_shape = (
# num_samples,
# ) + self.input_shape # just use num_samples, not num_batch
# NOTE: for the IID setting we might need to scale the noise with iid batch
# size, as in equation (7) in the paper.
eps = torch.randn(init_shape, device=self.device)
Expand Down

0 comments on commit 205d59e

Please sign in to comment.