Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF-IDF fix: punctuation removal + lowercasing #339

Merged
merged 14 commits into from
Sep 24, 2024
13 changes: 12 additions & 1 deletion fastembed/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from itertools import islice
from pathlib import Path
from typing import Generator, Iterable, Optional, Union

import unicodedata
import sys
import numpy as np


Expand Down Expand Up @@ -41,3 +42,13 @@ def define_cache_dir(cache_dir: Optional[str] = None) -> Path:
cache_path.mkdir(parents=True, exist_ok=True)

return cache_path


def get_all_punctuation() -> set:
    """Return the set of all Unicode punctuation characters.

    Scans every Unicode code point — including ``sys.maxunicode`` itself,
    which ``range(sys.maxunicode)`` would miss (harmless today, since
    U+10FFFF is category Cn, but the inclusive bound matches the "all
    punctuation" contract) — and keeps those whose Unicode general
    category starts with "P" (Pc, Pd, Pe, Pf, Pi, Po, Ps).

    Returns:
        set: every single-character punctuation string in Unicode.
    """
    return {
        chr(code_point)
        for code_point in range(sys.maxunicode + 1)
        if unicodedata.category(chr(code_point)).startswith("P")
    }


def replace_punctuation(text: str, punctuation: set) -> str:
    """Replace every punctuation character in ``text`` with a space.

    Replacing (rather than deleting) keeps word boundaries intact, so a
    downstream whitespace tokenizer still splits "foo,bar" into two tokens.

    Args:
        text: the input string.
        punctuation: single-character strings to replace; a set is
            recommended so the per-character membership test is O(1).

    Returns:
        str: a copy of ``text`` with each punctuation character replaced
        by a single space.
    """
    return "".join(" " if char in punctuation else char for char in text)
18 changes: 14 additions & 4 deletions fastembed/sparse/bm25.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import string
from collections import defaultdict
from multiprocessing import get_all_start_methods
from pathlib import Path
Expand All @@ -9,7 +8,12 @@
import numpy as np
from snowballstemmer import stemmer as get_stemmer

from fastembed.common.utils import define_cache_dir, iter_batch
from fastembed.common.utils import (
define_cache_dir,
iter_batch,
get_all_punctuation,
replace_punctuation,
)
from fastembed.parallel_processor import ParallelWorkerPool, Worker
from fastembed.sparse.sparse_embedding_base import (
SparseEmbedding,
Expand Down Expand Up @@ -120,11 +124,15 @@ def __init__(
model_description, self.cache_dir, local_files_only=self._local_files_only
)

self.punctuation = set(string.punctuation)
self.punctuation = set(get_all_punctuation())
self.stopwords = set(self._load_stopwords(model_dir, self.language))

self.stemmer = get_stemmer(language)
self.tokenizer = WordTokenizer

self.num_terms: int = 0
self.token_accum: int = 0
I8dNLo marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def list_supported_models(cls) -> List[Dict[str, Any]]:
"""Lists the supported models.
Expand Down Expand Up @@ -222,7 +230,7 @@ def _stem(self, tokens: List[str]) -> List[str]:
if token.lower() in self.stopwords:
continue

stemmed_token = self.stemmer.stemWord(token)
stemmed_token = self.stemmer.stemWord(token.lower())

if stemmed_token:
stemmed_tokens.append(stemmed_token)
Expand All @@ -234,6 +242,7 @@ def raw_embed(
) -> List[SparseEmbedding]:
embeddings = []
for document in documents:
document = replace_punctuation(document, self.punctuation)
tokens = self.tokenizer.tokenize(document)
stemmed_tokens = self._stem(tokens)
token_id2value = self._term_frequency(stemmed_tokens)
Expand Down Expand Up @@ -282,6 +291,7 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[Sp
query = [query]

for text in query:
text = replace_punctuation(text, self.punctuation)
tokens = self.tokenizer.tokenize(text)
stemmed_tokens = self._stem(tokens)
token_ids = np.array(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ numpy = [
pillow = "^10.3.0"
snowballstemmer = "^2.2.0"
PyStemmer = "^2.2.0"
mmh3 = "^4.0"
mmh3 = "^4.1.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.2"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_attention_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def test_multilanguage(model_name):

model = SparseTextEmbedding(model_name=model_name, language="english")
embeddings = list(model.embed(docs))[:2]
assert embeddings[0].values.shape == (4,)
assert embeddings[0].indices.shape == (4,)
assert embeddings[0].values.shape == (5,)
assert embeddings[0].indices.shape == (5,)

assert embeddings[1].values.shape == (4,)
assert embeddings[1].indices.shape == (4,)
2 changes: 1 addition & 1 deletion tests/test_sparse_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,5 @@ def test_stem_case_insensitive_stopwords(bm25_instance):
result = bm25_instance._stem(tokens)

# Assert
expected = ["Quick", "Brown", "Fox", "Test", "Sentenc"]
expected = ["quick", "brown", "fox", "test", "sentenc"]
assert result == expected, f"Expected {expected}, but got {result}"
Loading