Skip to content

Commit

Permalink
quick bugfix to beta distribution with nonstandard bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
mbi6245 committed Jan 29, 2025
1 parent d2a20d7 commit 2249899
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 60 deletions.
84 changes: 37 additions & 47 deletions plots.ipynb

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions src/ensemble/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(
# not np.isinf(self.support[1]),
):
case (True, True):
# print(np.isinf(self.support[0]))
# print(np.isinf(self.support[1]))
if np.isinf(self.support[0]) or np.isinf(self.support[1]):
raise ValueError(
"You may not change an infinite bound to be finite or"
Expand Down Expand Up @@ -348,8 +350,8 @@ def __init__(
lb: float = 0,
ub: float = 1,
):
super().__init__(mean, variance, lb, ub)
self.width = np.abs(ub - lb)
super().__init__(mean, variance, lb, ub)

def _squeeze(self, x: float) -> float:
"""transform x to be within (0, 1)
Expand Down Expand Up @@ -392,8 +394,7 @@ def _create_scipy_dist(self, csd_mean) -> None:
+ "combinations. The supplied variance must be in between "
+ "(0, mean^2)"
)
beta_bounds(self.mean)
if self.lb != 0 and self.ub != 1:
if self.lb != 0 or self.ub != 1:
mean = (self.mean - self.lb) / self.width
var = self.variance / self.width
else:
Expand Down Expand Up @@ -511,6 +512,6 @@ def strict_positive_support(mean: float) -> None:
raise ValueError("This distribution is only supported on (0, np.inf)")


def beta_bounds(mean: float) -> None:
    """check that a mean lies within the standard beta support [0, 1]

    Parameters
    ----------
    mean : float
        mean to validate

    Raises
    ------
    ValueError
        if the mean is less than 0 or greater than 1
    """
    # both endpoints are accepted; only values strictly outside [0, 1] fail
    if (mean < 0) or (mean > 1):
        raise ValueError("This distribution is only supported on [0, 1]")
# def beta_bounds(mean: float) -> None:
# if (mean < 0) or (mean > 1):
# raise ValueError("This distribution is only supported on [0, 1]")
55 changes: 48 additions & 7 deletions src/ensemble/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,28 @@ def to_json(self, file_path: str, appending: bool = False) -> None:
def from_objs(
cls, fitted_distributions: List[Distribution]
) -> EnsembleDistribution:
"""generates ensemble distribution from Distribution objects also from
the ensemble package. all parameters, such as mean, variance, upper
bound, and lower bound, must match
Parameters
----------
fitted_distributions : List[Distribution]
list of distribution objects from the ensemble package
Returns
-------
EnsembleDistribution
ensemble distribution object with parameters equal to individual
input distributions
Raises
------
ValueError
if parameters across individual distributions don't match
ValueError
if the weight of a distribution is not set
"""
try:
mean, variance, lb, ub = (
fitted_distributions[0].mean,
Expand Down Expand Up @@ -441,23 +463,19 @@ class EnsembleFitter:
names of distributions in ensemble
objective: str
name of objective function for use in fitting ensemble
lb: str
"""

def __init__(
self,
distributions: List[str],
objective: str,
lb: str = None,
ub: str = None,
):
self.support = _check_supports_match(distributions)
self.distributions = distributions
self.objective = objective
self.lb = lb
self.ub = ub
# self.lb = lb
# self.ub = ub

def _objective_func(self, vec: np.ndarray) -> float:
"""applies different penalties to vector of distances given by user
Expand Down Expand Up @@ -517,13 +535,18 @@ def fit(
lb: float | None = None,
ub: float | None = None,
) -> EnsembleResult:
# TODO: HOW SHOULD WE DESCRIBE UB AND LB?
"""fits weighted sum of CDFs corresponding to distributions in
EnsembleModel object to empirical CDF of given data
Parameters
----------
data : npt.ArrayLike
individual-level data (i.e. microdata)
lb: float, optional
lower allowable bound of data, by default None
ub: float, optional
upper allowable bound of data, by default None
Returns
-------
Expand Down Expand Up @@ -604,7 +627,25 @@ def fit(
### HELPER FUNCTIONS


def _check_valid_ensemble(distributions: List[str], weights: List[float]):
def _check_valid_ensemble(
distributions: List[str], weights: List[float]
) -> None:
"""checks if ensemble distribution is valid
Parameters
----------
distributions : List[str]
list of named distributions, as strings
weights : List[float]
list of weights, in order of provided distribution list
Raises
------
ValueError
if there is a mismatch between num distributions and num weights
ValueError
if weights do not sum to 1
"""
if len(distributions) != len(weights):
raise ValueError(
"there must be the same number of distributions as weights!"
Expand Down
9 changes: 9 additions & 0 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ def test_beta():
assert np.isclose(res[0], BETA_MEAN)
assert np.isclose(res[1], BETA_VARIANCE)

mean = 3
vari = 1
lb = 0
ub = 5
beta_alt = Beta(mean, vari, lb, ub)
res_alt = beta_alt.stats(moments="mv")
assert np.isclose(res_alt[0], mean)
assert np.isclose(res_alt[1], vari)


def test_invalid_means():
# negative means for only positive RVs
Expand Down

0 comments on commit 2249899

Please sign in to comment.