Skip to content

Commit

Permalink
clean up setting defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Sep 20, 2024
1 parent 5fcbd42 commit 7c6244f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
22 changes: 12 additions & 10 deletions src/fairchem/core/common/relaxation/ml_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING

import torch
from torch_geometric.data import Batch
Expand All @@ -20,10 +21,13 @@
from .optimizers.lbfgs_torch import LBFGS
from .optimizers.optimizable import OptimizableBatch, OptimizableUnitCellBatch

if TYPE_CHECKING:
from fairchem.core.trainers import BaseTrainer


def ml_relax(
batch,
model,
batch: Batch,
model: BaseTrainer,
steps: int,
fmax: float,
relax_opt: dict[str],
Expand All @@ -38,7 +42,7 @@ def ml_relax(
Args:
batch: a data batch object.
model: a trainer object with model.q
model: a trainer object with model.
steps: Max number of steps in the structure relaxation.
fmax: Structure relaxation terminates when the max force of the system is no bigger than fmax.
relax_opt: Optimizer parameters to be used for structure relaxations.
Expand Down Expand Up @@ -82,16 +86,14 @@ def ml_relax(

# Run ML-based relaxation
traj_dir = relax_opt.get("traj_dir")
relax_opt.update({"traj_dir": Path(traj_dir) if traj_dir is not None else None})

optimizer = LBFGS(
optimizable_batch=optimizable,
maxstep=relax_opt.get("maxstep", 0.2),
memory=relax_opt["memory"],
damping=relax_opt.get("damping", 1.2),
alpha=relax_opt.get("alpha", 80.0),
device=device,
save_full_traj=save_full_traj,
traj_dir=Path(traj_dir) if traj_dir is not None else None,
traj_names=ids,
**relax_opt,
)

e: RuntimeError | None = None
Expand Down Expand Up @@ -126,8 +128,8 @@ def ml_relax(

# Batch.from_data_list is not intended to be used with a list of batches, so when sid is a list of str
# it will be incorrectly collated as a list of lists for each batch.
# but we can not use to_data_list in the relaxed batches (since the have been changed, see linked comment above).
# So instead just manually fix it
# but we can not use to_data_list in the relaxed batches (since they have been changed, see linked comment above).
# So instead just manually fix it for now. Remove this once pyg dependency is removed
if isinstance(relaxed_batch.sid, list):
relaxed_batch.sid = [sid for sid_list in relaxed_batch.sid for sid in sid_list]

Expand Down
4 changes: 2 additions & 2 deletions src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class LBFGS:
def __init__(
self,
optimizable_batch: OptimizableBatch,
maxstep: float = 0.01,
maxstep: float = 0.02,
memory: int = 100,
damping: float = 0.25,
damping: float = 1.2,
alpha: float = 100.0,
device: str = "cuda:0",
save_full_traj: bool = True,
Expand Down

0 comments on commit 7c6244f

Please sign in to comment.