Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add intersection of input, reference to overlap #1817

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions scripts/data_overlap/compute_data_overlap_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

PART_INPUT: str = "input"
PART_REF: str = "references"
PART_INTERSECT: str = "intersect"


# type alias for overlap-related data structures
Expand Down Expand Up @@ -106,6 +107,24 @@ def create_ngram_index(
ngram_index[n][reference_ngram].add(
EntryDataOverlapKey(stats_key=stats_key, instance_id=id, part=PART_REF)
)

# concatenate the last n-1 tokens of input and the first n-1 tokens
# of reference and compute n-grams on this "interesection token sequence"
# for instance: input = ["is 2+2 4 true or false"] reference = ["true"]
# the intersection is the 5-gram ["4 true or false true"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is good (as we discussed), but (i) do we have to get providers to rerun with this new code (non-trivial cost) and (ii) I wonder how often the question and answer will be juxtaposed. If there's any token that separates the Q and A, then we won't detect overlap.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think for now we can just run this for the pile and see what sort of insights we get; this is simply an additional metric that doesn't affect existing metrics. If we're concerned for a token gap, we can allow for a skip token budget, esp for this metric, though it may be premature to add at this point

# (which is formed from the input 4-gram [4 true or false] and the reference 1-gram [true])
input_end_tokens = input_tokens[-(n - 1) :]
for reference in instance.references:
reference_unigrams = tokenizer.tokenize(reference)
reference_start_tokens = reference_unigrams[: n - 1]
intersection_tokens = input_end_tokens + reference_start_tokens
for intersection_ngram in ngrams(intersection_tokens, n):
if intersection_ngram not in ngram_index[n]:
ngram_index[n][intersection_ngram] = set()
ngram_index[n][intersection_ngram].add(
EntryDataOverlapKey(stats_key=stats_key, instance_id=id, part=PART_INTERSECT)
)

return ngram_index


Expand All @@ -116,6 +135,7 @@ def compute_all_data_overlap(
tokenizer: LightTokenizer,
stats_key_to_input_ids: DefaultDict[DataOverlapStatsKey, Set[str]],
stats_key_to_reference_ids: DefaultDict[DataOverlapStatsKey, Set[str]],
stats_key_to_intersection_ids: DefaultDict[DataOverlapStatsKey, Set[str]],
entry_overlap_key_to_ngram_counts: DefaultDict[EntryDataOverlapKey, DefaultDict[str, int]],
output_ngrams: bool,
) -> None:
Expand All @@ -140,6 +160,7 @@ def compute_all_data_overlap(
tokenizer=tokenizer,
stats_key_to_input_ids=stats_key_to_input_ids,
stats_key_to_reference_ids=stats_key_to_reference_ids,
stats_key_to_intersection_ids=stats_key_to_intersection_ids,
entry_overlap_key_to_ngram_counts=entry_overlap_key_to_ngram_counts,
output_ngrams=output_ngrams,
)
Expand All @@ -151,6 +172,7 @@ def compute_document_data_overlap(
tokenizer: LightTokenizer,
stats_key_to_input_ids: DefaultDict[DataOverlapStatsKey, Set[str]],
stats_key_to_reference_ids: DefaultDict[DataOverlapStatsKey, Set[str]],
stats_key_to_intersection_ids: DefaultDict[DataOverlapStatsKey, Set[str]],
entry_overlap_key_to_ngram_counts: DefaultDict[EntryDataOverlapKey, DefaultDict[str, int]],
output_ngrams: bool,
) -> None:
Expand Down Expand Up @@ -182,6 +204,8 @@ def compute_document_data_overlap(
stats_key_to_input_ids[entry_overlap_key.stats_key].add(id)
elif part == PART_REF:
stats_key_to_reference_ids[entry_overlap_key.stats_key].add(id)
elif part == PART_INTERSECT:
stats_key_to_intersection_ids[entry_overlap_key.stats_key].add(id)
if output_ngrams:
entry_overlap_key_to_ngram_counts[entry_overlap_key][document_ngram] += 1

Expand Down Expand Up @@ -214,6 +238,7 @@ def compute_document_data_overlap(
# DataOverlapStatsKey -> Set[str] for ids
stats_key_to_input_ids: DefaultDict[DataOverlapStatsKey, Set] = defaultdict(set)
stats_key_to_reference_ids: DefaultDict[DataOverlapStatsKey, Set] = defaultdict(set)
stats_key_to_intersection_ids: DefaultDict[DataOverlapStatsKey, Set] = defaultdict(set)

entry_overlap_key_to_ngram_counts: DefaultDict[EntryDataOverlapKey, DefaultDict[str, int]] = defaultdict(
lambda: defaultdict(int)
Expand All @@ -232,6 +257,7 @@ def compute_document_data_overlap(
tokenizer=tokenizer,
stats_key_to_input_ids=stats_key_to_input_ids,
stats_key_to_reference_ids=stats_key_to_reference_ids,
stats_key_to_intersection_ids=stats_key_to_intersection_ids,
entry_overlap_key_to_ngram_counts=entry_overlap_key_to_ngram_counts,
output_ngrams=not args.no_output_ngrams,
)
Expand All @@ -255,6 +281,7 @@ def compute_document_data_overlap(
data_overlap_stats_key=stats_key,
instance_ids_with_overlapping_input=sorted(stats_key_to_input_ids[stats_key]),
instance_ids_with_overlapping_reference=sorted(stats_key_to_reference_ids[stats_key]),
instance_ids_with_overlapping_intersection=sorted(stats_key_to_intersection_ids[stats_key]),
num_instances=count,
)
all_data_overlap_stats.append(data_overlap_stats)
Expand Down
2 changes: 2 additions & 0 deletions scripts/data_overlap/data_overlap_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class DataOverlapStats:

instance_ids_with_overlapping_reference: List[str]

instance_ids_with_overlapping_intersection: List[str]


@dataclass(frozen=True)
class EntryDataOverlapKey:
Expand Down
Loading