diff --git a/changelog.md b/changelog.md index 57d1fd3a7..bd3734d73 100644 --- a/changelog.md +++ b/changelog.md @@ -8,6 +8,7 @@ during the processing. Autocasting should result in a slight speedup, but may lead to numerical instability. - Use `torch.inference_mode` to disable view tracking and version counter bumps during inference. - Added a new NER pipeline for suicide attempt detection +- Added date cues (regular expression matches that contributed to a date being detected) under the extension `ent._.date_cues` ### Changed diff --git a/edsnlp/pipes/misc/dates/dates.py b/edsnlp/pipes/misc/dates/dates.py index 801bbdc3b..c184c4d21 100644 --- a/edsnlp/pipes/misc/dates/dates.py +++ b/edsnlp/pipes/misc/dates/dates.py @@ -1,4 +1,5 @@ """`eds.dates` pipeline.""" + import warnings from itertools import chain from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -174,6 +175,8 @@ class DatesMatcher(BaseNERComponent): Label to use for periods span_setter : SpanSetterArg How to set matches in the doc. + explain : bool + Whether to keep track of regex cues for each entity. Authors and citation -------------------- @@ -206,10 +209,12 @@ def __init__( "durations": ["duration"], "periods": ["period"], }, + explain: bool = False, ): self.date_label = date_label self.duration_label = duration_label self.period_label = period_label + self.explain = explain # Backward compatibility if as_ents is True: @@ -302,6 +307,9 @@ def set_extensions(self) -> None: if not Span.has_extension(self.period_label): Span.set_extension(self.period_label, default=None) + if not Span.has_extension("date_cues"): + Span.set_extension("date_cues", default=None) + def process(self, doc: Doc) -> List[Tuple[Span, Dict[str, str]]]: """ Find dates in doc. @@ -406,6 +414,9 @@ def parse( span.label_ = self.duration_label span._.duration = parsed + if self.explain: + span._.date_cues = groupdict + return [span for span, _ in matches] def process_periods(self, dates: List[Span]) -> List[Span]: diff --git a/tests/pipelines/misc/test_dates.py b/tests/pipelines/misc/test_dates.py index d36d84031..2ff4e98bc 100644 --- a/tests/pipelines/misc/test_dates.py +++ b/tests/pipelines/misc/test_dates.py @@ -72,7 +72,9 @@ @fixture(autouse=True) def add_date_pipeline(blank_nlp: PipelineProtocol): - blank_nlp.add_pipe("eds.dates", config=dict(detect_periods=True, as_ents=True)) + blank_nlp.add_pipe( + "eds.dates", config=dict(detect_periods=True, as_ents=True, explain=True) + ) def test_dates_component(blank_nlp: PipelineProtocol): @@ -89,6 +91,7 @@ def test_dates_component(blank_nlp: PipelineProtocol): for span, entity in zip(spans, entities): assert span.text == text[entity.start_char : entity.end_char] + assert bool(span._.date_cues) date = span._.date if span.label_ == "date" else span._.duration d = {modifier.key: modifier.value for modifier in entity.modifiers}