Skip to content

Commit

Permalink
MTLSD model test
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Dec 4, 2023
1 parent 96a83d8 commit a4c8184
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
26 changes: 26 additions & 0 deletions tests/models/mtlsd_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import unittest
import torch
from autoseg.networks import setup_unet
from autoseg.utils import neighborhood
from autoseg.models import MTLSDModel


class TestMTLSDModel(unittest.TestCase):

def setUp(self):
# Set up any required data or configuration for your tests
unet = setup_unet()
self.mtlsd_model = MTLSDModel(unet, unet.out_channels)

def test_forward(self):
# Test the forward method of STELARRModel
input_tensor = torch.randn((1, 1, 100, 100, 100))
lsds, affs = self.mtlsd_model(input_tensor)

# Check if the output tensors have the correct shapes
self.assertEqual(lsds.shape, (1, 10, 8, 8, 8))
self.assertEqual(affs.shape, (1, len(neighborhood), 8, 8, 8))

# Check if the values are within a reasonable range (this can be adjusted based on your model)
self.assertTrue(torch.all(lsds >= 0) and torch.all(lsds <= 1))
self.assertTrue(torch.all(affs >= 0) and torch.all(affs <= 1))
6 changes: 3 additions & 3 deletions tests/models/stelarr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def test_forward(self):
lsds, affs, fake = self.stelarr_model(input_tensor)

# Check if the output tensors have the correct shapes
self.assertEqual(lsds.shape, (1, 10, 60, 60, 60)) # Replace with actual output shape
self.assertEqual(affs.shape, (1, len(neighborhood), 60, 60, 60)) # Replace with actual output shape
self.assertEqual(fake.shape, (1, 1, 60, 60, 60)) # Replace with actual output shape
self.assertEqual(lsds.shape, (1, 10, 60, 60, 60))
self.assertEqual(affs.shape, (1, len(neighborhood), 60, 60, 60))
self.assertEqual(fake.shape, (1, 1, 60, 60, 60))

# Check if the values are within a reasonable range (this can be adjusted based on your model)
self.assertTrue(torch.all(lsds >= 0) and torch.all(lsds <= 1))
Expand Down

0 comments on commit a4c8184

Please sign in to comment.