diff --git a/src/py4vasp/_calculation/magnetism.py b/src/py4vasp/_calculation/magnetism.py index bef02311..1964cea9 100644 --- a/src/py4vasp/_calculation/magnetism.py +++ b/src/py4vasp/_calculation/magnetism.py @@ -228,11 +228,8 @@ def _collinear_moments(self): return self._raw_data.spin_moments[self._steps, 1] def _noncollinear_moments(self, selection): - spin_moments = self._raw_data.spin_moments[self._steps, 1:] - if self._has_orbital_moments: - orbital_moments = self._raw_data.orbital_moments[self._steps] - else: - orbital_moments = np.zeros_like(spin_moments) + spin_moments = self._spin_moments() + orbital_moments = self._orbital_moments(spin_moments) if selection == "orbital": moments = orbital_moments elif selection == "spin": @@ -242,11 +239,21 @@ def _noncollinear_moments(self, selection): direction_axis = 1 if moments.ndim == 4 else 0 return np.moveaxis(moments, direction_axis, -1) + def _spin_moments(self): + return self._raw_data.spin_moments[self._steps, 1:] + + def _orbital_moments(self, spin_moments): + if not self._has_orbital_moments: + return np.zeros_like(spin_moments) + zero_s_moments = np.zeros((*spin_moments.shape[:-1], 1)) + orbital_moments = self._raw_data.orbital_moments[self._steps] + return np.concatenate((zero_s_moments, orbital_moments), axis=-1) + def _add_spin_and_orbital_moments(self): if not self._has_orbital_moments: return {} - spin_moments = self._raw_data.spin_moments[self._steps, 1:] - orbital_moments = self._raw_data.orbital_moments[self._steps] + spin_moments = self._spin_moments() + orbital_moments = self._orbital_moments(spin_moments) direction_axis = 1 if spin_moments.ndim == 4 else 0 return { "spin_moments": np.moveaxis(spin_moments, direction_axis, -1), diff --git a/tests/calculation/test_magnetism.py b/tests/calculation/test_magnetism.py index c0e8427e..fa8fd477 100644 --- a/tests/calculation/test_magnetism.py +++ b/tests/calculation/test_magnetism.py @@ -63,7 +63,8 @@ def __getitem__(self, step): reference.moments = np.moveaxis(raw_magnetism.spin_moments[:, 1:4], 1, 3) else: spin_moments = np.moveaxis(raw_magnetism.spin_moments[:, 1:4], 1, 3) - orbital_moments = np.moveaxis(raw_magnetism.orbital_moments, 1, 3) + orbital_moments = np.zeros_like(spin_moments).astype(np.float64) + orbital_moments[:, :, 1:] += np.moveaxis(raw_magnetism.orbital_moments, 1, 3) reference.moments = spin_moments + orbital_moments reference.spin_moments = spin_moments reference.orbital_moments = orbital_moments diff --git a/tests/conftest.py b/tests/conftest.py index 4ec23f79..89148030 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -543,10 +543,8 @@ def _magnetism(selection): spin_moments=_make_data(np.arange(np.prod(shape)).reshape(shape)), ) if selection == "orbital_moments": - remove_charge_component = magnetism.spin_moments[:, 1:] - magnetism.orbital_moments = _make_data(np.sqrt(remove_charge_component)) - print("spin_moments", magnetism.spin_moments.shape) - print("orb_moments", magnetism.orbital_moments.shape) + remove_charge_and_s_component = magnetism.spin_moments[:, 1:, :, 1:] + magnetism.orbital_moments = _make_data(np.sqrt(remove_charge_and_s_component)) return magnetism