Skip to content

Commit

Permalink
Db tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Dec 5, 2023
1 parent 9db6030 commit ea476d7
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,16 @@ jobs:
- name: Install dependencies
run: pip install -r dev_requirments.txt

- name: Set up SQLite and create test database
run: |
sudo apt-get install sqlite3
python -m sqlite3 test_db.db <<EOF
CREATE TABLE test_scores_table (
network TEXT,
checkpoint INTEGER,
threshold REAL,
scores TEXT
);
EOF
- name: Test with pytest
run: pytest tests
72 changes: 72 additions & 0 deletions tests/eval/db_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import unittest
import os
from autoseg.eval import Database


class TestDatabase(unittest.TestCase):
def setUp(self):
# Create a test database
self.db_name = "test_db"
self.table_name = "test_scores_table"
self.db = Database(self.db_name, self.table_name)

def tearDown(self):
# Remove the test database file after tests
if os.path.exists(f"{self.db_name}.db"):
os.remove(f"{self.db_name}.db")

def test_add_score(self):
# Test adding a score to the database
network = "TestNetwork"
checkpoint = 1
threshold = 0.5
scores_dict = {"metric": 0.8}
self.db.add_score(network, checkpoint, threshold, scores_dict)

# Verify that the score is added correctly
result = self.db.get_scores(networks=network)
self.assertEqual(len(result), 1)
self.assertEqual(result[0][0], network)
self.assertEqual(result[0][1], checkpoint)
self.assertEqual(result[0][2], threshold)
self.assertEqual(result[0][3], scores_dict)

def test_get_scores(self):
# Test retrieving scores from the database
network = "TestNetwork"
checkpoint = 1
threshold = 0.5
scores_dict = {"metric": 0.8}
self.db.add_score(network, checkpoint, threshold, scores_dict)

# Verify that the retrieved score matches the added score
result = self.db.get_scores(networks=network)
self.assertEqual(len(result), 1)
self.assertEqual(result[0][0], network)
self.assertEqual(result[0][1], checkpoint)
self.assertEqual(result[0][2], threshold)
self.assertEqual(result[0][3], scores_dict)

def test_get_scores_multiple_conditions(self):
# Test retrieving scores with multiple conditions
network = "TestNetwork"
checkpoint = 1
threshold = 0.5
scores_dict = {"metric": 0.8}
self.db.add_score(network, checkpoint, threshold, scores_dict)

# Add another score with different conditions
network2 = "TestNetwork2"
checkpoint2 = 2
threshold2 = 0.7
scores_dict2 = {"metric": 0.9}
self.db.add_score(network2, checkpoint2, threshold2, scores_dict2)

# Verify that the retrieved scores match the added scores with multiple conditions
result = self.db.get_scores(networks=[network, network2], checkpoints=[checkpoint2])
self.assertEqual(len(result), 1)
self.assertEqual(result[0][0], network2)
self.assertEqual(result[0][1], checkpoint2)
self.assertEqual(result[0][2], threshold2)
self.assertEqual(result[0][3], scores_dict2)

0 comments on commit ea476d7

Please sign in to comment.