Skip to content

Commit

Permalink
scope imports of ete3 in tasccoda (#422)
Browse files Browse the repository at this point in the history
* Simplify ete3

* Simplify ete3

* Simplify ete3

* Simplify ete3
  • Loading branch information
Zethson authored Nov 7, 2023
1 parent 0f1e60f commit 38a5bac
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 36 deletions.
9 changes: 2 additions & 7 deletions pertpy/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from pertpy.plot._augur import AugurpyPlot as ag
from pertpy.plot._dialogue import DialoguePlot as dl

try:
from pertpy.plot._coda import CodaPlot as coda
except ImportError:
pass

from pertpy.plot._cinemaot import CinemaotPlot as cot
from pertpy.plot._coda import CodaPlot as coda
from pertpy.plot._dialogue import DialoguePlot as dl
from pertpy.plot._guide_rna import GuideRnaPlot as guide
from pertpy.plot._milopy import MilopyPlot as milo
from pertpy.plot._mixscape import MixscapePlot as ms
Expand Down
21 changes: 15 additions & 6 deletions pertpy/plot/_coda.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Literal, Optional, Union
from typing import TYPE_CHECKING, Literal, Optional, Union

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
Expand All @@ -9,7 +9,6 @@
import seaborn as sns
from adjustText import adjust_text
from anndata import AnnData
from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
from matplotlib import cm, rcParams
from matplotlib.axes import Axes
from matplotlib.colors import ListedColormap
Expand Down Expand Up @@ -687,7 +686,7 @@ def label_point(x, y, val, ax):
def draw_tree( # pragma: no cover
data: Union[AnnData, MuData],
modality_key: str = "coda",
tree: Union[Tree, str] = "tree",
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
tight_text: Optional[bool] = False,
show_scale: Optional[bool] = False,
show: Optional[bool] = True,
Expand Down Expand Up @@ -734,6 +733,11 @@ def draw_tree( # pragma: no cover
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
>>> pt.pl.coda.draw_tree(mdata, tree="lineage")
"""
try:
from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
except ImportError:
raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None

if isinstance(data, MuData):
data = data[modality_key]
if isinstance(data, AnnData):
Expand All @@ -750,9 +754,9 @@ def my_layout(node):
tree_style.layout_fn = my_layout
tree_style.show_scale = show_scale
if file_name is not None:
tree.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
tree.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) # type: ignore
if show:
return tree.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
return tree.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) # type: ignore
else:
return tree, tree_style

Expand All @@ -761,7 +765,7 @@ def draw_effects( # pragma: no cover
data: Union[AnnData, MuData],
covariate: str,
modality_key: str = "coda",
tree: Union[Tree, str] = "tree",
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
show_legend: Optional[bool] = None,
show_leaf_effects: Optional[bool] = False,
tight_text: Optional[bool] = False,
Expand Down Expand Up @@ -814,6 +818,11 @@ def draw_effects( # pragma: no cover
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
>>> pt.pl.coda.draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
"""
try:
from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
except ImportError:
raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None

if isinstance(data, MuData):
data = data[modality_key]
if isinstance(data, AnnData):
Expand Down
13 changes: 2 additions & 11 deletions pertpy/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from rich import print

from pertpy.tools._augur import Augur
from pertpy.tools._cinemaot import Cinemaot
from pertpy.tools._coda._sccoda import Sccoda
from pertpy.tools._coda._tasccoda import Tasccoda
from pertpy.tools._dialogue import Dialogue
from pertpy.tools._differential_gene_expression import DifferentialGeneExpression
from pertpy.tools._distances._distance_tests import DistanceTest
Expand All @@ -13,12 +13,3 @@
from pertpy.tools._perturbation_space._discriminator_classifier import DiscriminatorClassifierSpace
from pertpy.tools._perturbation_space._simple import CentroidSpace, DBSCANSpace, KMeansSpace, PseudobulkSpace
from pertpy.tools._scgen import SCGEN

try:
from pertpy.tools._coda._sccoda import Sccoda
from pertpy.tools._coda._tasccoda import Tasccoda
except ImportError as e:
if "ete3" in str(e):
print("[bold yellow]To use sccoda or tasccoda please install ete3 with [green]pip install ete3")
else:
raise e
22 changes: 17 additions & 5 deletions pertpy/tools/_coda/_base_coda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import TYPE_CHECKING

import arviz as az
import ete3 as ete
import jax.numpy as jnp
import numpy as np
import pandas as pd
Expand All @@ -22,6 +21,7 @@
if TYPE_CHECKING:
import numpyro as npy
import toytree as tt
from ete3 import Tree
from jax._src.prng import PRNGKeyArray
from jax._src.typing import Array

Expand Down Expand Up @@ -1242,7 +1242,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) ->


def get_a_2(
tree: ete.Tree,
tree: Tree,
leaf_order: list[str] = None,
node_order: list[str] = None,
) -> tuple[np.ndarray, int]:
Expand All @@ -1263,6 +1263,11 @@ def get_a_2(
T
number of nodes in the tree, excluding the root node
"""
try:
import ete3 as ete
except ImportError:
raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None

n_tips = len(tree.get_leaves())
n_nodes = len(tree.get_descendants())

Expand Down Expand Up @@ -1292,7 +1297,7 @@ def get_a_2(
return A_, n_nodes


def collapse_singularities_2(tree: ete.Tree) -> ete.Tree:
def collapse_singularities_2(tree: Tree) -> Tree:
"""Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child).
Args:
Expand Down Expand Up @@ -1368,8 +1373,10 @@ def import_tree(
dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object. Defaults to None.
levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}. Defaults to True.
key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
Defaults to True.
key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`.
If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
copy: Return a copy instead of writing to `data`. Defaults to False.
Returns:
Expand All @@ -1379,6 +1386,11 @@ def import_tree(
tree: A ete3 tree object.
"""
try:
import ete3 as ete
except ImportError:
raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None

if isinstance(data, MuData):
try:
data_1 = data[modality_1]
Expand Down
21 changes: 14 additions & 7 deletions pertpy/tools/_coda/_tasccoda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Literal

import arviz as az
import ete3 as ete
import jax.numpy as jnp
import numpy as np
import numpyro as npy
Expand Down Expand Up @@ -148,17 +147,19 @@ def prepare(
pen_args: dict = None,
modality_key: str = "coda",
) -> AnnData | MuData:
"""Handles data preprocessing, covariate matrix creation, reference selection, and zero count replacement for tascCODA. Also sets model parameters, model type (tree_agg), effect selection type (sslaso) and performs tree processing.
"""Handles data preprocessing, covariate matrix creation, reference selection, and zero count replacement for tascCODA.
Args:
data: Anndata object with cell counts as .X and covariates saved in .obs or a MuData object.
formula: R-style formula for building the covariate matrix.
Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
reference_cell_type: Column name that sets the reference cell type.
Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen. Defaults to "automatic".
automatic_reference_absence_threshold: If using reference_cell_type = "automatic", determine the maximum fraction of zero entries for a cell type
to be considered as a possible reference cell type. Defaults to 0.05.
If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen.
Defaults to "automatic".
automatic_reference_absence_threshold: If using reference_cell_type = "automatic",
determine the maximum fraction of zero entries for a cell type
to be considered as a possible reference cell type. Defaults to 0.05.
tree_key: Key in `adata.uns` that contains the tree structure
pen_args: Dictionary with penalty arguments. With `reg="scaled_3"`, the parameters phi (aggregation bias), lambda_1, lambda_0 can be set here.
See the tascCODA paper for an explanation of these parameters. Default: lambda_0 = 50, lambda_1 = 5, phi = 0.
Expand Down Expand Up @@ -199,6 +200,12 @@ def prepare(
if tree_key is None:
raise ValueError("Please specify the key in .uns that contains the tree structure!")

# Scoped import due to installation issues
try:
import ete3 as ete
except ImportError:
raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None

# toytree tree - only for legacy reasons, can be removed in the final version
if isinstance(adata.uns[tree_key], tt.tree):
# Collapse singularities in the tree
Expand Down

0 comments on commit 38a5bac

Please sign in to comment.