Skip to content

Commit

Permalink
Open issue for incompatibility with torch.Transform interface: #1057
Browse files Browse the repository at this point in the history
  • Loading branch information
Baschdl committed Mar 20, 2024
1 parent f2b016e commit 1fa20a7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions sbi/utils/conditional_density_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def log_prob(self, *args, **kwargs):
return self.full_prior.log_prob(*args, **kwargs)


# This class doesn't follow the interface of torch.Transform. This causes pyright's `IncompatibleMethodOverride`.
class RestrictedTransformForConditional(torch_tf.Transform):
"""
Class to restrict the transform to fewer dimensions for conditional sampling.
Expand Down Expand Up @@ -421,7 +422,7 @@ def __call__(self, theta: Tensor) -> Tensor:
tf_full_theta = self.transform(full_theta)
return tf_full_theta[:, self.dims_to_sample] # type: ignore

def inv(self, theta: Tensor) -> Tensor:
def inv(self, theta: Tensor) -> Tensor: # pyright: ignore[reportIncompatibleMethodOverride]
r"""
Inverse transform restricted $\theta$.
"""
Expand All @@ -430,7 +431,7 @@ def inv(self, theta: Tensor) -> Tensor:
tf_full_theta = self.transform.inv(full_theta)
return tf_full_theta[:, self.dims_to_sample] # type: ignore

def log_abs_det_jacobian(self, theta1: Tensor, theta2: Tensor) -> Tensor:
def log_abs_det_jacobian(self, theta1: Tensor, theta2: Tensor) -> Tensor: # pyright: ignore[reportIncompatibleMethodOverride]
"""
Return the `log_abs_det_jacobian` of |dtheta1 / dtheta2|.
Expand Down

0 comments on commit 1fa20a7

Please sign in to comment.