Smooth L1 loss function is added and the unit tests are improved. (#289)

* improve TensorNet model coverage

* Update pyproject.toml

Signed-off-by: Tsz Wai Ko <[email protected]>

* Improve the unit test for SO(3) equivariance in the TensorNet class

* improve SO3Net model class coverage and simplify TensorNet implementations

* improve the coverage in MLP_norm class

* Improve the implementation of three-body interactions

* Fix black formatting

* Optimize the speed of _compute_3body class

* Type checking is added for the scheduler

* update M3GNet Potential training notebook for the demonstration of obtaining and using element offsets

* Downgrade sympy to avoid crashes in SO3 operations

* Smooth L1 loss function is added and unit tests are improved

---------

Signed-off-by: Tsz Wai Ko <[email protected]>
kenko911 authored Jul 17, 2024
1 parent 6006ab4 commit 1cb40b5
Showing 2 changed files with 20 additions and 5 deletions.
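
Taken together, the changes let callers select the training loss by name ("mse_loss", "huber_loss", "smooth_l1_loss", or the "l1_loss" fallback) and pass extra keyword arguments to it through a new loss_params dict. A minimal usage sketch based on the updated tests below; the matgl.models import path and the element list are illustrative assumptions, not part of this diff:

from matgl.models import TensorNet  # import path assumed, not shown in this diff
from matgl.utils.training import PotentialLightningModule

# Example element set mirroring the LiFePO4 test fixture.
model = TensorNet(element_types=("Li", "Fe", "P", "O"), is_intensive=False)

# `loss` selects the function by name; `loss_params` is forwarded verbatim,
# so training ends up calling F.smooth_l1_loss(labels, preds, beta=1.0).
lit_model = PotentialLightningModule(
    model=model,
    stress_weight=0.0001,
    loss="smooth_l1_loss",
    loss_params={"beta": 1.0},
)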
11 changes: 10 additions & 1 deletion src/matgl/utils/training.py
@@ -146,6 +146,7 @@ def __init__(
         data_mean: float = 0.0,
         data_std: float = 1.0,
         loss: str = "mse_loss",
+        loss_params: dict | None = None,
         optimizer: Optimizer | None = None,
         scheduler: LRScheduler | None = None,
         lr: float = 0.001,
@@ -163,6 +164,7 @@ def __init__(
             data_mean: average of training data
             data_std: standard deviation of training data
             loss: loss function used for training
+            loss_params: parameters for loss function
             optimizer: optimizer for training
             scheduler: scheduler for training
             lr: learning rate for training
@@ -184,11 +186,16 @@ def __init__(
         self.decay_alpha = decay_alpha
         if loss == "mse_loss":
             self.loss = F.mse_loss
+        elif loss == "huber_loss":
+            self.loss = F.huber_loss
+        elif loss == "smooth_l1_loss":
+            self.loss = F.smooth_l1_loss
         else:
             self.loss = F.l1_loss
         self.optimizer = optimizer
         self.scheduler = scheduler
         self.sync_dist = sync_dist
+        self.loss_params = loss_params if loss_params is not None else {}
         self.save_hyperparameters(ignore=["model"])

     def forward(
@@ -243,7 +250,7 @@ def loss_fn(self, loss: nn.Module, labels: torch.Tensor, preds: torch.Tensor):
             {"Total_Loss": total_loss, "MAE": mae, "RMSE": rmse}
         """
         scaled_pred = torch.reshape(preds * self.data_std + self.data_mean, labels.size())
-        total_loss = loss(labels, scaled_pred)
+        total_loss = loss(labels, scaled_pred, **self.loss_params)
         mae = self.mae(labels, scaled_pred)
         rmse = self.rmse(labels, scaled_pred)
         return {"Total_Loss": total_loss, "MAE": mae, "RMSE": rmse}
@@ -339,6 +346,8 @@ def __init__(
             self.loss = F.mse_loss
         elif loss == "huber_loss":
             self.loss = F.huber_loss
+        elif loss == "smooth_l1_loss":
+            self.loss = F.smooth_l1_loss
         else:
             self.loss = F.l1_loss
         self.loss_params = loss_params if loss_params is not None else {}
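
For reference, PyTorch's smooth_l1_loss and huber_loss are closely related: with beta == delta they differ only by a factor of delta, so at the default value of 1.0 they coincide exactly. A self-contained sketch of that relationship and of the keyword-forwarding pattern the loss_fn change enables:

import torch
import torch.nn.functional as F

labels = torch.tensor([0.0, 1.0, -1.0])
preds = torch.tensor([0.2, 1.5, -3.0])

# huber_loss(delta=d) equals d * smooth_l1_loss(beta=d); at d == 1.0 they match.
assert torch.isclose(
    F.huber_loss(preds, labels, delta=1.0),
    F.smooth_l1_loss(preds, labels, beta=1.0),
)

# loss_params is unpacked into the loss call, mirroring
# loss(labels, scaled_pred, **self.loss_params) above.
loss_params = {"beta": 0.5}
total_loss = F.smooth_l1_loss(preds, labels, **loss_params)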
14 changes: 10 additions & 4 deletions tests/utils/test_training.py
@@ -192,7 +192,9 @@ def test_so3net_training(self, LiFePO4, BaNiO3):
             generator=torch.Generator(device=device),
         )
         model = SO3Net(element_types=element_types, lmax=2, is_intensive=False)
-        lit_model = PotentialLightningModule(model=model, stress_weight=0.0001)
+        lit_model = PotentialLightningModule(
+            model=model, stress_weight=0.0001, loss="huber_loss", loss_params={"delta": 1.0}
+        )
         # We will use CPU if MPS is available since there is a serious bug.
         trainer = pl.Trainer(max_epochs=2, accelerator=device, inference_mode=False)

@@ -264,7 +266,9 @@ def test_tensornet_training(self, LiFePO4, BaNiO3):
             generator=torch.Generator(device=device),
         )
         model = TensorNet(element_types=element_types, is_intensive=False)
-        lit_model = PotentialLightningModule(model=model, stress_weight=0.0001)
+        lit_model = PotentialLightningModule(
+            model=model, stress_weight=0.0001, loss="smooth_l1_loss", loss_params={"beta": 1.0}
+        )
         # We will use CPU if MPS is available since there is a serious bug.
         trainer = pl.Trainer(max_epochs=2, accelerator=device, inference_mode=False)

@@ -387,7 +391,9 @@ def test_m3gnet_property_training(self, LiFePO4, BaNiO3):
             is_intensive=True,
             readout_type="set2set",
         )
-        lit_model = ModelLightningModule(model=model, include_line_graph=True)
+        lit_model = ModelLightningModule(
+            model=model, include_line_graph=True, loss="huber_loss", loss_params={"delta": 1.0}
+        )
         # We will use CPU if MPS is available since there is a serious bug.
         trainer = pl.Trainer(max_epochs=2, accelerator=device)

@@ -458,7 +464,7 @@ def test_so3net_property_training(self, LiFePO4, BaNiO3):
             target_property="graph",
             readout_type="set2set",
         )
-        lit_model = ModelLightningModule(model=model)
+        lit_model = ModelLightningModule(model=model, loss="smooth_l1_loss", loss_params={"beta": 1.0})
         # We will use CPU if MPS is available since there is a serious bug.
         trainer = pl.Trainer(max_epochs=2, accelerator=device)

