feat: using Airflow to schedule the training of the autoencoders of all activity relation fields
Francesco Stablum committed Nov 18, 2021
1 parent 0f2e5fc commit cd1fb6e
Showing 6 changed files with 120 additions and 7 deletions.
5 changes: 4 additions & 1 deletion airflow/add_dag_bags.py.m4
@@ -1,7 +1,10 @@
""" add additional DAGs folders """
import os
from airflow.models import DagBag
dags_dirs = ['LEARNING_SETS_DIR/preprocess/']
dags_dirs = [
    'LEARNING_SETS_DIR/preprocess/',
    'LEARNING_SETS_DIR/models/'
]

for d in dags_dirs:
print(f"creating DagBag with path {d}")
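The rest of the loop is collapsed in the diff above. For reference, the standard Airflow recipe for registering DAGs from extra folders looks roughly like the sketch below (a hedged reconstruction of the usual pattern, not necessarily this file's exact code; the directory paths are placeholders for the m4-expanded LEARNING_SETS_DIR values):

import os
from airflow.models import DagBag

dags_dirs = ['/path/to/preprocess/', '/path/to/models/']  # placeholder paths

for d in dags_dirs:
    print(f"creating DagBag with path {d}")
    dag_bag = DagBag(os.path.expanduser(d))
    if dag_bag:
        # export each discovered DAG into this module's globals
        # so the scheduler can pick it up
        for dag_id, dag in dag_bag.dags.items():
            globals()[dag_id] = dag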
8 changes: 7 additions & 1 deletion common/utils.py
@@ -187,7 +187,7 @@ def str_shapes(stuff):
    return str(stuff.shape)


def load_model_config(config_name):
def load_model_config(config_name, dynamic_config=None):
    if os.path.exists(config_name):
        # a filename is given
        filename = config_name
@@ -206,6 +206,12 @@ def load_model_config(config_name):

    ret['config_name'] = config_name
    ret['config_filename'] = filename

    # dynamic config generation will override the yaml file config
    if dynamic_config is not None:
        for k,v in dynamic_config.items():
            logging.info(f"configuration item {k} dynamically set at {v}")
            ret[k] = v
    return ret


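To illustrate the override semantics added to load_model_config: keys passed via dynamic_config replace the values read from the YAML file. A minimal standalone sketch of the same behaviour (hypothetical helper, assuming PyYAML; not the project's actual function):

import logging
import yaml  # assumption: PyYAML is available, as for the project's yaml configs

def load_config_with_overrides(filename, dynamic_config=None):
    with open(filename) as f:
        ret = yaml.safe_load(f)
    # values supplied at call time win over the yaml file
    if dynamic_config is not None:
        for k, v in dynamic_config.items():
            logging.info(f"configuration item {k} dynamically set to {v}")
            ret[k] = v
    return ret

# e.g. a DAG task can force which relation is trained (hypothetical call):
# cfg = load_config_with_overrides("model_config/dspn_deepnarrow.yaml", {"rel_name": "activity_date"})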
6 changes: 5 additions & 1 deletion config/example.yaml
@@ -14,4 +14,8 @@ vm_host: somehost
pg_password: somepassword
airflow_user: someuser
airflow_password: somepassword
airflow_email: [email protected]
airflow_email: [email protected]

data_loader_num_workers: 4
models_dag_config_name: dspn_deepnarrow
models_dag_days_interval: 2
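The new keys are consumed in the code below as attributes of common.config (e.g. config.models_dag_days_interval). A hedged sketch of one way a config module can expose YAML keys as module attributes; the project's common/config.py may do this differently:

import sys
import yaml  # assumption: PyYAML

with open("config/example.yaml") as f:
    _cfg = yaml.safe_load(f)

# expose every yaml key as a module-level attribute,
# so callers can write e.g. config.data_loader_num_workers
for _key, _value in _cfg.items():
    setattr(sys.modules[__name__], _key, _value)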
73 changes: 73 additions & 0 deletions models/dag.py
@@ -0,0 +1,73 @@
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.operators.bash import BashOperator
from airflow.utils import timezone
import datetime
import os
import sys

import models.dspn_autoencoder

project_root_path = os.path.abspath(os.path.dirname(os.path.abspath(__file__))+"/..")
sys.path = [project_root_path]+sys.path

from common import relspecs, config
from models import run

config_name = config.models_dag_config_name

def in_days(n):
    """
    Get a datetime object representing the time `n` days from now
    (negative `n` gives a date in the past).
    """
    today = timezone.utcnow()
    return today + datetime.timedelta(days=n)

def train_model(rel,ti):
    dynamic_config = {'rel_name':rel.name}
    run.run(
        models.dspn_autoencoder.DSPNAE,
        config_name,
        dynamic_config=dynamic_config
    )


project_root_dir = os.path.abspath(os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    '..'  # parent directory of models/
))
os.chdir(project_root_dir)

default_args = {
    'retries': 2,
    'retry_delay': datetime.timedelta(minutes=5),
    'schedule_interval': None
}
with DAG(
    'train_dspn_models',
    description='trains DSPN models',
    tags=['train', 'dspn', 'sets', 'models'],
    default_args=default_args,
    schedule_interval=None
) as dag:
    days_interval = config.models_dag_days_interval
    for rel_i,rel in enumerate(relspecs.rels):

        train_cmd = f"cd {project_root_dir}; python3 models/dspn_autoencoder.py {config.models_dag_config_name} --rel_name={rel.name}"

        t_train_model = BashOperator(
            task_id=f"train_dsp_model_{rel.name}",
            depends_on_past=False,
            bash_command=train_cmd,
            start_date=in_days((rel_i-1)*days_interval),
            dag=dag
        )

        ### PythonOperator version:
        #t_train_model = PythonOperator(
        #    task_id=f"train_dsp_model_{rel.name}",
        #    python_callable=train_model,
        #    start_date=in_days((rel_i-1)*days_interval),
        #    op_kwargs={'rel':rel}
        #)
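The start_date staggering above spaces the per-relation training tasks models_dag_days_interval days apart. A small illustration of that arithmetic with hypothetical relation names (relspecs.rels itself is not shown in this diff):

import datetime

days_interval = 2                        # models_dag_days_interval from config/example.yaml
rel_names = ["rel_a", "rel_b", "rel_c"]  # placeholders for [r.name for r in relspecs.rels]

now = datetime.datetime.utcnow()
for rel_i, name in enumerate(rel_names):
    # same expression as in_days((rel_i - 1) * days_interval) above
    start = now + datetime.timedelta(days=(rel_i - 1) * days_interval)
    print(f"{name}: start_date = {start:%Y-%m-%d %H:%M} UTC")

# rel_a gets a start_date 2 days in the past (so it is immediately schedulable),
# rel_b starts now, rel_c two days from now.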
4 changes: 2 additions & 2 deletions models/dspn_autoencoder.py
@@ -10,7 +10,7 @@
import dspn.model
import dspn.dspn
from models import diagnostics
from common import utils
from common import utils, config

class InvariantModel(torch.nn.Module): #FIXME: delete?
    def __init__(self, phi, rho):
@@ -71,7 +71,7 @@ def make_train_loader(self, tsets):
        train_loader = torch.utils.data.DataLoader(
            tsets.sets_intervals('train'),
            shuffle=True,
            num_workers=4,
            num_workers=config.data_loader_num_workers,
            pin_memory=False,
            collate_fn=self.CollateFn(tsets.train_scaled)
        )
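The hard-coded num_workers=4 is replaced by the new data_loader_num_workers config key. A minimal standalone sketch of the same pattern with a toy dataset (not the project's tsets/CollateFn):

import torch

data_loader_num_workers = 4  # would come from common.config in the project

if __name__ == "__main__":
    # a tensor works as a simple map-style dataset
    toy_dataset = torch.arange(100, dtype=torch.float32).unsqueeze(1)
    loader = torch.utils.data.DataLoader(
        toy_dataset,
        shuffle=True,
        batch_size=8,
        num_workers=data_loader_num_workers,  # configurable instead of hard-coded
        pin_memory=False,
    )
    for batch in loader:
        pass  # a training step would consume `batch` here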
31 changes: 29 additions & 2 deletions models/run.py
@@ -4,10 +4,20 @@
import numpy as np
import logging
import os
import sys
import argparse

from common import utils, relspecs, persistency
from models import diagnostics, measurements as ms

def get_args():
    args = {}
    for arg in sys.argv:
        if arg.startswith(("--")):
            k = arg.split('=')[0][2:]
            v = arg.split('=')[1]
            args[k] = v
    return args

class MeasurementsCallback(pl.callbacks.Callback):
    rel = None
@@ -90,10 +100,27 @@ def teardown(self, trainer, lm, stage=None):
                type_=m.plot_type
            )

def run(Model,config_name):
def run(Model,config_name, dynamic_config={}):

    # need to make sure that logs/* and mlruns/* are generated
    # in the correct project root directory, and that
    # config files are loaded from model_config/
    project_root_dir = os.path.abspath(os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        '..'  # parent directory of models/
    ))
    os.chdir(project_root_dir)

    # collect --key=value arguments from the command line and put them into the
    # model run's configuration, overriding any values already given in dynamic_config,
    # as in --rel_name=activity_date for example
    args = get_args()
    for arg, val in args.items():
        dynamic_config[arg] = val

    log_filename = os.path.join("logs",utils.strnow_compact()+'.log')
    logging.basicConfig(filename=log_filename, filemode='w', level=logging.DEBUG)
    model_config = utils.load_model_config(config_name)
    model_config = utils.load_model_config(config_name, dynamic_config=dynamic_config)
    mlflow.set_experiment(model_config['experiment_name'])
    mlflow.pytorch.autolog()
    with mlflow.start_run(run_name=model_config['config_name']):
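get_args() above collects --key=value pairs from sys.argv so that, for example, the DAG's BashOperator can pass --rel_name=activity_date. A small self-contained illustration of the same parsing idea, taking argv as a parameter instead of reading sys.argv (the function name is chosen here for illustration only):

def parse_cli_overrides(argv):
    # same idea as models/run.py:get_args, parameterised for illustration
    args = {}
    for arg in argv:
        if arg.startswith("--") and "=" in arg:
            key, value = arg[2:].split("=", 1)
            args[key] = value
    return args

print(parse_cli_overrides(["models/dspn_autoencoder.py", "dspn_deepnarrow", "--rel_name=activity_date"]))
# -> {'rel_name': 'activity_date'}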
