diff --git a/mace/__version__.py b/mace/__version__.py index e19434e2..334b8995 100644 --- a/mace/__version__.py +++ b/mace/__version__.py @@ -1 +1 @@ -__version__ = "0.3.3" +__version__ = "0.3.4" diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 31170d3d..edb91b14 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -175,7 +175,9 @@ def from_config( else None ) virials = ( - torch.tensor(config.virials, dtype=torch.get_default_dtype()).unsqueeze(0) + voigt_to_matrix( + torch.tensor(config.virials, dtype=torch.get_default_dtype()) + ).unsqueeze(0) if config.virials is not None else None ) diff --git a/mace/data/utils.py b/mace/data/utils.py index 0069550f..908fdc17 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -17,8 +17,8 @@ Vector = np.ndarray # [3,] Positions = np.ndarray # [..., 3] Forces = np.ndarray # [..., 3] -Stress = np.ndarray # [6, ] -Virials = np.ndarray # [3,3] +Stress = np.ndarray # [6, ], [3,3], [9, ] +Virials = np.ndarray # [6, ], [3,3], [9, ] Charges = np.ndarray # [..., 1] Cell = np.ndarray # [3,3] Pbc = tuple # (3,) diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py index e0c4d546..349f1e3b 100644 --- a/mace/tools/torch_tools.py +++ b/mace/tools/torch_tools.py @@ -107,7 +107,7 @@ def cartesian_to_spherical(t: torch.Tensor): def voigt_to_matrix(t: torch.Tensor): """ Convert voigt notation to matrix notation - :param t: (6,) tensor or (3, 3) tensor + :param t: (6,) tensor or (3, 3) tensor or (9,) tensor :return: (3, 3) tensor """ if t.shape == (3, 3): @@ -121,9 +121,11 @@ def voigt_to_matrix(t: torch.Tensor): ], dtype=t.dtype, ) + if t.shape == (9,): + return t.view(3, 3) raise ValueError( - f"Stress tensor must be of shape (6,) or (3, 3), but has shape {t.shape}" + f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}" )