diff --git a/pertpy/plot/_augur.py b/pertpy/plot/_augur.py index 3353c795..ebfe874c 100644 --- a/pertpy/plot/_augur.py +++ b/pertpy/plot/_augur.py @@ -26,6 +26,23 @@ def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure Returns: Axes of the plot. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.bhattacherjee() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + + >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine") + >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4) + >>> adata_15_permute, results_15_permute = ag_rfc.predict(data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4) + + >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine") + >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4) + >>> adata_48_permute, results_48_permute = ag_rfc.predict(data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4) + + >>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \ + permuted_results1=results_15_permute, permuted_results2=results_48_permute) + >>> pt.pl.ag.dp_scatter(pvals) """ x = results["mean_augur_score1"] y = results["mean_augur_score2"] @@ -69,6 +86,14 @@ def important_features( Returns: Axes of the plot. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) + >>> pt.pl.ag.important_features(v_results) """ if isinstance(data, AnnData): results = data.uns[key] @@ -115,6 +140,14 @@ def lollipop( Returns: Axes of the plot. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) + >>> pt.pl.ag.lollipop(v_results) """ if isinstance(data, AnnData): results = data.uns[key] @@ -157,6 +190,15 @@ def scatterplot( Returns: Axes of the plot. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4) + >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) + >>> pt.pl.ag.scatterplot(v_results, h_results) """ cell_types = results1["summary_metrics"].columns diff --git a/pertpy/plot/_coda.py b/pertpy/plot/_coda.py index 4b80682f..cda77f54 100644 --- a/pertpy/plot/_coda.py +++ b/pertpy/plot/_coda.py @@ -108,6 +108,15 @@ def stacked_barplot( # pragma: no cover Returns: A :class:`~matplotlib.axes.Axes` object + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> pt.pl.coda.stacked_barplot(mdata, feature_name="samples") """ if isinstance(data, MuData): data = data[modality_key] @@ -196,6 +205,17 @@ def effects_barplot( # pragma: no cover Returns: Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`) or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> pt.pl.coda.effects_barplot(mdata) """ if args_barplot is None: args_barplot = {} @@ -366,6 +386,15 @@ def boxplots( # pragma: no cover Returns: Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`) or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> pt.pl.coda.boxplots(mdata, feature_name="condition", add_dots=True) """ if args_boxplot is None: args_boxplot = {} @@ -570,6 +599,17 @@ def rel_abundance_dispersion_plot( # pragma: no cover Returns: A :class:`~matplotlib.axes.Axes` object + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> pt.pl.coda.rel_abundance_dispersion_plot(mdata) """ if isinstance(data, MuData): data = data[modality_key] @@ -677,6 +717,22 @@ def draw_tree( # pragma: no cover Returns: Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`) + + Examples: + Example with tascCODA: + >>> import pertpy as pt + >>> adata = pt.dt.smillie() + >>> tasccoda = pt.tl.Tasccoda() + >>> mdata = tasccoda.load( + >>> adata, type="sample_level", + >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"], + >>> key_added="lineage", add_level_name=True + >>> ) + >>> mdata = tasccoda.prepare( + >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} + >>> ) + >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) + >>> pt.pl.coda.draw_tree(mdata, tree="lineage") """ if isinstance(data, MuData): data = data[modality_key] @@ -741,6 +797,22 @@ def draw_effects( # pragma: no cover Returns: Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`) + + Examples: + Example with tascCODA: + >>> import pertpy as pt + >>> adata = pt.dt.smillie() + >>> tasccoda = pt.tl.Tasccoda() + >>> mdata = tasccoda.load( + >>> adata, type="sample_level", + >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"], + >>> key_added="lineage", add_level_name=True + >>> ) + >>> mdata = tasccoda.prepare( + >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} + >>> ) + >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) + >>> pt.pl.coda.draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage") """ if isinstance(data, MuData): data = data[modality_key] @@ -895,6 +967,19 @@ def effects_umap( # pragma: no cover Returns: If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + + >>> pt.pl.coda.effects_umap(mdata, effect_name="", cluster_key="") + #TODO: Add effect_name parameter and cluster_key and test the example """ data_rna = data[modality_key_1] data_coda = data[modality_key_2] diff --git a/pertpy/plot/_dialogue.py b/pertpy/plot/_dialogue.py index 51b337a8..f5106eeb 100644 --- a/pertpy/plot/_dialogue.py +++ b/pertpy/plot/_dialogue.py @@ -28,6 +28,16 @@ def split_violins( Returns: A :class:`~matplotlib.axes.Axes` object + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.dialogue_example() + >>> sc.pp.pca(adata) + >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \ + n_counts_key = "nCount_RNA", n_mpcs = 3) + >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True) + >>> pt.pl.dl.split_violins(adata, split_key='gender', celltype_key='cell.subtypes') """ df = sc.get.obs_df(adata, [celltype_key, mcp, split_key]) if split_which is None: @@ -56,6 +66,19 @@ def pairplot(self, adata: AnnData, celltype_key: str, color: str, sample_id: str Returns: Seaborn Pairgrid object. + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.dialogue_example() + >>> sc.pp.pca(adata) + >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \ + n_counts_key = "nCount_RNA", n_mpcs = 3) + >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True) + #>>> dl_pl=pt.pl.dl() + #>>> dl_pl.pairplot(adata=adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status") + >>> pt.pl.dl.pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status") + #TODO: Is self parameter there on purpose -> create DialoguePlot object first? """ mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean() mean_mcps = mean_mcps.reset_index() diff --git a/pertpy/plot/_guide_rna.py b/pertpy/plot/_guide_rna.py index 2d46c95e..de030dd5 100644 --- a/pertpy/plot/_guide_rna.py +++ b/pertpy/plot/_guide_rna.py @@ -41,6 +41,17 @@ def heatmap( Returns: List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap. Order of cells in the y axis will be saved on adata.obs[key_to_save_order] if provided. + + Examples: + Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then + visualized using a heatmap. + + >>> import pertpy as pt + >>> mdata = pt.data.papalexi_2021() + >>> gdo = mdata.mod['gdo'] + >>> ga = pt.pp.GuideAssignment() + >>> ga.assign_by_threshold(gdo, assignment_threshold=5) + >>> pt.pl.guide.heatmap(gdo) """ data = adata.X if layer is None else adata.layers[layer] diff --git a/pertpy/plot/_milopy.py b/pertpy/plot/_milopy.py index a7cedee3..b5198307 100644 --- a/pertpy/plot/_milopy.py +++ b/pertpy/plot/_milopy.py @@ -42,6 +42,20 @@ def nhood_graph( save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. **kwargs: Additional arguments to `scanpy.pl.embedding`. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> sc.tl.umap(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") + >>> milo.da_nhoods(mdata, design="~label") + >>> milo.build_nhood_graph(mdata) + >>> pt.pl.milo.nhood_graph(mdata) + # TODO: If necessary adjust after fixing StopIteration error, which is currently thrown """ nhood_adata = mdata["milo"].T.copy() @@ -101,6 +115,17 @@ def nhood( show: Show the plot, do not return axis. save: If True or a str, save the figure. A string is appended to the default filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}. **kwargs: Additional arguments to `scanpy.pl.embedding`. + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> sc.tl.umap(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> pt.pl.milo.nhood(mdata, ix=0) """ mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel() @@ -126,6 +151,19 @@ def da_beeswarm( subset_nhoods: List of nhoods to plot. If None, plot all nhoods. (default: None) palette: Name of Seaborn color palette for violinplots. Defaults to pre-defined category colors for violinplots. + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") + >>> milo.da_nhoods(mdata, design="~label") + >>> milo.annotate_nhoods(mdata, anno_col='cell_type') + >>> pt.pl.milo.da_beeswarm(mdata) """ try: nhood_adata = mdata["milo"].T.copy() diff --git a/pertpy/plot/_mixscape.py b/pertpy/plot/_mixscape.py index bcae4adb..4663ff62 100644 --- a/pertpy/plot/_mixscape.py +++ b/pertpy/plot/_mixscape.py @@ -69,6 +69,14 @@ def barplot( # pragma: no cover Returns: If show is False, return ggplot object used to draw the plot. + + Examples: + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> mixscape_identifier = pt.tl.Mixscape() + >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT') """ if mixscape_class_global not in adata.obs: raise ValueError("Please run `pt.tl.mixscape` first.") @@ -148,6 +156,14 @@ def heatmap( # pragma: no cover save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. ax: A matplotlib axes object. Only works if plotting a single component. **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`. + + Examples: + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> mixscape_identifier = pt.tl.Mixscape() + >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> pt.pl.ms.heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT') """ if "mixscape_class" not in adata.obs: raise ValueError("Please run `pt.tl.mixscape` first.") @@ -195,6 +211,16 @@ def perturbscore( # pragma: no cover Returns: The ggplot object used for drawn. + + Examples: + Visualizing the perturbation scores for the cells in a dataset: + + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> mixscape_identifier = pt.tl.Mixscape() + >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange') """ if "mixscape" not in adata.uns: raise ValueError("Please run `pt.tl.mixscape` first.") @@ -361,6 +387,14 @@ def violin( # pragma: no cover Returns: A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`. + + Examples: + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> mixscape_identifier = pt.tl.Mixscape() + >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> pt.pl.ms.violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class') """ if isinstance(target_gene_idents, str): mixscape_class_mask = adata.obs[groupby] == target_gene_idents @@ -532,6 +566,15 @@ def lda( # pragma: no cover show: Show the plot, do not return axis. save: If `True` or a `str`, save the figure. A string is appended to the default filename. Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}. **kwds: Additional arguments to `scanpy.pl.umap`. + + Examples: + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> mixscape_identifier = pt.tl.Mixscape() + >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert') + >>> pt.pl.ms.lda(adata=mdata['rna'], control='NT') """ if mixscape_class not in adata.obs: raise ValueError(f'Did not find .obs["{mixscape_class!r}"]. Please run `pt.tl.mixscape` first.') diff --git a/pertpy/plot/_scgen.py b/pertpy/plot/_scgen.py index 013d99f6..18f7a8c2 100644 --- a/pertpy/plot/_scgen.py +++ b/pertpy/plot/_scgen.py @@ -44,6 +44,18 @@ def reg_mean_plot( gene_list: list of gene names to be plotted. show: if `True`: will show to the plot after saving it. **kwargs: + + Examples: + >>> import pertpy at pt + >>> data = pt.dt.kang_2018() + >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type") + >>> model = pt.tl.SCGEN(data) + >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5) + >>> pred, delta = model.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells') + >>> pred.obs['label'] = 'pred' + >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred) + >>> r2_value = pt.pl.scg.reg_mean_plot(eval_adata, condition_key='label', axis_keys={"x": "pred", "y": "stim"}, \ + labels={"x": "predicted", "y": "ground truth"}, save=False, show=True) """ import seaborn as sns diff --git a/pertpy/preprocessing/_guide_rna.py b/pertpy/preprocessing/_guide_rna.py index 0f092ab6..07c94ee9 100644 --- a/pertpy/preprocessing/_guide_rna.py +++ b/pertpy/preprocessing/_guide_rna.py @@ -33,6 +33,15 @@ def assign_by_threshold( output_layer: Assigned guide will be saved on adata.layers[output_key]. Defaults to `assigned_guides`. only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray. Defaults to False. + + Examples: + Each cell is assigned to gRNA that occurs at least 5 times in the respective cell. + + >>> import pertpy as pt + >>> mdata = pt.data.papalexi_2021() + >>> gdo = mdata.mod['gdo'] + >>> ga = pt.pp.GuideAssignment() + >>> ga.assign_by_threshold(gdo, assignment_threshold=5) """ counts = adata.X if layer is None else adata.layers[layer] if scipy.sparse.issparse(counts): @@ -69,6 +78,15 @@ def assign_to_max_guide( output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`. no_grna_assigned_key: The key to return if no gRNA is expressed enough. only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray. + + Examples: + Each cell is assigned to the most expressed gRNA if it has at least 5 counts. + + >>> import pertpy as pt + >>> mdata = pt.data.papalexi_2021() + >>> gdo = mdata.mod['gdo'] + >>> ga = pt.pp.GuideAssignment() + >>> ga.assign_to_max_guide(gdo, assignment_threshold=5) """ counts = adata.X if layer is None else adata.layers[layer] if scipy.sparse.issparse(counts): diff --git a/pertpy/tools/_augur.py b/pertpy/tools/_augur.py index 73ba5037..12d647fe 100644 --- a/pertpy/tools/_augur.py +++ b/pertpy/tools/_augur.py @@ -104,6 +104,12 @@ def load( Returns: Anndata object containing gene expression values (cells in rows, genes in columns) and cell type, label and y dummy variables as obs + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) """ if isinstance(input, AnnData): input.obs = input.obs.rename(columns={cell_type_col: "cell_type", label_col: "label"}) @@ -161,6 +167,11 @@ def create_estimator( Returns: Estimator object. + + Examples: + >>> import pertpy as pt + >>> augur = pt.tl.Augur("random_forest_classifier") + >>> estimator = augur.create_estimator("logistic_regression_classifier") """ if params is None: params = Params() @@ -195,6 +206,15 @@ def sample(self, adata: AnnData, categorical: bool, subsample_size: int, random_ Returns: Subsample of AnnData object of size subsample_size with given features + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> ag_rfc.select_highly_variable(loaded_data) + >>> features = loaded_data.var_names + >>> subsample = ag_rfc.sample(loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names) """ # export subsampling. random.seed(random_state) @@ -244,11 +264,22 @@ def draw_subsample( while setting augur_mode = "permute" will generate a null distribution of AUCs for each cell type by permuting the labels subsample_size: number of cells to subsample randomly per type from each experimental condition - categorical_data: `True` if target values are categorical + feature_perc: proportion of genes that are randomly selected as features for input to the classifier in each + subsample using the random gene filter + categorical: `True` if target values are categorical random_state: set numpy random seed and sampling seed Returns: Subsample of anndata object of size subsample_size + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> ag_rfc.select_highly_variable(loaded_data) + >>> subsample = ag_rfc.draw_subsample(adata, augur_mode="default", subsample_size=20, feature_perc=0.5, \ + categorical=True, random_state=42) """ random.seed(random_state) if augur_mode == "permute": @@ -296,6 +327,8 @@ def cross_validate_subsample( permuting the labels subsample_size: number of cells to subsample randomly per type from each experimental condition folds: number of folds to run cross validation on + feature_perc: proportion of genes that are randomly selected as features for input to the classifier in each + subsample using the random gene filter subsample_idx: index of the subsample random_state: set numpy random seed, sampling seed and fold seed zero_division: 0 or 1 or `warn`; Sets the value to return when there is a zero division. If @@ -303,6 +336,15 @@ def cross_validate_subsample( Returns: Results for each cross validation fold. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> ag_rfc.select_highly_variable(loaded_data) + >>> results = ag_rfc.cross_validate_subsample(loaded_data, augur_mode="default", subsample_size=20, \ + folds=3, feature_perc=0.5, subsample_idx=0, random_state=42, zero_division=0) """ subsample = self.draw_subsample( adata, @@ -359,6 +401,11 @@ def set_scorer( Returns: Dict linking name to scorer object and string name + + Examples: + >>> import pertpy as pt + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> scorer = ag_rfc.set_scorer(True, 0) """ if multiclass: return { @@ -409,6 +456,16 @@ def run_cross_validation( Returns: Dictionary containing prediction metrics and estimator for each fold. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> ag_rfc.select_highly_variable(loaded_data) + >>> subsample = ag_rfc.draw_subsample(adata, augur_mode="default", subsample_size=20, feature_perc=0.5, \ + categorical=True, random_state=42) + >>> results = ag_rfc.run_cross_validation(subsample=subsample, folds=3, subsample_idx=0, random_state=42, zero_division=0) """ x = subsample.to_df() y = subsample.obs["y_"] @@ -477,6 +534,13 @@ def select_highly_variable(self, adata: AnnData) -> AnnData: Results: Anndata object with highly variable genes added as layer + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> ag_rfc.select_highly_variable(loaded_data) """ min_features_for_selection = 1000 @@ -540,6 +604,13 @@ def select_variance(self, adata: AnnData, var_quantile: float, filter_negative_r Return: AnnData object with additional select_variance column in var. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> ag_rfc.select_variance(loaded_data, var_quantile=0.5, filter_negative_residuals=False, span=0.75) """ adata.var["highly_variable"] = False adata.var["means"] = np.ravel(adata.X.mean(axis=0)) @@ -644,6 +715,13 @@ def predict( * feature_importances: Pandas Dataframe containing feature importances of genes across all cross validation runs * full_results: Dict containing merged results of individual cross validation runs for each cell type * [cell_types]: Cross validation runs of the cell type called + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.sc_sim_augur() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + >>> loaded_data = ag_rfc.load(adata) + >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4) """ if augur_mode == "permute" and n_subsamples < 100: n_subsamples = 500 @@ -767,6 +845,22 @@ def predict_differential_prioritization( Returns: Results object containing mean augur scores. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.bhattacherjee() + >>> ag_rfc = pt.tl.Augur("random_forest_classifier") + + >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine") + >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4) + >>> adata_15_permute, results_15_permute = ag_rfc.predict(data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4) + + >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine") + >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4) + >>> adata_48_permute, results_48_permute = ag_rfc.predict(data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4) + + >>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \ + permuted_results1=results_15_permute, permuted_results2=results_48_permute) """ # compare available cell types cell_types = ( diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index 477635dc..efc528d7 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -333,6 +333,16 @@ def run_hmc( num_warmup: Number of burn-in (warmup) samples. Defaults to 5000. rng_key: The rng state used. If None, a random state will be selected. Defaults to None. copy: Return a copy instead of writing to adata. Defaults to False. + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_hmc(mdata, num_warmup=100, num_samples=1000) """ if isinstance(data, MuData): try: @@ -410,6 +420,17 @@ def summary_prepare( - SD: Standard deviation of MCMC samples - Delta: Decision boundary value - threshold of practical significance - Is credible: Boolean indicator whether effect is credible + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> intercept_df, effect_df = sccoda.summary_prepare(mdata["coda"]) """ # Get model and effect selection types select_type = sample_adata.uns["scCODA_params"]["select_type"] @@ -750,6 +771,17 @@ def summary(self, data: AnnData | MuData, extended: bool = False, modality_key: modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda". args: Passed to az.summary kwargs: Passed to az.summary + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> sccoda.summary(mdata) """ if isinstance(data, MuData): try: @@ -884,6 +916,17 @@ def get_intercept_df(self, data: AnnData | MuData, modality_key: str = "coda"): Returns: pd.DataFrame: Intercept data frame. + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> intercepts = sccoda.get_intercept_df(mdata) """ if isinstance(data, MuData): @@ -906,6 +949,17 @@ def get_effect_df(self, data: AnnData | MuData, modality_key: str = "coda"): Returns: pd.DataFrame: Effect data frame. + + Examples: + Example with scCODA: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> effects = sccoda.get_effect_df(mdata) """ if isinstance(data, MuData): @@ -939,6 +993,22 @@ def get_node_df(self, data: AnnData | MuData, modality_key: str = "coda"): Returns: pd.DataFrame: Node effect data frame. + + Examples: + Example with tascCODA (works only for model of type tree_agg, i.e. a tascCODA model): + >>> import pertpy as pt + >>> adata = pt.dt.smillie() + >>> tasccoda = pt.tl.Tasccoda() + >>> mdata = tasccoda.load( + >>> adata, type="sample_level", + >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"], + >>> key_added="lineage", add_level_name=True + >>> ) + >>> mdata = tasccoda.prepare( + >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} + >>> ) + >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) + >>> node_effects = tasccoda.get_node_df(mdata) """ if isinstance(data, MuData): diff --git a/pertpy/tools/_coda/_sccoda.py b/pertpy/tools/_coda/_sccoda.py index 9e15af93..ba27bc74 100644 --- a/pertpy/tools/_coda/_sccoda.py +++ b/pertpy/tools/_coda/_sccoda.py @@ -88,16 +88,10 @@ def load( Examples: >>> import pertpy as pt - >>> adata = pt.dt.haber_2017_regions() + >>> haber_cells = pt.dt.haber_2017_regions() >>> sccoda = pt.tl.Sccoda() - >>> mdata = sccoda.load( - >>> adata, - >>> type="cell_level", - >>> generate_sample_level=True, - >>> cell_type_identifier="cell_label", - >>> sample_identifier="batch", - >>> covariate_obs=["condition"], - >>> ) + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) """ if type == "cell_level": if generate_sample_level: @@ -148,17 +142,10 @@ def prepare( Examples: >>> import pertpy as pt - - >>> adata = pt.dt.haber_2017_regions() + >>> haber_cells = pt.dt.haber_2017_regions() >>> sccoda = pt.tl.Sccoda() - >>> mdata = sccoda.load( - >>> adata, - >>> type="cell_level", - >>> generate_sample_level=True, - >>> cell_type_identifier="cell_label", - >>> sample_identifier="batch", - >>> covariate_obs=["condition"], - >>> ) + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") """ if isinstance(data, MuData): @@ -201,6 +188,15 @@ def set_init_mcmc_states(self, rng_key: None, ref_index: np.ndarray, sample_adat Returns: Return AnnData object. + + Examples: + >>> import pertpy as pt + >>> haber_cells = pt.dt.haber_2017_regions() + >>> sccoda = pt.tl.Sccoda() + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> adata = sccoda.set_init_mcmc_states(rng_key=42, ref_index=0, sample_adata=mdata['coda']) """ # data dimensions N, D = sample_adata.obsm["covariate_matrix"].shape @@ -315,22 +311,13 @@ def make_arviz( # type: ignore Examples: >>> import pertpy as pt - >>> adata = pt.dt.haber_2017_regions() + >>> haber_cells = pt.dt.haber_2017_regions() >>> sccoda = pt.tl.Sccoda() - >>> adata_salm = adata[adata.obs["condition"].isin(["Control", "Salmonella"])] - >>> mdata_salm = sccoda.load( - >>> adata_salm, - >>> type="cell_level", - >>> generate_sample_level=True, - >>> cell_type_identifier="cell_label", - >>> sample_identifier="batch", - >>> covariate_obs=["condition"], - >>> ) - >>> mdata_salm = sccoda.prepare( - >>> mdata_salm, formula="condition", reference_cell_type="Goblet" - >>> ) - >>> sccoda.run_nuts(mdata_salm) - >>> sccoda.make_arviz(mdata_salm, num_prior_samples=100) + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> arviz_data = sccoda.make_arviz(mdata, num_prior_samples=100) """ if isinstance(data, MuData): try: @@ -425,18 +412,12 @@ def run_nuts( """ Examples: >>> import pertpy as pt - >>> adata = pt.dt.haber_2017_regions() + >>> haber_cells = pt.dt.haber_2017_regions() >>> sccoda = pt.tl.Sccoda() - >>> mdata = sccoda.load( - >>> adata, - >>> type="cell_level", - >>> generate_sample_level=True, - >>> cell_type_identifier="cell_label", - >>> sample_identifier="batch", - >>> covariate_obs=["condition"], - >>> ) + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") - >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000) + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) """ return super().run_nuts(data, modality_key, num_samples, num_warmup, rng_key, copy, *args, **kwargs) @@ -446,22 +427,13 @@ def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", e """ Examples: >>> import pertpy as pt - >>> adata = pt.dt.haber_2017_regions() + >>> haber_cells = pt.dt.haber_2017_regions() >>> sccoda = pt.tl.Sccoda() - >>> adata_salm = adata[adata.obs["condition"].isin(["Control", "Salmonella"])] - >>> mdata_salm = sccoda.load( - >>> adata_salm, - >>> type="cell_level", - >>> generate_sample_level=True, - >>> cell_type_identifier="cell_label", - >>> sample_identifier="batch", - >>> covariate_obs=["condition"], - >>> ) - >>> mdata_salm = sccoda.prepare( - >>> mdata_salm, formula="condition", reference_cell_type="Goblet" - >>> ) - >>> sccoda.run_nuts(mdata_salm) - >>> sccoda.credible_effects(mdata_salm) + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> credible_effects = sccoda.credible_effects(mdata) """ return super().credible_effects(data, modality_key, est_fdr) @@ -471,22 +443,13 @@ def summary(self, data: AnnData | MuData, extended: bool = False, modality_key: """ Examples: >>> import pertpy as pt - >>> adata = pt.dt.haber_2017_regions() + >>> haber_cells = pt.dt.haber_2017_regions() >>> sccoda = pt.tl.Sccoda() - >>> adata_salm = adata[adata.obs["condition"].isin(["Control", "Salmonella"])] - >>> mdata_salm = sccoda.load( - >>> adata_salm, - >>> type="cell_level", - >>> generate_sample_level=True, - >>> cell_type_identifier="cell_label", - >>> sample_identifier="batch", - >>> covariate_obs=["condition"], - >>> ) - >>> mdata_salm = sccoda.prepare( - >>> mdata_salm, formula="condition", reference_cell_type="Goblet" - >>> ) - >>> sccoda.run_nuts(mdata_salm) - >>> sccoda.summary(mdata_salm) + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> sccoda.summary(mdata) """ return super().summary(data, extended, modality_key, *args, **kwargs) @@ -496,24 +459,13 @@ def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "c """ Examples: >>> import pertpy as pt - - >>> adata = pt.dt.haber_2017_regions() + >>> haber_cells = pt.dt.haber_2017_regions() >>> sccoda = pt.tl.Sccoda() - >>> adata_salm = adata[adata.obs["condition"].isin(["Control", "Salmonella"])] - >>> mdata_salm = sccoda.load( - >>> adata_salm, - >>> type="cell_level", - >>> generate_sample_level=True, - >>> cell_type_identifier="cell_label", - >>> sample_identifier="batch", - >>> covariate_obs=["condition"], - >>> ) - >>> mdata_salm = sccoda.prepare( - >>> mdata_salm, formula="condition", reference_cell_type="Goblet" - >>> ) - >>> sccoda.run_nuts(mdata_salm) - >>> sccoda.set_fdr(mdata_salm, est_fdr=0.4) - >>> sccoda.summary(mdata_salm) + >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \ + sample_identifier="batch", covariate_obs=["condition"]) + >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine") + >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42) + >>> sccoda.set_fdr(mdata, est_fdr=0.4) """ return super().set_fdr(data, est_fdr, modality_key, *args, **kwargs) diff --git a/pertpy/tools/_coda/_tasccoda.py b/pertpy/tools/_coda/_tasccoda.py index f0fd30fa..41ef8fca 100644 --- a/pertpy/tools/_coda/_tasccoda.py +++ b/pertpy/tools/_coda/_tasccoda.py @@ -303,7 +303,7 @@ def prepare( def set_init_mcmc_states(self, rng_key: None, ref_index: np.ndarray, sample_adata: AnnData) -> AnnData: # type: ignore """ - Sets initial MCMC state values for scCODA model + Sets initial MCMC state values for tascCODA model Args: rng_key: RNG value to be set @@ -312,6 +312,20 @@ def set_init_mcmc_states(self, rng_key: None, ref_index: np.ndarray, sample_adat Returns: Return AnnData + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.smillie() + >>> tasccoda = pt.tl.Tasccoda() + >>> mdata = tasccoda.load( + >>> adata, type="sample_level", + >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"], + >>> key_added="lineage", add_level_name=True + >>> ) + >>> mdata = tasccoda.prepare( + >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} + >>> ) + >>> adata = tasccoda.set_init_mcmc_states(rng_key=42, ref_index=[0,1], sample_adata=mdata['coda']) """ N, D = sample_adata.obsm["covariate_matrix"].shape P = sample_adata.X.shape[1] @@ -476,8 +490,8 @@ def make_arviz( # type: ignore >>> mdata = tasccoda.prepare( >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} >>> ) - >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100) - >>> tasccoda.make_arviz(mdata) + >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) + >>> arviz_data = tasccoda.make_arviz(mdata) """ if isinstance(data, MuData): try: @@ -586,7 +600,7 @@ def run_nuts( >>> mdata = tasccoda.prepare( >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} >>> ) - >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100) + >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) """ return super().run_nuts(data, modality_key, num_samples, num_warmup, rng_key, copy, *args, **kwargs) @@ -606,7 +620,7 @@ def summary(self, data: AnnData | MuData, extended: bool = False, modality_key: >>> mdata = tasccoda.prepare( >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} >>> ) - >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100) + >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) >>> tasccoda.summary(mdata) """ return super().summary(data, extended, modality_key, *args, **kwargs) @@ -627,7 +641,7 @@ def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", e >>> mdata = tasccoda.prepare( >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} >>> ) - >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100) + >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) >>> tasccoda.credible_effects(mdata) """ return super().credible_effects(data, modality_key, est_fdr) @@ -648,8 +662,9 @@ def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "c >>> mdata = tasccoda.prepare( >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} >>> ) - >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100) - >>> tasccoda.set_fdr(mdata_salm, est_fdr=0.4) + >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42) + >>> tasccoda.set_fdr(mdata, est_fdr=0.4) + #TODO: Not working (throws error, because too many values to unpack -> Correct set_fdr first) """ return super().set_fdr(data, est_fdr, modality_key, *args, **kwargs) diff --git a/pertpy/tools/_dialogue.py b/pertpy/tools/_dialogue.py index ad20aef1..55c03b28 100644 --- a/pertpy/tools/_dialogue.py +++ b/pertpy/tools/_dialogue.py @@ -574,6 +574,16 @@ def load( Returns: A celltype_label:array dictionary. + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.dialogue_example() + >>> sc.pp.pca(adata) + >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \ + n_counts_key = "nCount_RNA", n_mpcs = 3) + >>> cell_types = adata.obs[dl.celltype_key].astype("category").cat.categories + >>> mcca_in, ct_subs = dl.load(adata, ct_order=cell_types) """ ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order} fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks @@ -615,6 +625,15 @@ def calculate_multifactor_PMD( Returns: MCP scores # TODO this requires more detail + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.dialogue_example() + >>> sc.pp.pca(adata) + >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \ + n_counts_key = "nCount_RNA", n_mpcs = 3) + >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True) """ # IMPORTANT NOTE: the order in which matrices are passed to multicca matters. As such, # it is important here that to obtain the same result as in R, we pass the matrices in @@ -681,6 +700,17 @@ def multilevel_modeling( - for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 - merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps TODO: Describe both returns + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.dialogue_example() + >>> sc.pp.pca(adata) + >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \ + n_counts_key = "nCount_RNA", n_mpcs = 3) + >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True) + >>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \ + confounder="gender") """ # all possible pairs of cell types with out pairing same cell type cell_types = list(ct_subs.keys()) @@ -825,10 +855,20 @@ def test_association( Args: adata: AnnData object with MCPs in obs condition_label: Column name in adata.obs with condition labels. Must be categorical. - conditions_compare: Tuple of length 2 with the two conditions to compare, must be in in adata.obs[condition_label] + conditions_compare: Tuple of length 2 with the two conditions to compare, must be in adata.obs[condition_label] Returns: Dict of data frames with pvals, tstats, and pvals_adj for each MCP + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.dialogue_example() + >>> sc.pp.pca(adata) + >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \ + n_counts_key = "nCount_RNA", n_mpcs = 3) + >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True) + >>> stats = dl.test_association(adata, condition_label="pathology") """ celltype_label = self.celltype_key sample_label = self.sample_id @@ -887,6 +927,18 @@ def get_mlm_mcp_genes( Returns: Dict with keys 'up_genes' and 'down_genes' and values of lists of genes + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.dialogue_example() + >>> sc.pp.pca(adata) + >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \ + n_counts_key = "nCount_RNA", n_mpcs = 3) + >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True) + >>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \ + confounder="gender") + >>> mcp_genes = dl.get_mlm_mcp_genes(celltype='Macrophages', results=all_results) """ # Convert "mcp_x" to "MCPx" format # REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD @@ -999,6 +1051,16 @@ def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1): Nested dictionary where keys of the first level are MCPs (of the form "mcp_0" etc) and the second level keys are cell types. The values are dataframes containing the results of the rank_genes_groups analysis. + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.dialogue_example() + >>> sc.pp.pca(adata) + >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \ + n_counts_key = "nCount_RNA", n_mpcs = 3) + >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True) + >>> extrema_mcp_genes = dl.get_extrema_MCP_genes(ct_subs) """ rank_dfs: dict[str, dict[Any, Any]] = {} _, ct_sub = next(iter(ct_subs.items())) diff --git a/pertpy/tools/_distances/_distance_tests.py b/pertpy/tools/_distances/_distance_tests.py index e85337a4..7f818f46 100644 --- a/pertpy/tools/_distances/_distance_tests.py +++ b/pertpy/tools/_distances/_distance_tests.py @@ -36,9 +36,9 @@ class DistanceTest: Examples: >>> import pertpy as pt - >>> adata = pt.dt.distance_example() - >>> etest = pt.tl.DistanceTest('edistance', n_perms=1000) - >>> tab = etest(adata, groupby='perturbation', contrast='control') + >>> adata = pt.dt.distance_example_data() + >>> distance_test = pt.tl.DistanceTest('edistance', n_perms=1000) + >>> tab = distance_test(adata, groupby='perturbation', contrast='control') """ def __init__( @@ -99,9 +99,9 @@ def __call__( Examples: >>> import pertpy as pt - >>> adata = pt.dt.distance_example() - >>> etest = pt.tl.DistanceTest('edistance', n_perms=1000) - >>> tab = etest(adata, groupby='perturbation', contrast='control') + >>> adata = pt.dt.distance_example_data() + >>> distance_test = pt.tl.DistanceTest('edistance', n_perms=1000) + >>> tab = distance_test(adata, groupby='perturbation', contrast='control') """ if self.distance.metric_fct.accepts_precomputed: # Much faster if the metric can be called on the precomputed @@ -130,6 +130,12 @@ def test_xy(self, adata: AnnData, groupby: str, contrast: str, show_progressbar: - significant: whether the group is significantly different from the contrast group - pvalue_adj: p-value after multiple testing correction - significant_adj: whether the group is significantly different from the contrast group after multiple testing correction + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example_data() + >>> distance_test = pt.tl.DistanceTest('edistance', n_perms=1000) + >>> test_results = distance_test.test_xy(adata, groupby='perturbation', contrast='control') """ groups = adata.obs[groupby].unique() if contrast not in groups: @@ -215,6 +221,12 @@ def test_precomputed(self, adata: AnnData, groupby: str, contrast: str, verbose: - significant: whether the group is significantly different from the contrast group - pvalue_adj: p-value after multiple testing correction - significant_adj: whether the group is significantly different from the contrast group after multiple testing correction + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example_data() + >>> distance_test = pt.tl.DistanceTest('edistance', n_perms=1000) + >>> test_results = distance_test.test_precomputed(adata, groupby='perturbation', contrast='control') """ if not self.distance.metric_fct.accepts_precomputed: raise ValueError(f"Metric {self.metric} does not accept precomputed distances.") diff --git a/pertpy/tools/_distances/_distances.py b/pertpy/tools/_distances/_distances.py index a3e9769a..82c35082 100644 --- a/pertpy/tools/_distances/_distances.py +++ b/pertpy/tools/_distances/_distances.py @@ -352,8 +352,13 @@ def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None: Args: adata: Annotated data matrix. - obs_key: Column name in adata.obs. n_jobs: Number of cores to use. Defaults to -1 (all). + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example() + >>> distance = pt.tools.Distance(metric="edistance") + >>> distance.precompute_distances(adata) """ if self.layer_key: cells = adata.layers[self.layer_key] diff --git a/pertpy/tools/_metadata/_cell_line.py b/pertpy/tools/_metadata/_cell_line.py index 186854f0..178f19b4 100644 --- a/pertpy/tools/_metadata/_cell_line.py +++ b/pertpy/tools/_metadata/_cell_line.py @@ -224,6 +224,13 @@ def annotate_cell_lines( Returns: Returns an AnnData object with cell line annotation. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.dialogue_example() + >>> adata.obs['cell_line_name'] = 'MCF7' + >>> pt_metadata = pt.tl.CellLineMetaData() + >>> adata_annotated = pt_metadata.annotate_cell_lines(adata=adata, reference_id='cell_line_name', query_id='cell_line_name', copy=True) """ if copy: adata = adata.copy() @@ -344,6 +351,14 @@ def annotate_bulk_rna_expression( Returns: Returns an AnnData object with bulk rna expression annotation. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.dialogue_example() + >>> adata.obs['cell_line_name'] = 'MCF7' + >>> pt_metadata = pt.tl.CellLineMetaData() + >>> adata_annotated = pt_metadata.annotate_cell_lines(adata=adata, reference_id='cell_line_name', query_id='cell_line_name', copy=True) + >>> pt_metadata.annotate_bulk_rna_expression(adata_annotated) """ if copy: adata = adata.copy() @@ -434,6 +449,14 @@ def annotate_protein_expression( Returns: Returns an AnnData object with protein expression annotation. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.dialogue_example() + >>> adata.obs['cell_line_name'] = 'MCF7' + >>> pt_metadata = pt.tl.CellLineMetaData() + >>> adata_annotated = pt_metadata.annotate_cell_lines(adata=adata, reference_id='cell_line_name', query_id='cell_line_name', copy=True) + >>> pt_metadata.annotate_protein_expression(adata_annotated) """ if copy: adata = adata.copy() @@ -514,6 +537,12 @@ def annotate_from_gdsc( Returns: Returns an AnnData object with drug response annotation. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.mcfarland_2020() + >>> pt_metadata = pt.tl.CellLineMetaData() + >>> pt_metadata.annotate_from_gdsc(adata, query_id='cell_line') """ if copy: adata = adata.copy() @@ -573,6 +602,11 @@ def lookup(self) -> LookUp: Returns: Returns a LookUp object specific for cell line annotation. + + Examples: + >>> import pertpy as pt + >>> pt_metadata = pt.tl.CellLineMetaData() + >>> lookup = pt_metadata.lookup() """ return LookUp( type="cell_line", diff --git a/pertpy/tools/_milo.py b/pertpy/tools/_milo.py index 7b060af8..a6c0e6d9 100644 --- a/pertpy/tools/_milo.py +++ b/pertpy/tools/_milo.py @@ -40,6 +40,12 @@ def load( feature_key: Key to store the cell-level AnnData object in the MuData object Returns: MuData: MuData object with original AnnData (default is `mudata[feature_key]`). + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) """ mdata = MuData({feature_key: input, "milo": AnnData()}) @@ -86,6 +92,15 @@ def make_nhoods( nhood_neighbors_key: `adata.uns["nhood_neighbors_key"]` KNN graph key, used for neighbourhood construction + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) """ if isinstance(data, MuData): adata = data[feature_key] @@ -185,6 +200,16 @@ def count_nhoods( - `mudata['milo'].var_names` are neighbourhoods - `mudata['milo'].X` is the matrix counting the number of cells from each sample in each neighbourhood + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") """ if isinstance(data, MuData): adata = data[feature_key] @@ -246,8 +271,19 @@ def da_nhoods( None, modifies `milo_mdata['milo']` in place, adding the results of the DA test to `.var`: - `logFC` stores the log fold change in cell abundance (coefficient from the GLM) - `PValue` stores the p-value for the QLF test before multiple testing correction - - `SpatialFDR` stores the the p-value adjusted for multiple testing to limit the false discovery rate, + - `SpatialFDR` stores the p-value adjusted for multiple testing to limit the false discovery rate, calculated with weighted Benjamini-Hochberg procedure + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") + >>> milo.da_nhoods(mdata, design="~label") """ try: sample_adata = mdata["milo"] @@ -377,6 +413,17 @@ def annotate_nhoods( - `milo_mdata['milo'].var["nhood_annotation_frac"]` stores the fraciton of cells in the neighbourhood with the assigned label - `milo_mdata['milo'].varm['frac_annotation']`: stores the fraction of cells from each label in each nhood - `milo_mdata['milo'].uns["annotation_labels"]`: stores the column names for `milo_mdata['milo'].varm['frac_annotation']` + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") + >>> milo.annotate_nhoods(mdata, anno_col='cell_type') """ try: sample_adata = mdata["milo"] @@ -417,6 +464,17 @@ def annotate_nhoods_continuous(self, mdata: MuData, anno_col: str, feature_key: Returns: None. Adds in place: - `milo_mdata['milo'].var["nhood_{anno_col}"]`: assigning a continuous value to each nhood + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") + >>> milo.annotate_nhoods_continuous(mdata, anno_col='nUMI') """ if "milo" not in mdata.mod: raise ValueError( @@ -446,6 +504,17 @@ def add_covariate_to_nhoods_var(self, mdata: MuData, new_covariates: list[str], Returns: None, adds columns to `milo_mdata['milo']` in place + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") + >>> milo.add_covariate_to_nhoods_var(mdata, new_covariates=["label"]) """ try: sample_adata = mdata["milo"] @@ -488,6 +557,18 @@ def build_nhood_graph(self, mdata: MuData, basis: str = "X_umap", feature_key: s Returns: - `milo_mdata['milo'].varp['nhood_connectivities']`: graph of overlap between neighbourhoods (i.e. no of shared cells) - `milo_mdata['milo'].var["Nhood_size"]`: number of cells in neighbourhoods + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> sc.tl.umap(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") + >>> milo.build_nhood_graph(mdata) """ adata = mdata[feature_key] # # Add embedding positions @@ -513,6 +594,17 @@ def add_nhood_expression(self, mdata: MuData, layer: str | None = None, feature_ Returns: Updates adata in place to store the matrix of average expression in each neighbourhood in `milo_mdata['milo'].varm['expr']` + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> adata = pt.dt.bhattacherjee() + >>> milo = pt.tl.Milo() + >>> mdata = milo.load(adata) + >>> sc.pp.neighbors(mdata["rna"]) + >>> milo.make_nhoods(mdata["rna"]) + >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") + >>> milo.add_nhood_expression(mdata) """ try: sample_adata = mdata["milo"] diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 14303df7..b2afbab2 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -66,6 +66,14 @@ def perturbation_signature( Returns: If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`. Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`. + + Examples: + Calcutate perturbation signature for each cell in the dataset: + + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> mixscape_identifier = pt.tl.Mixscape() + >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') """ if copy: adata = adata.copy() @@ -180,6 +188,15 @@ def mixscape( mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`). Posterior probabilities used to determine if a cell is KO (default). Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP + + Examples: + Calcutate perturbation signature for each cell in the dataset: + + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> mixscape_identifier = pt.tl.Mixscape() + >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') """ if copy: adata = adata.copy() @@ -337,6 +354,16 @@ def lda( mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`). LDA result. + + Examples: + Use LDA dimensionality reduction to visualize the perturbation effects: + + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> mixscape_identifier = pt.tl.Mixscape() + >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert') """ if copy: adata = adata.copy() diff --git a/pertpy/tools/_perturbation_space/_clustering.py b/pertpy/tools/_perturbation_space/_clustering.py index fe799fc8..30fa2482 100644 --- a/pertpy/tools/_perturbation_space/_clustering.py +++ b/pertpy/tools/_perturbation_space/_clustering.py @@ -31,6 +31,15 @@ def evaluate_clustering( true_label_col: ground truth labels. cluster_col: cluster computed labels. metrics: Metrics to compute. Defaults to ['nmi', 'ari', 'asw']. + + Examples: + Example usage with KMeansSpace: + + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> kmeans = pt.tl.KMeansSpace() + >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26) + >>> results = kmeans.evaluate_clustering(kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=['nmi']) """ if metrics is None: metrics = ["nmi", "ari", "asw"] diff --git a/pertpy/tools/_perturbation_space/_discriminator_classifier.py b/pertpy/tools/_perturbation_space/_discriminator_classifier.py index cedd70f5..6b7edaa3 100644 --- a/pertpy/tools/_perturbation_space/_discriminator_classifier.py +++ b/pertpy/tools/_perturbation_space/_discriminator_classifier.py @@ -55,6 +55,12 @@ def load( # type: ignore test_split_size: Default to 0.2. validation_split_size: Size of the validation split taking into account that is taking with respect to the resultant train split. Defaults to 0.25. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.papalexi_2021()['rna'] + >>> dcs = pt.tl.DiscriminatorClassifierSpace() + >>> dcs.load(adata, target_col="gene_target") """ if layer_key is not None and layer_key not in adata.obs.columns: raise ValueError(f"Layer key {layer_key} not found in adata. {layer_key}") @@ -121,6 +127,13 @@ def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int = max_epochs: max epochs for training. Default to 40 val_epochs_check: check in validation dataset each val_epochs_check epochs patience: patience before the early stopping flag is activated + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.papalexi_2021()['rna'] + >>> dcs = pt.tl.DiscriminatorClassifierSpace() + >>> dcs.load(adata, target_col="gene_target") + >>> dcs.train(max_epochs=5) """ self.trainer = pl.Trainer( min_epochs=1, @@ -143,6 +156,14 @@ def get_embeddings(self) -> AnnData: Returns: AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.papalexi_2021()['rna'] + >>> dcs = pt.tl.DiscriminatorClassifierSpace() + >>> dcs.load(adata, target_col="gene_target") + >>> dcs.train() + >>> embeddings = dcs.get_embeddings() """ with torch.no_grad(): self.model.eval() diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index ce10e878..930d4757 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -46,6 +46,13 @@ def compute_control_diff( # type: ignore new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control diff'. all_data: if True, do the computation in all data representations (X, all layers and all embeddings) copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object. + + Examples: + Example usage with PseudobulkSpace: + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> ps = pt.tl.PseudobulkSpace() + >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key='NT') """ if reference_key not in adata.obs[target_col].unique(): raise ValueError( @@ -125,6 +132,14 @@ def add( perturbations: Perturbations to add. reference_key: perturbation source from which the perturbation summation starts. ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space. + + Examples: + Example usage with PseudobulkSpace: + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> ps = pt.tl.PseudobulkSpace() + >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target") + >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key='NT') """ new_pert_name = "" for perturbation in perturbations: @@ -216,6 +231,14 @@ def subtract( perturbations: Perturbations to subtract, reference_key: Perturbation source from which the perturbation subtraction starts ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space. + + Examples: + Example usage with PseudobulkSpace: + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> ps = pt.tl.PseudobulkSpace() + >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target") + >>> new_perturbation = ps.add(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"]) """ new_pert_name = reference_key + "-" for perturbation in perturbations: diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index f8123537..68f9d2c8 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -26,6 +26,18 @@ def compute( target_col: .obs column that stores the label of the perturbation applied to each cell. layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise. + + Examples: + Compute the centroids of a UMAP embedding of the papalexi_2021 dataset: + + >>> import pertpy as pt + >>> import scanpy as sc + >>> mdata = pt.dt.papalexi_2021() + >>> sc.pp.pca(mdata["rna"]) + >>> sc.pp.neighbors(mdata['rna']) + >>> sc.tl.umap(mdata["rna"]) + >>> cs = pt.tl.CentroidSpace() + >>> cs_adata = cs.compute(mdata["rna"], target_col="gene_target") """ X = None @@ -94,6 +106,12 @@ def compute( target_col: .obs column that stores the label of the perturbation applied to each cell. layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise. + + Examples: + >>> import pertpy as pp + >>> mdata = pt.dt.papalexi_2021() + >>> ps = pt.tl.PseudobulkSpace() + >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target") """ if "groups_col" not in kwargs: kwargs["groups_col"] = "perturbations" @@ -144,6 +162,12 @@ def compute( # type: ignore copy: if True returns a new Anndata of same size with the new column; otherwise it updates the initial adata return_object: if True returns the clustering object **kwargs: Are passed to sklearn's KMeans. + + Examples: + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> kmeans = pt.tl.KMeansSpace() + >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26) """ if copy: adata = adata.copy() @@ -200,6 +224,12 @@ def compute( # type: ignore cluster_key: name of the .obs column to store the cluster labels. Defaults to 'k-means' copy: if True returns a new Anndata of same size with the new column; otherwise it updates the initial adata return_object: if True returns the clustering object + + Examples: + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> dbscan = pt.tl.DBSCANSpace() + >>> dbscan_adata = dbscan.compute(mdata["rna"]) """ if copy: adata = adata.copy() diff --git a/pertpy/tools/_scgen/_jax_scgen.py b/pertpy/tools/_scgen/_jax_scgen.py index 55839ac7..ba114a97 100644 --- a/pertpy/tools/_scgen/_jax_scgen.py +++ b/pertpy/tools/_scgen/_jax_scgen.py @@ -74,6 +74,14 @@ def predict( `np nd-array` of predicted cells in primary space. delta: float Difference between stimulated and control cells in latent space + + Examples: + >>> import pertpy as pt + >>> data = pt.dt.kang_2018() + >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type") + >>> model = pt.tl.SCGEN(data) + >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5) + >>> pred, delta = model.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells') """ # use keys registered from `setup_anndata()` cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key @@ -142,7 +150,25 @@ def get_decoded_expression( indices: Sequence[int] | None = None, batch_size: int | None = None, ) -> Array: - """Get decoded expression.""" + """Get decoded expression. + + Args: + adata: AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the + AnnData object used to initialize the model. + indices: Indices of cells in adata to use. If `None`, all cells are used. + batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + + Returns: + Decoded expression for each cell + + Examples: + >>> import pertpy as pt + >>> data = pt.dt.kang_2018() + >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type") + >>> model = pt.tl.SCGEN(data) + >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5) + >>> decoded_X = model.get_decoded_expression() + """ if self.is_trained_ is False: raise RuntimeError("Please train the model first.") @@ -168,6 +194,14 @@ def batch_removal(self, adata: AnnData | None = None) -> AnnData: corrected: `~anndata.AnnData` AnnData of corrected gene expression in adata.X and corrected latent space in adata.obsm["latent"]. A reference to the original AnnData is in `corrected.raw` if the input adata had no `raw` attribute. + + Examples: + >>> import pertpy as pt + >>> data = pt.dt.kang_2018() + >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type") + >>> model = pt.tl.SCGEN(data) + >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5) + >>> corrected_adata = model.batch_removal() """ adata = self._validate_anndata(adata) latent_all = self.get_latent_representation(adata) @@ -264,6 +298,11 @@ def setup_anndata( %(param_batch_key)s %(param_labels_key)s + + Examples: + >>> import pertpy as pt + >>> data = pt.dt.kang_2018() + >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type") """ setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ @@ -300,6 +339,14 @@ def get_latent_representation( Returns: Low-dimensional representation for each cell + + Examples: + >>> import pertpy as pt + >>> data = pt.dt.kang_2018() + >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type") + >>> model = pt.tl.SCGEN(data) + >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5) + >>> latent_X = model.get_latent_representation() """ self._check_if_trained(warn=False)