diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py
index 020416344..c17b8bda7 100644
--- a/src/fairchem/core/models/escn/escn.py
+++ b/src/fairchem/core/models/escn/escn.py
@@ -88,6 +88,7 @@ def __init__(
         distance_resolution: float = 0.02,
         show_timing_info: bool = False,
         resolution: int | None = None,
+        activation_checkpoint: bool | None = False,
     ) -> None:
         if mmax_list is None:
             mmax_list = [2]
@@ -101,6 +102,7 @@ def __init__(
             logging.error("You need to install the e3nn library to use the SCN model")
             raise ImportError
 
+        self.activation_checkpoint = activation_checkpoint
         self.regress_forces = regress_forces
         self.use_pbc = use_pbc
         self.use_pbc_single = use_pbc_single
@@ -287,28 +289,33 @@ def forward(self, data):
         ###############################################################
 
         for i in range(self.num_layers):
-            if i > 0:
-                x_message = self.layer_blocks[i](
+            if self.activation_checkpoint:
+                x_message = torch.utils.checkpoint.checkpoint(
+                    self.layer_blocks[i],
                     x,
                     atomic_numbers,
                     graph.edge_distance,
                     graph.edge_index,
                     self.SO3_edge_rot,
                     mappingReduced,
+                    use_reentrant=not self.training,
                 )
-
-                # Residual layer for all layers past the first
-                x.embedding = x.embedding + x_message.embedding
             else:
-                # No residual for the first layer
-                x = self.layer_blocks[i](
+                x_message = self.layer_blocks[i](
                     x,
                     atomic_numbers,
                     graph.edge_distance,
                     graph.edge_index,
                     self.SO3_edge_rot,
                     mappingReduced,
                 )
 
+            if i > 0:
+                # Residual layer for all layers past the first
+                x.embedding = x.embedding + x_message.embedding
+            else:
+                # No residual for the first layer
+                x = x_message
+
         # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere.
         # These values are fed into the output blocks.
         x_pt = torch.tensor([], device=device)
diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py
index c78725db1..68e3c8f65 100644
--- a/tests/core/e2e/test_s2ef.py
+++ b/tests/core/e2e/test_s2ef.py
@@ -316,6 +316,52 @@ def test_train_and_predict(
             num_workers=0,
         )
 
+    # test that both escn and equiv2 run with activation checkpointing
+    @pytest.mark.parametrize(
+        ("model_name"),
+        [
+            ("escn_hydra"),
+            ("equiformer_v2_hydra"),
+        ],
+    )
+    def test_train_and_predict_with_checkpointing(
+        self,
+        model_name,
+        configs,
+        tutorial_val_src,
+    ):
+        with tempfile.TemporaryDirectory() as tempdirname:
+            # first train a very simple model, checkpoint
+            train_rundir = Path(tempdirname) / "train"
+            train_rundir.mkdir()
+            checkpoint_path = str(train_rundir / "checkpoint.pt")
+            training_predictions_filename = str(train_rundir / "train_predictions.npz")
+            update_dict = {
+                "optim": {
+                    "max_epochs": 2,
+                    "eval_every": 8,
+                    "batch_size": 5,
+                    "num_workers": 1,
+                },
+                "dataset": oc20_lmdb_train_and_val_from_paths(
+                    train_src=str(tutorial_val_src),
+                    val_src=str(tutorial_val_src),
+                    test_src=str(tutorial_val_src),
+                ),
+            }
+            if "hydra" in model_name:
+                update_dict["model"] = {"backbone": {"activation_checkpoint": True}}
+            else:
+                update_dict["model"] = {"activation_checkpoint": True}
+            acc = _run_main(
+                rundir=str(train_rundir),
+                input_yaml=configs[model_name],
+                update_dict_with=update_dict,
+                save_checkpoint_to=checkpoint_path,
+                save_predictions_to=training_predictions_filename,
+                world_size=1,
+            )
+
     def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
         with tempfile.TemporaryDirectory() as tempdirname:
             tempdir = Path(tempdirname)