Skip to content

Commit

Permalink
Merge branch 'main' into add_activation_checkpointing_to_escn
Browse files Browse the repository at this point in the history
  • Loading branch information
misko authored Sep 23, 2024
2 parents a9cdee7 + 83fd9d2 commit a8bb2bf
Show file tree
Hide file tree
Showing 37 changed files with 585 additions and 361 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
python-version: '3.12'

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.11
python-version: 3.12

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
max-parallel: 10
matrix:
python_version: ['3.9', '3.11']
python_version: ['3.9', '3.12']

steps:
- uses: actions/checkout@v4
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.11
python-version: 3.12
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -29,4 +29,5 @@ jobs:
- name: ruff
run: |
ruff --version
ruff check src
ruff check src # tests has a lot of issues , TODO
ruff format --check src # tests
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
max-parallel: 10
matrix:
python_version: ['3.9', '3.11']
python_version: ['3.9', '3.12']

steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -56,7 +56,7 @@ jobs:
run: |
pytest tests -vv --ignore=tests/demo/ocpapi/tests/integration/ --cov-report=xml --cov=fairchem -c ./packages/fairchem-core/pyproject.toml
- if: ${{ matrix.python_version == '3.11' }}
- if: ${{ matrix.python_version == '3.12' }}
name: codecov-report
uses: codecov/codecov-action@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion docs/core/model_checkpoints.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ This page summarizes all the pretrained models released as part of the [Open Cat
| SCN-t4-b2-S2EF-OC20-2M | SCN-t4-b2 | 2M |[checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/scn_t4_b2_s2ef_2M.pt) \| [config](https://github.com/FAIR-Chem/fairchem/blob/main/configs/s2ef/2M/scn/scn-t4-b2.yml) |0.0193 |2.68% |
| SCN-S2EF-OC20-All+MD | SCN | All+MD |[checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/scn_all_md_s2ef.pt) \| [config](https://github.com/FAIR-Chem/fairchem/blob/main/configs/s2ef/all/scn/scn-all-md.yml) |0.0160 |5.08% |
| eSCN-L4-M2-Lay12-S2EF-OC20-2M | eSCN-L4-M2-Lay12 | 2M |[checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l4_m2_lay12_2M_s2ef.pt) \| [config](https://github.com/FAIR-Chem/fairchem/blob/main/configs/s2ef/2M/escn/eSCN-L4-M2-Lay12.yml) |0.0191 |2.55% |
| eSCN-L6-M2-Lay12-S2EF-OC20-2M | eSCN-L6-M2-Lay12 | 2M |[checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m2_lay12_2M_s2ef.pt) \| [config](https://github.com/FAIR-Chem/fairchem/blob/main/configs/s2ef/2M/escn/eSCN-L6-M2-Lay12.yml) |0.0186 |2.66% |
| eSCN-L6-M2-Lay12-S2EF-OC20-2M | eSCN-L6-M2-Lay12 | 2M |[checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m2_lay12_2M_s2ef.pt) \| [config](https://github.com/FAIR-Chem/fairchem/blob/main/configs/s2ef/2M/escn/eSCN-L6-M2-Lay12.yml) \| [exported](https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m2_lay12_2M_s2ef_export_cuda_9182024.pt2) |0.0186 |2.66% |
| eSCN-L6-M2-Lay12-S2EF-OC20-All+MD | eSCN-L6-M2-Lay12 | All+MD |[checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m2_lay12_all_md_s2ef.pt) \| [config](https://github.com/FAIR-Chem/fairchem/blob/main/configs/s2ef/all/escn/eSCN-L6-M2-Lay12-All-MD.yml) |0.0161 |4.28% |
| eSCN-L6-M3-Lay20-S2EF-OC20-All+MD | eSCN-L6-M3-Lay20 | All+MD |[checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m3_lay20_all_md_s2ef.pt) \| [config](https://github.com/FAIR-Chem/fairchem/blob/main/configs/s2ef/all/escn/eSCN-L6-M3-Lay20-All-MD.yml) |0.0139 |6.64% |
| EquiformerV2-83M-S2EF-OC20-2M | EquiformerV2 (83M) | 2M |[checkpoint](https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_83M_2M.pt) \| [config](https://github.com/FAIR-Chem/fairchem/blob/main/configs/s2ef/2M/equiformer_v2/equiformer_v2_N@12_L@[email protected]) |0.0167 |4.26% |
Expand Down
17 changes: 10 additions & 7 deletions packages/env.cpu.yml
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
channels:
- pytorch
- pyg
- conda-forge
- defaults
dependencies:
- cpuonly
- pytorch>=2
- pyg
- pytorch-scatter
- pytorch-sparse
- pytorch-cluster
- pytorch>=2.4
- ase
- e3nn>=0.5
- numpy >=1.25.0,<2.0.0
- numpy >=1.26.0,<2.0.0
- pymatgen>=2023.10.3
- numba
- orjson
- pip
- pip:
- --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
- torch_cluster==1.6.3+pt24cpu
- torch_geometric==2.5.3
- pyg-lib==0.4.0+pt24cpu
- torch_scatter==2.1.2+pt24cpu
- torch_sparse==0.6.18+pt24cpu
- torch_spline_conv==1.2.2+pt24cpu
- pyyaml
- tqdm
- python-lmdb
Expand Down
19 changes: 11 additions & 8 deletions packages/env.gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,25 @@ channels:
- pytorch
- nvidia
- conda-forge
- pyg
- defaults
dependencies:
- pytorch-cuda=11.8
- pytorch>=2
- pytorch-scatter
- pytorch-sparse
- pytorch-cluster
- pyg
- pytorch-cuda=12.1
- pytorch>=2.4
- ase
- e3nn>=0.5
- numpy >=1.25.0,<2.0.0
- numpy >=1.26.0,<2.0.0
- pymatgen>=2023.10.3
- numba
- orjson
- pip
- pip:
- --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
- torch_cluster==1.6.3+pt24cu121
- torch_geometric==2.5.3
- pyg-lib==0.4.0+pt24cu121
- torch_scatter==2.1.2+pt24cu121
- torch_sparse==0.6.18+pt24cu121
- torch_spline_conv==1.2.2+pt24cu121
- pyyaml
- tqdm
- python-lmdb
Expand Down
2 changes: 1 addition & 1 deletion packages/fairchem-applications-cattsunami/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dynamic = ["version", "readme"]
description = "Accelerating Transition State Energy Calculations with Pre-trained Graph Neural Networks"
license = {text = "MIT License"}
dependencies = [
"torch>=2.2",
"torch>=2.4",
"scipy",
"ase",
"networkx",
Expand Down
4 changes: 2 additions & 2 deletions packages/fairchem-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ license = {text = "MIT License"}
dynamic = ["version", "readme"]
requires-python = ">=3.9, <3.13"
dependencies = [
"torch>=2.2",
"numpy >=1.25.0, <2.0.0",
"torch>=2.4",
"numpy >=1.26.0, <2.0.0",
"lmdb",
"ase",
"pymatgen>=2023.10.3",
Expand Down
2 changes: 1 addition & 1 deletion packages/fairchem-data-oc/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dynamic = ["version", "readme"]
description = "Code for generating adsorbate-catalyst input configurations"
license = {text = "MIT License"}
dependencies = [
"numpy >=1.25.0, <2.0.0",
"numpy >=1.26.0, <2.0.0",
"scipy",
"matplotlib",
"ase", # this was pinned to 3.22.1
Expand Down
2 changes: 1 addition & 1 deletion packages/fairchem-demo-ocpapi/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = [
dependencies = [
"dataclasses-json == 0.6.0",
"inquirer == 3.1.3",
"requests == 2.32.0",
"requests == 2.32.3",
"responses == 0.23.2",
"tenacity == 8.2.3",
"tqdm == 4.66.1",
Expand Down
4 changes: 2 additions & 2 deletions packages/requirements-optional.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch_geometric==2.3.0
-f https://data.pyg.org/whl/torch-2.2.0+cpu.html
torch_geometric==2.6.0
-f https://data.pyg.org/whl/torch-2.4.0+cpu.html
torch_scatter==2.1.2
torch_sparse==0.6.18
torch_cluster==1.6.3
4 changes: 2 additions & 2 deletions packages/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch==2.2.0
numpy==1.23.5
torch==2.4.1
numpy==1.26.4
ase==3.23.0
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
include = ["src/fairchem/core/**/*.py", "src/fairchem/data/oc/**/*.py"]
include = ["src/fairchem/core/**/*.py", "src/fairchem/data/oc/**/*.py", "tests/**/*.py"]
line-length = 88

[lint]
Expand Down
20 changes: 13 additions & 7 deletions src/fairchem/core/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,19 @@ def checkpoint(self, *args, **kwargs):
self.config["timestamp_id"] = self.trainer.timestamp_id
if self.trainer.logger is not None:
self.trainer.logger.mark_preempting()
logging.info(f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}')
logging.info(
f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}'
)
return DelayedSubmission(new_runner, self.config)


def runner_wrapper(config: dict):
Runner()(config)


def main(args: argparse.Namespace | None = None, override_args: list[str] | None = None):
def main(
args: argparse.Namespace | None = None, override_args: list[str] | None = None
):
"""Run the main fairchem program."""
setup_logging()

Expand All @@ -66,7 +70,9 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None
args, override_args = parser.parse_known_args()

# TODO: rename num_gpus -> num_ranks everywhere
assert args.num_gpus > 0, "num_gpus is used to determine number ranks, so it must be at least 1"
assert (
args.num_gpus > 0
), "num_gpus is used to determine number ranks, so it must be at least 1"
config = build_config(args, override_args)

if args.submit: # Run on cluster
Expand Down Expand Up @@ -98,9 +104,7 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None

else: # Run locally on a single node, n-processes
if args.num_gpus > 1:
logging.info(
f"Running in local mode with {args.num_gpus} ranks"
)
logging.info(f"Running in local mode with {args.num_gpus} ranks")
# HACK to disable multiprocess dataloading in local mode
# there is an open issue where LMDB's environment cannot be pickled and used
# during torch multiprocessing https://github.com/pytorch/examples/issues/526
Expand All @@ -119,7 +123,9 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None
)
elastic_launch(launch_config, runner_wrapper)(config)
else:
logging.info("Running in local mode without elastic launch (single gpu only)")
logging.info(
"Running in local mode without elastic launch (single gpu only)"
)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
Expand Down
16 changes: 13 additions & 3 deletions src/fairchem/core/common/distutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
T = TypeVar("T")
DISTRIBUTED_PORT = 13356


def os_environ_get_or_throw(x: str) -> str:
if x not in os.environ:
raise RuntimeError(f"Could not find {x} in ENV variables")
Expand Down Expand Up @@ -68,7 +69,9 @@ def setup(config) -> None:
)

# ensures GPU0 does not have extra context/higher peak memory
logging.info(f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}")
logging.info(
f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}"
)
torch.cuda.set_device(config["local_rank"])

dist.init_process_group(
Expand Down Expand Up @@ -104,13 +107,20 @@ def setup(config) -> None:
)
else:
if not os.environ.get("MASTER_ADDR"):
assert config["world_size"] == 1, "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
assert (
config["world_size"] == 1
), "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
config["local_rank"] = int(os.environ.get("LOCAL_RANK"))
dist.init_process_group(backend=config["distributed_backend"], rank=int(os.environ.get("RANK")), world_size=config["world_size"], timeout=timeout)
dist.init_process_group(
backend=config["distributed_backend"],
rank=int(os.environ.get("RANK")),
world_size=config["world_size"],
timeout=timeout,
)


def cleanup() -> None:
Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def log_summary(self, summary_dict: dict[str, Any]) -> None:
def log_artifact(self, name: str, type: str, file_location: str) -> None:
pass


@registry.register_logger("wandb")
class WandBLogger(Logger):
def __init__(self, config) -> None:
Expand Down Expand Up @@ -115,6 +116,7 @@ def log_artifact(self, name: str, type: str, file_location: str) -> None:
art.add_file(file_location)
art.save()


@registry.register_logger("tensorboard")
class TensorboardLogger(Logger):
def __init__(self, config) -> None:
Expand Down
7 changes: 6 additions & 1 deletion src/fairchem/core/common/profiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
if TYPE_CHECKING:
from fairchem.core.common.logger import Logger


def get_default_profiler_handler(run_id: str, output_dir: str, logger: Logger):
"""Get a standard callback handle for the pytorch profiler"""

Expand All @@ -20,9 +21,13 @@ def trace_handler(p):
print(f"Saving trace in {output_path}")
p.export_chrome_trace(output_path)
if logger:
logger.log_artifact(name=trace_name, type="profile", file_location=output_path)
logger.log_artifact(
name=trace_name, type="profile", file_location=output_path
)

return trace_handler


def get_profile_schedule(wait: int = 5, warmup: int = 5, active: int = 2):
"""Get a profile schedule and total number of steps to run
check pytorch docs on the meaning of these paramters:
Expand Down
4 changes: 3 additions & 1 deletion src/fairchem/core/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def load_checkpoint(
Path to trained model
"""
try:
self.trainer.load_checkpoint(checkpoint_path, checkpoint)
self.trainer.load_checkpoint(
checkpoint_path, checkpoint, inference_only=True
)
except NotImplementedError:
logging.warning("Unable to load checkpoint!")

Expand Down
8 changes: 6 additions & 2 deletions src/fairchem/core/common/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
from submitit.core.utils import JobPaths


def add_timestamp_id_to_submission_pickle(slurm_folder: str, slurm_job_id: str, timestamp_id: str):
def add_timestamp_id_to_submission_pickle(
slurm_folder: str, slurm_job_id: str, timestamp_id: str
):
# Try to put the timestamp-id into the original submission pickle's config
# so that if the node crashes, it can be pick up the correct run to resume
#
# we need to do this after the job has started because the timestamp-id is generated at runtime
# instead a-priori before the submission starts (ie: if we had a db to store a global job unique job)
submission_pickle_path = JobPaths(folder=slurm_folder, job_id=slurm_job_id).submitted_pickle
submission_pickle_path = JobPaths(
folder=slurm_folder, job_id=slurm_job_id
).submitted_pickle
try:
with open(str(submission_pickle_path), "rb") as f:
pkl = pickle.load(f)
Expand Down
1 change: 1 addition & 0 deletions src/fairchem/core/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def spawn_multi_process(

return [mp_output_dict[i] for i in range(config.world_size)]


def init_local_distributed_process_group(backend="nccl"):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
Expand Down
4 changes: 3 additions & 1 deletion src/fairchem/core/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def _metadata(self) -> DatasetMetadata:
for field in DatasetMetadata._fields
}
)

assert np.issubdtype(
metadata.natoms.dtype, np.integer
), f"Metadata natoms must be an integer type! not {metadata.natoms.dtype}"
assert metadata.natoms.shape[0] == len(
self
), "Loaded metadata and dataset size mismatch."
Expand Down
4 changes: 3 additions & 1 deletion src/fairchem/core/datasets/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def sample_property_metadata(self, num_samples: int = 100):
}


def data_list_collater(data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False) -> BaseData | dict[str, torch.Tensor]:
def data_list_collater(
data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False
) -> BaseData | dict[str, torch.Tensor]:
batch = Batch.from_data_list(data_list)

if not otf_graph:
Expand Down
Loading

0 comments on commit a8bb2bf

Please sign in to comment.