Skip to content

Commit

Permalink
Merge pull request #299 from mims-harvard/geneformer_deploy
Browse files Browse the repository at this point in the history
Geneformer deploy
  • Loading branch information
amva13 authored Aug 6, 2024
2 parents 1a47468 + 585d0af commit d3e0b92
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 2 deletions.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ numpy>=1.26.4,<2.0.0
openpyxl>=3.0.10,<4.0.0
pandas>=2.1.4,<3.0.0
requests>=2.31.0,<3.0.0
# scikit-learn>=1.2.2,<2.0.0
scikit-learn==1.2.2
seaborn>=0.12.2,<1.0.0
tqdm>=4.65.0,<5.0.0
Expand Down
22 changes: 21 additions & 1 deletion tdc/tdc_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
'CYP3A4_Veith-AttentiveFP',
]

model_hub = ["Geneformer"]


class tdc_hf_interface:
'''
Expand All @@ -29,7 +31,10 @@ class tdc_hf_interface:

def __init__(self, repo_name):
self.repo_id = "tdc/" + repo_name
self.model_name = repo_name.split('-')[1]
try:
self.model_name = repo_name.split('-')[1]
except:
self.model_name = repo_name

def upload(self, folder_path):
create_repo(repo_id=self.repo_id)
Expand All @@ -47,6 +52,21 @@ def file_download(self, save_path, filename):
def repo_download(self, save_path):
snapshot_download(repo_id=self.repo_id, cache_dir=save_path)

def load(self):
if self.model_name not in model_hub:
raise Exception("this model is not in the TDC model hub GH repo.")
elif self.model_name == "Geneformer":
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
# tokenizer = AutoTokenizer.from_pretrained("ctheodoris/Geneformer")
model = AutoModelForMaskedLM.from_pretrained(
"ctheodoris/Geneformer")
# pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer)
# pipe = pipeline("fill-mask", model="ctheodoris/Geneformer")
# return pipe
return model
raise Exception("Not implemented yet!")

def load_deeppurpose(self, save_path):
if self.repo_id[4:] in deeppurpose_repo:
save_path = save_path + '/' + self.repo_id[4:]
Expand Down
66 changes: 66 additions & 0 deletions tdc/test/test_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-

from __future__ import division
from __future__ import print_function

import os
import sys

import unittest
import shutil
import pytest

# temporary solution for relative imports in case TDC is not installed
# if TDC is installed, no need to use the following line
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
# TODO: add verification for the generation other than simple integration


class TestHF(unittest.TestCase):

def setUp(self):
print(os.getcwd())
pass

@pytest.mark.skip(
reason="This test is skipped due to deeppurpose installation dependency"
)
@unittest.skip(reason="DeepPurpose")
def test_hf_load_predict(self):
from tdc.single_pred import Tox
data = Tox(name='herg_karim')

from tdc import tdc_hf_interface
tdc_hf = tdc_hf_interface("hERG_Karim-CNN")
# load deeppurpose model from this repo
dp_model = tdc_hf.load_deeppurpose('./data')
tdc_hf.predict_deeppurpose(dp_model, ['CC(=O)NC1=CC=C(O)C=C1'])

def test_hf_transformer(self):
from tdc import tdc_hf_interface
# from transformers import Pipeline
from transformers import BertForMaskedLM as BertModel
geneformer = tdc_hf_interface("Geneformer")
model = geneformer.load()
# assert isinstance(pipeline, Pipeline)
assert isinstance(model, BertModel), type(model)

# def test_hf_load_new_pytorch_standard(self):
# from tdc import tdc_hf_interface
# # from tdc.resource.dataloader import DataLoader
# # data = DataLoader(name="pinnacle_dti")
# tdc_hf = tdc_hf_interface("mli-PINNACLE")
# dp_model = tdc_hf.load()
# assert dp_model is not None

def tearDown(self):
try:
print(os.getcwd())
shutil.rmtree(os.path.join(os.getcwd(), "data"))
except:
pass


if __name__ == "__main__":
unittest.main()

0 comments on commit d3e0b92

Please sign in to comment.