format by black
narumiruna committed Aug 31, 2023
1 parent a8aacde commit 642cd43
Showing 9 changed files with 58 additions and 42 deletions.
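The diffs below are consistent with Black's default style: single-quoted strings become double-quoted, and calls that exceed the 88-character default line length are split across multiple lines. As a minimal illustration (not part of this commit), Black's Python API shows the quote normalization; the snippet and the default black.Mode() are assumptions for illustration, and the equivalent command line would be something like "black template/":

import black

# Single quotes are normalized to double quotes, as seen throughout this diff.
src = "config = {'root': 'data', 'batch_size': 128}\n"
print(black.format_str(src, mode=black.Mode()), end="")
# config = {"root": "data", "batch_size": 128}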
4 changes: 2 additions & 2 deletions template/cli.py
@@ -5,8 +5,8 @@


@click.command()
@click.option('-c', '--config-file', type=click.STRING, default='configs/mnist.yaml')
@click.option('-r', '--resume', type=click.STRING, default=None)
@click.option("-c", "--config-file", type=click.STRING, default="configs/mnist.yaml")
@click.option("-r", "--resume", type=click.STRING, default=None)
def main(config_file, resume):
config = load_config(config_file)

17 changes: 10 additions & 7 deletions template/datasets/mnist.py
@@ -7,14 +7,17 @@

@register
class MNISTDataLoader(data.DataLoader):

def __init__(self, root: str, train: bool, batch_size: int, **kwargs):
transform = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
])
transform = transforms.Compose(
[
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)

dataset = datasets.MNIST(root, train=train, transform=transform, download=True)

super(MNISTDataLoader, self).__init__(dataset=dataset, batch_size=batch_size, shuffle=train, **kwargs)
super(MNISTDataLoader, self).__init__(
dataset=dataset, batch_size=batch_size, shuffle=train, **kwargs
)
1 change: 0 additions & 1 deletion template/jobs/job.py
@@ -2,6 +2,5 @@


class Job:

def run(self, config: OmegaConf, resume=None) -> None:
raise NotImplementedError
15 changes: 11 additions & 4 deletions template/jobs/mnist.py
@@ -10,22 +10,29 @@

@register
class MNISTTrainingJob(Job):

def run(self, config: OmegaConf, resume=None) -> None:
mlflow.log_text(OmegaConf.to_yaml(config), artifact_file='config.yaml')
mlflow.log_text(OmegaConf.to_yaml(config), artifact_file="config.yaml")

mlflow.log_params(config.log_params)

manual_seed()

device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
device = torch.device(config.device if torch.cuda.is_available() else "cpu")
model = instantiate(config.model).to(device)
optimizer = instantiate(config.optimizer, model.parameters())
scheduler = instantiate(config.scheduler, optimizer)
train_loader = instantiate(config.dataset, train=True)
test_loader = instantiate(config.dataset, train=False)

trainer = instantiate(config.trainer, device, model, optimizer, scheduler, train_loader, test_loader)
trainer = instantiate(
config.trainer,
device,
model,
optimizer,
scheduler,
train_loader,
test_loader,
)

if resume is not None:
trainer.resume(resume)
2 changes: 0 additions & 2 deletions template/models/lenet.py
@@ -4,7 +4,6 @@


class ConvBNReLU(nn.Sequential):

def __init__(self, in_channels, out_channels, kernel_size):
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_channels, out_channels, kernel_size, bias=False),
@@ -15,7 +14,6 @@ def __init__(self, in_channels, out_channels, kernel_size):

@register
class LeNet(nn.Module):

def __init__(self):
super(LeNet, self).__init__()
self.features = nn.Sequential(
46 changes: 28 additions & 18 deletions template/trainers/mnist.py
@@ -12,8 +12,9 @@

@register
class MNISTTrainer(Trainer):

def __init__(self, device, model, optimizer, scheduler, train_loader, test_loader, num_epochs):
def __init__(
self, device, model, optimizer, scheduler, train_loader, test_loader, num_epochs
):
self.device = device
self.model = model
self.optimizer = optimizer
@@ -31,13 +32,22 @@ def fit(self):
test_loss, test_acc = self.evaluate()
self.scheduler.step()

metrics = dict(train_loss=train_loss, train_acc=train_acc, test_loss=test_loss, test_acc=test_acc)
metrics = dict(
train_loss=train_loss,
train_acc=train_acc,
test_loss=test_loss,
test_acc=test_acc,
)
mlflow.log_metrics(metrics, step=self.epoch)

format_string = 'Epoch: {}/{}, '.format(self.epoch, self.num_epochs)
format_string += 'train loss: {:.4f}, train acc: {:.4f}, '.format(train_loss, train_acc)
format_string += 'test loss: {:.4f}, test acc: {:.4f}, '.format(test_loss, test_acc)
format_string += 'best test acc: {:.4f}.'.format(self.best_acc)
format_string = "Epoch: {}/{}, ".format(self.epoch, self.num_epochs)
format_string += "train loss: {:.4f}, train acc: {:.4f}, ".format(
train_loss, train_acc
)
format_string += "test loss: {:.4f}, test acc: {:.4f}, ".format(
test_loss, test_acc
)
format_string += "best test acc: {:.4f}.".format(self.best_acc)
tqdm.write(format_string)

def train(self):
@@ -82,19 +92,19 @@ def evaluate(self):
test_acc = acc_metric.compute().item()
if test_acc > self.best_acc:
self.best_acc = test_acc
self.save_checkpoint('best.pth')
self.save_checkpoint("best.pth")

return loss_metric.compute().item(), test_acc

def save_checkpoint(self, f):
self.model.eval()

checkpoint = {
'model': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict(),
'epoch': self.epoch,
'best_acc': self.best_acc
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict(),
"epoch": self.epoch,
"best_acc": self.best_acc,
}

torch.save(checkpoint, f)
@@ -103,8 +113,8 @@ def save_checkpoint(self, f):
def resume(self, f):
checkpoint = torch.load(f, map_location=self.device)

self.model.load_state_dict(checkpoint['model'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
self.epoch = checkpoint['epoch'] + 1
self.best_acc = checkpoint['best_acc']
self.model.load_state_dict(checkpoint["model"])
self.optimizer.load_state_dict(checkpoint["optimizer"])
self.scheduler.load_state_dict(checkpoint["scheduler"])
self.epoch = checkpoint["epoch"] + 1
self.best_acc = checkpoint["best_acc"]
1 change: 0 additions & 1 deletion template/trainers/trainer.py
@@ -1,4 +1,3 @@
class Trainer:

def train(self):
raise NotImplementedError
6 changes: 3 additions & 3 deletions template/utils/conf.py
@@ -4,7 +4,7 @@
from omegaconf import OmegaConf

_REGISTRY = {}
_KEY_OF_FUNC_OR_CLS = 'name'
_KEY_OF_FUNC_OR_CLS = "name"


def load_config(f=None, obj=None) -> OmegaConf:
@@ -29,7 +29,7 @@ def load_config(f=None, obj=None) -> OmegaConf:

length = len(configs)
if length == 0:
raise ValueError('No configuration file or structured object provided.')
raise ValueError("No configuration file or structured object provided.")
elif length == 1:
return configs[0]

@@ -53,7 +53,7 @@ def _register(func_or_cls, name=None):
if name not in _REGISTRY:
_REGISTRY[name] = func_or_cls
else:
raise ValueError('duplicate name {} found'.format(name))
raise ValueError("duplicate name {} found".format(name))

return func_or_cls

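For context on how this registry is presumably consumed: the job in template/jobs/mnist.py calls instantiate(config.model), instantiate(config.dataset, train=True), and so on. Below is a hedged sketch of that lookup, assuming the "name" key selects the registered class and the remaining config entries become constructor arguments; instantiate_sketch and the example config layout are illustrative, not the repository's actual code (it relies on _REGISTRY and _KEY_OF_FUNC_OR_CLS from template/utils/conf.py above):

def instantiate_sketch(cfg, *args, **kwargs):
    # Look up the registered class by its "name" key, then forward the
    # remaining config entries as keyword arguments (assumption).
    cls = _REGISTRY[cfg[_KEY_OF_FUNC_OR_CLS]]
    params = {k: v for k, v in cfg.items() if k != _KEY_OF_FUNC_OR_CLS}
    return cls(*args, **params, **kwargs)

# e.g. a config section such as
#   dataset: {name: MNISTDataLoader, root: data, batch_size: 128}
# would make instantiate_sketch(config.dataset, train=True) construct
# MNISTDataLoader(root="data", batch_size=128, train=True).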
8 changes: 4 additions & 4 deletions template/utils/utils.py
@@ -14,24 +14,24 @@ def manual_seed(seed=0):


def load_yaml(f):
with open(f, 'r') as fp:
with open(f, "r") as fp:
return yaml.safe_load(fp)


def save_yaml(data, f, **kwargs):
Path(f).parent.mkdir(parents=True, exist_ok=True)
with open(f, 'w') as fp:
with open(f, "w") as fp:
yaml.safe_dump(data, fp, **kwargs)


def load_json(f):
data = None
with open(f, 'r') as fp:
with open(f, "r") as fp:
data = json.load(fp)
return data


def save_json(data, f, **kwargs):
Path(f).parent.mkdir(parents=True, exist_ok=True)
with open(f, 'w') as fp:
with open(f, "w") as fp:
json.dump(data, fp, **kwargs)
