From 8caadf33051d414938975180a9818365d991be7f Mon Sep 17 00:00:00 2001 From: Daniel Levine Date: Wed, 24 Jul 2024 22:38:29 +0000 Subject: [PATCH] Use better function name and re-use fairchem_root function --- .../2023_neurips_challenge/challenge_eval.py | 2 +- src/fairchem/core/scripts/download_large_files.py | 15 ++++++++++----- src/fairchem/data/oc/core/bulk.py | 2 +- src/fairchem/data/oc/databases/update.py | 2 +- .../promising_mof_energies/energy.py | 2 +- tests/applications/cattsunami/tests/conftest.py | 4 ++-- tests/core/test_download_large_files.py | 2 +- 7 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py b/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py index 7de3f9bee..01c492bba 100644 --- a/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py +++ b/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py @@ -167,7 +167,7 @@ def main(): not Path(__file__).with_name("oc20dense_val_targets.pkl").exists() or not Path(__file__).with_name("ml_relaxed_dft_targets.pkl").exists() ): - download_large_files.main("adsorbml") + download_large_files.download_file_group("adsorbml") targets = pickle.load( open(Path(__file__).with_name("oc20dense_val_targets.pkl"), "rb") ) diff --git a/src/fairchem/core/scripts/download_large_files.py b/src/fairchem/core/scripts/download_large_files.py index da9780944..f79fa2156 100644 --- a/src/fairchem/core/scripts/download_large_files.py +++ b/src/fairchem/core/scripts/download_large_files.py @@ -4,7 +4,8 @@ from pathlib import Path from urllib.request import urlretrieve -FAIRCHEM_ROOT = Path(__file__).parents[4] +from fairchem.core.common.tutorial_utils import fairchem_root + S3_ROOT = "https://dl.fbaipublicfiles.com/opencatalystproject/data/large_files/" FILE_GROUPS = { @@ -51,7 +52,7 @@ def parse_args(): return parser.parse_args() -def main(file_group): +def download_file_group(file_group): if file_group in FILE_GROUPS: files_to_download = FILE_GROUPS[file_group] elif file_group == "ALL": @@ -61,11 +62,15 @@ def main(file_group): f'Requested file group {file_group} not recognized. Please select one of {["ALL", *list(FILE_GROUPS)]}' ) + fc_root = fairchem_root().parents[1] for file in files_to_download: - print(f"Downloading {file}...") - urlretrieve(S3_ROOT + file.name, FAIRCHEM_ROOT / file) + if not (fc_root / file).exists(): + print(f"Downloading {file}...") + urlretrieve(S3_ROOT + file.name, fc_root / file) + else: + print(f"{file} already exists") if __name__ == "__main__": args = parse_args() - main(args.file_group) + download_file_group(args.file_group) diff --git a/src/fairchem/data/oc/core/bulk.py b/src/fairchem/data/oc/core/bulk.py index 0a57ed3f8..6710b4388 100644 --- a/src/fairchem/data/oc/core/bulk.py +++ b/src/fairchem/data/oc/core/bulk.py @@ -54,7 +54,7 @@ def __init__( else: if bulk_db is None: if bulk_db_path == BULK_PKL_PATH and not os.path.exists(BULK_PKL_PATH): - download_large_files.main("oc") + download_large_files.download_file_group("oc") with open(bulk_db_path, "rb") as fp: bulk_db = pickle.load(fp) diff --git a/src/fairchem/data/oc/databases/update.py b/src/fairchem/data/oc/databases/update.py index a30aea4ec..bab75709c 100644 --- a/src/fairchem/data/oc/databases/update.py +++ b/src/fairchem/data/oc/databases/update.py @@ -47,7 +47,7 @@ def update_pkls(): pickle.dump(data, fp) if not Path("oc/databases/pkls/bulks.pkl").exists(): - download_large_files.main("oc") + download_large_files.download_file_group("oc") with open( "oc/databases/pkls/bulks.pkl", "rb", diff --git a/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py b/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py index e086a2dab..547806cc0 100644 --- a/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py +++ b/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py @@ -8,7 +8,7 @@ from fairchem.core.scripts import download_large_files if not os.path.exists("adsorption_energy.txt"): - download_large_files.main("odac") + download_large_files.download_file_group("odac") raw_ads_energy_data = pd.read_csv("adsorption_energy.txt", header=None, sep=" ") complete_data = pd.DataFrame( index=range(raw_ads_energy_data.shape[0]), diff --git a/tests/applications/cattsunami/tests/conftest.py b/tests/applications/cattsunami/tests/conftest.py index 96b5cd8e8..9afdc0a96 100644 --- a/tests/applications/cattsunami/tests/conftest.py +++ b/tests/applications/cattsunami/tests/conftest.py @@ -22,7 +22,7 @@ def desorption_inputs(request): def dissociation_inputs(request): pkl_path = Path(__file__).parent / "autoframe_inputs_dissociation.pkl" if not pkl_path.exists(): - download_large_files.main("cattsunami") + download_large_files.download_file_group("cattsunami") with open(pkl_path, "rb") as fp: request.cls.inputs = pickle.load(fp) @@ -31,6 +31,6 @@ def dissociation_inputs(request): def transfer_inputs(request): pkl_path = Path(__file__).parent / "autoframe_inputs_transfer.pkl" if not pkl_path.exists(): - download_large_files.main("cattsunami") + download_large_files.download_file_group("cattsunami") with open(pkl_path, "rb") as fp: request.cls.inputs = pickle.load(fp) diff --git a/tests/core/test_download_large_files.py b/tests/core/test_download_large_files.py index 2f12c9db2..991f8ce34 100644 --- a/tests/core/test_download_large_files.py +++ b/tests/core/test_download_large_files.py @@ -13,4 +13,4 @@ def urlretrieve_mock(x, y): ) url_mock.side_effect = urlretrieve_mock - dl_large.main("ALL") + dl_large.download_file_group("ALL")