From 35c2875335e52d2c355581db864bd3f26b1e9a73 Mon Sep 17 00:00:00 2001 From: narumi Date: Tue, 25 Jun 2024 01:46:59 +0800 Subject: [PATCH 1/2] support mypy --- .github/workflows/poetry.yml | 8 ++++++-- .github/workflows/python.yml | 6 ++++-- pyproject.toml | 3 +++ template/jobs/job.py | 4 ++-- template/jobs/mnist.py | 4 ++-- template/trainers/mnist.py | 30 +++++++++++++++--------------- template/trainers/trainer.py | 2 +- 7 files changed, 33 insertions(+), 24 deletions(-) diff --git a/.github/workflows/poetry.yml b/.github/workflows/poetry.yml index ee56b70..04c521f 100644 --- a/.github/workflows/poetry.yml +++ b/.github/workflows/poetry.yml @@ -3,9 +3,11 @@ name: Poetry on: workflow_dispatch: push: - branches: ["main"] + branches: + - main pull_request: - branches: ["main"] + branches: + - main jobs: poetry: @@ -27,6 +29,8 @@ jobs: run: poetry install - name: Lint run: poetry run ruff check . + - name: Type check + run: poetry run mypy --install-types --non-interactive . - name: Test run: poetry run pytest -v -s --cov=. --cov-report=xml tests - name: Upload coverage reports to Codecov diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 151ad52..7496904 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -3,9 +3,11 @@ name: Python on: workflow_dispatch: push: - branches: ["main"] + branches: + - main pull_request: - branches: ["main"] + branches: + - main jobs: python: diff --git a/pyproject.toml b/pyproject.toml index fa271f9..532f9f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,3 +55,6 @@ known-third-party = ["wandb"] [tool.pytest.ini_options] filterwarnings = ["ignore::DeprecationWarning"] + +[tool.mypy] +ignore_missing_imports = true diff --git a/template/jobs/job.py b/template/jobs/job.py index f769f65..e4d10ec 100644 --- a/template/jobs/job.py +++ b/template/jobs/job.py @@ -1,6 +1,6 @@ -from omegaconf import OmegaConf +from omegaconf import DictConfig class Job: - def run(self, config: OmegaConf, resume=None) -> None: + def run(self, config: DictConfig, resume: str | None = None) -> None: raise NotImplementedError diff --git a/template/jobs/mnist.py b/template/jobs/mnist.py index 7599456..57e774f 100644 --- a/template/jobs/mnist.py +++ b/template/jobs/mnist.py @@ -1,7 +1,7 @@ import torch from mlconfig import instantiate from mlconfig import register -from omegaconf import OmegaConf +from omegaconf import DictConfig from ..utils import manual_seed from .job import Job @@ -9,7 +9,7 @@ @register class MNISTTrainingJob(Job): - def run(self, config: OmegaConf, resume=None) -> None: + def run(self, config: DictConfig, resume: str | None = None) -> None: manual_seed() device = torch.device(config.device if torch.cuda.is_available() else "cpu") diff --git a/template/trainers/mnist.py b/template/trainers/mnist.py index dc922ae..115baca 100644 --- a/template/trainers/mnist.py +++ b/template/trainers/mnist.py @@ -6,8 +6,8 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader -from torchmetrics import Accuracy from torchmetrics import MeanMetric +from torchmetrics.classification import MulticlassAccuracy from tqdm import tqdm from tqdm import trange @@ -36,7 +36,7 @@ def __init__( self.num_epochs = num_epochs self.num_classes = num_classes - self.best_acc = 0 + self.best_acc = 0.0 self.state = {"epoch": 1} def fit(self) -> None: @@ -61,11 +61,11 @@ def fit(self) -> None: self.state["epoch"] = epoch - def train(self) -> None: + def train(self) -> tuple[float, float]: self.model.train() - loss_metric = MeanMetric().to(self.device) - acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes).to(self.device) + loss_metric = MeanMetric() + acc_metric = MulticlassAccuracy(num_classes=self.num_classes) for x, y in tqdm(self.train_loader): x = x.to(self.device) @@ -78,17 +78,17 @@ def train(self) -> None: loss.backward() self.optimizer.step() - loss_metric.update(loss, weight=x.size(0)) - acc_metric.update(output, y) + loss_metric.update(loss.item(), weight=x.size(0)) + acc_metric.update(output.cpu(), y.cpu()) - return loss_metric.compute().item(), acc_metric.compute().item() + return float(loss_metric.compute()), float(acc_metric.compute()) @torch.no_grad() - def evaluate(self) -> None: + def evaluate(self) -> tuple[float, float]: self.model.eval() - loss_metric = MeanMetric().to(self.device) - acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes).to(self.device) + loss_metric = MeanMetric() + acc_metric = MulticlassAccuracy(num_classes=self.num_classes) for x, y in tqdm(self.test_loader): x = x.to(self.device) @@ -97,15 +97,15 @@ def evaluate(self) -> None: output = self.model(x) loss = f.cross_entropy(output, y) - loss_metric.update(loss, weight=x.size(0)) - acc_metric.update(output, y) + loss_metric.update(loss.item(), weight=x.size(0)) + acc_metric.update(output.cpu(), y.cpu()) - test_acc = acc_metric.compute().item() + test_acc = float(acc_metric.compute()) if test_acc > self.best_acc: self.best_acc = test_acc self.save_checkpoint("best.pth") - return loss_metric.compute().item(), test_acc + return float(loss_metric.compute()), test_acc def save_checkpoint(self, f) -> None: self.model.eval() diff --git a/template/trainers/trainer.py b/template/trainers/trainer.py index ae1a555..609a710 100644 --- a/template/trainers/trainer.py +++ b/template/trainers/trainer.py @@ -1,3 +1,3 @@ class Trainer: - def train(self) -> None: + def train(self): raise NotImplementedError From d6ad95dbbfbd304d37d744da7e06ce4b9017eb4e Mon Sep 17 00:00:00 2001 From: narumi Date: Tue, 25 Jun 2024 01:54:03 +0800 Subject: [PATCH 2/2] delete pip workflow --- .github/workflows/poetry.yml | 39 ------------------------------------ .github/workflows/python.yml | 21 +++++++++---------- 2 files changed, 10 insertions(+), 50 deletions(-) delete mode 100644 .github/workflows/poetry.yml diff --git a/.github/workflows/poetry.yml b/.github/workflows/poetry.yml deleted file mode 100644 index 04c521f..0000000 --- a/.github/workflows/poetry.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: Poetry - -on: - workflow_dispatch: - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - poetry: - runs-on: ubuntu-latest - strategy: - max-parallel: 1 - matrix: - python-version: ["3.11"] - poetry-version: ["1.8.3"] - steps: - - uses: actions/checkout@v4 - - name: Install poetry - run: pipx install poetry==${{ matrix.poetry-version }} - - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: poetry - - name: Install dependencies - run: poetry install - - name: Lint - run: poetry run ruff check . - - name: Type check - run: poetry run mypy --install-types --non-interactive . - - name: Test - run: poetry run pytest -v -s --cov=. --cov-report=xml tests - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v4 - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 7496904..d1f0725 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -16,24 +16,23 @@ jobs: max-parallel: 1 matrix: python-version: ["3.11"] + poetry-version: ["1.8.3"] steps: - uses: actions/checkout@v4 + - name: Install poetry + run: pipx install poetry==${{ matrix.poetry-version }} - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - cache: pip - - name: Lint - run: | - pip install ruff - ruff check . + cache: poetry - name: Install dependencies - run: | - pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu - pip install . + run: poetry install + - name: Lint + run: poetry run ruff check . + - name: Type check + run: poetry run mypy --install-types --non-interactive . - name: Test - run: | - pip install pytest pytest-cov - pytest -v -s --cov=. --cov-report=xml tests + run: poetry run pytest -v -s --cov=. --cov-report=xml tests - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4 env: