feat: model storage: trained model versioning
Francesco Stablum committed Nov 29, 2021
1 parent cd0f5c5 commit 4c51197
Showing 5 changed files with 143 additions and 37 deletions.
4 changes: 2 additions & 2 deletions model_config/dspn_result_short.yaml
@@ -4,12 +4,12 @@ latent_dim: 2
activation_function: "ELU"
depth: 5
weight_decay: 0
max_epochs: 10
max_epochs: 2
rel_name: 'result'
batch_size: 1024
divide_output_layer: True
latent_l1_norm: 10.0
max_set_size: 200
dspn_iter: 4
dspn_lr: 800
cap_dataset: 1000
cap_dataset: 500
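These changes shorten the training run (fewer epochs, a smaller dataset cap), which suits a quick check of the new storage code. A minimal sketch of how such config fields reach the trainer, mirroring the models/run.py changes further down (the yaml loading step is an assumption, not shown in this diff):

    import pytorch_lightning as pl
    import yaml

    # assumption: the config file is plain yaml loaded into a dict;
    # the project's actual config loading is outside this diff
    with open("model_config/dspn_result_short.yaml") as f:
        model_config = yaml.safe_load(f)

    trainer = pl.Trainer(max_epochs=model_config['max_epochs'])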
6 changes: 6 additions & 0 deletions models/dspn_autoencoder.py
@@ -11,6 +11,7 @@
import dspn.dspn
from models import diagnostics
from common import utils, config
from models import models_storage

class InvariantModel(torch.nn.Module): #FIXME: delete?
def __init__(self, phi, rho):
@@ -29,8 +30,13 @@ class DSPNAE(generic_model.GenericModel):
DSPNAE is an acronym for Deep Set Prediction Network AutoEncoder
"""

@classmethod
def storage(cls):
return models_storage.DSPNAEModelsStorage()

with_set_index = True


class CollateFn(object):
"""
CollateFn makes use of the start-item-index
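The storage() hook added above gives each model class a single place to declare its storage backend. A minimal usage sketch (not part of this diff; model is assumed to be an already-constructed DSPNAE instance):

    # classmethods are callable on instances too, so training code can
    # stay agnostic of the concrete model class (see models/run.py below)
    storage = model.storage()                       # DSPNAEModelsStorage()
    callback = storage.create_write_callback(model)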
22 changes: 10 additions & 12 deletions models/generic_model.py
@@ -150,24 +150,22 @@ class GenericModel(pl.LightningModule):

with_set_index = None # please set in subclass

@property
@classmethod
def storage(cls):
raise Exception("implement in subclass")

@property
def classname(self):
return self.__class__.__name__

@property
def name(self):
return f"{self.classname}_{self.rel.name}"
#@classmethod FIXME DELME
#def name(cls,rel):
# return f"{cls.__name__}_{rel.name}"

@property
def kwargs_filename(self):
os.path.join(
config.trained_models_dirpath,
f"{self.name}_kwargs.pickle"
)

def dump_kwargs(self):
with open(self.kwargs_filename, 'wb') as handle:
pickle.dump(self.kwargs, handle, protocol=pickle.HIGHEST_PROTOCOL)
def name(self):
return f"{self.__class__.__name__}_{self.rel.name}"

def make_train_loader(self,tsets):
raise Exception("implement in subclass")
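This is the base-class side of the hook: storage() now joins the other methods that raise until a subclass overrides them, while kwargs persistence moves out to the storage classes. A hypothetical subclass sketch (MyModel and MyModelsStorage are illustrative names, not part of the codebase):

    class MyModelsStorage:                 # hypothetical stand-in for a
        pass                               # DSPNAEModelsStorage-like class

    class MyModel(GenericModel):
        with_set_index = False

        @classmethod
        def storage(cls):
            return MyModelsStorage()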
132 changes: 120 additions & 12 deletions models/models_storage.py
@@ -2,14 +2,18 @@
import os
import sys
import pickle

import pytorch_lightning as pl
import glob
import re
project_root_dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__))+"/..")
sys.path.insert(0, project_root_dir)

from common import utils, relspecs
from common import utils, relspecs, config
from models import dspn_autoencoder

class DSPNAEModelsStorage(utils.Collection): # FIXME: abstraction?


class DSPNAEModelsStorage(utils.Collection): # FIXME: classname-parameterized?
"""
We need to store the DSPNAE models somewhere and to recall them
easily. This class offers a straightforward interface to load
@@ -21,25 +25,129 @@ def __init__(self):
model = self.load(rel)
self[rel.name] = model

def load(self, rel):
model_filename = os.path.join(
def create_write_callback(self, model):
"""
creates a model checkpoint callback for the training
of the given model.
Also enables model versioning, which is taken care of
by the pytorch_lightning library.
:param model: the model to be saved
:return: the callback object to be passed via the
callbacks=[...] parameter of the pytorch_lightning Trainer
"""
model_filename = model.name
checkpoint_callback = pl.callbacks.ModelCheckpoint(
monitor="train_loss",
dirpath=config.trained_models_dirpath,
filename=model_filename,
save_top_k=1,
save_last=False,
mode="min",
)
return checkpoint_callback
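With a fixed filename, pytorch_lightning's ModelCheckpoint does not overwrite an existing checkpoint: as the docstring notes, versioning is handled by the library, which appends a -v<N> suffix. Successive training runs of the same model would therefore be expected to leave files like (illustrative paths):

    trained_models/DSPNAE_result.ckpt        # first run, implicit version 0
    trained_models/DSPNAE_result-v1.ckpt     # second run
    trained_models/DSPNAE_result-v2.ckpt     # third run

This suffix convention is exactly what the filename-parsing code below relies on.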

def generate_kwargs_filename(self, model):

# reuse the version number of the most recently saved model checkpoint,
# since the checkpoint is written before the kwargs dump
version = self.last_version(model.rel)

ret = os.path.join(
config.trained_models_dirpath,
f"{model.name}-v{version}.kwargs.pickle"
)
return ret

def dump_kwargs(self,model):
kwargs_filename = self.generate_kwargs_filename(model)
with open(kwargs_filename, 'wb') as handle:
pickle.dump(model.kwargs, handle, protocol=pickle.HIGHEST_PROTOCOL)
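Each checkpoint version gets a companion kwargs pickle so the model can later be re-instantiated with the constructor arguments it was trained with. For the second run in the example above, the expected pair on disk would be (illustrative paths):

    trained_models/DSPNAE_result-v1.ckpt
    trained_models/DSPNAE_result-v1.kwargs.pickle

Note that on the very first run the checkpoint carries no suffix while the kwargs file would be written as -v0, since last_version() reports an unsuffixed checkpoint as version 0.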

def filenames(self,rel,extension):
filenames_glob = os.path.join(
"trained_models",
f"DSPNAE_{rel.name}.ckpt"
f"DSPNAE_{rel.name}*.{extension}"
)
print("filenames_glob",filenames_glob)
ret = {}
filenames = glob.glob(filenames_glob)
for curr in filenames:
m = re.match(rf'.*-v(\d+)\.{extension}', curr)
if m:
# has version number
print(curr, m.groups())
version = int(m.groups()[0])
else:
# no version number in filename: this was the first
print(f'filename {curr} not matching versioned pattern')
version = 0
ret[version] = curr
print("ret",ret)
return ret
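filenames() maps each discovered version number to its path, treating an unsuffixed file as version 0. For the checkpoint files in the example above, the returned dict would be roughly:

    {0: 'trained_models/DSPNAE_result.ckpt',
     1: 'trained_models/DSPNAE_result-v1.ckpt',
     2: 'trained_models/DSPNAE_result-v2.ckpt'}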

def kwargs_filenames(self,rel):
return self.filenames(rel,'kwargs.pickle')

def models_filenames(self,rel):
return self.filenames(rel,'ckpt')

def last_version(self, rel):
filenames = self.models_filenames(rel)
versions = filenames.keys()
if len(versions) == 0:
# there was no model stored
return None
last_version = max(versions)
return last_version

def most_recent_kwargs_filename(self, rel):
last_version = self.last_version(rel)
if last_version is None:
# there was no model stored
return None
filenames = self.kwargs_filenames(rel)
if last_version not in filenames.keys():
raise Exception(f"cannot find kwargs file for version {last_version}")
return filenames[last_version]

def most_recent_model_filename(self, rel):
filenames = self.models_filenames(rel)
last_version = self.last_version(rel)
if last_version is None:
# there was no model stored
return None
return filenames[last_version]

def load(self, rel):
"""
:param rel: the relation this model has been trained on
:return: the loaded model, or None if no trained model is found
"""
filename = self.most_recent_model_filename(rel)

if filename is None or not os.path.exists(filename):
logging.warning(f"could not find model saved in {filename}")
return

# FIXME: duplicated from GenericModel
kwargs_filename = f"DSPNAE_{rel.name}_kwargs.pickle"
kwargs_filename = self.most_recent_kwargs_filename(rel)
with open(kwargs_filename, 'rb') as f:
kwargs = pickle.load(f)
logging.info(f"loading {model_filename}..")
if not os.path.exists(model_filename):
logging.warning(f"could not find model saved in {model_filename}")
return
logging.info(f"loading {filename}..")

model = dspn_autoencoder.DSPNAE(**kwargs)
model.load_from_checkpoint(model_filename)
model.load_from_checkpoint(filename)
return model
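A minimal loading sketch (assuming, as the folded constructor above suggests, that DSPNAEModelsStorage populates itself from relspecs.rels at construction time):

    from common import relspecs
    from models import models_storage

    storage = models_storage.DSPNAEModelsStorage()
    # models are then reachable by relation name
    model = storage[relspecs.rels[0].name]

One caveat: LightningModule.load_from_checkpoint is a classmethod that returns a new instance, so the call inside load() appears to discard its result; restoring the weights explicitly would look like model = dspn_autoencoder.DSPNAE.load_from_checkpoint(filename).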

def test(self, rel):
"""
tests that a loaded model is functioning correctly
:param rel: the relation the model was trained on
"""
model = self.load(rel)
# FIXME: TODO


def test():
ms = DSPNAEModelsStorage()
Expand Down
16 changes: 5 additions & 11 deletions models/run.py
@@ -8,7 +8,7 @@
import argparse
import pickle
from common import utils, relspecs, persistency, config
from models import diagnostics, measurements as ms
from models import diagnostics, measurements as ms, models_storage

def get_args():
args = {}
@@ -153,26 +153,20 @@ def run(Model,config_name, dynamic_config={}):
item_dim=tsets.item_dim,
**model_config
).to(device)
model.dump_kwargs()
storage = model.storage()
train_loader = model.make_train_loader(tsets)
test_loader = model.make_test_loader(tsets)
model_filename = model.name
checkpoint_callback = pl.callbacks.ModelCheckpoint(
monitor="train_loss",
dirpath=config.trained_models_dirpath,
filename=model_filename,
save_top_k=1,
mode="min",
)
model_write_callback = storage.create_write_callback(model)
callbacks = [
MeasurementsCallback(rel=rel,model=model),
checkpoint_callback
model_write_callback
]
trainer = pl.Trainer(
limit_train_batches=1.0,
callbacks=callbacks,
max_epochs=model_config['max_epochs']
)
trainer.fit(model, train_loader, test_loader)
storage.dump_kwargs(model)
print("current mlflow run:",mlflow.active_run().info.run_id, " - all done.")
#log_net_visualization(model,torch.zeros(model_config['batch_size'], tsets.item_dim))
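Pulling the run.py changes together, persistence is now fully delegated to the storage object. A condensed recap of the new flow (all names as in this diff, not a standalone script):

    storage = model.storage()                      # DSPNAEModelsStorage
    trainer = pl.Trainer(
        callbacks=[MeasurementsCallback(rel=rel, model=model),
                   storage.create_write_callback(model)],
        max_epochs=model_config['max_epochs'],
    )
    trainer.fit(model, train_loader, test_loader)  # writes the versioned *.ckpt
    storage.dump_kwargs(model)                     # writes the matching *.kwargs.pickle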
