Skip to content

Commit

Permalink
Why should weights be None. A 0.0 weight is basically the same.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Aug 14, 2023
1 parent eb5fed1 commit 3a00612
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 22 deletions.
28 changes: 8 additions & 20 deletions matgl/utils/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ def __init__(
element_refs: np.ndarray | None = None,
energy_weight: float = 1.0,
force_weight: float = 1.0,
stress_weight: float | None = None,
site_wise_weight: float | None = None,
stress_weight: float = 0.0,
site_wise_weight: float = 0.0,
data_mean: float = 0.0,
data_std: float = 1.0,
calc_stress: bool = False,
Expand Down Expand Up @@ -308,7 +308,7 @@ def __init__(
self.decay_steps = decay_steps
self.decay_alpha = decay_alpha

calc_site_wise = site_wise_weight is not None
calc_site_wise = site_wise_weight != 0
self.model = Potential(
model=model,
element_refs=element_refs,
Expand Down Expand Up @@ -369,10 +369,6 @@ def step(self, batch: tuple):
loss=self.loss, # type: ignore
preds=preds,
labels=labels,
energy_weight=self.energy_weight,
force_weight=self.force_weight,
stress_weight=self.stress_weight,
site_wise_weight=self.site_wise_weight,
num_atoms=num_atoms,
)
batch_size = preds[0].numel()
Expand All @@ -384,10 +380,6 @@ def loss_fn(
loss: nn.Module,
labels: tuple,
preds: tuple,
energy_weight: float | None = None,
force_weight: float | None = None,
stress_weight: float | None = None,
site_wise_weight: float | None = None,
num_atoms: int | None = None,
):
"""Compute losses for EFS.
Expand All @@ -396,10 +388,6 @@ def loss_fn(
loss: Loss function.
labels: Labels.
preds: Predictions
energy_weight: Weight for energy loss.
force_weight: Weight for force loss.
stress_weight: Weight for stress loss.
site_wise_weight: Weight for site-wise loss.
num_atoms: Number of atoms.
Returns::
Expand Down Expand Up @@ -431,19 +419,19 @@ def loss_fn(
m_mae = torch.zeros(1)
m_rmse = torch.zeros(1)

total_loss = energy_weight * e_loss + force_weight * f_loss
total_loss = self.energy_weight * e_loss + self.force_weight * f_loss

if stress_weight is not None:
if self.stress_weight:
s_loss = loss(labels[2], preds[2])
s_mae = self.mae(labels[2], preds[2])
s_rmse = self.rmse(labels[2], preds[2])
total_loss = total_loss + stress_weight * s_loss
total_loss = total_loss + self.stress_weight * s_loss

if site_wise_weight is not None:
if self.site_wise_weight:
m_loss = loss(labels[3], preds[3])
m_mae = self.mae(labels[3], preds[3])
m_rmse = self.rmse(labels[3], preds[3])
total_loss = total_loss + site_wise_weight * m_loss
total_loss = total_loss + self.site_wise_weight * m_loss

return {
"Total_Loss": total_loss,
Expand Down
2 changes: 0 additions & 2 deletions tests/utils/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def test_m3gnet_training(self, LiFePO4, BaNiO3):

trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

model = model.to(torch.device(device))
pred_LFP_energy = model.predict_structure(LiFePO4)
pred_BNO_energy = model.predict_structure(BaNiO3)

Expand Down Expand Up @@ -167,7 +166,6 @@ def test_m3gnet_property_training(self, LiFePO4, BaNiO3):

trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

model = model.to(torch.device(device))
pred_LFP_energy = model.predict_structure(LiFePO4)
pred_BNO_energy = model.predict_structure(BaNiO3)

Expand Down

0 comments on commit 3a00612

Please sign in to comment.