diff --git a/src/_normflowcore.py b/src/_normflowcore.py
index a462b37..556e34d 100644
--- a/src/_normflowcore.py
+++ b/src/_normflowcore.py
@@ -18,9 +18,9 @@
 used by other modules of this package.
 """
-
 import torch
 import time
+import os
 
 import numpy as np
@@ -28,7 +28,6 @@
 from .lib.combo import estimate_logz, fmt_val_err
 from .device import ModelDeviceHandler
 
-
 # =============================================================================
 class Model:
     """The central high-level class of the package, which
@@ -144,12 +143,13 @@ def __init__(self, model):
             print_stride=100,
             print_batch_size=1024,
             print_extra_func=None,
-            save_epochs=[],
-            save_fname_func=None
+            snapshot_path=None
         )
 
     def __call__(self,
                  n_epochs=1000,
+                 save_every=None,
+                 epochs_run=0,
                  batch_size=64,
                  optimizer_class=torch.optim.AdamW,
                  scheduler=None,
@@ -165,6 +165,9 @@ def __call__(self,
         n_epochs : int
             Number of epochs of training
 
+        save_every : int
+            Save a snapshot every save_every epochs (default: n_epochs // 10)
+
         batch_size : int
             Size of samples used at each epoch
@@ -187,6 +190,18 @@ def __call__(self,
         self.hyperparam.update(hyperparam)
         self.checkpoint_dict.update(checkpoint_dict)
 
+        snapshot_path = self.checkpoint_dict['snapshot_path']
+        self.epochs_run = epochs_run
+
+        if snapshot_path is not None and os.path.exists(snapshot_path):
+            print(f"Trying to load snapshot from {snapshot_path}")
+            self._load_snapshot()
+        else:
+            print("Starting training from scratch")
+
+        if save_every is None:
+            save_every = max(1, n_epochs // 10)
+
         self.loss_fn = Fitter.calc_kl_mean if loss_fn is None else loss_fn
 
         net_ = self._model.net_
@@ -198,9 +213,35 @@ def __call__(self,
 
         self.scheduler = None if scheduler is None else scheduler(self.optimizer)
 
-        return self.train(n_epochs, batch_size)
-
-    def train(self, n_epochs, batch_size):
+        return self.train(n_epochs, batch_size, self.epochs_run, save_every)
+
+    def _load_snapshot(self):
+        snapshot_path = self.checkpoint_dict['snapshot_path']
+        if torch.cuda.is_available():
+            gpu_id = int(os.environ.get("LOCAL_RANK", 0))
+            loc = f"cuda:{gpu_id}"
+            print(f"GPU: Attempting to load saved model into {loc}")
+        else:
+            loc = "cpu"  # CPU training: map saved tensors to the CPU
+            print("CPU: Attempting to load saved model")
+        snapshot = torch.load(snapshot_path, map_location=loc)
+        print(f"Snapshot found at\n{snapshot_path}\n")
+        self._model.net_.load_state_dict(snapshot["MODEL_STATE"])
+        self.epochs_run = snapshot["EPOCHS_RUN"]
+        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
+
+    def _save_snapshot(self, epoch):
+        """Save a snapshot of training for analysis and/or to continue
+        training at a later date."""
+        snapshot_path = self.checkpoint_dict['snapshot_path']
+        snapshot = {
+            "MODEL_STATE": self._model.net_.state_dict(),
+            "EPOCHS_RUN": epoch,
+        }
+        torch.save(snapshot, snapshot_path)
+        print(f"Epoch {epoch} | Training snapshot saved at {snapshot_path}")
+
+    def train(self, n_epochs, batch_size, epochs_run, save_every):
         """Train the model.
 
         Parameters
@@ -208,6 +249,9 @@ def train(self, n_epochs, batch_size):
         n_epochs : int
             Number of epochs of training
 
+        epochs_run : int
+            Epoch count to resume from (0 when starting from scratch)
+
         batch_size : int
             Size of samples used at each epoch
@@ -216,11 +260,10 @@ def train(self, n_epochs, batch_size):
         being called at least once.
         """
         self.train_batch_size = batch_size
-        last_epoch = len(self.train_history["loss"]) + 1
         T1 = time.time()
-        for epoch in range(last_epoch, last_epoch + n_epochs):
+        for epoch in range(epochs_run, epochs_run + n_epochs):
             loss, logqp = self.step()
-            self.checkpoint(epoch, loss)
+            self.checkpoint(epoch, loss, save_every)
             if self.scheduler is not None:
                 self.scheduler.step()
         T2 = time.time()
@@ -248,7 +291,7 @@ def step(self):
 
         return loss, logq - logp
 
-    def checkpoint(self, epoch, loss):
+    def checkpoint(self, epoch, loss, save_every):
 
         rank = self._model.device_handler.rank
@@ -259,8 +302,6 @@ def checkpoint(self, epoch, loss):
         # For the rest
         print_stride = self.checkpoint_dict['print_stride']
         print_batch_size = self.checkpoint_dict['print_batch_size']
-        save_epochs = self.checkpoint_dict['save_epochs']
-        save_fname_func = self.checkpoint_dict['save_fname_func']
 
         print_batch_size = print_batch_size // self._model.device_handler.nranks
@@ -275,16 +316,14 @@ def checkpoint(self, epoch, loss):
             loss_ = self.loss_fn(logq, logp)
             self._append_to_train_history(logq, logp)
             self.print_fit_status(epoch, loss=loss_)
-
-            # if self.checkpoint_dict['display']:
-            #     self.live_plot_handle.update(self.train_history)
-
-        if rank == 0 and epoch in save_epochs:
-            torch.save(self._model.net_, save_fname_func(epoch))
+
+        if rank == 0 and (epoch % save_every == 0):
+            self._save_snapshot(epoch)
 
     @staticmethod
     def calc_kl_mean(logq, logp):
-        """Return Kullback-Leibler divergence estimated from logq and logp"""
+        """Return the Kullback-Leibler divergence estimated
+        from logq and logp."""
         return (logq - logp).mean()  # KL, assuming samples from q
@@ -297,18 +336,18 @@ def calc_corrcoef(logq, logp):
 
     @staticmethod
     def calc_direct_kl_mean(logq, logp):
-        r"""Return *direct* KL mean, which is defined as
+        r"""Return the *direct* KL mean, which is defined as
 
         .. math::
-            \frac{\sum \frac{p}{q} (\log(\frac{p}{q}) + logz)}{\sum \frac{p}{q}}
-        where
-        .. math::
-            logz = \log( \sum(frac{p}{q}) / N)
-        wbere N is the number of samples.
-        The direct KL means is invariant under scaling p and/or q.
+            \frac{\sum \frac{p}{q} \big(\log\frac{p}{q} + \log z\big)}{\sum \frac{p}{q}},
+
+        where :math:`\log z = \log\big(\sum \frac{p}{q} / N\big)`
+        and N is the number of samples.
+
+        The direct KL mean is invariant under scaling p and/or q.
         """
@@ -376,7 +415,7 @@ def print_fit_status(self, epoch, loss=None):
 
-        if epoch == 1:
+        if epoch == 0:
             print(f"\n>>> Training progress ({ess.device}) <<<\n")
-            print("Note: log(q/p) is esitamted with normalized p; " \
+            print("Note: log(q/p) is estimated with normalized p; " \
                   + "mean & error are obtained from samples in a batch\n")
 
         str_ = f"Epoch: {epoch} | loss: {loss:g} | ess: {ess:g} | rho: {rho:g}"