Skip to content

Commit

Permalink
Fixed parsing of the joint gCNV callset
Browse files Browse the repository at this point in the history
  • Loading branch information
asmirnov239 committed Jul 13, 2022
1 parent 186b09a commit 6c22e41
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 52 deletions.
42 changes: 27 additions & 15 deletions evaluation_code/pygcnveval/pygcnveval/callset.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ def _construct_sample_to_pyrange_map(callset_pyrange: pr.PyRanges, sample_set: F
sample_to_events_list_map[sample].append(event)

for sample in sample_set:
if not sample_to_events_list_map[sample]:
sample_to_pyrange_map[sample] = pr.PyRanges()
continue
events_df = pd.DataFrame(sample_to_events_list_map[sample])
events_df = events_df.astype(Callset.CALLSET_COLUMN_TYPES)
sample_to_pyrange_map[sample] = pr.PyRanges(events_df)
Expand Down Expand Up @@ -355,7 +358,7 @@ def __init__(self, sample_to_pyrange_map: dict, joint_callset: pr.PyRanges, inte
@classmethod
def read_in_callset(cls, gcnv_segment_vcfs: List[str], gcnv_callset_tsv: Optional[str], gcnv_joint_vcf: Optional[str],
interval_collection: IntervalCollection, max_events_allowed: int = 100,
sq_min_del: int =100, sq_min_dup: int = 50):
sq_min_del: int = 100, sq_min_dup: int = 50):

sample_to_pyrange_map = {}

Expand Down Expand Up @@ -412,7 +415,7 @@ def _get_attribute_list(call, num_exon, qs, sf, hg38_str):

events_df = pd.DataFrame(events_df_lists, columns=Callset.CALLSET_COLUMNS)
if len(events_df) > max_events_allowed:
continue
continue
events_pr = pr.PyRanges(events_df)
sample_to_pyrange_map[sample_name] = events_pr

Expand Down Expand Up @@ -448,45 +451,54 @@ def _get_attribute_list(call, num_exon, qs, sf, hg38_str):
number_overlaps = sum(sample_to_length_map[s] / length > 0.5 for s in sample_set)
af = number_overlaps / num_samples
event_frequencies.append(af)
#print(sample_to_pyrange_map[sample].Frequency[index])
sample_to_pyrange_map[sample].Frequency = pd.Series(event_frequencies)
print("Done calculating AF")

# Read in joint vcf
joint_callset_pr = None
if gcnv_joint_vcf is not None:
sample_to_pyrange_map = {}
joint_vcf_reader = vcf.Reader(open(gcnv_joint_vcf, 'r'))
joint_sample_list = list(joint_vcf_reader.samples)
sample_to_event_list = {s: [] for s in joint_sample_list}
events_list = []
for record in joint_vcf_reader:
if not record.FILTER:
if record.FILTER:
continue
del_call_samples = []
dup_call_samples = []
if record.ALT[0] is None:
continue
del_call_sample_to_qual = {}
dup_call_samples_to_qual = {}
for sample in joint_sample_list:
# TODO this does not work properly for allosomal contigs but we ignore them for now
if record.genotype(sample)['CN']: # check that there is a value stored in the field
if record.genotype(sample)['CN'] < 2:
if record.genotype(sample)['QS'] >= sq_min_del:
del_call_samples.append(sample)
del_call_sample_to_qual[sample] = record.genotype(sample)['QS']
elif record.genotype(sample)['CN'] > 2:
if record.genotype(sample)['QS'] >= sq_min_dup:
dup_call_samples.append(sample)
dup_call_samples_to_qual[sample] = record.genotype(sample)['QS']
chrom, start, end, num_bins = record.CHROM, int(record.POS), int(record.INFO['END']), int(
record.genotype(joint_sample_list[0])['NP'])
name_prefix = str(chrom) + "_" + str(start) + "_" + str(end) + "_"
af = (len(del_call_samples) + len(dup_call_samples)) / len(joint_sample_list)
if del_call_samples:
events_list.append([chrom, start, end, name_prefix + "DEL", EventType.DEL, frozenset(del_call_samples), af, num_bins])
if dup_call_samples:
events_list.append([chrom, start, end, name_prefix + "DUP", EventType.DUP, frozenset(dup_call_samples), af, num_bins])
af = (len(del_call_sample_to_qual) + len(dup_call_samples_to_qual)) / len(joint_sample_list)
if del_call_sample_to_qual:
for s, q in del_call_sample_to_qual.items():
sample_to_event_list[s].append([chrom, start, end, EventType.DEL, num_bins, q, af])
events_list.append([chrom, start, end, name_prefix + "DEL", EventType.DEL, frozenset(del_call_sample_to_qual.keys()), af, num_bins])
if dup_call_samples_to_qual:
for s, q in dup_call_samples_to_qual.items():
sample_to_event_list[s].append([chrom, start, end, EventType.DUP, num_bins, q, af])
events_list.append([chrom, start, end, name_prefix + "DUP", EventType.DUP, frozenset(dup_call_samples_to_qual.keys()), af, num_bins])
joint_events_df = pd.DataFrame(events_list, columns=Callset.JOINT_CALLSET_COLUMNS)
joint_events_df.astype(Callset.JOINT_CALLSET_COLUMN_TYPES)
joint_callset_pr = pr.PyRanges(joint_events_df)

sample_to_pyrange_map = Callset._construct_sample_to_pyrange_map(joint_callset_pr, frozenset(joint_sample_list))
for s in joint_sample_list:
if len(sample_to_event_list[s]) > max_events_allowed:
continue
# TODO remove these samples from the joint callset as well
events_df = pd.DataFrame(sample_to_event_list[s], columns=Callset.CALLSET_COLUMNS)
sample_to_pyrange_map[s] = pr.PyRanges(events_df)

return cls(sample_to_pyrange_map, joint_callset_pr, interval_collection)

Expand Down
20 changes: 14 additions & 6 deletions evaluation_code/pygcnveval/pygcnveval/evaluate_cnv_callset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from callset import TruthCallset, GCNVCallset, XHMMCallset
from interval_collection import IntervalCollection
from evaluator import PerEventEvaluator, PerBinEvaluator, PerSiteEvaluator
from evaluator import PerEventEvaluator, PerSiteEvaluator
import plotting


Expand All @@ -18,12 +18,14 @@ def evaluate_cnv_callsets_and_plot_results(analyzed_intervals: str,
minimum_overlap: float,
gcnv_sq_min_del: int,
gcnv_sq_min_dup: int,
site_frequency_threshold: float,
samples_to_evaluate_path: str):
perform_per_site_eval = False

print("Reading in interval list...", flush=True)
interval_collection = IntervalCollection.read_interval_list(analyzed_intervals)
callsets_to_evaluate = []
gcnv_callset = None
if gcnv_vcfs or gcnv_callset_tsv or gcnv_joint_vcf:
print("Reading in gCNV callset...", flush=True)
gcnv_callset = GCNVCallset.read_in_callset(gcnv_segment_vcfs=gcnv_vcfs,
Expand All @@ -37,9 +39,10 @@ def evaluate_cnv_callsets_and_plot_results(analyzed_intervals: str,

if xhmm_vcfs:
print("Reading in XHMM callset", flush=True)
samples_to_keep = None if gcnv_callset else gcnv_callset.sample_set
xhmm_callset = XHMMCallset.read_in_callset(xhmm_vcfs=xhmm_vcfs,
interval_collection=interval_collection,
samples_to_keep=gcnv_callset.sample_set)
samples_to_keep=samples_to_keep)
callsets_to_evaluate.append(xhmm_callset)

if samples_to_evaluate_path:
Expand Down Expand Up @@ -78,7 +81,8 @@ def evaluate_cnv_callsets_and_plot_results(analyzed_intervals: str,
(len(per_event_evaluator.sample_list_to_eval), len(callset.sample_set)))
per_event_evaluation_result = per_event_evaluator.evaluate_callset_against_the_truth(minimum_overlap=minimum_overlap,
gcnv_sq_min_del=gcnv_sq_min_del,
gcnv_sq_min_dup=gcnv_sq_min_dup)
gcnv_sq_min_dup=gcnv_sq_min_dup,
site_frequency_threshold=site_frequency_threshold)
per_event_evaluation_results.append(per_event_evaluation_result)

per_event_evaluation_result.write_to_file(output_directory)
Expand Down Expand Up @@ -125,6 +129,9 @@ def main():
parser.add_argument('--gcnv_min_sq_dup_threshold', metavar='MinimumSQDelThreshold', type=int,
help='SQ threshold to filter gCNV duplication events on', required=True)

parser.add_argument('--site_frequency_threshold', metavar='SiteFrequencyThreshold', type=float,
help='Site frequency threshold for filtering variant sites', required=True)

parser.add_argument('--samples_to_evaluate_path', metavar='SamplesToEvaluate', type=str,
help='A file containing the set of samples to evaluate, one sample per line.',
required=False)
Expand All @@ -141,14 +148,15 @@ def main():
min_required_overlap = args.min_required_overlap
gcnv_sq_min_del = args.gcnv_min_sq_del_threshold
gcnv_sq_min_dup = args.gcnv_min_sq_dup_threshold
site_frequency_threshold = args.site_frequency_threshold
samples_to_evaluate_path = args.samples_to_evaluate_path

assert (gcnv_segment_vcfs is None) ^ (gcnv_callset_tsv is None) ^ (gcnv_segment_vcfs is None), \
"Exactly one of the gCNV segment VCF list or gCNV TSV callset or joint gCNV VCF must be defined"

evaluate_cnv_callsets_and_plot_results(analyzed_intervals, truth_callset, gcnv_segment_vcfs, gcnv_callset_tsv, gcnv_max_event_number,
gcnv_joint_vcf, xhmm_vcfs, output_dir, min_required_overlap,
gcnv_sq_min_del, gcnv_sq_min_dup, samples_to_evaluate_path)
evaluate_cnv_callsets_and_plot_results(analyzed_intervals, truth_callset, gcnv_segment_vcfs, gcnv_callset_tsv,
gcnv_max_event_number, gcnv_joint_vcf, xhmm_vcfs, output_dir, min_required_overlap,
gcnv_sq_min_del, gcnv_sq_min_dup, site_frequency_threshold, samples_to_evaluate_path)


if __name__ == '__main__':
Expand Down
37 changes: 10 additions & 27 deletions evaluation_code/pygcnveval/pygcnveval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@ def __init__(self, truth_callset: TruthCallset, callset: Callset, sample_list_to
assert set(sample_list_to_evaluate).issubset(self.callset.sample_set)
self.sample_list_to_eval = sample_list_to_evaluate

def evaluate_callset_against_the_truth(self, minimum_overlap: float = 0.2,
gcnv_sq_min_del: int = 100, gcnv_sq_min_dup: int = 50) -> PerEventEvaluationResult:
def evaluate_callset_against_the_truth(self, minimum_overlap: float = 0.2, gcnv_sq_min_del: int = 100,
gcnv_sq_min_dup: int = 50, site_frequency_threshold: float = 0.01) -> PerEventEvaluationResult:
evaluation_result = PerEventEvaluationResult(self.callset.get_name())

# construct callset filter
def gcnv_calls_filter(e):
is_allosome = e.interval.chrom == "chrX" or e.interval.chrom == "chrY" \
or e.interval.chrom == "X" or e.interval.chrom == "Y"
return not is_allosome and (e.call_attributes['Frequency'] <= 0.02) \
return not is_allosome and (e.call_attributes['Frequency'] <= site_frequency_threshold) \
and ((e.call_attributes['Quality'] >= gcnv_sq_min_del and e.event_type == EventType.DEL)
or (e.call_attributes['Quality'] >= gcnv_sq_min_dup and e.event_type == EventType.DUP))

def xhmm_calls_filter(e):
is_allosome = e.interval.chrom == "chrX" or e.interval.chrom == "chrY" \
or e.interval.chrom == "X" or e.interval.chrom == "Y"
return not is_allosome and (e.call_attributes['Frequency'] <= 0.02) and (e.call_attributes['Quality'] >= 60)
return not is_allosome and (e.call_attributes['Frequency'] <= site_frequency_threshold) and (e.call_attributes['Quality'] >= 60)

calls_filter = None
if self.callset.get_name() == "gCNV_callset":
Expand All @@ -42,29 +42,12 @@ def xhmm_calls_filter(e):
# Calculate precision
for validated_event in self.callset.get_event_generator(self.sample_list_to_eval, calls_filter):

# if self.callset.get_name() == "gCNV_callset" and validated_event.call_attributes['Quality'] < 50 and validated_event.event_type == EventType.DUP:
# continue
# if self.callset.get_name() == "gCNV_callset" and validated_event.call_attributes['Quality'] < 100 and validated_event.event_type == EventType.DEL:
# continue

# if self.callset.get_name() == "XHMM_callset" and validated_event.call_attributes['Quality'] < 60:
# continue
#
# if validated_event.interval.chrom == "chrX" or validated_event.interval.chrom == "chrY" or validated_event.interval.chrom == "X" or validated_event.interval.chrom == "Y":
# continue
#
# if self.callset.get_name() == "XHMM_callset" and validated_event.call_attributes['Frequency'] > 0.02:
# continue
# # check if event is common
# if self.callset.get_name() == "gCNV_callset" and validated_event.call_attributes['Frequency'] > 0.05:
# continue

overlapping_truth_events = self.truth_callset.get_overlapping_events_for_sample(validated_event.interval,
validated_event.sample)
overlapping_truth_event_best_match = validated_event.find_event_with_largest_overlap(overlapping_truth_events)

if overlapping_truth_event_best_match:
if overlapping_truth_event_best_match.call_attributes['Frequency'] > 0.02:
if overlapping_truth_event_best_match.call_attributes['Frequency'] > site_frequency_threshold:
evaluation_result.update_precision(validated_event.call_attributes['NumBins'], None,
validated_event.event_type, validated_event.interval,
"FILTERED_HIGH_TRUTH_AF", list([validated_event.sample]))
Expand All @@ -82,7 +65,7 @@ def xhmm_calls_filter(e):
for truth_event in self.truth_callset.get_event_generator(self.sample_list_to_eval):
if truth_event.interval.chrom == "chrX" or truth_event.interval.chrom == "chrY" or truth_event.interval.chrom == "X" or truth_event.interval.chrom == "Y":
continue
if truth_event.call_attributes['Frequency'] > 0.02:
if truth_event.call_attributes['Frequency'] > site_frequency_threshold:
continue
overlapping_gcnv_events = self.callset.get_overlapping_events_for_sample(truth_event.interval, truth_event.sample)

Expand All @@ -105,7 +88,7 @@ def xhmm_calls_filter(e):
"FILTERED_LOW_QUAL", list([truth_event.sample]))
continue

if self.callset.get_name() == "XHMM_callset" and overlapping_gcnv_event_best_match.call_attributes['Frequency'] > 0.02:
if self.callset.get_name() == "XHMM_callset" and overlapping_gcnv_event_best_match.call_attributes['Frequency'] > site_frequency_threshold:
evaluation_result.update_recall(truth_event.call_attributes['NumBins'], False,
truth_event.event_type, truth_event.interval,
"FILTERED_HIGH_AF", list([truth_event.sample])) # TODO refactor
Expand All @@ -116,7 +99,7 @@ def xhmm_calls_filter(e):
"FILTERED_LOW_QUAL", list([truth_event.sample])) # TODO refactor
continue

if self.callset.get_name() == "gCNV_callset" and overlapping_gcnv_event_best_match.call_attributes['Frequency'] > 0.02:
if self.callset.get_name() == "gCNV_callset" and overlapping_gcnv_event_best_match.call_attributes['Frequency'] > site_frequency_threshold:
evaluation_result.update_recall(truth_event.call_attributes['NumBins'], False, truth_event.event_type,
truth_event.interval, "FILTERED_HIGH_AF", list([truth_event.sample]))
continue
Expand Down Expand Up @@ -151,7 +134,7 @@ def evaluate_callset_against_the_truth(self) -> PerEventEvaluationResult:
overlapping_truth_alleles)

if overlapping_truth_allele_best_match:
if overlapping_truth_allele_best_match.allele_attributes['Frequency'] > 0.01:
if overlapping_truth_allele_best_match.allele_attributes['Frequency'] > 0.02:
evaluation_result.update_precision(validated_allele.allele_attributes['NumBins'], None,
validated_allele.event_type, validated_allele.interval,
"FILTERED_HIGH_TRUTH_AF", list(validated_allele.sample_set))
Expand All @@ -171,7 +154,7 @@ def evaluate_callset_against_the_truth(self) -> PerEventEvaluationResult:
continue
if truth_allele.interval.chrom == "chrX" or truth_allele.interval.chrom == "chrY":
continue
if truth_allele.allele_attributes['Frequency'] > 0.01:
if truth_allele.allele_attributes['Frequency'] > 0.02:
continue

overlapping_gcnv_alleles = self.callset.get_overlapping_alleles(truth_allele.interval)
Expand Down
8 changes: 4 additions & 4 deletions evaluation_code/pygcnveval/pygcnveval/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def sum_over_dict(d: dict, last_index: int, bins: list):
for index, evaluation_result in enumerate(evaluation_result_list):
tp = evaluation_result.precision_size_to_tp
fp = evaluation_result.precision_size_to_fp
print(tp)
print(fp)
print("Number of true positives for events of the following bin sizes (precision calculation): {0}".format(tp))
print("Number of false positives for events of the following bin sizes (precision calculation): {0}".format(fp))
precisions_for_bins.append([sum_over_dict(tp, i, bins) / max(1., (sum_over_dict(tp, i, bins)+sum_over_dict(fp, i, bins))) for i in bins])
tp = evaluation_result.recall_size_to_tp
fn = evaluation_result.recall_size_to_fn
print(tp)
print(fn)
print("Number of true positives for events of the following bin sizes (recall calculation): {0}".format(tp))
print("Number of false negatives for events of the following bin sizes (recall calculation): {0}".format(fn))
recall_for_bins.append([sum_over_dict(tp, i, bins) / max(1., (sum_over_dict(tp, i, bins) + sum_over_dict(fn, i, bins))) for i in bins])
f_1_for_bins.append([2 / ((1. / recall_for_bins[index][i]) + (1. / precisions_for_bins[index][i])) for i in bins])

Expand Down

0 comments on commit 6c22e41

Please sign in to comment.