Skip to content

Commit

Permalink
Use p-values from results_df in plot_paired
Browse files (browse the repository at this point in the history)
  • Loading branch information
Lilly-May committed Sep 9, 2024
1 parent 0ddfe79 commit 5a2e14d
Showing 1 changed file with 44 additions and 86 deletions.
130 changes: 44 additions & 86 deletions pertpy/tools/_differential_gene_expression/_base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -515,20 +515,21 @@ def _map_genes_categories_highlight(
def plot_paired(
self,
adata: ad.AnnData,
var_names: Sequence[str],
results_df: pd.DataFrame,
groupby: str,
pairedby: str,
*,
pairedby: str = None,
hue: str = None,
var_names: Sequence[str] = None,
n_top_vars: int = 15,
layer: str = None,
pvalue_col: str = "adj_p_value",
symbol_col: str = "variable",
n_cols: int = 4,
panel_size: tuple[int, int] = (5, 5),
show_legend: bool = True,
size: int = 10,
y_label: str = "expression",
pvalues: Sequence[float] = None,
pvalue_template=lambda x: f"unadj. p={x:.2f}, t-test",
adjust_fdr: bool = False,
pvalue_template=lambda x: f"unadj. p={x:.2e}, t-test",
boxplot_properties=None,
palette=None,
show: bool = True,
Expand All @@ -540,19 +541,20 @@ def plot_paired(
Args:
adata: AnnData object, can be pseudobulked.
results_df: DataFrame with results from a differential expression test.
groupby: .obs column containing the grouping. Must contain exactly two different values.
pairedby: .obs column containing the pairing (e.g. "patient_id"). If None, an independent t-test is performed.
var_names: Variables to plot.
groupby: Column in adata.obs containing the grouping. Must contain exactly two different values.
pairedby: Column in adata.obs containing the pairing (e.g. "patient_id"). If None, an independent t-test is performed.
hue: Column in adata.obs to color by.
n_top_vars: Number of top variables to plot.
layer: Layer to use for plotting.
pvalue_col: Column name of the p values.
symbol_col: Column name of gene IDs.
n_cols: Number of columns in the plot.
panel_size: Size of each panel.
show_legend: Whether to show the legend.
size: Size of the points.
y_label: Label for the y-axis.
pvalues: P-values for each variable. If None, they are calculated.
pvalue_template: Template for the p-value string displayed in the title of each panel.
adjust_fdr: Whether to correct p-values for false discovery rate.
boxplot_properties: Additional properties for the boxplot, passed to seaborn.boxplot.
palette: Color palette for the line- and stripplot.
{common_plot_args}
Expand Down Expand Up @@ -580,7 +582,7 @@ def plot_paired(
>>> res_df = edgr.test_contrasts(
... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
... )
>>> edgr.plot_paired(pdata, var_names=res_df["variable"][:8], groupby="Treatment", pairedby="Major celltype")
>>> edgr.plot_paired(pdata, results_df=res_df, n_top_vars=8, groupby="Treatment", pairedby="Efficacy")
Preview:
.. image:: /_static/docstring_previews/de_paired_expression.png
Expand All @@ -590,12 +592,13 @@ def plot_paired(
groups = adata.obs[groupby].unique()
if len(groups) != 2:
raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing")
if pairedby is not None and hue is not None and (adata.obs.groupby(pairedby)[hue].nunique().max() > 1):
raise ValueError("Each paired sample must have an unambiguous hue")

if var_names is None:
var_names = results_df.sort_values(pvalue_col, ascending=True).head(n_top_vars)[symbol_col].tolist()

adata = adata[:, var_names]

if pairedby is not None and any(adata.obs[[groupby, pairedby]].value_counts() > 1):
if any(adata.obs[[groupby, pairedby]].value_counts() > 1):
logger.info("Performing pseudobulk for paired samples")
ps = PseudobulkSpace()
adata = ps.compute(
Expand All @@ -611,52 +614,18 @@ def plot_paired(
except AttributeError:
pass

groupby_cols = [groupby]
if pairedby is not None:
groupby_cols.insert(0, pairedby)
if hue is not None:
groupby_cols.insert(0, hue)

groupby_cols = [pairedby, groupby]
df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))

if pairedby is not None:
# remove unpaired samples
paired_samples = set(df[df[groupby] == groups[0]][pairedby]) & set(df[df[groupby] == groups[1]][pairedby])
df = df[df[pairedby].isin(paired_samples)]
removed_samples = adata.obs[pairedby].nunique() - len(df[pairedby].unique())
if removed_samples > 0:
logger.warning(f"{removed_samples} unpaired samples removed")

# perform statistics (paired ttest)
if pvalues is None:
_, pvalues = scipy.stats.ttest_rel(
df.loc[
df[groupby] == groups[0],
var_names,
],
df.loc[
df[groupby] == groups[1],
var_names,
],
)

df.reset_index(drop=False, inplace=True)
# remove unpaired samples
paired_samples = set(df[df[groupby] == groups[0]][pairedby]) & set(df[df[groupby] == groups[1]][pairedby])
df = df[df[pairedby].isin(paired_samples)]
removed_samples = adata.obs[pairedby].nunique() - len(df[pairedby].unique())
if removed_samples > 0:
logger.warning(f"{removed_samples} unpaired samples removed")

else:
if pvalues is None:
_, pvalues = scipy.stats.ttest_ind(
df.loc[
df[groupby] == groups[0],
var_names,
],
df.loc[
df[groupby] == groups[1],
var_names,
],
)

if adjust_fdr:
pvalues = statsmodels.stats.multitest.fdrcorrection(pvalues)[1]
pvalues = results_df.set_index(symbol_col).loc[var_names, pvalue_col].values
df.reset_index(drop=False, inplace=True)

# transform data for seaborn
df_melt = df.melt(
Expand All @@ -665,7 +634,6 @@ def plot_paired(
value_name="val",
)

# start plotting
n_panels = len(var_names)
nrows = math.ceil(n_panels / n_cols)
ncols = min(n_cols, n_panels)
Expand All @@ -678,8 +646,6 @@ def plot_paired(
squeeze=False,
)
axes = axes.flatten()
if hue is None:
hue = pairedby
for i, (var, ax) in enumerate(zip_longest(var_names, axes)):
if var is not None:
sns.boxplot(
Expand All @@ -695,9 +661,9 @@ def plot_paired(
sns.lineplot(
x=groupby,
data=df_melt.loc[df_melt["var"] == var],
hue=hue,
y="val",
ax=ax,
hue=pairedby,
legend=False,
errorbar=None,
palette=palette,
Expand All @@ -708,7 +674,7 @@ def plot_paired(
data=df_melt.loc[df_melt["var"] == var],
y="val",
ax=ax,
hue=hue,
hue=pairedby,
jitter=jitter,
size=size,
linewidth=1,
Expand All @@ -729,8 +695,11 @@ def plot_paired(

if show_legend is True:
axes[n_panels - 1].legend().set_visible(True)
axes[n_panels - 1].legend(bbox_to_anchor=(1.1, 1.05))
axes[n_panels - 1].legend(
bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=adata.obs[pairedby].nunique()
)

plt.tight_layout()
if show:
plt.show()
if return_fig:
Expand All @@ -744,6 +713,8 @@ def plot_fold_change(
*,
var_names: Sequence[str] = None,
n_top_vars: int = 15,
log2fc_col: str = "log_fc",
symbol_col: str = "variable",
y_label: str = "Log2 fold change",
figsize: tuple[int, int] = (10, 5),
show: bool = True,
Expand All @@ -754,9 +725,10 @@ def plot_fold_change(
Args:
results_df: DataFrame with results from DE analysis.
pairedby: Column in results_df containing information about paired samples.
var_names: Variables to plot. If None, the top n_top_vars variables based on the log2 fold change are plotted.
n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively.
log2fc_col: Column name of log2 Fold-Change values.
symbol_col: Column name of gene IDs.
y_label: Label for the y-axis.
figsize: Size of the figure.
{common_plot_args}
Expand Down Expand Up @@ -791,33 +763,19 @@ def plot_fold_change(
.. image:: /_static/docstring_previews/de_fold_change.png
"""
if var_names is None:
var_names = results_df.sort_values("log_fc", ascending=False).head(n_top_vars)["variable"].tolist()
var_names += results_df.sort_values("log_fc", ascending=True).head(n_top_vars)["variable"].tolist()
var_names = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist()
var_names += results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
assert len(var_names) == 2 * n_top_vars

df = results_df[results_df["variable"].isin(var_names)]
df.sort_values("log_fc", ascending=False, inplace=True)

max_fc = max(df["log_fc"])
min_fc = min(df["log_fc"])

def value_to_color(val):
if val > 0:
return plt.cm.Reds(val / max_fc)
elif val < 0:
return plt.cm.Blues(val / min_fc)
else:
return "grey"

df["color"] = df["log_fc"].apply(value_to_color)
df = results_df[results_df[symbol_col].isin(var_names)]
df.sort_values(log2fc_col, ascending=False, inplace=True)

plt.figure(figsize=figsize)
sns.barplot(
x="variable",
y="log_fc",
x=symbol_col,
y=log2fc_col,
data=df,
hue="variable",
palette=df["color"].tolist(),
palette="RdBu",
legend=False,
**barplot_kwargs,
)
Expand Down

0 comments on commit 5a2e14d

Please sign in to comment.