diff --git a/evaluation_code/pygcnveval/pygcnveval/callset.py b/evaluation_code/pygcnveval/pygcnveval/callset.py
index 2ba6c31..2f4e40f 100644
--- a/evaluation_code/pygcnveval/pygcnveval/callset.py
+++ b/evaluation_code/pygcnveval/pygcnveval/callset.py
@@ -259,6 +259,9 @@ def _construct_sample_to_pyrange_map(callset_pyrange: pr.PyRanges, sample_set: F
                 sample_to_events_list_map[sample].append(event)
 
         for sample in sample_set:
+            if not sample_to_events_list_map[sample]:
+                sample_to_pyrange_map[sample] = pr.PyRanges()
+                continue
             events_df = pd.DataFrame(sample_to_events_list_map[sample])
             events_df = events_df.astype(Callset.CALLSET_COLUMN_TYPES)
             sample_to_pyrange_map[sample] = pr.PyRanges(events_df)
@@ -355,7 +358,7 @@ def __init__(self, sample_to_pyrange_map: dict, joint_callset: pr.PyRanges, inte
     @classmethod
     def read_in_callset(cls, gcnv_segment_vcfs: List[str], gcnv_callset_tsv: Optional[str], gcnv_joint_vcf: Optional[str],
                         interval_collection: IntervalCollection, max_events_allowed: int = 100,
-                        sq_min_del: int =100, sq_min_dup: int = 50):
+                        sq_min_del: int = 100, sq_min_dup: int = 50):
 
         sample_to_pyrange_map = {}
 
@@ -412,7 +415,7 @@ def _get_attribute_list(call, num_exon, qs, sf, hg38_str):
 
             events_df = pd.DataFrame(events_df_lists, columns=Callset.CALLSET_COLUMNS)
             if len(events_df) > max_events_allowed:
-                  continue
+                continue
             events_pr = pr.PyRanges(events_df)
             sample_to_pyrange_map[sample_name] = events_pr
 
@@ -448,45 +451,54 @@ def _get_attribute_list(call, num_exon, qs, sf, hg38_str):
                 number_overlaps = sum(sample_to_length_map[s] / length > 0.5 for s in sample_set)
                 af = number_overlaps / num_samples
                 event_frequencies.append(af)
-                #print(sample_to_pyrange_map[sample].Frequency[index])
             sample_to_pyrange_map[sample].Frequency = pd.Series(event_frequencies)
         print("Done calculating AF")
 
         # Read in joint vcf
         joint_callset_pr = None
         if gcnv_joint_vcf is not None:
+            sample_to_pyrange_map = {}
             joint_vcf_reader = vcf.Reader(open(gcnv_joint_vcf, 'r'))
             joint_sample_list = list(joint_vcf_reader.samples)
+            sample_to_event_list = {s: [] for s in joint_sample_list}
             events_list = []
             for record in joint_vcf_reader:
-                if not record.FILTER:
+                if record.FILTER:
                     continue
-                del_call_samples = []
-                dup_call_samples = []
                 if record.ALT[0] is None:
                     continue
+                del_call_sample_to_qual = {}
+                dup_call_samples_to_qual = {}
                 for sample in joint_sample_list:
                     # TODO this does not work properly for allosomal contigs but we ignore them for now
                     if record.genotype(sample)['CN']:  # check that there is a value stored in the field
                         if record.genotype(sample)['CN'] < 2:
                             if record.genotype(sample)['QS'] >= sq_min_del:
-                                del_call_samples.append(sample)
+                                del_call_sample_to_qual[sample] = record.genotype(sample)['QS']
                         elif record.genotype(sample)['CN'] > 2:
                             if record.genotype(sample)['QS'] >= sq_min_dup:
-                                dup_call_samples.append(sample)
+                                dup_call_samples_to_qual[sample] = record.genotype(sample)['QS']
                 chrom, start, end, num_bins = record.CHROM, int(record.POS), int(record.INFO['END']), int(
                     record.genotype(joint_sample_list[0])['NP'])
                 name_prefix = str(chrom) + "_" + str(start) + "_" + str(end) + "_"
-                af = (len(del_call_samples) + len(dup_call_samples)) / len(joint_sample_list)
-                if del_call_samples:
-                    events_list.append([chrom, start, end, name_prefix + "DEL", EventType.DEL, frozenset(del_call_samples), af, num_bins])
-                if dup_call_samples:
-                    events_list.append([chrom, start, end, name_prefix + "DUP", EventType.DUP, frozenset(dup_call_samples), af, num_bins])
+                af = (len(del_call_sample_to_qual) + len(dup_call_samples_to_qual)) / len(joint_sample_list)
+                if del_call_sample_to_qual:
+                    for s, q in del_call_sample_to_qual.items():
+                        sample_to_event_list[s].append([chrom, start, end, EventType.DEL, num_bins, q, af])
+                    events_list.append([chrom, start, end, name_prefix + "DEL", EventType.DEL, frozenset(del_call_sample_to_qual.keys()), af, num_bins])
+                if dup_call_samples_to_qual:
+                    for s, q in dup_call_samples_to_qual.items():
+                        sample_to_event_list[s].append([chrom, start, end, EventType.DUP, num_bins, q, af])
+                    events_list.append([chrom, start, end, name_prefix + "DUP", EventType.DUP, frozenset(dup_call_samples_to_qual.keys()), af, num_bins])
 
             joint_events_df = pd.DataFrame(events_list, columns=Callset.JOINT_CALLSET_COLUMNS)
             joint_events_df.astype(Callset.JOINT_CALLSET_COLUMN_TYPES)
             joint_callset_pr = pr.PyRanges(joint_events_df)
-
-            sample_to_pyrange_map = Callset._construct_sample_to_pyrange_map(joint_callset_pr, frozenset(joint_sample_list))
+            for s in joint_sample_list:
+                if len(sample_to_event_list[s]) > max_events_allowed:
+                    continue
+                # TODO remove these samples from the joint callset as well
+                events_df = pd.DataFrame(sample_to_event_list[s], columns=Callset.CALLSET_COLUMNS)
+                sample_to_pyrange_map[s] = pr.PyRanges(events_df)
 
         return cls(sample_to_pyrange_map, joint_callset_pr, interval_collection)
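Note on the reworked joint-VCF branch above: each record is now classified into per-sample DEL/DUP events keyed by sample and annotated with the genotype quality (QS), and one PyRanges is built per sample (skipping samples that exceed `max_events_allowed`). The snippet below is a minimal, self-contained sketch of that per-record classification step, not the PR code itself; `classify_joint_record` and the plain-dict record stand-in are illustrative only.

```python
# Illustrative sketch only: a simplified stand-in for the per-record classification done in
# the joint-VCF branch of read_in_callset. The "record" here is a plain dict of
# {sample: (copy_number, qs)} instead of a PyVCF record.
from enum import Enum


class EventType(Enum):
    DEL = "DEL"
    DUP = "DUP"


def classify_joint_record(sample_to_cn_qs: dict, sq_min_del: int = 100, sq_min_dup: int = 50) -> dict:
    """Return {sample: (EventType, qs)} for samples whose copy number and QS pass the SQ thresholds."""
    passing_calls = {}
    for sample, (copy_number, qs) in sample_to_cn_qs.items():
        if copy_number is None:
            continue  # no copy-number value stored for this sample
        if copy_number < 2 and qs >= sq_min_del:
            passing_calls[sample] = (EventType.DEL, qs)
        elif copy_number > 2 and qs >= sq_min_dup:
            passing_calls[sample] = (EventType.DUP, qs)
    return passing_calls


# A deletion passing the DEL threshold is kept; a low-QS duplication and a diploid sample are not.
print(classify_joint_record({"sampleA": (1, 150), "sampleB": (3, 20), "sampleC": (2, 0)}))
```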
diff --git a/evaluation_code/pygcnveval/pygcnveval/evaluate_cnv_callset.py b/evaluation_code/pygcnveval/pygcnveval/evaluate_cnv_callset.py
index dee0d04..130f81c 100644
--- a/evaluation_code/pygcnveval/pygcnveval/evaluate_cnv_callset.py
+++ b/evaluation_code/pygcnveval/pygcnveval/evaluate_cnv_callset.py
@@ -3,7 +3,7 @@
 
 from callset import TruthCallset, GCNVCallset, XHMMCallset
 from interval_collection import IntervalCollection
-from evaluator import PerEventEvaluator, PerBinEvaluator, PerSiteEvaluator
+from evaluator import PerEventEvaluator, PerSiteEvaluator
 import plotting
 
 
@@ -18,12 +18,14 @@ def evaluate_cnv_callsets_and_plot_results(analyzed_intervals: str,
                                            minimum_overlap: float,
                                            gcnv_sq_min_del: int,
                                            gcnv_sq_min_dup: int,
+                                           site_frequency_threshold: float,
                                            samples_to_evaluate_path: str):
     perform_per_site_eval = False
     print("Reading in interval list...", flush=True)
     interval_collection = IntervalCollection.read_interval_list(analyzed_intervals)
 
     callsets_to_evaluate = []
+    gcnv_callset = None
     if gcnv_vcfs or gcnv_callset_tsv or gcnv_joint_vcf:
         print("Reading in gCNV callset...", flush=True)
         gcnv_callset = GCNVCallset.read_in_callset(gcnv_segment_vcfs=gcnv_vcfs,
@@ -37,9 +39,10 @@ def evaluate_cnv_callsets_and_plot_results(analyzed_intervals: str,
 
     if xhmm_vcfs:
         print("Reading in XHMM callset", flush=True)
+        samples_to_keep = gcnv_callset.sample_set if gcnv_callset else None
         xhmm_callset = XHMMCallset.read_in_callset(xhmm_vcfs=xhmm_vcfs,
                                                    interval_collection=interval_collection,
-                                                   samples_to_keep=gcnv_callset.sample_set)
+                                                   samples_to_keep=samples_to_keep)
         callsets_to_evaluate.append(xhmm_callset)
 
     if samples_to_evaluate_path:
@@ -78,7 +81,8 @@ def evaluate_cnv_callsets_and_plot_results(analyzed_intervals: str,
               (len(per_event_evaluator.sample_list_to_eval), len(callset.sample_set)))
         per_event_evaluation_result = per_event_evaluator.evaluate_callset_against_the_truth(minimum_overlap=minimum_overlap,
                                                                                              gcnv_sq_min_del=gcnv_sq_min_del,
-                                                                                             gcnv_sq_min_dup=gcnv_sq_min_dup)
+                                                                                             gcnv_sq_min_dup=gcnv_sq_min_dup,
+                                                                                             site_frequency_threshold=site_frequency_threshold)
         per_event_evaluation_results.append(per_event_evaluation_result)
         per_event_evaluation_result.write_to_file(output_directory)
 
@@ -125,6 +129,9 @@ def main():
     parser.add_argument('--gcnv_min_sq_dup_threshold', metavar='MinimumSQDelThreshold', type=int,
                         help='SQ threshold to filter gCNV duplication events on', required=True)
 
+    parser.add_argument('--site_frequency_threshold', metavar='SiteFrequencyThreshold', type=float,
+                        help='Site frequency threshold for filtering variant sites', required=True)
+
     parser.add_argument('--samples_to_evaluate_path', metavar='SamplesToEvaluate', type=str,
                         help='A file containing the set of samples to evaluate, one sample per line.',
                         required=False)
@@ -141,14 +148,15 @@ def main():
    min_required_overlap = args.min_required_overlap
    gcnv_sq_min_del = args.gcnv_min_sq_del_threshold
    gcnv_sq_min_dup = args.gcnv_min_sq_dup_threshold
+    site_frequency_threshold = args.site_frequency_threshold
    samples_to_evaluate_path = args.samples_to_evaluate_path
 
    assert (gcnv_segment_vcfs is None) ^ (gcnv_callset_tsv is None) ^ (gcnv_segment_vcfs is None), \
        "Exactly one of the gCNV segment VCF list or gCNV TSV callset or joint gCNV VCF must be defined"
 
-    evaluate_cnv_callsets_and_plot_results(analyzed_intervals, truth_callset, gcnv_segment_vcfs, gcnv_callset_tsv, gcnv_max_event_number,
-                                           gcnv_joint_vcf, xhmm_vcfs, output_dir, min_required_overlap,
-                                           gcnv_sq_min_del, gcnv_sq_min_dup, samples_to_evaluate_path)
+    evaluate_cnv_callsets_and_plot_results(analyzed_intervals, truth_callset, gcnv_segment_vcfs, gcnv_callset_tsv,
+                                           gcnv_max_event_number, gcnv_joint_vcf, xhmm_vcfs, output_dir, min_required_overlap,
+                                           gcnv_sq_min_del, gcnv_sq_min_dup, site_frequency_threshold, samples_to_evaluate_path)
 
 
 if __name__ == '__main__':
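A minimal illustration of the new `--site_frequency_threshold` plumbing added above (a sketch, not the module itself): the flag is parsed as a required float and handed through to the per-event evaluation; `run_evaluation` is a hypothetical stand-in for `evaluate_cnv_callsets_and_plot_results`.

```python
# Sketch of the new flag's plumbing; run_evaluation is a hypothetical stand-in for
# evaluate_cnv_callsets_and_plot_results and only echoes the parsed value.
import argparse


def run_evaluation(site_frequency_threshold: float) -> None:
    print("Filtering out events with site frequency above {0}".format(site_frequency_threshold))


parser = argparse.ArgumentParser()
parser.add_argument('--site_frequency_threshold', metavar='SiteFrequencyThreshold', type=float,
                    help='Site frequency threshold for filtering variant sites', required=True)
args = parser.parse_args(['--site_frequency_threshold', '0.02'])  # argv supplied inline for the example
run_evaluation(site_frequency_threshold=args.site_frequency_threshold)
```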
diff --git a/evaluation_code/pygcnveval/pygcnveval/evaluator.py b/evaluation_code/pygcnveval/pygcnveval/evaluator.py
index 9a7e676..fd43131 100644
--- a/evaluation_code/pygcnveval/pygcnveval/evaluator.py
+++ b/evaluation_code/pygcnveval/pygcnveval/evaluator.py
@@ -16,22 +16,22 @@ def __init__(self, truth_callset: TruthCallset, callset: Callset, sample_list_to
 
         assert set(sample_list_to_evaluate).issubset(self.callset.sample_set)
         self.sample_list_to_eval = sample_list_to_evaluate
 
-    def evaluate_callset_against_the_truth(self, minimum_overlap: float = 0.2,
-                                           gcnv_sq_min_del: int = 100, gcnv_sq_min_dup: int = 50) -> PerEventEvaluationResult:
+    def evaluate_callset_against_the_truth(self, minimum_overlap: float = 0.2, gcnv_sq_min_del: int = 100,
+                                           gcnv_sq_min_dup: int = 50, site_frequency_threshold: float = 0.01) -> PerEventEvaluationResult:
         evaluation_result = PerEventEvaluationResult(self.callset.get_name())
 
         # construct callset filter
         def gcnv_calls_filter(e):
             is_allosome = e.interval.chrom == "chrX" or e.interval.chrom == "chrY" \
                 or e.interval.chrom == "X" or e.interval.chrom == "Y"
-            return not is_allosome and (e.call_attributes['Frequency'] <= 0.02) \
+            return not is_allosome and (e.call_attributes['Frequency'] <= site_frequency_threshold) \
                 and ((e.call_attributes['Quality'] >= gcnv_sq_min_del and e.event_type == EventType.DEL)
                      or (e.call_attributes['Quality'] >= gcnv_sq_min_dup and e.event_type == EventType.DUP))
 
         def xhmm_calls_filter(e):
             is_allosome = e.interval.chrom == "chrX" or e.interval.chrom == "chrY" \
                 or e.interval.chrom == "X" or e.interval.chrom == "Y"
-            return not is_allosome and (e.call_attributes['Frequency'] <= 0.02) and (e.call_attributes['Quality'] >= 60)
+            return not is_allosome and (e.call_attributes['Frequency'] <= site_frequency_threshold) and (e.call_attributes['Quality'] >= 60)
 
         calls_filter = None
         if self.callset.get_name() == "gCNV_callset":
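The hunk above is the core of the evaluator change: the hard-coded 0.02 frequency cutoff in both call filters becomes the `site_frequency_threshold` parameter (default 0.01). Below is a standalone, runnable restatement of the gCNV predicate under minimal assumptions; `FakeEvent` and `Interval` are hypothetical stand-ins that model only the fields the filter touches.

```python
# Standalone sketch of the gCNV call filter with the parameterized site-frequency cutoff.
# FakeEvent/Interval are hypothetical stand-ins for the evaluator's event objects.
from collections import namedtuple
from enum import Enum


class EventType(Enum):
    DEL = "DEL"
    DUP = "DUP"


Interval = namedtuple("Interval", ["chrom", "start", "end"])
FakeEvent = namedtuple("FakeEvent", ["interval", "call_attributes", "event_type"])


def gcnv_calls_filter(e, site_frequency_threshold=0.01, gcnv_sq_min_del=100, gcnv_sq_min_dup=50):
    """Keep rare autosomal events whose quality passes the type-specific SQ threshold."""
    is_allosome = e.interval.chrom in ("chrX", "chrY", "X", "Y")
    return (not is_allosome
            and e.call_attributes['Frequency'] <= site_frequency_threshold
            and ((e.call_attributes['Quality'] >= gcnv_sq_min_del and e.event_type == EventType.DEL)
                 or (e.call_attributes['Quality'] >= gcnv_sq_min_dup and e.event_type == EventType.DUP)))


# A rare, high-quality autosomal deletion passes; the same event on chrX does not.
autosomal_del = FakeEvent(Interval("chr1", 1000, 5000), {'Frequency': 0.005, 'Quality': 200}, EventType.DEL)
allosomal_del = FakeEvent(Interval("chrX", 1000, 5000), {'Frequency': 0.005, 'Quality': 200}, EventType.DEL)
print(gcnv_calls_filter(autosomal_del), gcnv_calls_filter(allosomal_del))  # True False
```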
@@ -42,29 +42,12 @@ def xhmm_calls_filter(e):
 
         # Calculate precision
         for validated_event in self.callset.get_event_generator(self.sample_list_to_eval, calls_filter):
-            # if self.callset.get_name() == "gCNV_callset" and validated_event.call_attributes['Quality'] < 50 and validated_event.event_type == EventType.DUP:
-            #     continue
-            # if self.callset.get_name() == "gCNV_callset" and validated_event.call_attributes['Quality'] < 100 and validated_event.event_type == EventType.DEL:
-            #     continue
-
-            # if self.callset.get_name() == "XHMM_callset" and validated_event.call_attributes['Quality'] < 60:
-            #     continue
-            #
-            # if validated_event.interval.chrom == "chrX" or validated_event.interval.chrom == "chrY" or validated_event.interval.chrom == "X" or validated_event.interval.chrom == "Y":
-            #     continue
-            #
-            # if self.callset.get_name() == "XHMM_callset" and validated_event.call_attributes['Frequency'] > 0.02:
-            #     continue
-            # # check if event is common
-            # if self.callset.get_name() == "gCNV_callset" and validated_event.call_attributes['Frequency'] > 0.05:
-            #     continue
-
             overlapping_truth_events = self.truth_callset.get_overlapping_events_for_sample(validated_event.interval, validated_event.sample)
             overlapping_truth_event_best_match = validated_event.find_event_with_largest_overlap(overlapping_truth_events)
             if overlapping_truth_event_best_match:
-                if overlapping_truth_event_best_match.call_attributes['Frequency'] > 0.02:
+                if overlapping_truth_event_best_match.call_attributes['Frequency'] > site_frequency_threshold:
                     evaluation_result.update_precision(validated_event.call_attributes['NumBins'], None,
                                                        validated_event.event_type, validated_event.interval,
                                                        "FILTERED_HIGH_TRUTH_AF", list([validated_event.sample]))
@@ -82,7 +65,7 @@ def xhmm_calls_filter(e):
         for truth_event in self.truth_callset.get_event_generator(self.sample_list_to_eval):
             if truth_event.interval.chrom == "chrX" or truth_event.interval.chrom == "chrY" or truth_event.interval.chrom == "X" or truth_event.interval.chrom == "Y":
                 continue
-            if truth_event.call_attributes['Frequency'] > 0.02:
+            if truth_event.call_attributes['Frequency'] > site_frequency_threshold:
                 continue
 
             overlapping_gcnv_events = self.callset.get_overlapping_events_for_sample(truth_event.interval, truth_event.sample)
@@ -105,7 +88,7 @@ def xhmm_calls_filter(e):
                                                     "FILTERED_LOW_QUAL", list([truth_event.sample]))
                     continue
 
-                if self.callset.get_name() == "XHMM_callset" and overlapping_gcnv_event_best_match.call_attributes['Frequency'] > 0.02:
+                if self.callset.get_name() == "XHMM_callset" and overlapping_gcnv_event_best_match.call_attributes['Frequency'] > site_frequency_threshold:
                     evaluation_result.update_recall(truth_event.call_attributes['NumBins'], False, truth_event.event_type,
                                                     truth_event.interval, "FILTERED_HIGH_AF", list([truth_event.sample]))  # TODO refactor
@@ -116,7 +99,7 @@ def xhmm_calls_filter(e):
                                                     "FILTERED_LOW_QUAL", list([truth_event.sample]))  # TODO refactor
                     continue
 
-                if self.callset.get_name() == "gCNV_callset" and overlapping_gcnv_event_best_match.call_attributes['Frequency'] > 0.02:
+                if self.callset.get_name() == "gCNV_callset" and overlapping_gcnv_event_best_match.call_attributes['Frequency'] > site_frequency_threshold:
                     evaluation_result.update_recall(truth_event.call_attributes['NumBins'], False, truth_event.event_type,
                                                     truth_event.interval, "FILTERED_HIGH_AF", list([truth_event.sample]))
                     continue
@@ -151,7 +134,7 @@ def evaluate_callset_against_the_truth(self) -> PerEventEvaluationResult:
                                                                                            overlapping_truth_alleles)
 
             if overlapping_truth_allele_best_match:
-                if overlapping_truth_allele_best_match.allele_attributes['Frequency'] > 0.01:
+                if overlapping_truth_allele_best_match.allele_attributes['Frequency'] > 0.02:
                     evaluation_result.update_precision(validated_allele.allele_attributes['NumBins'], None,
                                                        validated_allele.event_type, validated_allele.interval,
                                                        "FILTERED_HIGH_TRUTH_AF", list(validated_allele.sample_set))
@@ -171,7 +154,7 @@ def evaluate_callset_against_the_truth(self) -> PerEventEvaluationResult:
                 continue
             if truth_allele.interval.chrom == "chrX" or truth_allele.interval.chrom == "chrY":
                 continue
-            if truth_allele.allele_attributes['Frequency'] > 0.01:
+            if truth_allele.allele_attributes['Frequency'] > 0.02:
                 continue
 
             overlapping_gcnv_alleles = self.callset.get_overlapping_alleles(truth_allele.interval)
diff --git a/evaluation_code/pygcnveval/pygcnveval/plotting.py b/evaluation_code/pygcnveval/pygcnveval/plotting.py
index a9dc17d..2c12092 100644
--- a/evaluation_code/pygcnveval/pygcnveval/plotting.py
+++ b/evaluation_code/pygcnveval/pygcnveval/plotting.py
@@ -35,13 +35,13 @@ def sum_over_dict(d: dict, last_index: int, bins: list):
 
     for index, evaluation_result in enumerate(evaluation_result_list):
         tp = evaluation_result.precision_size_to_tp
         fp = evaluation_result.precision_size_to_fp
-        print(tp)
-        print(fp)
+        print("Number of true positives for events of the following bin sizes (precision calculation): {0}".format(tp))
+        print("Number of false positives for events of the following bin sizes (precision calculation): {0}".format(fp))
         precisions_for_bins.append([sum_over_dict(tp, i, bins) / max(1., (sum_over_dict(tp, i, bins)+sum_over_dict(fp, i, bins))) for i in bins])
         tp = evaluation_result.recall_size_to_tp
         fn = evaluation_result.recall_size_to_fn
-        print(tp)
-        print(fn)
+        print("Number of true positives for events of the following bin sizes (recall calculation): {0}".format(tp))
+        print("Number of false negatives for events of the following bin sizes (recall calculation): {0}".format(fn))
         recall_for_bins.append([sum_over_dict(tp, i, bins) / max(1., (sum_over_dict(tp, i, bins) + sum_over_dict(fn, i, bins))) for i in bins])
         f_1_for_bins.append([2 / ((1. / recall_for_bins[index][i]) + (1. / precisions_for_bins[index][i])) for i in bins])
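For context on the reworded log lines above: precision and recall per bin size are computed from the tp/fp/fn dictionaries via `sum_over_dict`, whose implementation is not shown in this diff. The sketch below assumes (from its `last_index` signature) that it accumulates counts over all bins up to a given size; treat it as an illustration of the ratios being printed, not as the actual plotting.py helper.

```python
# Minimal sketch of the cumulative precision aggregation behind the new log lines.
# tp/fp map "number of bins spanned by the event" -> count; sum_over_dict is a simplified,
# assumed stand-in for the helper of the same name in plotting.py.
def sum_over_dict(d, last_index, bins):
    return sum(d.get(b, 0) for b in bins if b <= last_index)


tp = {1: 40, 5: 25, 10: 12}   # true positives keyed by event size in bins
fp = {1: 10, 5: 3, 10: 1}     # false positives keyed by event size in bins
bins = [1, 5, 10]

precisions_for_bins = [sum_over_dict(tp, i, bins) / max(1., sum_over_dict(tp, i, bins) + sum_over_dict(fp, i, bins))
                       for i in bins]
print("Number of true positives for events of the following bin sizes (precision calculation): {0}".format(tp))
print(precisions_for_bins)  # cumulative precision for events spanning <= 1, <= 5, and <= 10 bins
```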