Skip to content

Commit

Permalink
fix few comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 13, 2024
1 parent ba23d7d commit a6abf92
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
12 changes: 5 additions & 7 deletions src/fairchem/core/models/finetune_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def load_hydra_model(checkpoint_path: str) -> HydraInterface:
)
logging.info(f"Loading checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
config_copy = copy.deepcopy(checkpoint["config"]["model"])
name = config_copy.pop("name")
hydra_model = registry.get_model_class(name)(**config_copy)
config = checkpoint["config"]["model"]
name = config.pop("name")
hydra_model = registry.get_model_class(name)(**config)
assert isinstance(
hydra_model, HydraInterface
), "Can only load models with the HydraInterface"
Expand Down Expand Up @@ -85,13 +85,13 @@ def load_model(self) -> nn.Module:
config_copy = copy.deepcopy(self.config[FTConfig.STARTING_MODEL])
name = config_copy.pop("name")
hydra_model = registry.get_model_class(name)(**config_copy)
assert isinstance(hydra_model, HydraInterface)
# if provided a checkpoint to start then load the model and weights from the given checkpoint
# this happens used in the beginning of a finetuning run
elif FTConfig.STARTING_CHECKPOINT in self.config:
hydra_model: HydraInterface = load_hydra_model(
self.config[FTConfig.STARTING_CHECKPOINT]
)
assert isinstance(hydra_model, HydraInterface)

num_params = sum(p.numel() for p in hydra_model.parameters())
logging.info(f"Loaded Original hydra model with {num_params} params")
Expand All @@ -115,9 +115,7 @@ def get_standalone_config(self) -> dict:
)
)
standalone_config[FTConfig.FT_CONFIG_NAME] = new_config
return standalone_config
else:
return standalone_config
return standalone_config

@property
def mode(self) -> FineTuneMode:
Expand Down
2 changes: 1 addition & 1 deletion src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def get_dataloader(self, dataset, sampler) -> DataLoader:
)

def load_datasets(self) -> None:
self.ocp_collater = OCPCollater(self.config["model"].get("otf_graph", True))
self.ocp_collater = OCPCollater(self.config["model"].get("otf_graph", False))
self.train_loader = None
self.val_loader = None
self.test_loader = None
Expand Down

0 comments on commit a6abf92

Please sign in to comment.