Skip to content

Commit

Permalink
Merge pull request #12 from datasciencecampus/feature/unit-tests
Browse files Browse the repository at this point in the history
Add unit tests
  • Loading branch information
ColinDaglish authored Jul 17, 2023
2 parents 8951a4c + 38fcf59 commit fe6d915
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ feature_count: #dict
max_features: null #null converts to None, or int value
lowercase: True #whether to convert all words to lowercase
lda: #dict
n_topics: 5 #int
n_topics: 5 #int greater than 0
n_top_words: 10 #int
max_iter: 25 #int
title: "Topic Summary" #str
Expand Down
4 changes: 2 additions & 2 deletions src/modules/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _generate_topic_labels(n_topics: int, topic_labels: list = None) -> list:
list of topic labels
"""
if topic_labels is None:
topic_labels = [f"Topic_{n}" for n in range(1, n_topics)]
topic_labels = [f"Topic_{n}" for n in range(1, n_topics + 1)]
else:
if len(topic_labels) != n_topics:
raise AttributeError("len(topic_labels) does not equal n_topics")
Expand All @@ -121,7 +121,7 @@ def _get_n_columns_and_n_rows(n_topics: int) -> int:
Parameters
----------
n_topics: int
number of topics
number of topics (must be integer greater than 0)
Returns
-------
int
Expand Down
102 changes: 102 additions & 0 deletions tests/modules/test_visualisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os
import sys
import unittest
from datetime import datetime as dt

import matplotlib.pyplot as plt
import pytest

from src.modules.visualisation import (
_generate_topic_labels,
_get_factors,
_get_fig_size,
_get_n_columns_and_n_rows,
save_figure,
)


class TestSaveFigure:
@pytest.mark.skipif(
sys.platform.startswith("linux"), reason="Unknown error during CI"
)
def test_file_created(self):
figure = plt.figure()
name = "test"
datestamp = dt.strftime(dt.now(), "%Y%m%d")
filepath = f"data/outputs/{datestamp}_{name}.jpeg"
save_figure(name, figure)
assert os.path.isfile(filepath), "Did not save a file with correct filename"
os.remove(filepath)


class TestGenerateTopicLabels(unittest.TestCase):
def test_topic_labels_is_none(self):
topic_labels = None
n_topics = 2
actual = _generate_topic_labels(n_topics, topic_labels)
expected = ["Topic_1", "Topic_2"]
assert actual == expected, "Topic labels did not match expected"

def test_topic_labels_preset(self):
topic_labels = ["My Topic", "Your Topic"]
n_topics = 2
actual = _generate_topic_labels(n_topics, topic_labels)
expected = ["My Topic", "Your Topic"]
assert actual == expected, "Topic labels did not match expected"

def test_raise_attribute_error(self):
topic_labels = ["One"]
n_topics = 2
with (self.assertRaises(Exception)) as context:
_generate_topic_labels(n_topics, topic_labels)
self.assertTrue("Does not raise an AttributeError", context.exception)


class TestGetNColumnsAndNRows(unittest.TestCase):
def test_raise_value_error(self):
n_topics = 0
with (self.assertRaises(Exception)) as context:
_get_n_columns_and_n_rows(n_topics)
self.assertTrue("Does not raise a value error", context.exception)

def test_n_topics_5_or_less(self):
n_topics = 4
actual = _get_n_columns_and_n_rows(n_topics)
expected = (1, 4)
assert actual == expected, "Did not produce the correct number of rows/columns"

def test_n_topics_above_5_with_factor(self):
n_topics = 24
actual = _get_n_columns_and_n_rows(n_topics)
expected = (6, 4)
assert actual == expected, "Did not produce the correct number of rows/columns"

def test_n_topics_above_5_without_factor(self):
n_topics = 23
actual = _get_n_columns_and_n_rows(n_topics)
expected = (6, 4)
assert actual == expected, "Did not produce the correct number of rows/columns"


class TestGetFactors:
def test_get_factors_of_4(self):
test_input = 4
actual = _get_factors(test_input)
expected = [1, 2, 4]
assert actual == expected, "Did not return the correct factors of 4"

def test_get_factors_of_5(self):
actual = _get_factors(5)
expected = [1, 5]
assert actual == expected, "Did not return the correct factors of 5"


class TestGetFigSize:
def test_get_fig_size(self):
columns = 2
rows = 4
actual = _get_fig_size(columns, rows)
expected_width = 12
expected_height = 27
expected_result = (expected_width, expected_height)
assert actual == expected_result, "expected height or width is not correct"

0 comments on commit fe6d915

Please sign in to comment.