Skip to content

Commit

Permalink
feat: add date cues (#265)
Browse files Browse the repository at this point in the history
  • Loading branch information
cvinot authored Jul 17, 2024
1 parent da4321c commit 4995daa
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
during the processing. Autocasting should result in a slight speedup, but may lead to numerical instability.
- Use `torch.inference_mode` to disable view tracking and version counter bumps during inference.
- Added a new NER pipeline for suicide attempt detection
- Added date cues (regular expression matches that contributed to a date being detected) under the extension `ent._.date_cues`

### Changed

Expand Down
11 changes: 11 additions & 0 deletions edsnlp/pipes/misc/dates/dates.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""`eds.dates` pipeline."""

import warnings
from itertools import chain
from typing import Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -174,6 +175,8 @@ class DatesMatcher(BaseNERComponent):
Label to use for periods
span_setter : SpanSetterArg
How to set matches in the doc.
explain : bool
Whether to keep track of regex cues for each entity.
Authors and citation
--------------------
Expand Down Expand Up @@ -206,10 +209,12 @@ def __init__(
"durations": ["duration"],
"periods": ["period"],
},
explain: bool = False,
):
self.date_label = date_label
self.duration_label = duration_label
self.period_label = period_label
self.explain = explain

# Backward compatibility
if as_ents is True:
Expand Down Expand Up @@ -302,6 +307,9 @@ def set_extensions(self) -> None:
if not Span.has_extension(self.period_label):
Span.set_extension(self.period_label, default=None)

if not Span.has_extension("date_cues"):
Span.set_extension("date_cues", default=None)

def process(self, doc: Doc) -> List[Tuple[Span, Dict[str, str]]]:
"""
Find dates in doc.
Expand Down Expand Up @@ -406,6 +414,9 @@ def parse(
span.label_ = self.duration_label
span._.duration = parsed

if self.explain:
span._.date_cues = groupdict

return [span for span, _ in matches]

def process_periods(self, dates: List[Span]) -> List[Span]:
Expand Down
5 changes: 4 additions & 1 deletion tests/pipelines/misc/test_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@

@fixture(autouse=True)
def add_date_pipeline(blank_nlp: PipelineProtocol):
blank_nlp.add_pipe("eds.dates", config=dict(detect_periods=True, as_ents=True))
blank_nlp.add_pipe(
"eds.dates", config=dict(detect_periods=True, as_ents=True, explain=True)
)


def test_dates_component(blank_nlp: PipelineProtocol):
Expand All @@ -89,6 +91,7 @@ def test_dates_component(blank_nlp: PipelineProtocol):

for span, entity in zip(spans, entities):
assert span.text == text[entity.start_char : entity.end_char]
assert bool(span._.date_cues)

date = span._.date if span.label_ == "date" else span._.duration
d = {modifier.key: modifier.value for modifier in entity.modifiers}
Expand Down

0 comments on commit 4995daa

Please sign in to comment.