latent space bias functions moved into exploration file and notebook to use these functions (#517)

* latent space bias functions moved into exploration file and notebook to use these functions
anamika1302 committed Jun 6, 2023
1 parent 1f665ee commit 69fb14d
Showing 2 changed files with 1,179 additions and 0 deletions.
235 changes: 235 additions & 0 deletions ml4h/explorations.py
@@ -8,6 +8,7 @@
import logging
import operator
import datetime
from scipy import stats
from functools import reduce
from itertools import combinations
from collections import defaultdict, Counter, OrderedDict
@@ -20,8 +21,10 @@
from sklearn.decomposition import PCA
from tensorflow.keras.models import Model


import matplotlib
matplotlib.use('Agg') # Need this to write images from the GSA servers. Order matters:
import matplotlib.cm as cm
import matplotlib.pyplot as plt # First import matplotlib, then use Agg, then import plt

from ml4h.models.legacy_models import legacy_multimodal_multitask_model
@@ -33,10 +36,242 @@
from ml4h.defines import JOIN_CHAR, MRI_SEGMENTED_CHANNEL_MAP, CODING_VALUES_MISSING, CODING_VALUES_LESS_THAN_ONE
from ml4h.defines import TENSOR_EXT, IMAGE_EXT, ECG_CHAR_2_IDX, ECG_IDX_2_CHAR, PARTNERS_CHAR_2_IDX, PARTNERS_IDX_2_CHAR, PARTNERS_READ_TEXT

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression, LinearRegression, ElasticNet, Ridge, Lasso


CSV_EXT = '.tsv'


def stratify_and_project_latent_space(
    stratify_column: str, stratify_thresh: float, stratify_std: float,
    latent_cols: List[str], latent_df: pd.DataFrame,
    normalize: bool = False, train_ratio: float = 1.0,
):
    """
    Stratify samples on a phenotype column and project the test set onto the
    vector between the stratum means in latent space.
    Args:
        stratify_column (str): Name of the column used for stratification.
        stratify_thresh (float): Threshold value for stratification.
        stratify_std (float): Standard deviation of the stratification column; samples
            within one standard deviation of the threshold are excluded.
        latent_cols (List[str]): List of column names for the latent space.
        latent_df (pd.DataFrame): DataFrame containing the latent space data.
        normalize (bool): Whether to normalize vectors before projection. Default is False.
        train_ratio (float): Fraction of the data used to fit the stratum means. Default is 1.0.
    Returns:
        Dict[str, Tuple[float, float, float]]: Maps stratify_column to the Welch
        t-statistic, its p-value, and the number of above-threshold training samples.
    """
    if train_ratio == 1.0:
        train = latent_df
        test = latent_df
    else:
        train = latent_df.sample(frac=train_ratio)
        test = latent_df.drop(train.index)
    hit = train.loc[train[stratify_column] >= stratify_thresh + stratify_std]
    miss = train.loc[train[stratify_column] < stratify_thresh - stratify_std]
    hit_np = hit[latent_cols].to_numpy()
    miss_np = miss[latent_cols].to_numpy()
    miss_mean_vector = np.mean(miss_np, axis=0)
    hit_mean_vector = np.mean(hit_np, axis=0)
    angle = angle_between(miss_mean_vector, hit_mean_vector)  # angle between stratum means, for reference

    hit_test = test.loc[test[stratify_column] >= stratify_thresh + stratify_std]
    miss_test = test.loc[test[stratify_column] < stratify_thresh - stratify_std]

    if normalize:
        phenotype_vector = unit_vector(hit_mean_vector - miss_mean_vector)
        hit_dots = [np.dot(phenotype_vector, unit_vector(v)) for v in hit_test[latent_cols].to_numpy()]
        miss_dots = [np.dot(phenotype_vector, unit_vector(v)) for v in miss_test[latent_cols].to_numpy()]
    else:
        phenotype_vector = hit_mean_vector - miss_mean_vector
        hit_dots = [np.dot(phenotype_vector, v) for v in hit_test[latent_cols].to_numpy()]
        miss_dots = [np.dot(phenotype_vector, v) for v in miss_test[latent_cols].to_numpy()]
    t2, p2 = stats.ttest_ind(hit_dots, miss_dots, equal_var=False)

    return {stratify_column: (t2, p2, len(hit))}
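

# Usage sketch, not part of the original commit: the 'bmi' column and the 0.8
# split are hypothetical, and latent_df is assumed to come from
# latent_space_dataframe (defined later in this module).
def _example_stratify_and_project(latent_df: pd.DataFrame) -> None:
    latent_cols = [c for c in latent_df.columns if c.startswith('latent')]
    scores = stratify_and_project_latent_space(
        stratify_column='bmi',
        stratify_thresh=latent_df['bmi'].mean(),
        stratify_std=latent_df['bmi'].std(),
        latent_cols=latent_cols,
        latent_df=latent_df,
        train_ratio=0.8,
    )
    print(scores)  # {'bmi': (t_statistic, p_value, n_hit_train)}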



def plot_nested_dictionary(all_scores: DefaultDict[str, DefaultDict[str, Tuple[float,float,float]]]) -> None:
"""
Function to create a plot displaying T-statistics v/s Negative Log P-Value for each covariate.
Args:
all_scores (DefaultDict[str, DefaultDict[str, Tuple[float, float, float]]]): Nested dictionary containing the scores.
Returns:
None
"""
n = 4
eps = 1e-300
for model in all_scores:
n = max(n, len(all_scores[model]))
cols = max(2, int(math.ceil(math.sqrt(n))))
rows = max(2, int(math.ceil(n / cols)))
fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 3), sharex=True, dpi=300)
renest = defaultdict(dict)
errors = defaultdict(dict)
lens = {}
max_tstat = 0
max_pval = 0
for model in all_scores:
for metric in all_scores[model]:
renest[metric][model] = all_scores[model][metric][0]
errors[metric][model] = all_scores[model][metric][1]
lens[metric] = all_scores[model][metric][2]
max_tstat = max(abs(all_scores[model][metric][0]), max_tstat)
max_pval = max(-np.log10(all_scores[model][metric][1]+eps), max_pval)
    for metric, ax in zip(renest, axes.ravel()):
        models = [k for k, v in sorted(renest[metric].items(), key=lambda x: x[0].lower())]
        tstats = [abs(v) for k, v in sorted(renest[metric].items(), key=lambda x: x[0].lower())]
        # p-values below float64's minimum underflow to 0.0; cap those at 500
        pvalues = [-np.log10(v) if v > 0 else 500 for k, v in sorted(errors[metric].items(), key=lambda x: x[0].lower())]
        y_pos = np.arange(len(models))
        # Throwaway image solely to give the colorbar a mappable to draw from
        x = np.linspace(0, 1, int(max_pval))
        plt.imshow(x[:, np.newaxis], cmap=cm.jet)
        cb = plt.colorbar(ax=ax, ticks=[0, 1.0])
        cb.set_label('Negative Log P-Value')
        cb.ax.set_yticklabels(['0', f'{max_pval:0.0f}'])
        ax.barh(y_pos, tstats, color=[cm.jet(p / max_pval) for p in pvalues], align='center')
ax.set_yticks(y_pos)
ax.set_yticklabels(models)
ax.invert_yaxis() # labels read top-to-bottom
        ax.set_xlabel('T-Statistic')
ax.xaxis.set_tick_params(which='both', labelbottom=True)
ax.set_title(f'{metric}\n n={lens[metric]}')

plt.tight_layout()
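

# Usage sketch, not part of the original commit: model names and phenotype
# columns are hypothetical. Nests t-test results per model so each covariate
# gets one subplot of |t-statistics| colored by -log10 p-value.
def _example_plot_nested_dictionary(latent_dfs: Dict[str, pd.DataFrame], latent_cols: List[str]) -> None:
    all_scores = defaultdict(dict)
    for model_name, df in latent_dfs.items():
        for phenotype in ['bmi', 'age']:
            all_scores[model_name].update(
                stratify_and_project_latent_space(
                    phenotype, df[phenotype].mean(), df[phenotype].std(), latent_cols, df,
                ),
            )
    plot_nested_dictionary(all_scores)
    plt.savefig('latent_space_tstats.png')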



def angle_between(v1: np.ndarray, v2: np.ndarray) -> float:
    """ Returns the angle in degrees between vectors 'v1' and 'v2'::
            angle_between((1, 0, 0), (0, 1, 0))
            90.0
            angle_between((1, 0, 0), (1, 0, 0))
            0.0
            angle_between((1, 0, 0), (-1, 0, 0))
            180.0
    """
    v1_u = unit_vector(v1)
    v2_u = unit_vector(v2)
    return np.degrees(np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)))


def unit_vector(vector: np.ndarray) -> np.ndarray:
""" Returns the unit vector of the vector. """
return vector / np.linalg.norm(vector)


def latent_space_dataframe(infer_hidden_tsv: str, explore_csv: str):
"""
Read raw data from a CSV file and generate a representation of the data in a latent space.
Args:
infer_hidden_tsv (str): Path to the TSV file containing the inferred hidden representations.
explore_csv (str): Path to the CSV file containing the data to be explored.
Returns:
pandas.DataFrame: Dataframe representing the data in the latent space.
"""
df = pd.read_csv(explore_csv)
df['sample_id'] = pd.to_numeric(df['fpath'], errors='coerce')
df2 = pd.read_csv(infer_hidden_tsv, sep='\t', engine='python')
df2['sample_id'] = pd.to_numeric(df2['sample_id'], errors='coerce')
latent_df = pd.merge(df, df2, on='sample_id', how='inner')
return latent_df
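

# Usage sketch, not part of the original commit; both file paths are
# hypothetical. The TSV holds one row of inferred latent coordinates per
# sample and the CSV holds phenotypes produced by an explore run.
def _example_latent_space_dataframe() -> pd.DataFrame:
    return latent_space_dataframe(
        infer_hidden_tsv='hidden_inference.tsv',
        explore_csv='explore_output.csv',
    )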


def confounder_vector(labels: pd.Series, space: np.ndarray):
"""
Compute the confounder vector based on labels and latent space.
Args:
labels (numpy.ndarray or list): The labels representing the dependent variable.
space (numpy.ndarray or list): The latent space representing the independent variable.
Returns:
cv and r2
"""
clf = make_pipeline(StandardScaler(with_mean=True), Ridge(solver='lsqr'))
clf.fit(space, labels)
train_score = clf.score(space, labels)
return clf[-1].coef_/clf[0].scale_, train_score
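

# Usage sketch, not part of the original commit; the 'age' column is
# hypothetical. Fits a ridge direction for a single confounder and reports
# how well it is linearly decodable from the latent space.
def _example_confounder_vector(latent_df: pd.DataFrame, latent_cols: List[str]) -> None:
    space = latent_df[latent_cols].to_numpy()
    cv, r2 = confounder_vector(latent_df['age'], space)
    print(f'age direction has norm {np.linalg.norm(cv):.3f} and train R^2 {r2:.3f}')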


def confounder_matrix(adjust_cols: List[str], df: pd.DataFrame, space: np.ndarray):
"""
Compute the confounder matrix based on specified columns, a dataframe, and a latent space.
Args:
adjust_cols (list): List of column names to adjust for as confounders.
df (pandas.DataFrame): The dataframe containing the data.
space (numpy.ndarray): The latent space representing the independent variable.
Returns:
computed confounder matrix and scores.
"""
vectors = []
scores = {}
for col in adjust_cols:
cv, r2 = confounder_vector(df[col], space)
scores[col] = r2
vectors.append(cv)
return np.array(vectors), scores
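
# Usage sketch, not part of the original commit; confounder column names are
# hypothetical. Each row of the returned matrix is one confounder direction.
def _example_confounder_matrix(latent_df: pd.DataFrame, latent_cols: List[str]) -> None:
    space = latent_df[latent_cols].to_numpy()
    cfm, scores = confounder_matrix(['age', 'sex'], latent_df, space)
    print(cfm.shape)  # (2, len(latent_cols))
    print(scores)     # per-confounder training R^2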

def iterative_subspace_removal(adjust_cols: List[str], latent_df: pd.DataFrame, latent_cols: List[str],
r2_thresh: float = 0.01, fit_pca: bool = False):
"""
Perform iterative subspace removal based on specified columns, a latent dataframe,
and other parameters to remove confounder variables.
Args:
adjust_cols (List[str]): List of column names to adjust for as confounders.
latent_df (pd.DataFrame): The dataframe containing the latent data.
latent_cols (List[str]): List of column names representing the latent variables.
r2_thresh (float, optional): The threshold for the coefficient of determination (R-squared).
Default is 0.01.
fit_pca (bool, optional): Whether to fit Principal Component Analysis (PCA) on the latent data.
Default is False.
    Returns:
        List[str]: Names of the remaining latent columns.
        pd.DataFrame: The latent dataframe after iterative subspace removal.
"""
new_cols = latent_cols
new_adjust_cols = adjust_cols
space = latent_df[latent_cols].to_numpy()

if fit_pca:
pca = PCA()
pca.fit(space)
space = pca.transform(space)

iteration = 0
while len(new_adjust_cols) > 0 and space.shape[-1] > len(new_adjust_cols):
cfm, scores = confounder_matrix(new_adjust_cols, latent_df, space)
        u, s, vt = np.linalg.svd(cfm, full_matrices=True)
        # Rows of vt are the right singular vectors; the rows after the first
        # len(new_adjust_cols) span the null space of cfm, so projecting onto
        # them removes the confounder directions from the latent space.
        nspace = np.matmul(space, vt[len(new_adjust_cols):].T)
        new_cols = [f'new_latent_{iteration}_{i}' for i in range(nspace.shape[-1])]
df2 = pd.DataFrame(nspace, columns=new_cols, index=latent_df.index)
latent_df = pd.concat([latent_df, df2], axis=1)

iteration += 1
space = nspace

        new_adjust_cols = [col for col, score in scores.items() if score > r2_thresh]
        keep_cols = new_cols + [c for c in latent_df.columns if 'latent' not in c]
        latent_df = latent_df[keep_cols]
        r_scores = {k: round(v, 4) for k, v in scores.items()}
        print(f'Scores were {r_scores}, remaining columns are {new_adjust_cols}')
        print(f'After iteration {iteration}, space shape is: {space.shape}')
return new_cols, latent_df
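

# Usage sketch, not part of the original commit; confounder names are
# hypothetical. Projects the latent space away from directions that predict
# age and sex until their ridge R^2 drops below one percent.
def _example_iterative_subspace_removal(latent_df: pd.DataFrame, latent_cols: List[str]) -> pd.DataFrame:
    new_cols, adjusted_df = iterative_subspace_removal(
        ['age', 'sex'], latent_df, latent_cols, r2_thresh=0.01, fit_pca=True,
    )
    return adjusted_df[new_cols]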



def predictions_to_pngs(
predictions: np.ndarray, tensor_maps_in: List[TensorMap], tensor_maps_out: List[TensorMap], data: Dict[str, np.ndarray],
labels: Dict[str, np.ndarray], paths: List[str], folder: str,
