Fetch models from https://modeldb.science/ (#88)
* getmodels: use modeldb.science website

* CI: drop monthly cache invalidation

* CI: tweak cache key so it misses

* Add a better error message

* suggestion from Robert

* change download url
olupton authored Jun 21, 2023
1 parent 9fdae92 commit 98587d9
Showing 4 changed files with 148 additions and 91 deletions.
9 changes: 2 additions & 7 deletions .github/workflows/nrn-modeldb-ci.yaml
@@ -138,20 +138,17 @@ jobs:
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
           #Install project in editable mode
           python -m pip install -e .
-          echo "date=$(date -u "+%Y%m")" >> $GITHUB_OUTPUT
       - name: Cache ModelDB models
         id: cache-models
         uses: actions/cache@v3
         with:
           path: |
             cache
             modeldb/modeldb-meta.yaml
-          key: models-${{steps.install-deps.outputs.date}}
+          key: dynamic-models
 
       - name: Get ModelDB models
-        if: steps.cache-models.outputs.cache-hit != 'true'
-        run: getmodels
+        run: getmodels $MODELS_TO_RUN
 
       - name: Run Models with NEURON V1 -> ${{ env.NEURON_V1 }}
         run: |
@@ -162,7 +159,6 @@
             python -m pip install $NEURON_V1
           fi
           nrn_ver=`python -c "from neuron import __version__ as nrn_ver; print(nrn_ver)"`
-          ps uxf # debug
           runmodels --gout --workdir=$nrn_ver $MODELS_TO_RUN
           # Filter out the gout data before generating HTML reports. The HTML
           # diff uses the original gout files on disk anyway. Compress the large
@@ -189,7 +185,6 @@
             python -m pip install $NEURON_V2
           fi
           nrn_ver=`python -c "from neuron import __version__ as nrn_ver; print(nrn_ver)"`
-          ps uxf # debug
           runmodels --gout --workdir=$nrn_ver $MODELS_TO_RUN
           # Filter out the gout data before generating HTML reports. The HTML
           # diff uses the original gout files on disk anyway. Compress the large
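The step now runs `getmodels $MODELS_TO_RUN` unconditionally; the freshness check moves into `getmodels` itself (see `modeldb/modeldb.py` below), so revalidating a warm cache is cheap. As a rough illustration, the entry point plausibly maps onto the Python API like the following sketch; the argument handling shown here is an assumption, not the project's actual console script:

    # Hypothetical sketch of the `getmodels` entry point: model IDs given on
    # the command line become `model_list`; no arguments means None, i.e.
    # "fetch every NEURON model".
    import sys

    from modeldb.modeldb import ModelDB


    def main():
        model_ids = [int(arg) for arg in sys.argv[1:]]
        ModelDB().download_models(model_ids or None)


    if __name__ == "__main__":
        main()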
3 changes: 2 additions & 1 deletion README.md
@@ -66,7 +66,8 @@ The following commands are now available:
 | MODELDB_RUN_FILE | yaml file containing run instructions for models (required for `runmodels`) |
 | MODELDB_METADATA_FILE | yaml file containing model info for those downloaded with `getmodels` |
 | MODELS_ZIP_DIR | location of cache folder for models populated via `getmodels` |
-| MDB_NEURON_MODELS_URL | url used to get list of all NEURON model ids (necessary for `getmodels`) |
+| MDB_NEURON_MODELS_URL | url template used to get NEURON model IDs and last-updated timestamps (needed for `getmodels`) |
+| MDB_MODEL_METADATA_URL | url template used to get metadata about a single NEURON model (needed for `getmodels`) |
 | MDB_MODEL_DOWNLOAD_URL | url template used for model downloading (cf `{model_id}`) |

## Model Run
10 changes: 4 additions & 6 deletions modeldb/config.py
@@ -2,17 +2,15 @@

 import os
 
-# MDB_NEURON_MODELS_URL = "https://senselab.med.yale.edu/_site/webapi/object.json/?cl=19&oid=1882"
 MDB_NEURON_MODELS_URL = (
-    "http://modeldb.science/api/v1/models?modeling_application=NEURON"
-)
-MDB_MODEL_DOWNLOAD_URL = (
-    "https://senselab.med.yale.edu/_site/webapi/object.json/{model_id}"
+    "http://modeldb.science/api/v1/models/{model_field}?modeling_application=NEURON"
 )
+MDB_MODEL_METADATA_URL = "https://modeldb.science/api/v1/models/{model_id}"
+MDB_MODEL_DOWNLOAD_URL = "https://modeldb.science/download/{model_id}"
 
 ROOT_DIR = os.path.abspath(__file__ + "/../../")
 
 MODELS_ZIP_DIR = "%s/cache" % ROOT_DIR
 MODELDB_ROOT_DIR = "%s/modeldb" % ROOT_DIR
 MODELDB_METADATA_FILE = "%s/modeldb-meta.yaml" % MODELDB_ROOT_DIR
-MODELDB_RUN_FILE = "%s/modeldb-run.yaml" % MODELDB_ROOT_DIR
\ No newline at end of file
+MODELDB_RUN_FILE = "%s/modeldb-run.yaml" % MODELDB_ROOT_DIR
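For illustration, here is how the three URL templates expand. This is just a sketch using `str.format`; model ID 3264 is an arbitrary example, not one taken from the project's configuration:

    from modeldb.config import (
        MDB_MODEL_DOWNLOAD_URL,
        MDB_MODEL_METADATA_URL,
        MDB_NEURON_MODELS_URL,
    )

    # One query per field yields parallel lists over all NEURON models.
    MDB_NEURON_MODELS_URL.format(model_field="id")        # -> list of model IDs
    MDB_NEURON_MODELS_URL.format(model_field="ver_date")  # -> matching timestamps
    # Per-model endpoints take a single model ID.
    MDB_MODEL_METADATA_URL.format(model_id=3264)  # -> JSON metadata for one model
    MDB_MODEL_DOWNLOAD_URL.format(model_id=3264)  # -> the model's .zip archive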
217 changes: 140 additions & 77 deletions modeldb/modeldb.py
@@ -1,84 +1,99 @@
+from . import config
+from .config import *
+from .data import Model
+import logging
+import multiprocessing
-import base64
 import os
+from pprint import pformat
+from .progressbar import ProgressBar
 import requests
 import time
-from .progressbar import ProgressBar
 import yaml
-from .data import Model
-from .config import *
-import traceback
-from pprint import pformat


 def download_model(arg_tuple):
-    model_id, model_run_info = arg_tuple
+    model_id, model_run_info, expected_ver_date = arg_tuple
     try:
-        model_json = requests.get(MDB_MODEL_DOWNLOAD_URL.format(model_id=model_id)).json()
+        # Fetch the model metadata from ModelDB.
+        model_json = requests.get(
+            MDB_MODEL_METADATA_URL.format(model_id=model_id)
+        ).json()
+        # Check that the timestamp matches our expectations.
+        assert model_json["ver_date"] == expected_ver_date
+        # Assemble a Model object from the JSON metadata just fetched
         model = Model(
-            *(
-                model_json[key]
-                for key in ("object_id", "object_name", "object_created", "object_ver_date")
-            )
-        )
-        url = None
-        for att in model_json["object_attribute_values"]:
-            if att["attribute_id"] == 23:
-                url = att["value"]
-                break
-        # print(model.id)
-        model_zip_uri = os.path.join(
-            MODELS_ZIP_DIR, "{model_id}.zip".format(model_id=model.id)
+            model_json["id"],
+            model_json["name"],
+            model_json["created"],
+            model_json["ver_date"],
         )
-        with open(model_zip_uri, "wb+") as zipfile:
-            zipfile.write(base64.standard_b64decode(url["file_content"]))

+        # Now fetch the actual model data .zip file. By default this also comes
+        # from ModelDB, but it can be overridden to come from GitHub instead.
         if "github" in model_run_info:
-            # This means we should try to replace the version of the model that
-            # we downloaded from the ModelDB API just above with a version from
-            # GitHub
+            # This means we should try to download the model content from
+            # GitHub instead of from ModelDB.
             github = model_run_info["github"]
             organisation = "ModelDBRepository"
-            suffix = "" # default branch
+            suffix = ""  # default branch
             if github == "default":
+                # Using
+                #   github: "default"
+                # in modeldb-run.yaml implies that we fetch the HEAD of the
+                # default branch from ModelDBRepository on GitHub. In general
+                # this should be the same thing as fetching from ModelDB.
                 pass
             elif github.startswith("pull/"):
+                # Using
+                #   github: "pull/4"
+                # in modeldb-run.yaml implies that we use the branch from pull
+                # request #4 to ModelDBRepository/{model_id} on GitHub. This is
+                # used if you want to test updates to models.
                 pr_number = int(github[5:])
                 suffix = "/pull/{}/head".format(pr_number)
-            elif github.startswith('/'):
-                # /org implies that we use the default branch from org/model_id
+            elif github.startswith("/"):
+                # Using
+                #   github: "/myname"
+                # in modeldb-run.yaml implies that we fetch the HEAD of the
+                # default branch of myname/{model_id} on GitHub. This is useful
+                # if you need to test changes to a model that does not exist on
+                # GitHub under the ModelDBRepository organisation.
                 organisation = github[1:]
             else:
                 raise Exception("Invalid value for github key: {}".format(github))
-            github_url = "https://api.github.com/repos/{organisation}/{model_id}/zipball{suffix}".format(
+            url = "https://api.github.com/repos/{organisation}/{model_id}/zipball{suffix}".format(
                 model_id=model_id, organisation=organisation, suffix=suffix
             )
-            # Replace the local file `model_zip_uri` with the zip file we
-            # downloaded from `github_url`
-            num_attempts = 3
-            status_codes = []
-            for _ in range(num_attempts):
-                github_response = requests.get(github_url)
-                status_codes.append(github_response.status_code)
-                if github_response.status_code == requests.codes.ok:
-                    break
-                time.sleep(5)
-            else:
-                raise Exception(
-                    "Failed to download {} with status codes {}".format(
-                        github_url, status_codes
-                    )
-                )
-            with open(model_zip_uri, "wb+") as zipfile:
-                zipfile.write(github_response.content)
+        else:
+            # Get the .zip file from ModelDB, not from GitHub.
+            url = MDB_MODEL_DOWNLOAD_URL.format(model_id=model_id)
+
+        # Construct the path we want to save the .zip at locally.
+        model_zip_uri = os.path.join(
+            MODELS_ZIP_DIR, "{model_id}.zip".format(model_id=model.id)
+        )
+
+        # Download the model data from `url`. Retry a few times on failure.
+        num_attempts = 3
+        status_codes = []
+        for _ in range(num_attempts):
+            model_download_response = requests.get(url)
+            status_codes.append(model_download_response.status_code)
+            if model_download_response.status_code == requests.codes.ok:
+                break
+            time.sleep(5)
+        else:
+            raise Exception(
+                "Failed to download {} with status codes {}".format(url, status_codes)
+            )
+        with open(model_zip_uri, "wb+") as zipfile:
+            zipfile.write(model_download_response.content)
     except Exception as e:  # noqa
         model = e
 
     return model_id, model


 class ModelDB(object):
+    logger = None
     metadata = property(lambda self: self._metadata)
     run_instr = property(lambda self: self._run_instr)

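Note the `for`/`else` retry idiom used in `download_model` above: the `else` block runs only when the loop finishes every attempt without hitting `break`. A minimal self-contained sketch of the same pattern (the names here are illustrative, not the project's):

    import time


    def fetch_with_retries(get, url, num_attempts=3, delay=5):
        status_codes = []
        for _ in range(num_attempts):
            response = get(url)
            status_codes.append(response.status_code)
            if response.status_code == 200:
                break  # success: the else clause below is skipped
            time.sleep(delay)
        else:
            # Reached only if every attempt failed (the loop never broke).
            raise Exception(
                "Failed to download {} with status codes {}".format(url, status_codes)
            )
        return response

Called as `fetch_with_retries(requests.get, url)`, this returns the first successful response and raises after exhausting all attempts.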
@@ -87,38 +102,82 @@ def __init__(self):
         self._run_instr = {}
 
         self._load_run_instructions()
+        self._setup_logging()
         try:
             self._load_metadata()
         except FileNotFoundError:
-            logging.warning(
-                "{} not found!".format(MODELDB_METADATA_FILE)
-            )
+            ModelDB.logger.warning("{} not found!".format(MODELDB_METADATA_FILE))
         except yaml.YAMLError as y:
-            logging.error("Error loading {}: {}".format(MODELDB_METADATA_FILE, y))
+            ModelDB.logger.error(
+                "Error loading {}: {}".format(MODELDB_METADATA_FILE, y)
+            )
             raise y
         except Exception as e:
             raise e
 
-    def _download_models(self, model_list=None):
+    def download_models(self, model_list=None):
         if not os.path.isdir(MODELS_ZIP_DIR):
-            logging.info("Creating cache directory: {}".format(MODELS_ZIP_DIR))
+            ModelDB.logger.info("Creating cache directory: {}".format(MODELS_ZIP_DIR))
             os.mkdir(MODELS_ZIP_DIR)
-        models = requests.get(MDB_NEURON_MODELS_URL).json() if model_list is None else model_list
-        pool = multiprocessing.Pool()
-        processed_models = pool.imap_unordered(
-            download_model,
-            [(model_id, self._run_instr.get(model_id, {})) for model_id in models],
-        )
+        # Fetch the list of NEURON model IDs, and a list of timestamps for
+        # those models. We do this even if `model_list` is not None to build
+        # the model ID -> timestamp mapping.
+        def query(field):
+            return requests.get(MDB_NEURON_MODELS_URL.format(model_field=field)).json()
+
+        all_model_ids = query("id")
+        all_model_timestamps = query("ver_date")
+        metadata = {
+            model_id: timestamp
+            for model_id, timestamp in zip(all_model_ids, all_model_timestamps)
+        }
+        # If we were passed a non-None `model_list`, restrict to those models now.
+        if model_list is not None:
+            missing_ids = set(model_list) - set(metadata.keys())
+            if missing_ids:
+                raise Exception(
+                    "Model IDs {} were explicitly requested, but are not known NEURON models.".format(
+                        missing_ids
+                    )
+                )
+            metadata = {model_id: metadata[model_id] for model_id in model_list}
+        # For each model in `metadata`, check if a cached entry exists and is
+        # up to date. If not, download it.
+        models_to_download = []
+        for model_id, new_ver_date in metadata.items():
+            if model_id in self._metadata:
+                cached_ver_date = self._metadata[model_id]._ver_date
+                if cached_ver_date == new_ver_date:
+                    ModelDB.logger.debug(
+                        "Model {} cache up to date ({})".format(model_id, new_ver_date)
+                    )
+                    continue
+                else:
+                    ModelDB.logger.debug(
+                        "Model {} cache out of date (cached: {}, new: {})".format(
+                            model_id, cached_ver_date, new_ver_date
+                        )
+                    )
+            else:
+                ModelDB.logger.debug("Model {} not found in cache".format(model_id))
+            models_to_download.append(
+                (model_id, self._run_instr.get(model_id, {}), new_ver_date)
+            )
+        # Download the missing or out-of-date models in parallel
+        pool = multiprocessing.Pool(8)
+        processed_models = pool.imap_unordered(download_model, models_to_download)
         download_err = {}
-        for model_id, model in ProgressBar.iter(processed_models, len(models)):
+        for model_id, model in ProgressBar.iter(
+            processed_models, len(models_to_download)
+        ):
             if not isinstance(model, Exception):
                 self._metadata[model_id] = model
             else:
                 download_err[model_id] = model
 
         if download_err:
-            logging.error("Error downloading models:")
-            logging.error(pformat(download_err))
+            ModelDB.logger.error("Error downloading models:")
+            ModelDB.logger.error(pformat(download_err))
 
         self._save_metadata()

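`download_models` streams results with `multiprocessing.Pool.imap_unordered`, which yields each `(model_id, model)` tuple as soon as its worker finishes, in completion order rather than submission order; that is what lets the progress bar advance incrementally. A standalone sketch of the pattern, with a stand-in worker and made-up model IDs:

    import multiprocessing


    def work(arg_tuple):
        model_id, run_info, ver_date = arg_tuple
        return model_id, "downloaded {} @ {}".format(model_id, ver_date)


    if __name__ == "__main__":
        tasks = [(3264, {}, "2023-01-01"), (87284, {}, "2022-11-30")]
        pool = multiprocessing.Pool(8)
        # Results arrive as workers finish, not in the order submitted.
        for model_id, result in pool.imap_unordered(work, tasks):
            print(model_id, result)
        pool.close()
        pool.join()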
@@ -134,15 +193,19 @@ def _save_metadata(self):
         with open(MODELDB_METADATA_FILE, "w+") as meta_file:
             yaml.dump(self._metadata, meta_file, sort_keys=True)
 
-    def download_models(self, model_list=None):
-        if model_list is None:
-            try:
-                os.remove(MODELDB_METADATA_FILE)
-            except OSError:
-                pass
-        self._download_models(model_list)
-
     # TODO -> check/update models
     def update_models(self):
         pass
 
+    def _setup_logging(self):
+        if ModelDB.logger is not None:
+            return
+        formatter = logging.Formatter(
+            fmt="%(asctime)s :: %(levelname)-8s :: %(message)s"
+        )
+        consoleHandler = logging.StreamHandler()
+        consoleHandler.setFormatter(formatter)
+        consoleHandler.setLevel(logging.INFO)
+        fileHandler = logging.FileHandler("modeldb.log")
+        fileHandler.setFormatter(formatter)
+        fileHandler.setLevel(logging.DEBUG)
+        ModelDB.logger = logging.getLogger("modeldb")
+        ModelDB.logger.setLevel(logging.DEBUG)
+        ModelDB.logger.addHandler(consoleHandler)
+        ModelDB.logger.addHandler(fileHandler)

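With the handler levels chosen in `_setup_logging`, `DEBUG` records such as the per-model cache-freshness messages go only to `modeldb.log`, while `INFO` and above also reach the console. A short usage sketch, assuming a `ModelDB` instance has already attached the handlers:

    import logging

    logger = logging.getLogger("modeldb")
    logger.debug("Model 3264 cache up to date")     # written to modeldb.log only
    logger.info("Creating cache directory: cache")  # console and modeldb.log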