Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Documentation examples #391

Merged
merged 12 commits into from
Oct 16, 2023
42 changes: 42 additions & 0 deletions pertpy/plot/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure

Returns:
Axes of the plot.

Examples:
>>> import pertpy as pt
>>> adata = pt.dt.bhattacherjee()
>>> ag_rfc = pt.tl.Augur("random_forest_classifier")

>>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine")
>>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4)
>>> adata_15_permute, results_15_permute = ag_rfc.predict(data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4)

>>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine")
>>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4)
>>> adata_48_permute, results_48_permute = ag_rfc.predict(data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4)

>>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \
permuted_results1=results_15_permute, permuted_results2=results_48_permute)
>>> pt.pl.ag.dp_scatter(pvals)
"""
x = results["mean_augur_score1"]
y = results["mean_augur_score2"]
Expand Down Expand Up @@ -69,6 +86,14 @@ def important_features(

Returns:
Axes of the plot.

Examples:
>>> import pertpy as pt
>>> adata = pt.dt.sc_sim_augur()
>>> ag_rfc = pt.tl.Augur("random_forest_classifier")
>>> loaded_data = ag_rfc.load(adata)
>>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
>>> pt.pl.ag.important_features(v_results)
"""
if isinstance(data, AnnData):
results = data.uns[key]
Expand Down Expand Up @@ -115,6 +140,14 @@ def lollipop(

Returns:
Axes of the plot.

Examples:
>>> import pertpy as pt
>>> adata = pt.dt.sc_sim_augur()
>>> ag_rfc = pt.tl.Augur("random_forest_classifier")
>>> loaded_data = ag_rfc.load(adata)
>>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
>>> pt.pl.ag.lollipop(v_results)
"""
if isinstance(data, AnnData):
results = data.uns[key]
Expand Down Expand Up @@ -157,6 +190,15 @@ def scatterplot(

Returns:
Axes of the plot.

Examples:
>>> import pertpy as pt
>>> adata = pt.dt.sc_sim_augur()
>>> ag_rfc = pt.tl.Augur("random_forest_classifier")
>>> loaded_data = ag_rfc.load(adata)
>>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
>>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
>>> pt.pl.ag.scatterplot(v_results, h_results)
"""
cell_types = results1["summary_metrics"].columns

Expand Down
85 changes: 85 additions & 0 deletions pertpy/plot/_coda.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ def stacked_barplot( # pragma: no cover

Returns:
A :class:`~matplotlib.axes.Axes` object

Examples:
Example with scCODA:
>>> import pertpy as pt
>>> haber_cells = pt.dt.haber_2017_regions()
>>> sccoda = pt.tl.Sccoda()
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
sample_identifier="batch", covariate_obs=["condition"])
>>> pt.pl.coda.stacked_barplot(mdata, feature_name="samples")
"""
if isinstance(data, MuData):
data = data[modality_key]
Expand Down Expand Up @@ -196,6 +205,17 @@ def effects_barplot( # pragma: no cover
Returns:
Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

Examples:
Example with scCODA:
>>> import pertpy as pt
>>> haber_cells = pt.dt.haber_2017_regions()
>>> sccoda = pt.tl.Sccoda()
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
sample_identifier="batch", covariate_obs=["condition"])
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
>>> pt.pl.coda.effects_barplot(mdata)
"""
if args_barplot is None:
args_barplot = {}
Expand Down Expand Up @@ -366,6 +386,15 @@ def boxplots( # pragma: no cover
Returns:
Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

Examples:
Example with scCODA:
>>> import pertpy as pt
>>> haber_cells = pt.dt.haber_2017_regions()
>>> sccoda = pt.tl.Sccoda()
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
sample_identifier="batch", covariate_obs=["condition"])
>>> pt.pl.coda.boxplots(mdata, feature_name="condition", add_dots=True)
"""
if args_boxplot is None:
args_boxplot = {}
Expand Down Expand Up @@ -570,6 +599,17 @@ def rel_abundance_dispersion_plot( # pragma: no cover

Returns:
A :class:`~matplotlib.axes.Axes` object

Examples:
Example with scCODA:
>>> import pertpy as pt
>>> haber_cells = pt.dt.haber_2017_regions()
>>> sccoda = pt.tl.Sccoda()
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
sample_identifier="batch", covariate_obs=["condition"])
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
>>> pt.pl.coda.rel_abundance_dispersion_plot(mdata)
"""
if isinstance(data, MuData):
data = data[modality_key]
Expand Down Expand Up @@ -677,6 +717,22 @@ def draw_tree( # pragma: no cover

Returns:
Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plots the tree inline (`show = True`)

Examples:
Example with tascCODA:
>>> import pertpy as pt
>>> adata = pt.dt.smillie()
>>> tasccoda = pt.tl.Tasccoda()
>>> mdata = tasccoda.load(
...     adata, type="sample_level",
...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
...     key_added="lineage", add_level_name=True
... )
>>> mdata = tasccoda.prepare(
...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
... )
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
>>> pt.pl.coda.draw_tree(mdata, tree="lineage")
"""
if isinstance(data, MuData):
data = data[modality_key]
Expand Down Expand Up @@ -741,6 +797,22 @@ def draw_effects( # pragma: no cover
Returns:
Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
or plots the tree inline (`show = True`)

Examples:
Example with tascCODA:
>>> import pertpy as pt
>>> adata = pt.dt.smillie()
>>> tasccoda = pt.tl.Tasccoda()
>>> mdata = tasccoda.load(
...     adata, type="sample_level",
...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
...     key_added="lineage", add_level_name=True
... )
>>> mdata = tasccoda.prepare(
...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
... )
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
>>> pt.pl.coda.draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
"""
if isinstance(data, MuData):
data = data[modality_key]
Expand Down Expand Up @@ -895,6 +967,19 @@ def effects_umap( # pragma: no cover

Returns:
If `show==False`, a :class:`~matplotlib.axes.Axes` or a list of them.

Examples:
Example with scCODA:
>>> import pertpy as pt
>>> haber_cells = pt.dt.haber_2017_regions()
>>> sccoda = pt.tl.Sccoda()
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
sample_identifier="batch", covariate_obs=["condition"])
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)

>>> pt.pl.coda.effects_umap(mdata, effect_name="", cluster_key="")
#TODO: Add effect_name parameter and cluster_key and test the example
"""
data_rna = data[modality_key_1]
data_coda = data[modality_key_2]
Expand Down
23 changes: 23 additions & 0 deletions pertpy/plot/_dialogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ def split_violins(

Returns:
A :class:`~matplotlib.axes.Axes` object

Examples:
>>> import pertpy as pt
>>> import scanpy as sc
>>> adata = pt.dt.dialogue_example()
>>> sc.pp.pca(adata)
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
n_counts_key = "nCount_RNA", n_mpcs = 3)
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
>>> pt.pl.dl.split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
"""
df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
if split_which is None:
Expand Down Expand Up @@ -56,6 +66,19 @@ def pairplot(self, adata: AnnData, celltype_key: str, color: str, sample_id: str

Returns:
Seaborn Pairgrid object.

Examples:
>>> import pertpy as pt
>>> import scanpy as sc
>>> adata = pt.dt.dialogue_example()
>>> sc.pp.pca(adata)
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
n_counts_key = "nCount_RNA", n_mpcs = 3)
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
#>>> dl_pl=pt.pl.dl()
#>>> dl_pl.pairplot(adata=adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
>>> pt.pl.dl.pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
#TODO: Is self parameter there on purpose -> create DialoguePlot object first?
"""
mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
mean_mcps = mean_mcps.reset_index()
Expand Down
11 changes: 11 additions & 0 deletions pertpy/plot/_guide_rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def heatmap(
Returns:
List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
Order of cells in the y axis will be saved on adata.obs[key_to_save_order] if provided.

Examples:
Each cell is assigned to the gRNA that occurs at least 5 times in that cell, and the resulting
assignment is then visualized using a heatmap.

>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> gdo = mdata.mod['gdo']
>>> ga = pt.pp.GuideAssignment()
>>> ga.assign_by_threshold(gdo, assignment_threshold=5)
>>> pt.pl.guide.heatmap(gdo)
"""
data = adata.X if layer is None else adata.layers[layer]

Expand Down
38 changes: 38 additions & 0 deletions pertpy/plot/_milopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ def nhood_graph(
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
**kwargs: Additional arguments to `scanpy.pl.embedding`.

Examples:
>>> import pertpy as pt
>>> import scanpy as sc
>>> adata = pt.dt.bhattacherjee()
>>> milo = pt.tl.Milo()
>>> mdata = milo.load(adata)
>>> sc.pp.neighbors(mdata["rna"])
>>> sc.tl.umap(mdata["rna"])
>>> milo.make_nhoods(mdata["rna"])
>>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
>>> milo.da_nhoods(mdata, design="~label")
>>> milo.build_nhood_graph(mdata)
>>> pt.pl.milo.nhood_graph(mdata)
# TODO: If necessary adjust after fixing StopIteration error, which is currently thrown
"""
nhood_adata = mdata["milo"].T.copy()

Expand Down Expand Up @@ -101,6 +115,17 @@ def nhood(
show: Show the plot, do not return axis.
save: If True or a str, save the figure. A string is appended to the default filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
**kwargs: Additional arguments to `scanpy.pl.embedding`.

Examples:
>>> import pertpy as pt
>>> import scanpy as sc
>>> adata = pt.dt.bhattacherjee()
>>> milo = pt.tl.Milo()
>>> mdata = milo.load(adata)
>>> sc.pp.neighbors(mdata["rna"])
>>> sc.tl.umap(mdata["rna"])
>>> milo.make_nhoods(mdata["rna"])
>>> pt.pl.milo.nhood(mdata, ix=0)
"""

mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel()
Expand All @@ -126,6 +151,19 @@ def da_beeswarm(
subset_nhoods: List of nhoods to plot. If None, plot all nhoods. (default: None)
palette: Name of Seaborn color palette for violinplots.
Defaults to pre-defined category colors for violinplots.

Examples:
>>> import pertpy as pt
>>> import scanpy as sc
>>> adata = pt.dt.bhattacherjee()
>>> milo = pt.tl.Milo()
>>> mdata = milo.load(adata)
>>> sc.pp.neighbors(mdata["rna"])
>>> milo.make_nhoods(mdata["rna"])
>>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
>>> milo.da_nhoods(mdata, design="~label")
>>> milo.annotate_nhoods(mdata, anno_col='cell_type')
>>> pt.pl.milo.da_beeswarm(mdata)
"""
try:
nhood_adata = mdata["milo"].T.copy()
Expand Down
43 changes: 43 additions & 0 deletions pertpy/plot/_mixscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ def barplot( # pragma: no cover

Returns:
If show is False, return ggplot object used to draw the plot.

Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT')
"""
if mixscape_class_global not in adata.obs:
raise ValueError("Please run `pt.tl.mixscape` first.")
Expand Down Expand Up @@ -148,6 +156,14 @@ def heatmap( # pragma: no cover
save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
ax: A matplotlib axes object. Only works if plotting a single component.
**kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.

Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> pt.pl.ms.heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
"""
if "mixscape_class" not in adata.obs:
raise ValueError("Please run `pt.tl.mixscape` first.")
Expand Down Expand Up @@ -195,6 +211,16 @@ def perturbscore( # pragma: no cover

Returns:
The ggplot object used to draw the plot.

Examples:
Visualizing the perturbation scores for the cells in a dataset:

>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
"""
if "mixscape" not in adata.uns:
raise ValueError("Please run `pt.tl.mixscape` first.")
Expand Down Expand Up @@ -361,6 +387,14 @@ def violin( # pragma: no cover

Returns:
A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.

Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> pt.pl.ms.violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
"""
if isinstance(target_gene_idents, str):
mixscape_class_mask = adata.obs[groupby] == target_gene_idents
Expand Down Expand Up @@ -532,6 +566,15 @@ def lda( # pragma: no cover
show: Show the plot, do not return axis.
save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
**kwds: Additional arguments to `scanpy.pl.umap`.

Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
>>> pt.pl.ms.lda(adata=mdata['rna'], control='NT')
"""
if mixscape_class not in adata.obs:
raise ValueError(f'Did not find .obs["{mixscape_class!r}"]. Please run `pt.tl.mixscape` first.')
Expand Down
Loading