Skip to content

Commit

Permalink
Use better function name and re-use fairchem_root function
Browse files (browse the repository at this point in the history)
  • Loading branch information
levineds committed Jul 24, 2024
1 parent 0721ce0 commit 8caadf3
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def main():
not Path(__file__).with_name("oc20dense_val_targets.pkl").exists()
or not Path(__file__).with_name("ml_relaxed_dft_targets.pkl").exists()
):
download_large_files.main("adsorbml")
download_large_files.download_file_group("adsorbml")
targets = pickle.load(
open(Path(__file__).with_name("oc20dense_val_targets.pkl"), "rb")
)
Expand Down
15 changes: 10 additions & 5 deletions src/fairchem/core/scripts/download_large_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from pathlib import Path
from urllib.request import urlretrieve

FAIRCHEM_ROOT = Path(__file__).parents[4]
from fairchem.core.common.tutorial_utils import fairchem_root

S3_ROOT = "https://dl.fbaipublicfiles.com/opencatalystproject/data/large_files/"

FILE_GROUPS = {
Expand Down Expand Up @@ -51,7 +52,7 @@ def parse_args():
return parser.parse_args()


def main(file_group):
def download_file_group(file_group):
if file_group in FILE_GROUPS:
files_to_download = FILE_GROUPS[file_group]
elif file_group == "ALL":
Expand All @@ -61,11 +62,15 @@ def main(file_group):
f'Requested file group {file_group} not recognized. Please select one of {["ALL", *list(FILE_GROUPS)]}'
)

fc_root = fairchem_root().parents[1]
for file in files_to_download:
print(f"Downloading {file}...")
urlretrieve(S3_ROOT + file.name, FAIRCHEM_ROOT / file)
if not (fc_root / file).exists():
print(f"Downloading {file}...")
urlretrieve(S3_ROOT + file.name, fc_root / file)
else:
print(f"{file} already exists")


if __name__ == "__main__":
args = parse_args()
main(args.file_group)
download_file_group(args.file_group)
2 changes: 1 addition & 1 deletion src/fairchem/data/oc/core/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
else:
if bulk_db is None:
if bulk_db_path == BULK_PKL_PATH and not os.path.exists(BULK_PKL_PATH):
download_large_files.main("oc")
download_large_files.download_file_group("oc")
with open(bulk_db_path, "rb") as fp:
bulk_db = pickle.load(fp)

Expand Down
2 changes: 1 addition & 1 deletion src/fairchem/data/oc/databases/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def update_pkls():
pickle.dump(data, fp)

if not Path("oc/databases/pkls/bulks.pkl").exists():
download_large_files.main("oc")
download_large_files.download_file_group("oc")
with open(
"oc/databases/pkls/bulks.pkl",
"rb",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fairchem.core.scripts import download_large_files

if not os.path.exists("adsorption_energy.txt"):
download_large_files.main("odac")
download_large_files.download_file_group("odac")
raw_ads_energy_data = pd.read_csv("adsorption_energy.txt", header=None, sep=" ")
complete_data = pd.DataFrame(
index=range(raw_ads_energy_data.shape[0]),
Expand Down
4 changes: 2 additions & 2 deletions tests/applications/cattsunami/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def desorption_inputs(request):
def dissociation_inputs(request):
pkl_path = Path(__file__).parent / "autoframe_inputs_dissociation.pkl"
if not pkl_path.exists():
download_large_files.main("cattsunami")
download_large_files.download_file_group("cattsunami")
with open(pkl_path, "rb") as fp:
request.cls.inputs = pickle.load(fp)

Expand All @@ -31,6 +31,6 @@ def dissociation_inputs(request):
def transfer_inputs(request):
pkl_path = Path(__file__).parent / "autoframe_inputs_transfer.pkl"
if not pkl_path.exists():
download_large_files.main("cattsunami")
download_large_files.download_file_group("cattsunami")
with open(pkl_path, "rb") as fp:
request.cls.inputs = pickle.load(fp)
2 changes: 1 addition & 1 deletion tests/core/test_download_large_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ def urlretrieve_mock(x, y):
)

url_mock.side_effect = urlretrieve_mock
dl_large.main("ALL")
dl_large.download_file_group("ALL")

0 comments on commit 8caadf3

Please sign in to comment.