From 642cd43157c272201e1cb18729395ee62b70f8ee Mon Sep 17 00:00:00 2001
From: narumi
Date: Thu, 31 Aug 2023 17:35:35 +0800
Subject: [PATCH] format by black

---
 template/cli.py              |  4 ++--
 template/datasets/mnist.py   | 17 +++++++------
 template/jobs/job.py         |  1 -
 template/jobs/mnist.py       | 15 ++++++++----
 template/models/lenet.py     |  2 --
 template/trainers/mnist.py   | 46 ++++++++++++++++++++++--------------
 template/trainers/trainer.py |  1 -
 template/utils/conf.py       |  6 ++---
 template/utils/utils.py      |  8 +++----
 9 files changed, 58 insertions(+), 42 deletions(-)

diff --git a/template/cli.py b/template/cli.py
index c59513f..de216d1 100644
--- a/template/cli.py
+++ b/template/cli.py
@@ -5,8 +5,8 @@
 
 
 @click.command()
-@click.option('-c', '--config-file', type=click.STRING, default='configs/mnist.yaml')
-@click.option('-r', '--resume', type=click.STRING, default=None)
+@click.option("-c", "--config-file", type=click.STRING, default="configs/mnist.yaml")
+@click.option("-r", "--resume", type=click.STRING, default=None)
 def main(config_file, resume):
     config = load_config(config_file)
 
diff --git a/template/datasets/mnist.py b/template/datasets/mnist.py
index 9177f69..366cb36 100755
--- a/template/datasets/mnist.py
+++ b/template/datasets/mnist.py
@@ -7,14 +7,17 @@
 
 
 @register
 class MNISTDataLoader(data.DataLoader):
-
     def __init__(self, root: str, train: bool, batch_size: int, **kwargs):
-        transform = transforms.Compose([
-            transforms.Resize(32),
-            transforms.ToTensor(),
-            transforms.Normalize((0.1307,), (0.3081,)),
-        ])
+        transform = transforms.Compose(
+            [
+                transforms.Resize(32),
+                transforms.ToTensor(),
+                transforms.Normalize((0.1307,), (0.3081,)),
+            ]
+        )
         dataset = datasets.MNIST(root, train=train, transform=transform, download=True)
-        super(MNISTDataLoader, self).__init__(dataset=dataset, batch_size=batch_size, shuffle=train, **kwargs)
+        super(MNISTDataLoader, self).__init__(
+            dataset=dataset, batch_size=batch_size, shuffle=train, **kwargs
+        )
diff --git a/template/jobs/job.py b/template/jobs/job.py
index a4955ba..f769f65 100644
--- a/template/jobs/job.py
+++ b/template/jobs/job.py
@@ -2,6 +2,5 @@
 
 
 class Job:
-
     def run(self, config: OmegaConf, resume=None) -> None:
         raise NotImplementedError
diff --git a/template/jobs/mnist.py b/template/jobs/mnist.py
index f739256..e5d8f6d 100644
--- a/template/jobs/mnist.py
+++ b/template/jobs/mnist.py
@@ -10,22 +10,29 @@
 
 @register
 class MNISTTrainingJob(Job):
-
     def run(self, config: OmegaConf, resume=None) -> None:
-        mlflow.log_text(OmegaConf.to_yaml(config), artifact_file='config.yaml')
+        mlflow.log_text(OmegaConf.to_yaml(config), artifact_file="config.yaml")
         mlflow.log_params(config.log_params)
 
         manual_seed()
 
-        device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
+        device = torch.device(config.device if torch.cuda.is_available() else "cpu")
 
         model = instantiate(config.model).to(device)
         optimizer = instantiate(config.optimizer, model.parameters())
         scheduler = instantiate(config.scheduler, optimizer)
         train_loader = instantiate(config.dataset, train=True)
         test_loader = instantiate(config.dataset, train=False)
 
-        trainer = instantiate(config.trainer, device, model, optimizer, scheduler, train_loader, test_loader)
+        trainer = instantiate(
+            config.trainer,
+            device,
+            model,
+            optimizer,
+            scheduler,
+            train_loader,
+            test_loader,
+        )
 
         if resume is not None:
             trainer.resume(resume)
diff --git a/template/models/lenet.py b/template/models/lenet.py
index ca20b84..1f5da56 100644
--- a/template/models/lenet.py
+++ b/template/models/lenet.py
@@ -4,7 +4,6 @@ class ConvBNReLU(nn.Sequential):
-
     def __init__(self, in_channels, out_channels, kernel_size):
         super(ConvBNReLU, self).__init__(
             nn.Conv2d(in_channels, out_channels, kernel_size, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(inplace=True),
         )
@@ -15,7 +14,6 @@ def __init__(self, in_channels, out_channels, kernel_size):
 
 @register
 class LeNet(nn.Module):
-
     def __init__(self):
         super(LeNet, self).__init__()
         self.features = nn.Sequential(
diff --git a/template/trainers/mnist.py b/template/trainers/mnist.py
index eee36f9..1f9df23 100644
--- a/template/trainers/mnist.py
+++ b/template/trainers/mnist.py
@@ -12,8 +12,9 @@
 
 @register
 class MNISTTrainer(Trainer):
-
-    def __init__(self, device, model, optimizer, scheduler, train_loader, test_loader, num_epochs):
+    def __init__(
+        self, device, model, optimizer, scheduler, train_loader, test_loader, num_epochs
+    ):
         self.device = device
         self.model = model
         self.optimizer = optimizer
@@ -31,13 +32,22 @@ def fit(self):
             test_loss, test_acc = self.evaluate()
             self.scheduler.step()
 
-            metrics = dict(train_loss=train_loss, train_acc=train_acc, test_loss=test_loss, test_acc=test_acc)
+            metrics = dict(
+                train_loss=train_loss,
+                train_acc=train_acc,
+                test_loss=test_loss,
+                test_acc=test_acc,
+            )
             mlflow.log_metrics(metrics, step=self.epoch)
 
-            format_string = 'Epoch: {}/{}, '.format(self.epoch, self.num_epochs)
-            format_string += 'train loss: {:.4f}, train acc: {:.4f}, '.format(train_loss, train_acc)
-            format_string += 'test loss: {:.4f}, test acc: {:.4f}, '.format(test_loss, test_acc)
-            format_string += 'best test acc: {:.4f}.'.format(self.best_acc)
+            format_string = "Epoch: {}/{}, ".format(self.epoch, self.num_epochs)
+            format_string += "train loss: {:.4f}, train acc: {:.4f}, ".format(
+                train_loss, train_acc
+            )
+            format_string += "test loss: {:.4f}, test acc: {:.4f}, ".format(
+                test_loss, test_acc
+            )
+            format_string += "best test acc: {:.4f}.".format(self.best_acc)
             tqdm.write(format_string)
 
     def train(self):
@@ -82,7 +92,7 @@ def evaluate(self):
         test_acc = acc_metric.compute().item()
         if test_acc > self.best_acc:
             self.best_acc = test_acc
-            self.save_checkpoint('best.pth')
+            self.save_checkpoint("best.pth")
 
         return loss_metric.compute().item(), test_acc
 
@@ -90,11 +100,11 @@ def save_checkpoint(self, f):
         self.model.eval()
 
         checkpoint = {
-            'model': self.model.state_dict(),
-            'optimizer': self.optimizer.state_dict(),
-            'scheduler': self.scheduler.state_dict(),
-            'epoch': self.epoch,
-            'best_acc': self.best_acc
+            "model": self.model.state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+            "scheduler": self.scheduler.state_dict(),
+            "epoch": self.epoch,
+            "best_acc": self.best_acc,
         }
 
         torch.save(checkpoint, f)
@@ -103,8 +113,8 @@
     def resume(self, f):
         checkpoint = torch.load(f, map_location=self.device)
 
-        self.model.load_state_dict(checkpoint['model'])
-        self.optimizer.load_state_dict(checkpoint['optimizer'])
-        self.scheduler.load_state_dict(checkpoint['scheduler'])
-        self.epoch = checkpoint['epoch'] + 1
-        self.best_acc = checkpoint['best_acc']
+        self.model.load_state_dict(checkpoint["model"])
+        self.optimizer.load_state_dict(checkpoint["optimizer"])
+        self.scheduler.load_state_dict(checkpoint["scheduler"])
+        self.epoch = checkpoint["epoch"] + 1
+        self.best_acc = checkpoint["best_acc"]
diff --git a/template/trainers/trainer.py b/template/trainers/trainer.py
index ece8eb7..609a710 100644
--- a/template/trainers/trainer.py
+++ b/template/trainers/trainer.py
@@ -1,4 +1,3 @@
 class Trainer:
-
     def train(self):
         raise NotImplementedError
diff --git a/template/utils/conf.py b/template/utils/conf.py
index cc80d9f..d93ada1 100644
--- a/template/utils/conf.py
+++ b/template/utils/conf.py
@@ -4,7 +4,7 @@
 from omegaconf import OmegaConf
 
 _REGISTRY = {}
-_KEY_OF_FUNC_OR_CLS = 'name'
+_KEY_OF_FUNC_OR_CLS = "name"
 
 
 def load_config(f=None, obj=None) -> OmegaConf:
@@ -29,7 +29,7 @@
     length = len(configs)
 
     if length == 0:
-        raise ValueError('No configuration file or structured object provided.')
+        raise ValueError("No configuration file or structured object provided.")
     elif length == 1:
         return configs[0]
 
@@ -53,7 +53,7 @@ def _register(func_or_cls, name=None):
 
     if name not in _REGISTRY:
         _REGISTRY[name] = func_or_cls
     else:
-        raise ValueError('duplicate name {} found'.format(name))
+        raise ValueError("duplicate name {} found".format(name))
 
     return func_or_cls
diff --git a/template/utils/utils.py b/template/utils/utils.py
index e4f0aaa..fc7bd66 100644
--- a/template/utils/utils.py
+++ b/template/utils/utils.py
@@ -14,24 +14,24 @@ def manual_seed(seed=0):
 
 
 def load_yaml(f):
-    with open(f, 'r') as fp:
+    with open(f, "r") as fp:
         return yaml.safe_load(fp)
 
 
 def save_yaml(data, f, **kwargs):
     Path(f).parent.mkdir(parents=True, exist_ok=True)
-    with open(f, 'w') as fp:
+    with open(f, "w") as fp:
         yaml.safe_dump(data, fp, **kwargs)
 
 
 def load_json(f):
     data = None
-    with open(f, 'r') as fp:
+    with open(f, "r") as fp:
         data = json.load(fp)
     return data
 
 
 def save_json(data, f, **kwargs):
     Path(f).parent.mkdir(parents=True, exist_ok=True)
-    with open(f, 'w') as fp:
+    with open(f, "w") as fp:
         json.dump(data, fp, **kwargs)
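
Note on reproducing the patch: the hunks above are consistent with plain
black output at its default settings (line length 88, double-quote string
normalization, magic trailing commas). A minimal sketch, assuming black is
installed (pip install black) and the command is run from the repository
root:

    black template/

For spot-checking a single fragment in-process, black.format_str applies the
same rules. The snippet below feeds it one of the lines changed in
template/cli.py; the def is only there to make the decorator parseable:

    import black

    # Pre-black source, single quotes as in the "-" side of the diff.
    src = (
        "@click.option('-c', '--config-file', "
        "type=click.STRING, default='configs/mnist.yaml')\n"
        "def main(config_file):\n"
        "    pass\n"
    )

    # With the default Mode (line length 88), black only normalizes the
    # quotes here; the decorator already fits on one line, matching the
    # "+" side of the cli.py hunk.
    print(black.format_str(src, mode=black.Mode()))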