Skip to content

Commit

Permalink
Merge pull request #64 from narumiruna/mypy
Browse files Browse the repository at this point in the history
support mypy
  • Loading branch information
narumiruna authored Jun 24, 2024
2 parents c8ea8b2 + d6ad95d commit c19e8b2
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 68 deletions.
35 changes: 0 additions & 35 deletions .github/workflows/poetry.yml

This file was deleted.

27 changes: 14 additions & 13 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ name: Python
on:
workflow_dispatch:
push:
branches: ["main"]
branches:
- main
pull_request:
branches: ["main"]
branches:
- main

jobs:
python:
Expand All @@ -14,24 +16,23 @@ jobs:
max-parallel: 1
matrix:
python-version: ["3.11"]
poetry-version: ["1.8.3"]
steps:
- uses: actions/checkout@v4
- name: Install poetry
run: pipx install poetry==${{ matrix.poetry-version }}
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
- name: Lint
run: |
pip install ruff
ruff check .
cache: poetry
- name: Install dependencies
run: |
pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
pip install .
run: poetry install
- name: Lint
run: poetry run ruff check .
- name: Type check
run: poetry run mypy --install-types --non-interactive .
- name: Test
run: |
pip install pytest pytest-cov
pytest -v -s --cov=. --cov-report=xml tests
run: poetry run pytest -v -s --cov=. --cov-report=xml tests
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4
env:
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,6 @@ known-third-party = ["wandb"]

[tool.pytest.ini_options]
filterwarnings = ["ignore::DeprecationWarning"]

[tool.mypy]
ignore_missing_imports = true
4 changes: 2 additions & 2 deletions template/jobs/job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from omegaconf import OmegaConf
from omegaconf import DictConfig


class Job:
def run(self, config: OmegaConf, resume=None) -> None:
def run(self, config: DictConfig, resume: str | None = None) -> None:
raise NotImplementedError
4 changes: 2 additions & 2 deletions template/jobs/mnist.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import torch
from mlconfig import instantiate
from mlconfig import register
from omegaconf import OmegaConf
from omegaconf import DictConfig

from ..utils import manual_seed
from .job import Job


@register
class MNISTTrainingJob(Job):
def run(self, config: OmegaConf, resume=None) -> None:
def run(self, config: DictConfig, resume: str | None = None) -> None:
manual_seed()

device = torch.device(config.device if torch.cuda.is_available() else "cpu")
Expand Down
30 changes: 15 additions & 15 deletions template/trainers/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm
from tqdm import trange

Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(
self.num_epochs = num_epochs
self.num_classes = num_classes

self.best_acc = 0
self.best_acc = 0.0
self.state = {"epoch": 1}

def fit(self) -> None:
Expand All @@ -61,11 +61,11 @@ def fit(self) -> None:

self.state["epoch"] = epoch

def train(self) -> None:
def train(self) -> tuple[float, float]:
self.model.train()

loss_metric = MeanMetric().to(self.device)
acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes).to(self.device)
loss_metric = MeanMetric()
acc_metric = MulticlassAccuracy(num_classes=self.num_classes)

for x, y in tqdm(self.train_loader):
x = x.to(self.device)
Expand All @@ -78,17 +78,17 @@ def train(self) -> None:
loss.backward()
self.optimizer.step()

loss_metric.update(loss, weight=x.size(0))
acc_metric.update(output, y)
loss_metric.update(loss.item(), weight=x.size(0))
acc_metric.update(output.cpu(), y.cpu())

return loss_metric.compute().item(), acc_metric.compute().item()
return float(loss_metric.compute()), float(acc_metric.compute())

@torch.no_grad()
def evaluate(self) -> None:
def evaluate(self) -> tuple[float, float]:
self.model.eval()

loss_metric = MeanMetric().to(self.device)
acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes).to(self.device)
loss_metric = MeanMetric()
acc_metric = MulticlassAccuracy(num_classes=self.num_classes)

for x, y in tqdm(self.test_loader):
x = x.to(self.device)
Expand All @@ -97,15 +97,15 @@ def evaluate(self) -> None:
output = self.model(x)
loss = f.cross_entropy(output, y)

loss_metric.update(loss, weight=x.size(0))
acc_metric.update(output, y)
loss_metric.update(loss.item(), weight=x.size(0))
acc_metric.update(output.cpu(), y.cpu())

test_acc = acc_metric.compute().item()
test_acc = float(acc_metric.compute())
if test_acc > self.best_acc:
self.best_acc = test_acc
self.save_checkpoint("best.pth")

return loss_metric.compute().item(), test_acc
return float(loss_metric.compute()), test_acc

def save_checkpoint(self, f) -> None:
self.model.eval()
Expand Down
2 changes: 1 addition & 1 deletion template/trainers/trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
class Trainer:
def train(self) -> None:
def train(self):
raise NotImplementedError

0 comments on commit c19e8b2

Please sign in to comment.