Skip to content

Commit

Permalink
create simulation function and run simulations on all 4 univariate tr…
Browse files Browse the repository at this point in the history
…ansforms in simulations.ipynb
  • Loading branch information
mbi6245 committed Jul 19, 2024
1 parent 6304183 commit 549a2c9
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 104 deletions.
320 changes: 245 additions & 75 deletions simulations.ipynb

Large diffs are not rendered by default.

51 changes: 22 additions & 29 deletions src/distrx/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,28 @@ def __init__(self, transform: str) -> None:
self.transform = transform

def __call__(
self, mu: npt.ArrayLike, sigma: npt.ArrayLike, n: npt.ArrayLike
self, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
match self.transform:
case "log":
return self.log_trans(mu, sigma, n)
return self.log_trans(mu, sigma)
case "logit":
return self.logit_trans(mu, sigma, n)
return self.logit_trans(mu, sigma)
case "exp":
return self.exp_trans(mu, sigma, n)
return self.exp_trans(mu, sigma)
case "expit":
return self.expit_trans(mu, sigma, n)
return self.expit_trans(mu, sigma)
case _:
raise ValueError(f"Invalid transform '{self.transform}'.")

def log_trans(
self, mu: npt.ArrayLike, sigma: npt.ArrayLike, n: npt.ArrayLike
self, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
"""Performs delta method on data under log transform
.. math::
\\log(mu), \\frac{\\sigma}{\\mu} \\cdot \\frac{1}{\\sqrt{n}}
\\log(mu), \\frac{\\sigma}{\\mu}
Parameters
----------
Expand All @@ -77,16 +77,16 @@ def log_trans(
Transformed mean and standard error
"""
log = c2fun_dict["log"]
return log(mu), sigma * log(mu, order=1) / np.sqrt(n)
return log(mu), sigma * log(mu, order=1)

def logit_trans(
self, mu: npt.ArrayLike, sigma: npt.ArrayLike, n: npt.ArrayLike
self, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
"""Performs delta method on data under logit transform
.. math::
\\log(\\frac{\\mu}{1 - \\mu}), \\frac{\\sigma}{\\mu \\cdot (1 - \\mu)} \\cdot \\frac{1}{\\sqrt{n}}
\\log(\\frac{\\mu}{1 - \\mu}), \\frac{\\sigma}{\\mu \\cdot (1 - \\mu)}
Parameters
----------
Expand All @@ -101,16 +101,16 @@ def logit_trans(
Transformed mean and standard error
"""
logit = c2fun_dict["logit"]
return logit(mu), sigma * logit(mu, order=1) / np.sqrt(n)
return logit(mu), sigma * logit(mu, order=1)

def exp_trans(
self, mu: npt.ArrayLike, sigma: npt.ArrayLike, n: npt.ArrayLike
self, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
"""Performs delta method on data under exponential transform
.. math::
\\exp(\\mu), \\sigma \\cdot \\exp(\\mu) \\cdot \\frac{1}{\\sqrt{n}}
\\exp(\\mu), \\sigma \\cdot \\exp(\\mu)
Parameters
----------
Expand All @@ -125,16 +125,16 @@ def exp_trans(
Transformed mean and standard error
"""
exp = c2fun_dict["exp"]
return exp(mu), sigma * exp(mu, order=1) / np.sqrt(n)
return exp(mu), sigma * exp(mu, order=1)

def expit_trans(
self, mu: npt.ArrayLike, sigma: npt.ArrayLike, n: npt.ArrayLike
self, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
"""Performs delta method on data under expit transform
.. math::
\\frac{1}{1 + \\exp(-\\mu)}, \\sigma \\cdot \\frac{\\exp(\\mu)}{(1 + \\exp(\\mu))^2} \\cdot \\frac{1}{\\sqrt{n}}
\\frac{1}{1 + \\exp(-\\mu)}, \\sigma \\cdot \\frac{\\exp(\\mu)}{(1 + \\exp(\\mu))^2}
Parameters
----------
Expand All @@ -149,7 +149,7 @@ def expit_trans(
Transformed mean and standard error
"""
expit = c2fun_dict["expit"]
return expit(mu), sigma * expit(mu, order=1) / np.sqrt(n)
return expit(mu), sigma * expit(mu, order=1)


class FirstOrderBivariate:
Expand Down Expand Up @@ -226,7 +226,6 @@ def percentage_change_trans(
def transform_univariate(
mu: npt.ArrayLike,
sigma: npt.ArrayLike,
n: npt.ArrayLike,
transform: str,
method: str = "delta",
) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -258,11 +257,11 @@ def transform_univariate(
"""

mu, sigma = np.array(mu), np.array(sigma)
_check_input(mu, sigma, n)
_check_input(mu, sigma)
match method:
case "delta":
transformer = FirstOrder(transform)
return transformer(mu, sigma, n)
return transformer(mu, sigma)
case _:
raise ValueError(f"Invalid method '{method}'.")

Expand Down Expand Up @@ -364,9 +363,7 @@ def transform_bivariate(
# return delta_hat, np.sqrt(sigma_trans)


def _check_input(
mu: npt.ArrayLike, sigma: npt.ArrayLike, n: npt.ArrayLike
) -> None:
def _check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None:
"""Run checks on input data.
Parameters
Expand All @@ -382,7 +379,7 @@ def _check_input(
"""
# _check_lengths_match(mu, sigma)
_check_sigma_n_positive(sigma, n)
_check_sigma_positive(sigma)


def _check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None:
Expand All @@ -400,7 +397,7 @@ def _check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None:
raise ValueError("Lengths of mu and sigma don't match.")


def _check_sigma_n_positive(sigma: npt.ArrayLike, n: npt.ArrayLike) -> None:
def _check_sigma_positive(sigma: npt.ArrayLike) -> None:
"""Check that `sigma` is positive.
Parameters
Expand All @@ -413,7 +410,3 @@ def _check_sigma_n_positive(sigma: npt.ArrayLike, n: npt.ArrayLike) -> None:
warnings.warn("Sigma vector contains zeros.")
if np.any(sigma < 0.0):
raise ValueError("Sigma values must be positive.")
if np.any(n == 0.0):
warnings.warn("Sigma vector contains zeros.")
if np.any(n < 0.0):
raise ValueError("Sigma values must be positive.")

0 comments on commit 549a2c9

Please sign in to comment.