
Commit

added two functions, _save_snapshot and _load_snapshot for model analysis and stopping/restarting training
Gaurav Ray committed Apr 9, 2024
1 parent 7bafaf0 commit ff993b7
88 changes: 60 additions & 28 deletions src/_normflowcore.py
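In use, the new options flow through `Fitter.__call__`: a `snapshot_path` supplied via `checkpoint_dict` enables resuming, and `save_every` controls the snapshot cadence (defaulting to `n_epochs // 10`). A minimal sketch, assuming the fitter is invoked directly; how the package actually exposes it is not shown in this diff, and `model` is assumed constructed elsewhere:

    # Hypothetical usage; keyword names follow the diff below.
    fitter = Fitter(model)
    fitter(
        n_epochs=1000,
        save_every=100,   # snapshot every 100 epochs
        batch_size=64,
        checkpoint_dict=dict(snapshot_path="snapshot.pt"),
    )
    # A rerun with the same snapshot_path reloads snapshot.pt and resumes.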
@@ -18,17 +18,16 @@
used by other modules of this package.
"""


import torch
import time
import os

import numpy as np

from .mcmc import MCMCSampler, BlockedMCMCSampler
from .lib.combo import estimate_logz, fmt_val_err
from .device import ModelDeviceHandler


# =============================================================================
class Model:
"""The central high-level class of the package, which
@@ -144,12 +143,13 @@ def __init__(self, model):
print_stride=100,
print_batch_size=1024,
print_extra_func=None,
save_epochs=[],
save_fname_func=None
snapshot_path=None
)

def __call__(self,
n_epochs=1000,
save_every=None,
epochs_run=0,
batch_size=64,
optimizer_class=torch.optim.AdamW,
scheduler=None,
@@ -165,6 +165,9 @@ def __call__(self,
n_epochs : int
Number of epochs of training
save_every : int
Save a snapshot of the model every `save_every` epochs
batch_size : int
Size of samples used at each epoch
@@ -187,6 +190,18 @@ def __call__(self,
self.hyperparam.update(hyperparam)
self.checkpoint_dict.update(checkpoint_dict)

snapshot_path = self.checkpoint_dict['snapshot_path']

if snapshot_path is not None and os.path.exists(snapshot_path):
print(f"Trying to load snapshot from {snapshot_path}")
self._load_snapshot()
else:
print("Starting training from scratch")

if save_every is None:
save_every = max(1, n_epochs // 10)  # avoid save_every == 0 on short runs


self.loss_fn = Fitter.calc_kl_mean if loss_fn is None else loss_fn

net_ = self._model.net_
@@ -198,16 +213,44 @@ def __call__(self,

self.scheduler = None if scheduler is None else scheduler(self.optimizer)

return self.train(n_epochs, batch_size)

def train(self, n_epochs, batch_size):
return self.train(n_epochs, batch_size, epochs_run, save_every)

def _load_snapshot(self):
snapshot_path = self.checkpoint_dict['snapshot_path']
if torch.cuda.is_available():
gpu_id = int(os.environ["LOCAL_RANK"])
loc = f"cuda:{gpu_id}"
print(f"GPU: Attempting to load saved model into {loc}")
else:
loc = "cpu"  # map GPU-saved tensors onto the CPU
print("CPU: Attempting to load saved model")
snapshot = torch.load(snapshot_path, map_location=loc)
print(f"Snapshot found at \n {snapshot_path} \n")
self._model.net_.load_state_dict(snapshot["MODEL_STATE"])
self.epochs_run = snapshot["EPOCHS_RUN"]
print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

def _save_snapshot(self, epoch):
""" Save snapshot of training for analysis and/or to continue
training at a later date. """

snapshot_path = self.checkpoint_dict['snapshot_path']
snapshot = {
"MODEL_STATE": self._model.net_.state_dict(),
"EPOCHS_RUN": epoch }
torch.save(snapshot, snapshot_path)
print(f"Epoch {epoch} | Training snapshot saved at {snapshot_path}")

def train(self, n_epochs, batch_size, epochs_run, save_every):
"""Train the model.
Parameters
----------
n_epochs : int
Number of epochs of training
epochs_run : int
Number of epochs already run; epoch numbering resumes from this value
batch_size : int
Size of samples used at each epoch
@@ -216,11 +259,10 @@ def train(self, n_epochs, batch_size):
being called at least once.
"""
self.train_batch_size = batch_size
last_epoch = len(self.train_history["loss"]) + 1
T1 = time.time()
for epoch in range(last_epoch, last_epoch + n_epochs):
for epoch in range(epochs_run, epochs_run+n_epochs):
loss, logqp = self.step()
self.checkpoint(epoch, loss)
self.checkpoint(epoch, loss, save_every)
if self.scheduler is not None:
self.scheduler.step()
T2 = time.time()
@@ -248,7 +290,7 @@ def step(self):

return loss, logq - logp

def checkpoint(self, epoch, loss):
def checkpoint(self, epoch, loss, save_every):

rank = self._model.device_handler.rank

@@ -259,8 +301,6 @@ def checkpoint(self, epoch, loss):
# For the rest
print_stride = self.checkpoint_dict['print_stride']
print_batch_size = self.checkpoint_dict['print_batch_size']
save_epochs = self.checkpoint_dict['save_epochs']
save_fname_func = self.checkpoint_dict['save_fname_func']

print_batch_size = print_batch_size // self._model.device_handler.nranks

@@ -275,16 +315,14 @@ def checkpoint(self, epoch, loss):
loss_ = self.loss_fn(logq, logp)
self._append_to_train_history(logq, logp)
self.print_fit_status(epoch, loss=loss_)

# if self.checkpoint_dict['display']:
# self.live_plot_handle.update(self.train_history)

if rank == 0 and epoch in save_epochs:
torch.save(self._model.net_, save_fname_func(epoch))

if rank == 0 and (epoch % save_every == 0):
self._save_snapshot(epoch)

@staticmethod
def calc_kl_mean(logq, logp):
"""Return Kullback-Leibler divergence estimated from logq and logp"""
"""Return Kullback-Leibler divergence estimated
from logq and logp """
return (logq - logp).mean() # KL, assuming samples from q

@staticmethod
@@ -297,18 +335,12 @@ def calc_corrcoef(logq, logp):

@staticmethod
def calc_direct_kl_mean(logq, logp):
r"""Return *direct* KL mean, which is defined as
"""Return *direct* KL mean, which is defined as
.. math::
\frac{\sum \frac{p}{q} (\log(\frac{p}{q}) - \log z)}{\sum \frac{p}{q}}
where
.. math::
\log z = \log( \sum(\frac{p}{q}) / N)
where N is the number of samples. The direct KL mean is invariant
under scaling p and/or q.
"""
@@ -376,7 +408,7 @@ def print_fit_status(self, epoch, loss=None):

if epoch == 1:
print(f"\n>>> Training progress ({ess.device}) <<<\n")
print("Note: log(q/p) is esitamted with normalized p; " \
print("Note: log(q/p) is estimated with normalized p; " \
+ "mean & error are obtained from samples in a batch\n")

str_ = f"Epoch: {epoch} | loss: {loss:g} | ess: {ess:g} | rho: {rho:g}"
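For reference, here is the save/resume pattern that `_save_snapshot` and `_load_snapshot` implement, reduced to plain PyTorch so it runs standalone. The path, the toy network, and the cadence are assumptions for illustration, not part of the commit; note that `_load_snapshot` reads the `LOCAL_RANK` environment variable, which multi-GPU launchers such as `torchrun` set for each process (this sketch falls back to GPU 0 when it is absent):

    import os
    import torch

    SNAPSHOT_PATH = "snapshot.pt"  # hypothetical path, not from the commit
    net = torch.nn.Linear(4, 4)    # toy stand-in for self._model.net_

    def save_snapshot(epoch):
        # Same dict layout as _save_snapshot: weights plus epoch counter.
        torch.save({"MODEL_STATE": net.state_dict(), "EPOCHS_RUN": epoch},
                   SNAPSHOT_PATH)

    def load_snapshot():
        # Same device logic as _load_snapshot; torchrun sets LOCAL_RANK,
        # so default to GPU 0 when it is absent.
        if torch.cuda.is_available():
            loc = f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}"
        else:
            loc = "cpu"
        snapshot = torch.load(SNAPSHOT_PATH, map_location=loc)
        net.load_state_dict(snapshot["MODEL_STATE"])
        return snapshot["EPOCHS_RUN"]

    epochs_run = load_snapshot() if os.path.exists(SNAPSHOT_PATH) else 0
    for epoch in range(epochs_run, epochs_run + 10):
        # one optimization step would go here
        if epoch % 5 == 0:
            save_snapshot(epoch)

Saving only the `state_dict` plus an epoch counter keeps the snapshot small; extending the dict with optimizer and scheduler state would let a resumed run continue exactly where it stopped.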
