diff --git a/config/example.yaml b/config/example.yaml index 449b0c1..8667448 100644 --- a/config/example.yaml +++ b/config/example.yaml @@ -23,3 +23,5 @@ download_max_set_size: 50 data_loader_num_workers: 4 models_dag_config_name: dspn_deepnarrow models_dag_days_interval: 2 + +trained_models_dirpath: trained_models/ diff --git a/models/generic_model.py b/models/generic_model.py index 3d758e6..cd27fc6 100644 --- a/models/generic_model.py +++ b/models/generic_model.py @@ -1,6 +1,13 @@ import pytorch_lightning as pl import torch +import os +import sys +import pickle +project_root_dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__))+"/..") +sys.path.insert(0, project_root_dir) + +from common import config class AEModule(torch.nn.Module): """ @@ -14,9 +21,6 @@ def __init__(self,**kwargs): self.rel = kwargs.get('rel', None) super().__init__() - @property - def name(self): - return f"{self.classname}_{self.rel.name}" def activate(self, x): return self.activation_function()(x) @@ -152,7 +156,18 @@ def classname(self): @property def name(self): - raise Exception("implement in subclass") + return f"{self.classname}_{self.rel.name}" + + @property + def kwargs_filename(self): + os.path.join( + config.trained_models_dirpath, + f"{self.name}_kwargs.pickle" + ) + + def dump_kwargs(self): + with open(self.kwargs_filename, 'wb') as handle: + pickle.dump(self.kwargs, handle, protocol=pickle.HIGHEST_PROTOCOL) def make_train_loader(self,tsets): raise Exception("implement in subclass") diff --git a/models/models_storage.py b/models/models_storage.py index ee1572a..7958d47 100644 --- a/models/models_storage.py +++ b/models/models_storage.py @@ -1,4 +1,10 @@ +import logging import os +import sys +import pickle + +project_root_dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__))+"/..") +sys.path.insert(0, project_root_dir) from common import utils, relspecs from models import dspn_autoencoder @@ -10,11 +16,35 @@ class DSPNAEModelsStorage(utils.Collection): # FIXME: abstraction? the models. """ def __init__(self): - + os.chdir(project_root_dir) for rel in relspecs.rels: - model_filename = os.path.join( - "trained_models", - f"DSPNAE_{rel.name}" - ) - model = dspn_autoencoder.DSPNAE.load_from_checkpoint(model_filename) + model = self.load(rel) self[rel.name] = model + + def load(self, rel): + model_filename = os.path.join( + "trained_models", + f"DSPNAE_{rel.name}.ckpt" + ) + + # FIXME: duplicated from GenericModel + kwargs_filename = f"DSPNAE_{rel.name}_kwargs.pickle" + with open(kwargs_filename) as f: + kwargs = pickle.load(f) + logging.info(f"loading {model_filename}..") + if not os.path.exists(model_filename): + logging.warning(f"could not find model saved in {model_filename}") + return + + model = dspn_autoencoder.DSPNAE(**kwargs) + model.load_from_checkpoint(model_filename) + return model + + +def test(): + ms = DSPNAEModelsStorage() + print("ms",ms) + print("done.") + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/models/run.py b/models/run.py index dee648e..318a786 100644 --- a/models/run.py +++ b/models/run.py @@ -6,8 +6,8 @@ import os import sys import argparse - -from common import utils, relspecs, persistency +import pickle +from common import utils, relspecs, persistency, config from models import diagnostics, measurements as ms def get_args(): @@ -153,13 +153,13 @@ def run(Model,config_name, dynamic_config={}): item_dim=tsets.item_dim, **model_config ).to(device) - + model.dump_kwargs() train_loader = model.make_train_loader(tsets) test_loader = model.make_test_loader(tsets) model_filename = model.name checkpoint_callback = pl.callbacks.ModelCheckpoint( monitor="train_loss", - dirpath="trained_models/", + dirpath=config.trained_models_dirpath, filename=model_filename, save_top_k=1, mode="min",