Skip to content

Commit

Permalink
Merge pull request #181 from MAIF/feature/improve_regex
Browse files Browse the repository at this point in the history
❇️ Add hooks to MelusineRegex
  • Loading branch information
HugoPerrier authored Oct 8, 2024
2 parents a265c56 + 5028d89 commit 29815b5
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 8 deletions.
69 changes: 68 additions & 1 deletion melusine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import re
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, TypeVar, Union

import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
Expand Down Expand Up @@ -308,12 +308,16 @@ class MissingFieldError(Exception):
"""


MatchData = Dict[str, List[Dict[str, Any]]]


class MelusineRegex(ABC):
"""
Class to standardise text pattern detection using regex.
"""

REGEX_FLAGS: re.RegexFlag = re.IGNORECASE | re.MULTILINE
PAIRED_MATCHING_PREFIX: str = "_"

# Match fields
MATCH_RESULT: str = "match_result"
Expand Down Expand Up @@ -483,6 +487,9 @@ def __call__(self, text: str) -> dict[str, Any]:
Returns:
_: Regex match results.
"""
# Apply pre match hook
text = self.pre_match_hook(text)

match_dict = {
self.MATCH_RESULT: False,
self.NEUTRAL_MATCH_FIELD: {},
Expand All @@ -509,6 +516,9 @@ def __call__(self, text: str) -> dict[str, Any]:

match_dict[self.MATCH_RESULT] = positive_match and not negative_match

# Apply post match hook
match_dict = self.post_match_hook(match_dict)

return match_dict

def describe(self, text: str, position: bool = False) -> None:
Expand Down Expand Up @@ -563,6 +573,63 @@ def _describe_match_field(match_field_data: dict[str, list[dict[str, Any]]]) ->
print("The following text matched positively:")
_describe_match_field(match_data[self.POSITIVE_MATCH_FIELD])

def apply_paired_matching(self, negative_match_data: MatchData, positive_match_data: MatchData) -> bool:
"""
Check if negative match is effective in the case of paired matching.
Args:
negative_match_data: negative_match_data
positive_match_data: positive_match_data
Returns:
effective_negative_match: negative_match adapted for paired matching
"""
effective_negative_match = False
if positive_match_data and negative_match_data:
positive_match_keys = set(positive_match_data.keys())

for key in negative_match_data:
if key.startswith(self.PAIRED_MATCHING_PREFIX):
if key[1:] in positive_match_keys:
effective_negative_match = True
else:
effective_negative_match = True

return effective_negative_match

def pre_match_hook(self, text: str) -> str:
"""
Hook to run before the Melusine regex match.
Args:
text: input text.
Returns:
_: Modified text.
"""
return text

def post_match_hook(self, match_dict: dict[str, Any]) -> dict[str, Any]:
"""
Hook to run after the Melusine regex match.
Args:
match_dict: Match results.
Returns:
_: Modified match results.
"""

# Paired matching
negative_match = self.apply_paired_matching(
match_dict[self.NEGATIVE_MATCH_FIELD], match_dict[self.POSITIVE_MATCH_FIELD]
)
positive_match = bool(match_dict[self.POSITIVE_MATCH_FIELD])

match_dict[self.MATCH_RESULT] = positive_match and not negative_match

return match_dict

def test(self) -> None:
"""
Test the MelusineRegex on the match_list and no_match_list.
Expand Down
92 changes: 85 additions & 7 deletions tests/base/test_melusine_regex.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import pytest

Expand Down Expand Up @@ -46,13 +46,12 @@ def no_match_list(self) -> List[str]:

def test_erroneous_substitution_pattern():
with pytest.raises(ValueError):
regex = VirusRegex(substitution_pattern="12345")
_ = VirusRegex(substitution_pattern="12345")


def test_method_test():
regex = VirusRegex()
regex.test()
assert True


def test_match_method():
Expand Down Expand Up @@ -94,7 +93,7 @@ def test_describe_method(capfd):

# Negative match on bug (group NEGATIVE_BUG) and ignore ladybug and corona virus
regex.describe("The computer virus in the ladybug software caused a bug in the corona virus dashboard")
out, err = capfd.readouterr()
out, _ = capfd.readouterr()
assert "NEGATIVE_BUG" in out
assert "start" not in out

Expand All @@ -103,18 +102,18 @@ def test_describe_method(capfd):
"The computer virus in the ladybug software caused a bug in the corona virus dashboard",
position=True,
)
out, err = capfd.readouterr()
out, _ = capfd.readouterr()
assert "match result is : NEGATIVE" in out
assert "NEGATIVE_BUG" in out
assert "start" in out

regex.describe("This is a dangerous virus")
out, err = capfd.readouterr()
out, _ = capfd.readouterr()
assert "match result is : POSITIVE" in out
assert "start" not in out

regex.describe("Nada")
out, err = capfd.readouterr()
out, _ = capfd.readouterr()
assert "The input text did not match anything" in out


Expand Down Expand Up @@ -151,3 +150,82 @@ def no_match_list(self):
regex = SomeRegex()
assert regex.neutral is None
assert regex.negative is None


class PreMatchHookVirusRegex(VirusRegex):
def pre_match_hook(self, text: str) -> str:
text = text.replace("virrrrus", "virus")
return text


def test_pre_match_hook():
reg = PreMatchHookVirusRegex()

bool_match_result = reg.get_match_result("I see a virrrrus !")

assert bool_match_result is True


class PostMatchHookVirusRegex(VirusRegex):
def post_match_hook(self, match_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Test custom post processing of match data"""
if (
match_dict[self.MATCH_RESULT] is True
and "NEUTRAL_MEDICAL_VIRUS" in match_dict[self.NEUTRAL_MATCH_FIELD]
and "NEUTRAL_INSECT" in match_dict[self.NEUTRAL_MATCH_FIELD]
):
match_dict[self.MATCH_RESULT] = False

return match_dict


def test_post_match_hook():
reg = PostMatchHookVirusRegex()

bool_match_result = reg.get_match_result("I see a virus, a corona virus and a ladybug")
assert bool_match_result is False

bool_match_result = reg.get_match_result("I see a virus and a ladybug")
assert bool_match_result is True


class PairedMatchRegex(MelusineRegex):
"""
Test paired matching.
"""

@property
def positive(self) -> Union[str, Dict[str, str]]:
return {
"test_1": r"pos_pattern_1",
"test_2": r"pos_pattern_2",
}

@property
def negative(self) -> Optional[Union[str, Dict[str, str]]]:
return {
"_test_1": r"neg_pattern_1",
"generic": r"neg_pattern_2",
}

@property
def match_list(self) -> List[str]:
return [
"Test pos_pattern_1",
"pos_pattern_2",
"pos_pattern_2 and neg_pattern_1",
]

@property
def no_match_list(self) -> List[str]:
return [
"test",
"Test pos_pattern_1 and neg_pattern_1",
"pos_pattern_2 and neg_pattern_2",
"pos_pattern_1 and neg_pattern_2",
]


def test_paired_matching_test():
regex = PairedMatchRegex()
regex.test()

0 comments on commit 29815b5

Please sign in to comment.