Skip to content

Commit

Permalink
PR Reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilly-May committed Sep 10, 2024
1 parent 5a2e14d commit 407df8c
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 35 deletions.
13 changes: 7 additions & 6 deletions pertpy/preprocessing/_guide_rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

if TYPE_CHECKING:
from anndata import AnnData
from matplotlib.axes import Axes
from matplotlib.pyplot import Figure


class GuideAssignment:
Expand Down Expand Up @@ -113,13 +113,14 @@ def assign_to_max_guide(
def plot_heatmap(
self,
adata: AnnData,
*,
layer: str | None = None,
order_by: np.ndarray | str | None = None,
key_to_save_order: str = None,
show: bool = True,
return_fig: bool = False,
**kwargs,
) -> list[Axes]:
) -> Figure | None:
"""Heatmap plotting of guide RNA expression matrix.
Assuming guides have sparse expression, this function reorders cells
Expand All @@ -141,8 +142,8 @@ def plot_heatmap(
kwargs: Are passed to sc.pl.heatmap.
Returns:
If return_fig is True, returns a list of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.
If `return_fig` is `True`, returns the figure, otherwise `None`.
Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided.
Examples:
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
Expand Down Expand Up @@ -179,7 +180,7 @@ def plot_heatmap(
adata.obs[key_to_save_order] = pd.Categorical(order)

try:
axis_group = sc.pl.heatmap(
fig = sc.pl.heatmap(
adata[order, :],
var_names=adata.var.index.tolist(),
groupby=temp_col_name,
Expand All @@ -196,5 +197,5 @@ def plot_heatmap(
if show:
plt.show()
if return_fig:
return axis_group
return fig
return None
6 changes: 5 additions & 1 deletion pertpy/tools/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,7 @@ def predict_differential_prioritization(
def plot_dp_scatter(
self,
results: pd.DataFrame,
*,
top_n: int = None,
ax: Axes = None,
show: bool = True,
Expand Down Expand Up @@ -1050,6 +1051,7 @@ def plot_dp_scatter(
def plot_important_features(
self,
data: dict[str, Any],
*,
key: str = "augurpy_results",
top_n: int = 10,
ax: Axes = None,
Expand Down Expand Up @@ -1117,11 +1119,12 @@ def plot_important_features(
def plot_lollipop(
self,
data: dict[str, Any] | AnnData,
*,
key: str = "augurpy_results",
ax: Axes = None,
show: bool = True,
return_fig: bool = False,
) -> Axes | Figure | None:
) -> Figure | None:
"""Plot a lollipop plot of the mean augur values.
Args:
Expand Down Expand Up @@ -1180,6 +1183,7 @@ def plot_scatterplot(
self,
results1: dict[str, Any],
results2: dict[str, Any],
*,
top_n: int = None,
show: bool = True,
return_fig: bool = False,
Expand Down
9 changes: 5 additions & 4 deletions pertpy/tools/_cinemaot.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ def plot_vis_matching(
control: str,
de_label: str,
source_label: str,
*,
matching_rep: str = "ot",
resolution: float = 0.5,
normalize: str = "col",
Expand All @@ -677,6 +678,9 @@ def plot_vis_matching(
{common_plot_args}
**kwargs: Other parameters to input for seaborn.heatmap.
Returns:
If `return_fig` is `True`, returns the figure, otherwise `None`.
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.cinemaot_example()
Expand Down Expand Up @@ -716,10 +720,7 @@ def plot_vis_matching(
if show:
plt.show()
if return_fig:
if ax is not None:
return ax
else:
return g
return g
return None


Expand Down
27 changes: 16 additions & 11 deletions pertpy/tools/_coda/_base_coda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,7 @@ def plot_stacked_barplot( # pragma: no cover
self,
data: AnnData | MuData,
feature_name: str,
*,
modality_key: str = "coda",
palette: ListedColormap | None = cm.tab20,
show_legend: bool | None = True,
Expand All @@ -1200,7 +1201,7 @@ def plot_stacked_barplot( # pragma: no cover
dpi: int | None = 100,
show: bool = True,
return_fig: bool = False,
) -> plt.Axes | Figure | None:
) -> Figure | None:
"""Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
Args:
Expand Down Expand Up @@ -1287,6 +1288,7 @@ def plot_stacked_barplot( # pragma: no cover
def plot_effects_barplot( # pragma: no cover
self,
data: AnnData | MuData,
*,
modality_key: str = "coda",
covariates: str | list | None = None,
parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
Expand All @@ -1300,7 +1302,7 @@ def plot_effects_barplot( # pragma: no cover
dpi: int | None = 100,
show: bool = True,
return_fig: bool = False,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
) -> Figure | None:
"""Barplot visualization for effects.
The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
Expand All @@ -1322,8 +1324,7 @@ def plot_effects_barplot( # pragma: no cover
{common_plot_args}
Returns:
Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
If `return_fig` is `True`, returns the figure, otherwise `None`.
Examples:
>>> import pertpy as pt
Expand Down Expand Up @@ -1476,6 +1477,7 @@ def plot_boxplots( # pragma: no cover
self,
data: AnnData | MuData,
feature_name: str,
*,
modality_key: str = "coda",
y_scale: Literal["relative", "log", "log10", "count"] = "relative",
plot_facets: bool = False,
Expand All @@ -1490,7 +1492,7 @@ def plot_boxplots( # pragma: no cover
dpi: int | None = 100,
show: bool = True,
return_fig: bool = False,
) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None:
) -> Figure | None:
"""Grouped boxplot visualization.
The cell counts for each cell type are shown as a group of boxplots
Expand All @@ -1515,8 +1517,7 @@ def plot_boxplots( # pragma: no cover
{common_plot_args}
Returns:
Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
If `return_fig` is `True`, returns the figure, otherwise `None`.
Examples:
>>> import pertpy as pt
Expand Down Expand Up @@ -1707,6 +1708,7 @@ def plot_boxplots( # pragma: no cover
def plot_rel_abundance_dispersion_plot( # pragma: no cover
self,
data: AnnData | MuData,
*,
modality_key: str = "coda",
abundant_threshold: float | None = 0.9,
default_color: str | None = "Grey",
Expand All @@ -1717,7 +1719,7 @@ def plot_rel_abundance_dispersion_plot( # pragma: no cover
ax: plt.Axes | None = None,
show: bool = True,
return_fig: bool = False,
) -> plt.Axes | plt.Figure | None:
) -> Figure | None:
"""Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.
Expand All @@ -1735,7 +1737,7 @@ def plot_rel_abundance_dispersion_plot( # pragma: no cover
{common_plot_args}
Returns:
A :class:`~matplotlib.axes.Axes` object
If `return_fig` is `True`, returns the figure, otherwise `None`.
Examples:
>>> import pertpy as pt
Expand Down Expand Up @@ -1829,6 +1831,7 @@ def label_point(x, y, val, ax):
def plot_draw_tree( # pragma: no cover
self,
data: AnnData | MuData,
*,
modality_key: str = "coda",
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
tight_text: bool | None = False,
Expand Down Expand Up @@ -1912,6 +1915,7 @@ def plot_draw_effects( # pragma: no cover
self,
data: AnnData | MuData,
covariate: str,
*,
modality_key: str = "coda",
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
show_legend: bool | None = None,
Expand Down Expand Up @@ -2106,6 +2110,7 @@ def plot_effects_umap( # pragma: no cover
mdata: MuData,
effect_name: str | list | None,
cluster_key: str,
*,
modality_key_1: str = "rna",
modality_key_2: str = "coda",
color_map: Colormap | str | None = None,
Expand All @@ -2114,7 +2119,7 @@ def plot_effects_umap( # pragma: no cover
show: bool = True,
return_fig: bool = False,
**kwargs,
) -> plt.Axes | plt.Figure | None:
) -> Figure | None:
"""Plot a UMAP visualization colored by effect strength.
Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
Expand All @@ -2134,7 +2139,7 @@ def plot_effects_umap( # pragma: no cover
**kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`
Returns:
If `return_fig==True` a :class:`~matplotlib.axes.Axes` or a list of it.
If `return_fig` is `True`, returns the figure, otherwise `None`.
Examples:
>>> import pertpy as pt
Expand Down
6 changes: 4 additions & 2 deletions pertpy/tools/_dialogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,11 +1067,12 @@ def plot_split_violins(
adata: AnnData,
split_key: str,
celltype_key: str,
*,
split_which: tuple[str, str] = None,
mcp: str = "mcp_0",
show: bool = True,
return_fig: bool = False,
) -> Axes | Figure | None:
) -> Figure | None:
"""Plots split violin plots for a given MCP and split variable.
Any cells with a value for split_key not in split_which are removed from the plot.
Expand Down Expand Up @@ -1122,10 +1123,11 @@ def plot_pairplot(
celltype_key: str,
color: str,
sample_id: str,
*,
mcp: str = "mcp_0",
show: bool = True,
return_fig: bool = False,
) -> PairGrid | Figure | None:
) -> Figure | None:
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
Expand Down
4 changes: 2 additions & 2 deletions pertpy/tools/_differential_gene_expression/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def plot_paired(
show_legend: bool = True,
size: int = 10,
y_label: str = "expression",
pvalue_template=lambda x: f"unadj. p={x:.2e}, t-test",
pvalue_template=lambda x: f"p={x:.2e}",
boxplot_properties=None,
palette=None,
show: bool = True,
Expand Down Expand Up @@ -594,7 +594,7 @@ def plot_paired(
raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing")

if var_names is None:
var_names = results_df.sort_values(pvalue_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
var_names = results_df.head(n_top_vars)[symbol_col].tolist()

adata = adata[:, var_names]

Expand Down
2 changes: 2 additions & 0 deletions pertpy/tools/_enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def gsea(
def plot_dotplot(
self,
adata: AnnData,
*,
targets: dict[str, dict[str, list[str]]] = None,
source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
category_name: str = "interaction_type",
Expand Down Expand Up @@ -426,6 +427,7 @@ def plot_gsea(
self,
adata: AnnData,
enrichment: dict[str, pd.DataFrame],
*,
n: int = 10,
key: str = "pertpy_enrichment_gsea",
interactive_plot: bool = False,
Expand Down
14 changes: 12 additions & 2 deletions pertpy/tools/_milo.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ def _graph_spatial_fdr(
def plot_nhood_graph(
self,
mdata: MuData,
*,
alpha: float = 0.1,
min_logFC: float = 0,
min_size: int = 10,
Expand Down Expand Up @@ -813,6 +814,7 @@ def plot_nhood(
self,
mdata: MuData,
ix: int,
*,
feature_key: str | None = "rna",
basis: str = "X_umap",
color_map: Colormap | str | None = None,
Expand Down Expand Up @@ -874,14 +876,15 @@ def plot_nhood(
def plot_da_beeswarm(
self,
mdata: MuData,
*,
feature_key: str | None = "rna",
anno_col: str = "nhood_annotation",
alpha: float = 0.1,
subset_nhoods: list[str] = None,
palette: str | Sequence[str] | dict[str, str] | None = None,
show: bool = True,
return_fig: bool = False,
) -> Figure | Axes | None:
) -> Figure | None:
"""Plot beeswarm plot of logFC against nhood labels
Args:
Expand All @@ -894,6 +897,9 @@ def plot_da_beeswarm(
Defaults to pre-defined category colors for violinplots.
{common_plot_args}
Returns:
If `return_fig` is `True`, returns the figure, otherwise `None`.
Examples:
>>> import pertpy as pt
>>> import scanpy as sc
Expand Down Expand Up @@ -999,11 +1005,12 @@ def plot_nhood_counts_by_cond(
self,
mdata: MuData,
test_var: str,
*,
subset_nhoods: list[str] = None,
log_counts: bool = False,
show: bool = True,
return_fig: bool = False,
) -> Figure | Axes | None:
) -> Figure | None:
"""Plot boxplot of cell numbers vs condition of interest.
Args:
Expand All @@ -1012,6 +1019,9 @@ def plot_nhood_counts_by_cond(
subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods.
log_counts: Whether to plot log1p of cell counts.
{common_plot_args}
Returns:
If `return_fig` is `True`, returns the figure, otherwise `None`.
"""
try:
nhood_adata = mdata["milo"].T.copy()
Expand Down
Loading

0 comments on commit 407df8c

Please sign in to comment.