Skip to content

Commit

Permalink
Add inference_only option to trainers (#850)
Browse files Browse the repository at this point in the history
* deprecate ocp-collater

* add type hints to trainers

* fix deprecation warning

* add inference_only key to avoid loading unnecessary things

* no need to remove relax dataset anymore

* fix typo

* remove deprecated decorator

* fix normalizer loading

* fix typo
  • Loading branch information
lbluque authored Sep 13, 2024
1 parent 6c54ad9 commit 6ded0d3
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 87 deletions.
5 changes: 4 additions & 1 deletion src/fairchem/core/common/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
import torch.distributed
from torch.utils.data import BatchSampler, Dataset, DistributedSampler
from typing_extensions import override
from typing_extensions import deprecated, override

from fairchem.core.common import distutils, gp_utils
from fairchem.core.datasets import data_list_collater
Expand All @@ -29,6 +29,9 @@
from torch_geometric.data import Batch, Data


@deprecated(
"OCPCollater is deprecated. Please use data_list_collater optionally with functools.partial to set defaults"
)
class OCPCollater:
def __init__(self, otf_graph: bool = False) -> None:
self.otf_graph = otf_graph
Expand Down
6 changes: 1 addition & 5 deletions src/fairchem/core/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,6 @@ def __init__(
config["model_attributes"]["name"] = config.pop("model")
config["model"] = config["model_attributes"]

# for checkpoints with relaxation datasets defined, remove to avoid
# unnecesarily trying to load that dataset
if "relax_dataset" in config.get("task", {}):
del config["task"]["relax_dataset"]

# Calculate the edge indices on the fly
config["model"]["otf_graph"] = True

Expand All @@ -189,6 +184,7 @@ def __init__(
is_debug=config.get("is_debug", True),
cpu=cpu,
amp=config.get("amp", False),
inference_only=True,
)

if checkpoint_path is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
pass



from .edge_rot_mat import init_edge_rot_mat
from .gaussian_rbf import GaussianRadialBasisLayer
from .input_block import EdgeDegreeEmbedding
Expand Down Expand Up @@ -155,7 +154,7 @@ def __init__(
load_energy_lin_ref: bool | None = False,
):
logging.warning(
"equiformer_v2 (EquiformerV2) class is deprecaed in favor of equiformer_v2_backbone_and_heads (EquiformerV2BackboneAndHeads)"
"equiformer_v2 (EquiformerV2) class is deprecated in favor of equiformer_v2_backbone_and_heads (EquiformerV2BackboneAndHeads)"
)
if mmax_list is None:
mmax_list = [2]
Expand Down
131 changes: 78 additions & 53 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import random
import sys
from abc import ABC, abstractmethod
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
import numpy.typing as npt
Expand All @@ -29,7 +30,7 @@

from fairchem.core import __version__
from fairchem.core.common import distutils, gp_utils
from fairchem.core.common.data_parallel import BalancedBatchSampler, OCPCollater
from fairchem.core.common.data_parallel import BalancedBatchSampler
from fairchem.core.common.registry import registry
from fairchem.core.common.slurm import (
add_timestamp_id_to_submission_pickle,
Expand All @@ -44,15 +45,19 @@
save_checkpoint,
update_config,
)
from fairchem.core.datasets import data_list_collater
from fairchem.core.datasets.base_dataset import create_dataset
from fairchem.core.modules.evaluator import Evaluator
from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage
from fairchem.core.modules.loss import DDPLoss
from fairchem.core.modules.normalization.element_references import (
LinearReferences,
create_element_references,
load_references_from_config,
)
from fairchem.core.modules.normalization.normalizer import load_normalizers_from_config
from fairchem.core.modules.normalization.normalizer import (
create_normalizer,
load_normalizers_from_config,
)
from fairchem.core.modules.scaling.compat import load_scales_compat
from fairchem.core.modules.scaling.util import ensure_fitted
from fairchem.core.modules.scheduler import LRScheduler
Expand All @@ -65,13 +70,13 @@
class BaseTrainer(ABC):
def __init__(
self,
task,
model,
outputs,
dataset,
optimizer,
loss_functions,
evaluation_metrics,
task: dict[str, str | Any],
model: dict[str, Any],
outputs: dict[str, str | int],
dataset: dict[str, str | float],
optimizer: dict[str, str | float],
loss_functions: dict[str, str | float],
evaluation_metrics: dict[str, str],
identifier: str,
# TODO: dealing with local rank is dangerous
# T201111838 remove this and use CUDA_VISIBLE_DEVICES instead so trainers don't need to know about which device to use
Expand All @@ -87,6 +92,7 @@ def __init__(
name: str = "ocp",
slurm=None,
gp_gpus: int | None = None,
inference_only: bool = False,
) -> None:
if slurm is None:
slurm = {}
Expand Down Expand Up @@ -200,12 +206,16 @@ def __init__(
if distutils.is_master():
logging.info(yaml.dump(self.config, default_flow_style=False))

# define attributes for readability
self.elementrefs = {}
self.normalizers = {}
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
self.load()
self.best_val_metric = None
self.primary_metric = None

self.load(inference_only)

@abstractmethod
def train(self, disable_eval_tqdm: bool = False) -> None:
Expand All @@ -224,20 +234,24 @@ def _get_timestamp(device: torch.device, suffix: str | None) -> str:
timestamp_str += "-" + suffix
return timestamp_str

def load(self) -> None:
def load(self, inference_only: bool) -> None:
self.load_seed_from_config()
self.load_logger()
self.load_datasets()
self.load_references_and_normalizers()
self.load_task()
self.load_model()
self.load_loss()
self.load_optimizer()
self.load_extras()

if inference_only is False:
self.load_datasets()
self.load_references_and_normalizers()
self.load_loss()
self.load_optimizer()
self.load_extras()

if self.config["optim"].get("load_datasets_and_model_then_exit", False):
sys.exit(0)

def set_seed(self, seed) -> None:
@staticmethod
def set_seed(seed) -> None:
# https://pytorch.org/docs/stable/notes/randomness.html
random.seed(seed)
np.random.seed(seed)
Expand Down Expand Up @@ -298,14 +312,16 @@ def get_sampler(
def get_dataloader(self, dataset, sampler) -> DataLoader:
return DataLoader(
dataset,
collate_fn=self.ocp_collater,
collate_fn=self.collater,
num_workers=self.config["optim"]["num_workers"],
pin_memory=True,
batch_sampler=sampler,
)

def load_datasets(self) -> None:
self.ocp_collater = OCPCollater(self.config["model"].get("otf_graph", False))
self.collater = partial(
data_list_collater, otf_graph=self.config["model"].get("otf_graph", False)
)
self.train_loader = None
self.val_loader = None
self.test_loader = None
Expand Down Expand Up @@ -345,7 +361,7 @@ def convert_settings_to_split_settings(config, split_name):
or "sample_n" in self.config["dataset"]
or "max_atom" in self.config["dataset"]
):
logging.warn(
logging.warning(
"Dataset attributes (first_n/sample_n/max_atom) passed to all datasets! Please don't do this, its dangerous!\n"
+ "Add them under each dataset 'train_split_settings'/'val_split_settings'/'test_split_settings'"
)
Expand Down Expand Up @@ -498,15 +514,15 @@ def load_task(self):
][target_name].get("level", "system")
if "train_on_free_atoms" not in self.output_targets[subtarget]:
self.output_targets[subtarget]["train_on_free_atoms"] = (
self.config["outputs"][target_name].get(
"train_on_free_atoms", True
)
self.config[
"outputs"
][target_name].get("train_on_free_atoms", True)
)
if "eval_on_free_atoms" not in self.output_targets[subtarget]:
self.output_targets[subtarget]["eval_on_free_atoms"] = (
self.config["outputs"][target_name].get(
"eval_on_free_atoms", True
)
self.config[
"outputs"
][target_name].get("eval_on_free_atoms", True)
)

# TODO: Assert that all targets, loss fn, metrics defined are consistent
Expand Down Expand Up @@ -559,7 +575,10 @@ def _unwrapped_model(self):
return module

def load_checkpoint(
self, checkpoint_path: str, checkpoint: dict | None = None
self,
checkpoint_path: str,
checkpoint: dict | None = None,
inference_only: bool | None = None,
) -> None:
map_location = torch.device("cpu") if self.cpu else self.device
if checkpoint is None:
Expand All @@ -570,24 +589,28 @@ def load_checkpoint(
logging.info(f"Loading checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=map_location)

self.epoch = checkpoint.get("epoch", 0)
self.step = checkpoint.get("step", 0)
self.best_val_metric = checkpoint.get("best_val_metric", None)
self.primary_metric = checkpoint.get("primary_metric", None)
# attributes that are necessary for training and validation
inference_only = self.train_dataset is None or inference_only
if inference_only is False:
self.epoch = checkpoint.get("epoch", 0)
self.step = checkpoint.get("step", 0)
self.best_val_metric = checkpoint.get("best_val_metric", None)
self.primary_metric = checkpoint.get("primary_metric", None)

new_dict = match_state_dict(self.model.state_dict(), checkpoint["state_dict"])
strict = self.config.get("task", {}).get("strict_load", True)
load_state_dict(self.model, new_dict, strict=strict)
if "optimizer" in checkpoint:
self.optimizer.load_state_dict(checkpoint["optimizer"])
if "scheduler" in checkpoint and checkpoint["scheduler"] is not None:
self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"])

if "optimizer" in checkpoint:
self.optimizer.load_state_dict(checkpoint["optimizer"])
if "scheduler" in checkpoint and checkpoint["scheduler"] is not None:
self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"])
if "ema" in checkpoint and checkpoint["ema"] is not None:
self.ema.load_state_dict(checkpoint["ema"])
else:
self.ema = None

new_dict = match_state_dict(self.model.state_dict(), checkpoint["state_dict"])
strict = self.config.get("task", {}).get("strict_load", True)
load_state_dict(self.model, new_dict, strict=strict)

scale_dict = checkpoint.get("scale_dict", None)
if scale_dict:
logging.info(
Expand All @@ -597,7 +620,7 @@ def load_checkpoint(
)
load_scales_compat(self._unwrapped_model, scale_dict)

for key in checkpoint["normalizers"]:
for key, state_dict in checkpoint["normalizers"].items():
### Convert old normalizer keys to new target keys
if key == "target":
target_key = "energy"
Expand All @@ -606,22 +629,24 @@ def load_checkpoint(
else:
target_key = key

if target_key in self.normalizers:
mkeys = self.normalizers[target_key].load_state_dict(
checkpoint["normalizers"][key]
)
self.normalizers[target_key].to(map_location)
if target_key not in self.normalizers:
self.normalizers[target_key] = create_normalizer(state_dict=state_dict)
else:
mkeys = self.normalizers[target_key].load_state_dict(state_dict)
assert len(mkeys.missing_keys) == 0
assert len(mkeys.unexpected_keys) == 0

self.normalizers[target_key].to(map_location)

for key, state_dict in checkpoint.get("elementrefs", {}).items():
elementrefs = LinearReferences(
max_num_elements=len(state_dict["element_references"]) - 1
).to(map_location)
mkeys = elementrefs.load_state_dict(state_dict)
self.elementrefs[key] = elementrefs
assert len(mkeys.missing_keys) == 0
assert len(mkeys.unexpected_keys) == 0
if key not in self.elementrefs:
self.elementrefs[key] = create_element_references(state_dict=state_dict)
else:
mkeys = self.elementrefs[key].load_state_dict(state_dict)
assert len(mkeys.missing_keys) == 0
assert len(mkeys.unexpected_keys) == 0

self.elementrefs[key].to(map_location)

if self.scaler and checkpoint["amp"]:
self.scaler.load_state_dict(checkpoint["amp"])
Expand Down
Loading

0 comments on commit 6ded0d3

Please sign in to comment.