From 38a5baccc30a5a587cc8bdf822f7e9064ae02921 Mon Sep 17 00:00:00 2001 From: Lukas Heumos Date: Tue, 7 Nov 2023 05:28:41 -0800 Subject: [PATCH] scope imports of ete3 in tasccoda (#422) * Simplify ete3 * Simplify ete3 * Simplify ete3 * Simplify ete3 --- pertpy/plot/__init__.py | 9 ++------- pertpy/plot/_coda.py | 21 +++++++++++++++------ pertpy/tools/__init__.py | 13 ++----------- pertpy/tools/_coda/_base_coda.py | 22 +++++++++++++++++----- pertpy/tools/_coda/_tasccoda.py | 21 ++++++++++++++------- 5 files changed, 50 insertions(+), 36 deletions(-) diff --git a/pertpy/plot/__init__.py b/pertpy/plot/__init__.py index 84406443..505578e0 100644 --- a/pertpy/plot/__init__.py +++ b/pertpy/plot/__init__.py @@ -1,12 +1,7 @@ from pertpy.plot._augur import AugurpyPlot as ag -from pertpy.plot._dialogue import DialoguePlot as dl - -try: - from pertpy.plot._coda import CodaPlot as coda -except ImportError: - pass - from pertpy.plot._cinemaot import CinemaotPlot as cot +from pertpy.plot._coda import CodaPlot as coda +from pertpy.plot._dialogue import DialoguePlot as dl from pertpy.plot._guide_rna import GuideRnaPlot as guide from pertpy.plot._milopy import MilopyPlot as milo from pertpy.plot._mixscape import MixscapePlot as ms diff --git a/pertpy/plot/_coda.py b/pertpy/plot/_coda.py index cda77f54..01773108 100644 --- a/pertpy/plot/_coda.py +++ b/pertpy/plot/_coda.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Literal, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import matplotlib.image as mpimg import matplotlib.pyplot as plt @@ -9,7 +9,6 @@ import seaborn as sns from adjustText import adjust_text from anndata import AnnData -from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces from matplotlib import cm, rcParams from matplotlib.axes import Axes from matplotlib.colors import ListedColormap @@ -687,7 +686,7 @@ def label_point(x, y, val, ax): def draw_tree( # pragma: no cover data: Union[AnnData, MuData], modality_key: str = "coda", - tree: Union[Tree, str] = "tree", + tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors tight_text: Optional[bool] = False, show_scale: Optional[bool] = False, show: Optional[bool] = True, @@ -734,6 +733,11 @@ def draw_tree( # pragma: no cover >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) >>> pt.pl.coda.draw_tree(mdata, tree="lineage") """ + try: + from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces + except ImportError: + raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None + if isinstance(data, MuData): data = data[modality_key] if isinstance(data, AnnData): @@ -750,9 +754,9 @@ def my_layout(node): tree_style.layout_fn = my_layout tree_style.show_scale = show_scale if file_name is not None: - tree.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) + tree.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) # type: ignore if show: - return tree.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) + return tree.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) # type: ignore else: return tree, tree_style @@ -761,7 +765,7 @@ def draw_effects( # pragma: no cover data: Union[AnnData, MuData], covariate: str, modality_key: str = "coda", - tree: Union[Tree, str] = "tree", + tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors show_legend: Optional[bool] = None, show_leaf_effects: Optional[bool] = False, tight_text: Optional[bool] = False, @@ -814,6 +818,11 @@ def draw_effects( # pragma: no cover >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) >>> pt.pl.coda.draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage") """ + try: + from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces + except ImportError: + raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None + if isinstance(data, MuData): data = data[modality_key] if isinstance(data, AnnData): diff --git a/pertpy/tools/__init__.py b/pertpy/tools/__init__.py index 67a6d587..cd545da7 100644 --- a/pertpy/tools/__init__.py +++ b/pertpy/tools/__init__.py @@ -1,7 +1,7 @@ -from rich import print - from pertpy.tools._augur import Augur from pertpy.tools._cinemaot import Cinemaot +from pertpy.tools._coda._sccoda import Sccoda +from pertpy.tools._coda._tasccoda import Tasccoda from pertpy.tools._dialogue import Dialogue from pertpy.tools._differential_gene_expression import DifferentialGeneExpression from pertpy.tools._distances._distance_tests import DistanceTest @@ -13,12 +13,3 @@ from pertpy.tools._perturbation_space._discriminator_classifier import DiscriminatorClassifierSpace from pertpy.tools._perturbation_space._simple import CentroidSpace, DBSCANSpace, KMeansSpace, PseudobulkSpace from pertpy.tools._scgen import SCGEN - -try: - from pertpy.tools._coda._sccoda import Sccoda - from pertpy.tools._coda._tasccoda import Tasccoda -except ImportError as e: - if "ete3" in str(e): - print("[bold yellow]To use sccoda or tasccoda please install ete3 with [green]pip install ete3") - else: - raise e diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index b108e9f9..42fa2119 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING import arviz as az -import ete3 as ete import jax.numpy as jnp import numpy as np import pandas as pd @@ -22,6 +21,7 @@ if TYPE_CHECKING: import numpyro as npy import toytree as tt + from ete3 import Tree from jax._src.prng import PRNGKeyArray from jax._src.typing import Array @@ -1242,7 +1242,7 @@ def df2newick(df: pd.DataFrame, levels: list[str], inner_label: bool = True) -> def get_a_2( - tree: ete.Tree, + tree: Tree, leaf_order: list[str] = None, node_order: list[str] = None, ) -> tuple[np.ndarray, int]: @@ -1263,6 +1263,11 @@ def get_a_2( T number of nodes in the tree, excluding the root node """ + try: + import ete3 as ete + except ImportError: + raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None + n_tips = len(tree.get_leaves()) n_nodes = len(tree.get_descendants()) @@ -1292,7 +1297,7 @@ def get_a_2( return A_, n_nodes -def collapse_singularities_2(tree: ete.Tree) -> ete.Tree: +def collapse_singularities_2(tree: Tree) -> Tree: """Collapses (deletes) nodes in a ete3 tree that are singularities (have only one child). Args: @@ -1368,8 +1373,10 @@ def import_tree( dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object. Defaults to None. levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None. levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None. - add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}. Defaults to True. - key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2]. Defaults to "tree". + add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}. + Defaults to True. + key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`. + If `data` is MuData, save tree in data[modality_2]. Defaults to "tree". copy: Return a copy instead of writing to `data`. Defaults to False. Returns: @@ -1379,6 +1386,11 @@ def import_tree( tree: A ete3 tree object. """ + try: + import ete3 as ete + except ImportError: + raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None + if isinstance(data, MuData): try: data_1 = data[modality_1] diff --git a/pertpy/tools/_coda/_tasccoda.py b/pertpy/tools/_coda/_tasccoda.py index 575095f9..a4e2ea03 100644 --- a/pertpy/tools/_coda/_tasccoda.py +++ b/pertpy/tools/_coda/_tasccoda.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Literal import arviz as az -import ete3 as ete import jax.numpy as jnp import numpy as np import numpyro as npy @@ -148,17 +147,19 @@ def prepare( pen_args: dict = None, modality_key: str = "coda", ) -> AnnData | MuData: - """Handles data preprocessing, covariate matrix creation, reference selection, and zero count replacement for tascCODA. Also sets model parameters, model type (tree_agg), effect selection type (sslaso) and performs tree processing. + """Handles data preprocessing, covariate matrix creation, reference selection, and zero count replacement for tascCODA. Args: data: Anndata object with cell counts as .X and covariates saved in .obs or a MuData object. formula: R-style formula for building the covariate matrix. - Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category. - To set a different level as the base category for a categorical covariate, use "C(, Treatment(''))" + Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category. + To set a different level as the base category for a categorical covariate, use "C(, Treatment(''))" reference_cell_type: Column name that sets the reference cell type. - Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen. Defaults to "automatic". - automatic_reference_absence_threshold: If using reference_cell_type = "automatic", determine the maximum fraction of zero entries for a cell type - to be considered as a possible reference cell type. Defaults to 0.05. + If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen. + Defaults to "automatic". + automatic_reference_absence_threshold: If using reference_cell_type = "automatic", + determine the maximum fraction of zero entries for a cell type + to be considered as a possible reference cell type. Defaults to 0.05. tree_key: Key in `adata.uns` that contains the tree structure pen_args: Dictionary with penalty arguments. With `reg="scaled_3"`, the parameters phi (aggregation bias), lambda_1, lambda_0 can be set here. See the tascCODA paper for an explanation of these parameters. Default: lambda_0 = 50, lambda_1 = 5, phi = 0. @@ -199,6 +200,12 @@ def prepare( if tree_key is None: raise ValueError("Please specify the key in .uns that contains the tree structure!") + # Scoped import due to installation issues + try: + import ete3 as ete + except ImportError: + raise ImportError("To use tasccoda please install ete3 with pip install ete3") from None + # toytree tree - only for legacy reasons, can be removed in the final version if isinstance(adata.uns[tree_key], tt.tree): # Collapse singularities in the tree