Skip to content

Commit

Permalink
lda test
Browse files Browse the repository at this point in the history
  • Loading branch information
ColinDaglish committed Jul 13, 2023
1 parent d850567 commit 905fad7
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions tests/modules/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import sys
from itertools import repeat

import pytest
from pandas import DataFrame, Series
from scipy.sparse._csr import csr_matrix
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer

from src.modules.analysis import (
extract_feature_count,
get_total_feature_count,
latent_dirichlet_allocation,
retrieve_named_entities,
)

Expand All @@ -15,25 +17,32 @@ class TestExtractFeatureCount:
def test_feature_count(self):
    """Check the per-token counts produced for a single short document.

    The element at index 1 of ``extract_feature_count``'s result is
    compared cell by cell against a one-row frame holding a count of 1
    for each lower-cased token.
    """
    data = Series(["My name is elf"])
    expected = DataFrame([[1, 1, 1, 1]], columns=("elf", "is", "my", "name"))
    # Index 1 is presumably the feature-count DataFrame (index 0 being the
    # fitted sparse matrix) -- confirm against src.modules.analysis.
    actual = extract_feature_count(data)[1]
    # Built-in all() over a DataFrame iterates its column labels, which are
    # non-empty strings and therefore always truthy; reduce the boolean frame
    # over both axes so the cell values are what is actually asserted.
    assert (expected == actual).all().all(), "Does not match expected output"

def test_remove_stopwords(self):
    """Check that tokens listed in ``stop_words`` are left out of the counts.

    With "is" and "my" supplied as stop words, only "elf" and "name" are
    expected to remain as columns, each with a count of 1.
    """
    stopwords = ["is", "my"]
    data = Series(["My name is elf"])
    # Index 1 is presumably the feature-count DataFrame -- confirm against
    # src.modules.analysis.
    actual = extract_feature_count(data, stop_words=stopwords)[1]
    expected = DataFrame([[1, 1]], columns=("elf", "name"))
    # Built-in all() over a DataFrame only inspects column labels (always
    # truthy strings); reduce the element-wise comparison over both axes so
    # a wrong count makes the test fail.
    assert (expected == actual).all().all(), "Does not remove stopwords"

def test_ngrams(self):
    """Check that ``ngram_range=(1, 2)`` yields unigram and bigram columns.

    The four-word document is expected to produce seven features (four
    unigrams and three bigrams), each with a count of 1.
    """
    data = Series(["My name is elf"])
    # Index 1 is presumably the feature-count DataFrame -- confirm against
    # src.modules.analysis.
    actual = extract_feature_count(data, ngram_range=(1, 2))[1]
    expected = DataFrame(
        [repeat(1, 7)],
        columns=["elf", "is", "is elf", "my", "my name", "name", "name is"],
    )
    # Built-in all() over a DataFrame only inspects column labels (always
    # truthy strings); reduce the element-wise comparison over both axes so
    # the counts themselves are verified.
    assert (expected == actual).all().all(), "Does not handle ngrams"

def test_get_fitted_vector(self):
    """Position 0 of the ``extract_feature_count`` result must be a csr_matrix."""
    corpus = Series(["My name is elf"])
    result = extract_feature_count(corpus)
    fitted_vector = result[0]
    is_sparse = isinstance(fitted_vector, csr_matrix)
    assert is_sparse, "Does not return a csr_matrix object in position 0"


class TestGetTotalFeatureCount:
def test_get_total_feature_count(self):
Expand All @@ -49,7 +58,6 @@ def test_get_total_feature_count(self):


class TestRetrieveNamedEntities:
@pytest.mark.skipif(sys.platform.startswith("linux"), reason="Not sure")
def test_retrieve_named_entities(self):
test_data = Series(
[
Expand All @@ -60,4 +68,19 @@ def test_retrieve_named_entities(self):
)
actual = retrieve_named_entities(test_data)
expected = [["ONS", "the UK Government's"], [], ["Hollywood"]]
assert actual == expected, "Did not successfully retrieve named entities"
trimmed_actual = [component for component in actual if component != []]
trimmed_expected = [component for component in expected if component != []]
assert (
trimmed_actual == trimmed_expected
), "Did not successfully retrieve named entities"


class TestLatentDirichletAllocation:
    """Smoke test for the ``latent_dirichlet_allocation`` helper."""

    def test_latent_dirichlet_allocation(self):
        """The helper should return a scikit-learn LatentDirichletAllocation.

        A one-document corpus is vectorised with a default CountVectorizer
        and the resulting matrix is passed as the third positional argument.
        """
        corpus = Series(["My name is Elf and I like ignoble hats"])
        vectorizer = CountVectorizer()
        document_term_matrix = vectorizer.fit_transform(corpus)
        model = latent_dirichlet_allocation(10, 10, document_term_matrix)
        assert isinstance(model, LatentDirichletAllocation), (
            "function did not return an latent dirichlet allocation object"
        )

0 comments on commit 905fad7

Please sign in to comment.