Skip to content

Commit

Permalink
support mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
narumiruna committed Jun 24, 2024
1 parent c8ea8b2 commit 35c2875
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 24 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/poetry.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ name: Poetry
on:
workflow_dispatch:
push:
branches: ["main"]
branches:
- main
pull_request:
branches: ["main"]
branches:
- main

jobs:
poetry:
Expand All @@ -27,6 +29,8 @@ jobs:
run: poetry install
- name: Lint
run: poetry run ruff check .
- name: Type check
run: poetry run mypy --install-types --non-interactive .
- name: Test
run: poetry run pytest -v -s --cov=. --cov-report=xml tests
- name: Upload coverage reports to Codecov
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ name: Python
on:
workflow_dispatch:
push:
branches: ["main"]
branches:
- main
pull_request:
branches: ["main"]
branches:
- main

jobs:
python:
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,6 @@ known-third-party = ["wandb"]

[tool.pytest.ini_options]
filterwarnings = ["ignore::DeprecationWarning"]

[tool.mypy]
ignore_missing_imports = true
4 changes: 2 additions & 2 deletions template/jobs/job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from omegaconf import OmegaConf
from omegaconf import DictConfig


class Job:
def run(self, config: OmegaConf, resume=None) -> None:
def run(self, config: DictConfig, resume: str | None = None) -> None:
raise NotImplementedError
4 changes: 2 additions & 2 deletions template/jobs/mnist.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import torch
from mlconfig import instantiate
from mlconfig import register
from omegaconf import OmegaConf
from omegaconf import DictConfig

from ..utils import manual_seed
from .job import Job


@register
class MNISTTrainingJob(Job):
def run(self, config: OmegaConf, resume=None) -> None:
def run(self, config: DictConfig, resume: str | None = None) -> None:
manual_seed()

device = torch.device(config.device if torch.cuda.is_available() else "cpu")
Expand Down
30 changes: 15 additions & 15 deletions template/trainers/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm
from tqdm import trange

Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(
self.num_epochs = num_epochs
self.num_classes = num_classes

self.best_acc = 0
self.best_acc = 0.0

Check warning on line 39 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L39

Added line #L39 was not covered by tests
self.state = {"epoch": 1}

def fit(self) -> None:
Expand All @@ -61,11 +61,11 @@ def fit(self) -> None:

self.state["epoch"] = epoch

def train(self) -> None:
def train(self) -> tuple[float, float]:
self.model.train()

loss_metric = MeanMetric().to(self.device)
acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes).to(self.device)
loss_metric = MeanMetric()
acc_metric = MulticlassAccuracy(num_classes=self.num_classes)

Check warning on line 68 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L67-L68

Added lines #L67 - L68 were not covered by tests

for x, y in tqdm(self.train_loader):
x = x.to(self.device)
Expand All @@ -78,17 +78,17 @@ def train(self) -> None:
loss.backward()
self.optimizer.step()

loss_metric.update(loss, weight=x.size(0))
acc_metric.update(output, y)
loss_metric.update(loss.item(), weight=x.size(0))
acc_metric.update(output.cpu(), y.cpu())

Check warning on line 82 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L81-L82

Added lines #L81 - L82 were not covered by tests

return loss_metric.compute().item(), acc_metric.compute().item()
return float(loss_metric.compute()), float(acc_metric.compute())

Check warning on line 84 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L84

Added line #L84 was not covered by tests

@torch.no_grad()
def evaluate(self) -> None:
def evaluate(self) -> tuple[float, float]:
self.model.eval()

loss_metric = MeanMetric().to(self.device)
acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes).to(self.device)
loss_metric = MeanMetric()
acc_metric = MulticlassAccuracy(num_classes=self.num_classes)

Check warning on line 91 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L90-L91

Added lines #L90 - L91 were not covered by tests

for x, y in tqdm(self.test_loader):
x = x.to(self.device)
Expand All @@ -97,15 +97,15 @@ def evaluate(self) -> None:
output = self.model(x)
loss = f.cross_entropy(output, y)

loss_metric.update(loss, weight=x.size(0))
acc_metric.update(output, y)
loss_metric.update(loss.item(), weight=x.size(0))
acc_metric.update(output.cpu(), y.cpu())

Check warning on line 101 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L100-L101

Added lines #L100 - L101 were not covered by tests

test_acc = acc_metric.compute().item()
test_acc = float(acc_metric.compute())

Check warning on line 103 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L103

Added line #L103 was not covered by tests
if test_acc > self.best_acc:
self.best_acc = test_acc
self.save_checkpoint("best.pth")

return loss_metric.compute().item(), test_acc
return float(loss_metric.compute()), test_acc

Check warning on line 108 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L108

Added line #L108 was not covered by tests

def save_checkpoint(self, f) -> None:
self.model.eval()
Expand Down
2 changes: 1 addition & 1 deletion template/trainers/trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
class Trainer:
def train(self) -> None:
def train(self):
raise NotImplementedError

0 comments on commit 35c2875

Please sign in to comment.