Commit

resolved ruff errors and mypy errors
rijuld committed Oct 24, 2024
1 parent 78c494d commit 75ca32c
Showing 2 changed files with 159 additions and 171 deletions.
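The main refactor in this diff replaces the ad-hoc `class Args: pass` objects, whose attributes were bolted on after construction and therefore invisible to mypy, with a module-level `@dataclass`. Below is a minimal sketch of that pattern; the trimmed field list and the '/tmp/data' path are illustrative only, not the full set used in the tests.

from dataclasses import dataclass
from pathlib import Path

import numpy as np
from numpy.typing import NDArray


@dataclass
class Args:
    # Declaring the fields up front gives mypy concrete types to check;
    # assigning them one by one onto an empty `class Args: pass` instance does not.
    data_dir: Path
    in_channels: int
    normalizing_factor: NDArray[np.float64]


# Construction is keyword-by-keyword, so a missing or misspelled field fails loudly.
args = Args(data_dir=Path('/tmp/data'), in_channels=4, normalizing_factor=np.array([1.0]))

Both spellings behave the same at runtime; the dataclass version is the one static analysis can verify.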
135 changes: 64 additions & 71 deletions tests/datasets/test_substation_seg.py
@@ -1,14 +1,28 @@
 import os
+from dataclasses import dataclass
 from pathlib import Path

 import numpy as np
 import pytest
 import torch
-from pytest import MonkeyPatch
+from numpy.typing import NDArray

 from torchgeo.datasets import DatasetNotFoundError, SubstationDataset


+@dataclass
+class Args:
+    data_dir: Path
+    in_channels: int
+    use_timepoints: bool
+    normalizing_type: str
+    normalizing_factor: NDArray[np.float64]
+    means: NDArray[np.float64]
+    stds: NDArray[np.float64]
+    mask_2d: bool
+    model_type: str
+
+
 class TestSubstationDataset:
     @pytest.fixture(
         params=[
@@ -21,23 +35,21 @@ class TestSubstationDataset:
             }
         ]
     )
-    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path, request: pytest.FixtureRequest) -> SubstationDataset:
+    def dataset(self, tmp_path: Path, request: pytest.FixtureRequest) -> SubstationDataset:
         """
         Fixture to create a mock dataset with specified parameters.
         """
-        class Args:
-            pass
-
-        args = Args()
-        args.data_dir = tmp_path
-        args.in_channels = 4
-        args.use_timepoints = False
-        args.normalizing_type = 'zscore'
-        args.normalizing_factor = np.array([1.0])
-        args.means = np.array([0.5])
-        args.stds = np.array([0.1])
-        args.mask_2d = True
-        args.model_type = 'segmentation'
+        args = Args(
+            data_dir=tmp_path,
+            in_channels=4,
+            use_timepoints=False,
+            normalizing_type='zscore',
+            normalizing_factor=np.array([1.0]),
+            means=np.array([0.5]),
+            stds=np.array([0.1]),
+            mask_2d=True,
+            model_type='segmentation',
+        )

         # Creating mock image and mask files
         for filename in request.param['image_files']:
@@ -46,91 +58,72 @@ class Args:
             np.savez_compressed(os.path.join(tmp_path, 'image_stack', filename), arr_0=np.random.rand(4, 128, 128))
             np.savez_compressed(os.path.join(tmp_path, 'mask', filename), arr_0=np.random.randint(0, 4, (128, 128)))

-        image_files = request.param['image_files']
-        geo_transforms = request.param['geo_transforms']
-        color_transforms = request.param['color_transforms']
-        image_resize = request.param['image_resize']
-        mask_resize = request.param['mask_resize']
-
         return SubstationDataset(
             args,
-            image_files=image_files,
-            geo_transforms=geo_transforms,
-            color_transforms=color_transforms,
-            image_resize=image_resize,
-            mask_resize=mask_resize,
+            image_files=request.param['image_files'],
+            geo_transforms=request.param['geo_transforms'],
+            color_transforms=request.param['color_transforms'],
+            image_resize=request.param['image_resize'],
+            mask_resize=request.param['mask_resize'],
         )

     def test_getitem(self, dataset: SubstationDataset) -> None:
-        image, mask = dataset[0]
+        """Test that __getitem__ returns a valid image and mask tensor."""
+        data = dataset[0]
+        image = data["image"]
+        mask = data["mask"]
         assert isinstance(image, torch.Tensor)
         assert isinstance(mask, torch.Tensor)
         assert image.shape[0] == 4  # Checking number of channels
         assert mask.shape == (1, 128, 128)

     def test_len(self, dataset: SubstationDataset) -> None:
+        """Test that __len__ returns the correct length of the dataset."""
         assert len(dataset) == 2

     def test_already_downloaded(self, tmp_path: Path) -> None:
-        # Test to ensure dataset initialization doesn't download if data already exists
-        class Args:
-            pass
-
-        args = Args()
-        args.data_dir = tmp_path
-        args.in_channels = 4
-        args.use_timepoints = False
-        args.normalizing_type = 'zscore'
-        args.normalizing_factor = np.array([1.0])
-        args.means = np.array([0.5])
-        args.stds = np.array([0.1])
-        args.mask_2d = True
-        args.model_type = 'segmentation'
-
+        """Test dataset initialization when data is already downloaded."""
+
         os.makedirs(os.path.join(tmp_path, 'image_stack'))
         os.makedirs(os.path.join(tmp_path, 'mask'))

         # No need to assign `dataset` variable, just assert
         SubstationDataset(args, image_files=[])
         assert os.path.exists(os.path.join(tmp_path, 'image_stack'))
         assert os.path.exists(os.path.join(tmp_path, 'mask'))

     def test_not_downloaded(self, tmp_path: Path) -> None:
-        class Args:
-            pass
-
-        args = Args()
-        args.data_dir = tmp_path
-        args.in_channels = 4
-        args.use_timepoints = False
-        args.normalizing_type = 'zscore'
-        args.normalizing_factor = np.array([1.0])
-        args.means = np.array([0.5])
-        args.stds = np.array([0.1])
-        args.mask_2d = True
-        args.model_type = 'segmentation'
+        """Test dataset initialization when data is not downloaded, expecting DatasetNotFoundError."""
+        args = Args(
+            data_dir=tmp_path,
+            in_channels=4,
+            use_timepoints=False,
+            normalizing_type='zscore',
+            normalizing_factor=np.array([1.0]),
+            means=np.array([0.5]),
+            stds=np.array([0.1]),
+            mask_2d=True,
+            model_type='segmentation',
+        )

         with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
             SubstationDataset(args, image_files=[])

     def test_plot(self, dataset: SubstationDataset) -> None:
+        """Test that the plot function runs without throwing exceptions."""
         dataset.plot()
         # No assertion, just ensuring that the plotting does not throw any exceptions.

     def test_corrupted(self, tmp_path: Path) -> None:
-        class Args:
-            pass
-
-        args = Args()
-        args.data_dir = tmp_path
-        args.in_channels = 4
-        args.use_timepoints = False
-        args.normalizing_type = 'zscore'
-        args.normalizing_factor = np.array([1.0])
-        args.means = np.array([0.5])
-        args.stds = np.array([0.1])
-        args.mask_2d = True
-        args.model_type = 'segmentation'
+        """Test dataset loading with corrupted files."""
+        args = Args(
+            data_dir=tmp_path,
+            in_channels=4,
+            use_timepoints=False,
+            normalizing_type='zscore',
+            normalizing_factor=np.array([1.0]),
+            means=np.array([0.5]),
+            stds=np.array([0.1]),
+            mask_2d=True,
+            model_type='segmentation',
+        )

         # Creating corrupted files
         os.makedirs(os.path.join(tmp_path, 'image_stack'))
