Commit: fixes by ruff formatter

janfb committed Mar 6, 2024
1 parent c8a1ef7 commit 887f087
Showing 28 changed files with 220 additions and 228 deletions.
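The recurring change across these files is what looks like ruff's bracket-hugging layout (a preview-style formatter feature around this ruff release): when a call's sole argument is a collection literal or comprehension, the opening bracket stays on the call line instead of opening an extra indentation level, and the closing brackets pair up on one line. A minimal sketch of the before/after pattern, with illustrative names (summarize_old/summarize_new are hypothetical, not from this repository):

import numpy as np

def summarize_old(values: list) -> np.ndarray:
    # Pre-ruff layout: the sole tuple argument opens on its own line,
    # costing an extra indentation level and two extra lines.
    return np.concatenate(
        (
            np.array([len(values)]),
            np.array([min(values), max(values)]),
        )
    )

def summarize_new(values: list) -> np.ndarray:
    # Hugged layout: "((" stays on the call line and "))" closes together.
    return np.concatenate((
        np.array([len(values)]),
        np.array([min(values), max(values)]),
    ))

# Both layouts are equivalent at runtime; only the formatting differs.
assert np.array_equal(summarize_old([1.0, 2.0]), summarize_new([1.0, 2.0]))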
18 changes: 9 additions & 9 deletions examples/HH_helper_functions.py
@@ -214,15 +214,15 @@ def calculate_summary_statistics(x):
     )
 
     # concatenation of summary statistics
-    sum_stats_vec = np.concatenate(
-        (
-            np.array([spike_times_stim.shape[0]]),
-            np.array(
-                [rest_pot, rest_pot_std, np.mean(x["data"][(t > t_on) & (t < t_off)])]
-            ),
-            moments,
-        )
-    )
+    sum_stats_vec = np.concatenate((
+        np.array([spike_times_stim.shape[0]]),
+        np.array([
+            rest_pot,
+            rest_pot_std,
+            np.mean(x["data"][(t > t_on) & (t < t_off)]),
+        ]),
+        moments,
+    ))
     sum_stats_vec = sum_stats_vec[0:n_summary]
 
     return sum_stats_vec
38 changes: 19 additions & 19 deletions sbi/analysis/conditional_density.py
@@ -145,24 +145,22 @@ def conditional_corrcoeff(
     correlation_matrices = []
     for cond in condition:
         correlation_matrices.append(
-            torch.stack(
-                [
-                    compute_corrcoeff(
-                        eval_conditional_density(
-                            density,
-                            cond.to(device),
-                            limits.to(device),
-                            dim1=dim1,
-                            dim2=dim2,
-                            resolution=resolution,
-                        ),
-                        limits[[dim1, dim2]].to(device),
-                    )
-                    for dim1 in subset_
-                    for dim2 in subset_
-                    if dim1 < dim2
-                ]
-            )
+            torch.stack([
+                compute_corrcoeff(
+                    eval_conditional_density(
+                        density,
+                        cond.to(device),
+                        limits.to(device),
+                        dim1=dim1,
+                        dim2=dim2,
+                        resolution=resolution,
+                    ),
+                    limits[[dim1, dim2]].to(device),
+                )
+                for dim1 in subset_
+                for dim2 in subset_
+                if dim1 < dim2
+            ])
         )
 
     average_correlations = torch.mean(torch.stack(correlation_matrices), dim=0)
@@ -294,7 +292,9 @@ def conditional_potential(
     condition = atleast_2d_float32_tensor(condition)
 
     conditioned_potential_fn = ConditionedPotential(
-        potential_fn, condition, dims_to_sample  # type: ignore
+        potential_fn,
+        condition,
+        dims_to_sample,  # type: ignore
     )
 
     restricted_prior = RestrictedPriorForConditional(prior, dims_to_sample)
20 changes: 9 additions & 11 deletions sbi/analysis/plot.py
@@ -31,9 +31,9 @@ def hex2rgb(hex):
 def rgb2hex(RGB):
     # Components need to be integers for hex to make sense
     RGB = [int(x) for x in RGB]
-    return "#" + "".join(
-        ["0{0:x}".format(v) if v < 16 else "{0:x}".format(v) for v in RGB]
-    )
+    return "#" + "".join([
+        "0{0:x}".format(v) if v < 16 else "{0:x}".format(v) for v in RGB
+    ])
 
 
 def _update(d, u):
@@ -870,12 +870,10 @@ def _arrange_plots(
             else:
                 _format_axis(ax, xhide=True, yhide=True)
             if opts["tick_labels"] is not None:
-                ax.set_xticklabels(
-                    (
-                        str(opts["tick_labels"][col][0]),
-                        str(opts["tick_labels"][col][1]),
-                    )
-                )
+                ax.set_xticklabels((
+                    str(opts["tick_labels"][col][0]),
+                    str(opts["tick_labels"][col][1]),
+                ))
 
         # Diagonals
         if current == "diag":
@@ -1135,9 +1133,9 @@ def _sbc_rank_plot(
     figsize = (num_parameters * 4, num_rows * 5) if params_in_subplots else (8, 5)
 
     if parameter_labels is None:
-        parameter_labels = [f"dim {i+1}" for i in range(num_parameters)]
+        parameter_labels = [f"dim {i + 1}" for i in range(num_parameters)]
     if ranks_labels is None:
-        ranks_labels = [f"rank set {i+1}" for i in range(num_ranks)]
+        ranks_labels = [f"rank set {i + 1}" for i in range(num_ranks)]
     if num_bins is None:
         # Recommendation from Talts et al.
         num_bins = num_sbc_runs // 20
32 changes: 14 additions & 18 deletions sbi/analysis/sbc.py
@@ -293,12 +293,10 @@ def check_prior_vs_dap(prior_samples: Tensor, dap_samples: Tensor) -> Tensor:
 
     assert prior_samples.shape == dap_samples.shape
 
-    return torch.tensor(
-        [
-            c2st(s1.unsqueeze(1), s2.unsqueeze(1))
-            for s1, s2 in zip(prior_samples.T, dap_samples.T)
-        ]
-    )
+    return torch.tensor([
+        c2st(s1.unsqueeze(1), s2.unsqueeze(1))
+        for s1, s2 in zip(prior_samples.T, dap_samples.T)
+    ])
 
 
 def check_uniformity_frequentist(ranks, num_posterior_samples) -> Tensor:
@@ -344,20 +342,18 @@ def check_uniformity_c2st(
         one for each dim_parameters.
     """
 
-    c2st_scores = torch.tensor(
+    c2st_scores = torch.tensor([
         [
-            [
-                c2st(
-                    rks.unsqueeze(1),
-                    Uniform(zeros(1), num_posterior_samples * ones(1)).sample(
-                        torch.Size((ranks.shape[0],))
-                    ),
-                )
-                for rks in ranks.T
-            ]
-            for _ in range(num_repetitions)
+            c2st(
+                rks.unsqueeze(1),
+                Uniform(zeros(1), num_posterior_samples * ones(1)).sample(
+                    torch.Size((ranks.shape[0],))
+                ),
+            )
+            for rks in ranks.T
         ]
-    )
+        for _ in range(num_repetitions)
+    ])
 
     # Use variance over repetitions to estimate robustness of c2st.
     if (c2st_scores.std(0) > 0.05).any():
1 change: 1 addition & 0 deletions sbi/analysis/tensorboard_output.py
@@ -1,6 +1,7 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
 """Utils for processing tensorboard event data."""
+
 import inspect
 import logging
 from copy import deepcopy
42 changes: 18 additions & 24 deletions sbi/inference/abc/smcabc.py
@@ -197,15 +197,13 @@ def sass_simulator(theta):
         )
         log_weights = torch.log(1 / num_particles * torch.ones(num_particles))
 
-        self.logger.info(
-            (
-                "population=%s, eps=%s, ess=%s, num_sims=%s",
-                pop_idx,
-                epsilon,
-                1.0,
-                num_initial_pop,
-            )
-        )
+        self.logger.info((
+            "population=%s, eps=%s, ess=%s, num_sims=%s",
+            pop_idx,
+            epsilon,
+            1.0,
+            num_initial_pop,
+        ))
 
         all_particles = [particles]
         all_log_weights = [log_weights]
@@ -246,14 +244,12 @@ def sass_simulator(theta):
                 particles, log_weights, ess_min, pop_idx
             )
 
-            self.logger.info(
-                (
-                    "population=%s done: eps={epsilon:.6f}, num_sims=%s.",
-                    pop_idx,
-                    epsilon,
-                    self.simulation_counter,
-                )
-            )
+            self.logger.info((
+                "population=%s done: eps={epsilon:.6f}, num_sims=%s.",
+                pop_idx,
+                epsilon,
+                self.simulation_counter,
+            ))
 
             # collect results
             all_particles.append(particles)
@@ -479,14 +475,12 @@ def _get_next_epsilon(self, distances: Tensor, quantile: float) -> float:
         try:
             qidx = torch.where(distances_cdf >= quantile)[0][0]
         except IndexError:
-            self.logger.warning(
-                (
-                    """Accepted unique distances=%s don't match quantile=%s. Selecting
+            self.logger.warning((
+                """Accepted unique distances=%s don't match quantile=%s. Selecting
                 last distance.""",
-                    distances,
-                    quantile,
-                )
-            )
+                distances,
+                quantile,
+            ))
             qidx = -1
 
         # The new epsilon is given by that distance.
5 changes: 3 additions & 2 deletions sbi/inference/posteriors/base_posterior.py
@@ -245,8 +245,9 @@ def map(
         return self._map
 
     def __repr__(self):
-        desc = f"""{self.__class__.__name__} sampler for potential_fn=<{self.
-        potential_fn.__class__.__name__}>"""
+        desc = f"""{self.__class__.__name__} sampler for potential_fn=<{
+            self.potential_fn.__class__.__name__
+        }>"""
         return desc
 
     def __str__(self):
9 changes: 5 additions & 4 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -534,7 +534,8 @@ def _pyro_mcmc(
         )
         sampler.run()
         samples = next(iter(sampler.get_samples().values())).reshape(
-            -1, initial_params.shape[1]  # .shape[1] = dim of theta
+            -1,
+            initial_params.shape[1],  # .shape[1] = dim of theta
         )
 
         # Save posterior sampler.
@@ -695,9 +696,9 @@ def get_arviz_inference_data(self) -> InferenceData:
                 *samples_shape
             )
 
-            inference_data = az.convert_to_inference_data(
-                {f"{self.param_name}": samples}
-            )
+            inference_data = az.convert_to_inference_data({
+                f"{self.param_name}": samples
+            })
 
         return inference_data
 
5 changes: 3 additions & 2 deletions sbi/inference/snle/mnle.py
@@ -148,8 +148,9 @@ def build_posterior(
 
         assert isinstance(
             likelihood_estimator, MixedDensityEstimator
-        ), f"""net must be of type MixedDensityEstimator but is {type
-        (likelihood_estimator)}."""
+        ), f"""net must be of type MixedDensityEstimator but is {
+            type(likelihood_estimator)
+        }."""
 
         (
             potential_fn,
10 changes: 5 additions & 5 deletions sbi/inference/snpe/snpe_a.py
@@ -407,9 +407,7 @@ def __init__(
                 logits_pp,
                 m_pp,
                 prec_pp,
-            ) = proposal.posterior_estimator._posthoc_correction(
-                proposal.default_x
-            )  # type: ignore
+            ) = proposal.posterior_estimator._posthoc_correction(proposal.default_x)  # type: ignore
             self._logits_pp, self._m_pp, self._prec_pp = (
                 logits_pp.detach(),
                 m_pp.detach(),
@@ -853,10 +851,12 @@ def _logits_posterior(
 
     # Compute for proposal, density estimator, and proposal posterior:
     exponent_pp = utils.batched_mixture_vmv(
-        precisions_pp, means_pp  # m_0 in eq (26) in Appendix C of [1]
+        precisions_pp,
+        means_pp,  # m_0 in eq (26) in Appendix C of [1]
     )
     exponent_d = utils.batched_mixture_vmv(
-        precisions_d, means_d  # m_k in eq (26) in Appendix C of [1]
+        precisions_d,
+        means_d,  # m_k in eq (26) in Appendix C of [1]
     )
     exponent_post = utils.batched_mixture_vmv(
         precisions_post,
15 changes: 9 additions & 6 deletions sbi/neural_nets/density_estimators/flow.py
@@ -99,16 +99,19 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
         if len(condition.shape) == condition_dims:
             # nflows.sample() expects conditions to be batched.
             condition = condition.unsqueeze(0)
-            samples = self.net.sample(num_samples, context=condition).reshape(
-                (*sample_shape, -1)
-            )
+            samples = self.net.sample(num_samples, context=condition).reshape((
+                *sample_shape,
+                -1,
+            ))
         else:
             # For batched conditions, we need to reshape the conditions and the samples
             batch_shape = condition.shape[:-condition_dims]
             condition = condition.reshape(-1, *self._condition_shape)
-            samples = self.net.sample(num_samples, context=condition).reshape(
-                (*batch_shape, *sample_shape, -1)
-            )
+            samples = self.net.sample(num_samples, context=condition).reshape((
+                *batch_shape,
+                *sample_shape,
+                -1,
+            ))
 
         return samples
 
24 changes: 10 additions & 14 deletions sbi/samplers/mcmc/slice.py
@@ -154,20 +154,16 @@ def _sample_from_conditional(self, params, dim):
         def _log_prob_d(x):
             assert self.potential_fn is not None, "Chain not initialized."
 
-            return -self.potential_fn(
-                {
-                    self._site_name: torch.cat(
-                        (
-                            params[self._site_name].view(-1)[:dim],
-                            x.reshape(1),
-                            params[self._site_name].view(-1)[dim + 1 :],
-                        )
-                    ).unsqueeze(
-                        0
-                    )  # TODO: The unsqueeze seems to give a speed up, figure out when
-                    # this is the case exactly
-                }
-            )
+            return -self.potential_fn({
+                self._site_name: torch.cat((
+                    params[self._site_name].view(-1)[:dim],
+                    x.reshape(1),
+                    params[self._site_name].view(-1)[dim + 1 :],
+                )).unsqueeze(
+                    0
+                )  # TODO: The unsqueeze seems to give a speed up, figure out when
+                # this is the case exactly
+            })
 
         assert (
             self._site_name is not None and self._width is not None