
Work on adding hmm likelihood as support
jonperdomo committed Jun 3, 2024
1 parent 9b304f3 commit 1e19485
Showing 8 changed files with 309 additions and 127 deletions.
README.md (5 changes: 4 additions & 1 deletion)
@@ -152,4 +152,7 @@ options:
For release history, please visit [here](https://github.com/WGLab/ContextSV/releases).

## Getting help
Please refer to the [contextSV issue pages](https://github.com/WGLab/ContextSV/issues) for posting your issues. We will also respond your questions quickly. Your comments are critical to improve our tool and will benefit other users.
Please refer to the [contextSV issue
pages](https://github.com/WGLab/ContextSV/issues) for posting your issues; we
will respond quickly. Your comments will
benefit other users and are crucial to improving this tool.
include/cnv_caller.h (8 changes: 5 additions & 3 deletions)
@@ -94,7 +94,7 @@ class CNVCaller {

void updateSNPVectors(SNPData& snp_data, std::vector<int64_t>& pos, std::vector<double>& pfb, std::vector<double>& baf, std::vector<double>& log2_cov, std::vector<int>& state_sequence, std::vector<bool>& is_snp);

std::vector<int> runViterbi(CHMM hmm, SNPData &snp_data);
std::pair<std::vector<int>, double> runViterbi(CHMM hmm, SNPData &snp_data);

// Query a region for SNPs and return the SNP data
std::pair<SNPData, bool> querySNPRegion(std::string chr, int64_t start_pos, int64_t end_pos, SNPInfo &snp_info, std::unordered_map<uint64_t, int> &pos_depth_map, double mean_chr_cov);
@@ -105,9 +105,11 @@
// Run copy number prediction for a chunk of SV candidates
SNPData runCopyNumberPredictionChunk(std::string chr, std::map<SVCandidate, SVInfo>& sv_candidates, std::vector<SVCandidate> sv_chunk, SNPInfo& snp_info, CHMM hmm, int window_size, double mean_chr_cov, std::unordered_map<uint64_t, int>& pos_depth_map);

void updateSVType(std::map<SVCandidate, SVInfo>& sv_candidates, SVCandidate key, int sv_type, std::string data_type);
void updateSVCopyNumber(std::map<SVCandidate, SVInfo>& sv_candidates, SVCandidate key, int sv_type_update, std::string data_type, std::string genotype, double hmm_likelihood);

void updateSVGenotype(std::map<SVCandidate, SVInfo>& sv_candidates, SVCandidate key, std::string genotype);
// void updateSVType(std::map<SVCandidate, SVInfo>& sv_candidates, SVCandidate key, int sv_type, std::string data_type);

// void updateSVGenotype(std::map<SVCandidate, SVInfo>& sv_candidates, SVCandidate key, std::string genotype);

void updateDPValue(std::map<SVCandidate, SVInfo>& sv_candidates, SVCandidate key, int dp_value);

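Note: the new runViterbi() signature returns both the Viterbi state path and its log-likelihood, and updateSVCopyNumber() now accepts that likelihood together with the genotype. The corresponding cnv_caller.cpp changes are not among the diffs shown here, so the following is only a rough sketch of how the two declarations could be wired together; the helper name exampleCopyNumberStep, the "HMM" data-type label, and all placeholder values are assumptions, not code from this commit.

#include <map>
#include <string>
#include <utility>
#include <vector>
#include "cnv_caller.h"

// Sketch of a hypothetical CNVCaller member (it would need a matching
// declaration in cnv_caller.h to compile); values below are invented.
void CNVCaller::exampleCopyNumberStep(std::map<SVCandidate, SVInfo>& sv_candidates,
                                      SVCandidate key, CHMM hmm, SNPData& snp_data)
{
    // runViterbi() now reports the most likely state sequence together with
    // the log-likelihood of that path, instead of the state sequence alone.
    std::pair<std::vector<int>, double> viterbi_result = runViterbi(hmm, snp_data);
    std::vector<int>& state_sequence = viterbi_result.first;
    double hmm_likelihood = viterbi_result.second;

    // Placeholders: in practice the updated SV type and genotype are derived
    // from the predicted state sequence.
    int sv_type_update = 0;
    std::string genotype = "./.";
    (void)state_sequence;

    // The likelihood is threaded through to the SV record so it can be stored
    // with the call (see the hmm_likelihood field added to SVInfo further down).
    updateSVCopyNumber(sv_candidates, key, sv_type_update, "HMM", genotype, hmm_likelihood);
}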
include/khmm.h (4 changes: 2 additions & 2 deletions)
@@ -45,10 +45,10 @@ CHMM ReadCHMM (const char *filename);
// void FreeCHMM(CHMM *phmm);

/// Run the main HMM algorithm
std::vector<int> testVit_CHMM(CHMM hmm, int T, std::vector<double>& O1, std::vector<double>& O2, std::vector<double>& pfb);
std::pair<std::vector<int>, double> testVit_CHMM(CHMM hmm, int T, std::vector<double>& O1, std::vector<double>& O2, std::vector<double>& pfb);

/// Viterbi algorithm
std::vector<int> ViterbiLogNP_CHMM(CHMM phmm, int T, std::vector<double>& O1, std::vector<double>& O2, std::vector<double>& pfb, double **delta, int **psi, std::vector<double>& pprob);
std::pair<std::vector<int>, double> ViterbiLogNP_CHMM(CHMM phmm, int T, std::vector<double>& O1, std::vector<double>& O2, std::vector<double>& pfb, double **delta, int **psi, std::vector<double>& pprob);

/// O1 emission probability
double b1iot (int state, double *mean, double *sd, double uf, double o);
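Note: both khmm entry points now return a std::pair carrying the decoded state sequence and the log-probability of the best path (presumably the pprob value computed by the Viterbi recursion). A minimal calling sketch follows, assuming O1 holds the log2 coverage values, O2 the B-allele frequencies, and pfb the population B-allele frequencies, as the vectors in cnv_caller.h suggest; the function name below is hypothetical.

#include <utility>
#include <vector>
#include "khmm.h"

// Hypothetical caller, for illustration only; T follows khmm's indexing conventions.
void exampleViterbiCall(CHMM hmm, int T, std::vector<double>& log2_cov,
                        std::vector<double>& baf, std::vector<double>& pfb)
{
    // The pair holds (most likely state sequence, log-likelihood of that path).
    std::pair<std::vector<int>, double> result = testVit_CHMM(hmm, T, log2_cov, baf, pfb);
    const std::vector<int>& states = result.first;
    double log_likelihood = result.second;

    (void)states;
    (void)log_likelihood;
}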
include/sv_types.h (7 changes: 4 additions & 3 deletions)
@@ -30,12 +30,13 @@ namespace sv_types {
std::set<std::string> data_type; // Alignment type used to call the SV
int sv_length;
std::string genotype = "./."; // Default genotype (no call)
double hmm_likelihood = 0.0; // HMM likelihood for the SV

SVInfo() :
sv_type(-1), read_support(0), read_depth(0), data_type({}), sv_length(0), genotype("./.") {}
sv_type(-1), read_support(0), read_depth(0), data_type({}), sv_length(0), genotype("./."), hmm_likelihood(0.0) {}

SVInfo(int sv_type, int read_support, int read_depth, std::string data_type, int sv_length, std::string genotype) :
sv_type(sv_type), read_support(read_support), read_depth(read_depth), data_type({data_type}), sv_length(sv_length), genotype(genotype) {}
SVInfo(int sv_type, int read_support, int read_depth, std::string data_type, int sv_length, std::string genotype, double hmm_likelihood) :
sv_type(sv_type), read_support(read_support), read_depth(read_depth), data_type({data_type}), sv_length(sv_length), genotype(genotype), hmm_likelihood(hmm_likelihood) {}
};

// SV (start, end, alt_allele)
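Note: SVInfo now carries the HMM likelihood of the copy-number call, both through a default member value and through the extended constructor. A small illustration of both paths; every value below is invented for the example, not taken from contextSV output.

#include <string>
#include "sv_types.h"

// Illustration only.
void exampleSVInfo()
{
    using namespace sv_types;

    // Extended constructor: the HMM likelihood is stored alongside the call.
    // The sv_type code, counts, and likelihood here are placeholders.
    SVInfo hmm_call(/*sv_type=*/1, /*read_support=*/12, /*read_depth=*/30,
                    /*data_type=*/"HMM", /*sv_length=*/5000,
                    /*genotype=*/"0/1", /*hmm_likelihood=*/-1523.7);

    // Default-constructed records keep the previous behavior: no call ("./.")
    // and a likelihood of 0.0.
    SVInfo no_call;

    (void)hmm_call;
    (void)no_call;
}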
python/sv_merger.py (173 changes: 129 additions & 44 deletions)
@@ -63,6 +63,39 @@ def plot_dbscan(breakpoints, chosen_breakpoints, filename='dbscan_clustering.png
plt.savefig(filename)


def update_support(record, cluster_size):
"""
Set the SUPPORT field in the INFO column of a VCF record to the cluster size.
"""
# Get the INFO column
info = record['INFO']

# Parse the INFO columns
info_fields = info.split(';')

# Loop and update the SUPPORT field, while creating a new INFO string
support = 0
updated_info = ''
for field in info_fields:
if field.startswith('SUPPORT='):
support = int(field.split('=')[1])

# Set the SUPPORT field to the cluster size
# updated_info = f'SUPPORT={cluster_size};'
updated_info += f'SUPPORT={cluster_size};'

# # Increment the SUPPORT field by the cluster size
# support += cluster_size
# updated_info += f'SUPPORT={support};'
else:
updated_info += field + ';' # Append the field to the updated INFO

# Update the INFO column
record['INFO'] = updated_info

return record


def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.merged'):
"""
Use DBSCAN to merge SVs with the same breakpoint.
@@ -78,14 +78,18 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
vcf_df = pd.read_csv(vcf_file_path, sep='\t', comment='#', header=None, usecols=[0, 1, 7], \
names=['CHROM', 'POS', 'INFO'], \
dtype={'CHROM': str, 'POS': np.int64, 'INFO': str})

# Add a column at the beginning with the index
vcf_df.insert(0, 'INDEX', range(0, len(vcf_df)))
logging.info("VCF file read into a pandas DataFrame.")

# Print total number of records
logging.info(f"Total number of records: {vcf_df.shape[0]}")

# Store a dataframe of records that will form the merged VCF file
merged_records = pd.DataFrame(columns=['CHROM', 'POS', 'INFO'])

# merged_records = pd.DataFrame(columns=['CHROM', 'POS', 'INFO'])
merged_records = pd.DataFrame(columns=['INDEX', 'CHROM', 'POS', 'INFO'])

# Create a set with each chromosome in the VCF file
chromosomes = set(vcf_df['CHROM'].values)

@@ -137,41 +174,26 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
chr_ins_dup_clipped_bases = chr_ins_dup_df['INFO'].str.extract(r'CLIPSUP=(\d+)', expand=False).astype(np.int32)
chr_ins_dup_depth_scores = chr_ins_dup_support + chr_ins_dup_clipped_bases

# Cluster SV breakpoints using the specified mode
if mode == 'dbscan':
# Cluster SV breakpoints using DBSCAN
dbscan = DBSCAN(eps=eps, min_samples=min_samples)

# Cluster deletion breakpoints
if len(chr_del_breakpoints) > 0:
logging.info(f"Clustering deletion breakpoints using DBSCAN with eps={eps} and min_samples={min_samples}...")
del_labels = dbscan.fit_predict(chr_del_breakpoints)
logging.info(f"Deletion label counts: {len(np.unique(del_labels))}")
else:
del_labels = []

# Cluster insertion and duplication breakpoints together since
# insertions are a subset of duplications
if len(chr_ins_dup_breakpoints) > 0:
logging.info(f"Clustering insertion and duplication breakpoints using DBSCAN with eps={eps} and min_samples={min_samples}...")
ins_dup_labels = dbscan.fit_predict(chr_ins_dup_breakpoints)
logging.info(f"Insertion and duplication label counts: {len(np.unique(ins_dup_labels))}")
else:
ins_dup_labels = []

elif mode == 'agglomerative':
# Cluster SV breakpoints using agglomerative clustering
logging.info(f"Clustering deletion breakpoints using agglomerative clustering with eps={eps}...")
agglomerative = AgglomerativeClustering(n_clusters=None, distance_threshold=eps, compute_full_tree=True)

# Cluster deletion breakpoints
del_labels = agglomerative.fit_predict(chr_del_breakpoints)
# Cluster SV breakpoints using DBSCAN
dbscan = DBSCAN(eps=eps, min_samples=min_samples)

# Cluster deletion breakpoints
if len(chr_del_breakpoints) > 0:
logging.info(f"Clustering deletion breakpoints using DBSCAN with eps={eps} and min_samples={min_samples}...")
del_labels = dbscan.fit_predict(chr_del_breakpoints)
logging.info(f"Deletion label counts: {len(np.unique(del_labels))}")
else:
del_labels = []

# Cluster insertion and duplication breakpoints together since
# insertions are a subset of duplications
if len(chr_ins_dup_breakpoints) > 0:
logging.info(f"Clustering insertion and duplication breakpoints using DBSCAN with eps={eps} and min_samples={min_samples}...")
ins_dup_labels = dbscan.fit_predict(chr_ins_dup_breakpoints)
logging.info(f"Insertion and duplication label counts: {len(np.unique(ins_dup_labels))}")
else:
ins_dup_labels = []

# Cluster insertion breakpoints
logging.info(f"Clustering insertion and duplication breakpoints using agglomerative clustering with eps={eps}...")
ins_labels = agglomerative.fit_predict(chr_ins_dup_breakpoints)
logging.info(f"Insertion label counts: {len(np.unique(ins_labels))}")

# Get the unique deletion labels for the chromosome
unique_del_labels = np.unique(del_labels)
@@ -185,6 +207,17 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer

# Skip label -1 (outliers)
if label == -1:
# Print the positions if any are within a certain range
pos_min = 180915940
pos_max = 180950356
idx = del_labels == label
pos_values = chr_del_breakpoints[idx][:, 0]
if (np.any(pos_values >= pos_min) and np.any(pos_values <= pos_max)):
# Print all within range
pos_within_range = pos_values[(pos_values >= pos_min) & (pos_values <= pos_max)]
logging.info(f"Outlier deletion positions: {pos_within_range}")
# logging.info(f"Outlier deletion positions: {pos_values}")

continue

# Get the indices of SVs with the same label
@@ -199,9 +232,41 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
# Get the VCF record with the highest depth score
max_del_record = chr_del_df.iloc[idx, :].iloc[max_depth_score_idx, :]

# Get the number of SVs in this cluster
cluster_size = np.sum(idx)
# logging.info("DEL Cluster size: %s", cluster_size)

# Update the SUPPORT field in the INFO column
max_del_record = update_support(max_del_record, cluster_size)

# Get all position values in the cluster
pos_values = chr_del_breakpoints[idx][:, 0]

# If the POS value is a certain value, plot the support
# target_pos = 180949294
pos_min = 180915940
pos_max = 180950356
# if max_del_record['POS'] == target_pos:
# if max_del_record['POS'] >= pos_min and max_del_record['POS'] <=
# pos_max:
# If any position in the cluster is within the range, log the
# information (or if the cluster size is greater than 1000)
if (np.any(pos_values >= pos_min) and np.any(pos_values <= pos_max)) or cluster_size > 1000:
logging.info(f"Cluster size: {cluster_size}")
logging.info(f"Pos values: {pos_values}")
# logging.info(f"FOUND POS {target_pos} - Cluster size:
# {cluster_size}")
# logging.info(f"FOUND POS {max_del_record['POS']} - Cluster size: {cluster_size}")
# logging.info(f"INFO: {max_del_record['INFO']}")

# Append the chosen record to the dataframe of records that will
# form the merged VCF file
merged_records.loc[merged_records.shape[0]] = max_del_record

# Plot the DBSCAN clustering behavior if there are 10 < X < 20 SVs with the same label
plot_enabled = False
if plot_enabled:
logging.info(f"Plotting DBSCAN clustering behavior for label {label}...")
if len(chr_del_breakpoints[idx]) > 10 and len(chr_del_breakpoints[idx]) < 20 and num_plots < max_plots:

# Increment the number of plots
Expand All @@ -212,14 +277,10 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
chosen_breakpoints = chr_del_breakpoints[chosen_idx]
plot_dbscan(chr_del_breakpoints[idx], chosen_breakpoints, filename=f"dbscan_clustering_{num_plots}.png")

# TEST: Return if the number of plots is reached
# Return if the number of plots is reached
if num_plots == max_plots:
return

# Append the chosen record to the dataframe of records that will
# form the merged VCF file
merged_records.loc[merged_records.shape[0]] = max_del_record

# Merge insertions and duplications with the same label
ins_dup_count = 0
for label in unique_ins_dup_labels:
@@ -228,7 +289,8 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
if label == -1:
continue

# Get the indices of SVs with the same label
# Get the indices of SVs with the same label. IDX is a boolean
# array, where True indicates the SV has the same label
idx = ins_dup_labels == label

# Get the SV depth scores with the same label
@@ -240,6 +302,13 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
# Get the VCF record with the highest depth score
max_ins_dup_record = chr_ins_dup_df.iloc[idx, :].iloc[max_depth_score_idx, :]

# Get the number of SVs in this cluster
cluster_size = np.sum(idx)
# logging.info("DUP Cluster size: %s", cluster_size)

# Update the SUPPORT field in the INFO column
max_ins_dup_record = update_support(max_ins_dup_record, cluster_size)

# Append the chosen record to the dataframe of records that will
# form the merged VCF file
merged_records.loc[merged_records.shape[0]] = max_ins_dup_record
@@ -278,6 +347,7 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
logging.info(f"Writing {merged_records.shape[0]} records to merged VCF file...")

merge_count = 0
index_start = 0
with open(merged_vcf, 'w', encoding='utf-8') as merged_vcf_file:

# Write the VCF header to the merged VCF file
@@ -295,12 +365,27 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \
'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}, \
chunksize=1000):

# Add an INDEX column to the chunk
chunk.insert(0, 'INDEX', range(index_start, index_start + chunk.shape[0]))
index_start += chunk.shape[0]

# Merge on INDEX, and use all information from the original VCF file
# (chunk) but update the INFO field with the merged INFO field.
# This is done by dropping the INFO column from the chunk so that
# the INFO column from the merged_records dataframe is used.
# matching_records = pd.merge(chunk.drop(columns=['INFO']), merged_records, on=['INDEX'], how='inner')
matching_records = pd.merge(chunk.drop(columns=['INFO']), merged_records[['INDEX', 'INFO']], on=['INDEX'], how='inner')

# Deduplicate the matching records on INDEX, then drop the helper INDEX column
matching_records = matching_records.drop_duplicates(subset=['INDEX']) # Drop duplicate records
matching_records = matching_records.drop(columns=['INDEX']) # Drop the INDEX column


# Get the matching records from the chunk by merging on CHROM, POS,
# and INFO, but only keep the records from the chunk since they
# contain the full VCF record
matching_records = pd.merge(chunk, merged_records, on=['CHROM', 'POS', 'INFO'], how='inner')
matching_records = matching_records.drop_duplicates(subset=['CHROM', 'POS', 'INFO']) # Drop duplicate records
# matching_records = pd.merge(chunk, merged_records, on=['CHROM', 'POS', 'INFO'], how='inner')
# matching_records = matching_records.drop_duplicates(subset=['CHROM', 'POS', 'INFO']) # Drop duplicate records

# Remove the matching records from the merged records dataframe
merged_records = merged_records[~merged_records.isin(matching_records)].dropna()
(Diffs for the remaining changed files are not shown in this view.)
