Skip to content

Commit

Permalink
fix: fixed models storage
Browse files Browse the repository at this point in the history
  • Loading branch information
Francesco Stablum committed Nov 25, 2021
1 parent 039f13f commit cd0f5c5
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 14 deletions.
2 changes: 2 additions & 0 deletions config/example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ download_max_set_size: 50
data_loader_num_workers: 4
models_dag_config_name: dspn_deepnarrow
models_dag_days_interval: 2

trained_models_dirpath: trained_models/
23 changes: 19 additions & 4 deletions models/generic_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import pytorch_lightning as pl
import torch
import os
import sys
import pickle

project_root_dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__))+"/..")
sys.path.insert(0, project_root_dir)

from common import config

class AEModule(torch.nn.Module):
"""
Expand All @@ -14,9 +21,6 @@ def __init__(self,**kwargs):
self.rel = kwargs.get('rel', None)
super().__init__()

@property
def name(self):
return f"{self.classname}_{self.rel.name}"

def activate(self, x):
return self.activation_function()(x)
Expand Down Expand Up @@ -152,7 +156,18 @@ def classname(self):

@property
def name(self):
raise Exception("implement in subclass")
return f"{self.classname}_{self.rel.name}"

@property
def kwargs_filename(self):
os.path.join(
config.trained_models_dirpath,
f"{self.name}_kwargs.pickle"
)

def dump_kwargs(self):
with open(self.kwargs_filename, 'wb') as handle:
pickle.dump(self.kwargs, handle, protocol=pickle.HIGHEST_PROTOCOL)

def make_train_loader(self,tsets):
raise Exception("implement in subclass")
Expand Down
42 changes: 36 additions & 6 deletions models/models_storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import logging
import os
import sys
import pickle

project_root_dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__))+"/..")
sys.path.insert(0, project_root_dir)

from common import utils, relspecs
from models import dspn_autoencoder
Expand All @@ -10,11 +16,35 @@ class DSPNAEModelsStorage(utils.Collection): # FIXME: abstraction?
the models.
"""
def __init__(self):

os.chdir(project_root_dir)
for rel in relspecs.rels:
model_filename = os.path.join(
"trained_models",
f"DSPNAE_{rel.name}"
)
model = dspn_autoencoder.DSPNAE.load_from_checkpoint(model_filename)
model = self.load(rel)
self[rel.name] = model

def load(self, rel):
model_filename = os.path.join(
"trained_models",
f"DSPNAE_{rel.name}.ckpt"
)

# FIXME: duplicated from GenericModel
kwargs_filename = f"DSPNAE_{rel.name}_kwargs.pickle"
with open(kwargs_filename) as f:
kwargs = pickle.load(f)
logging.info(f"loading {model_filename}..")
if not os.path.exists(model_filename):
logging.warning(f"could not find model saved in {model_filename}")
return

model = dspn_autoencoder.DSPNAE(**kwargs)
model.load_from_checkpoint(model_filename)
return model


def test():
ms = DSPNAEModelsStorage()
print("ms",ms)
print("done.")

if __name__ == "__main__":
test()
8 changes: 4 additions & 4 deletions models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import os
import sys
import argparse

from common import utils, relspecs, persistency
import pickle
from common import utils, relspecs, persistency, config
from models import diagnostics, measurements as ms

def get_args():
Expand Down Expand Up @@ -153,13 +153,13 @@ def run(Model,config_name, dynamic_config={}):
item_dim=tsets.item_dim,
**model_config
).to(device)

model.dump_kwargs()
train_loader = model.make_train_loader(tsets)
test_loader = model.make_test_loader(tsets)
model_filename = model.name
checkpoint_callback = pl.callbacks.ModelCheckpoint(
monitor="train_loss",
dirpath="trained_models/",
dirpath=config.trained_models_dirpath,
filename=model_filename,
save_top_k=1,
mode="min",
Expand Down

0 comments on commit cd0f5c5

Please sign in to comment.