Skip to content

Commit

Permalink
Check shapes of arrays is consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-schlipf committed Jun 18, 2024
1 parent 0228cf2 commit 2a17d36
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/py4vasp/_third_party/view/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def _verify(self):
self._raise_error_if_present_on_multiple_steps(self.grid_scalars)
self._raise_error_if_present_on_multiple_steps(self.ion_arrows)
self._raise_error_if_number_steps_inconsistent()
self._raise_error_if_any_shape_is_incorrect()

def _raise_error_if_present_on_multiple_steps(self, attributes):
if not attributes:
Expand All @@ -171,6 +172,17 @@ def _raise_error_if_number_steps_inconsistent(self):
"steps."
)

def _raise_error_if_any_shape_is_incorrect(self):
number_elements = len(self.elements[0])
_, number_positions, vector_size = np.shape(self.positions)
if number_elements != number_positions:
raise exception.IncorrectUsage(f"Number of elements ({number_elements}) inconsistent with number of positions ({number_positions}).")
if vector_size != 3:
raise exception.IncorrectUsage(f"Positions must have three components and not {vector_size}.")
cell_shape = np.shape(self.lattice_vectors)[1:]
if any(length != 3 for length in cell_shape):
raise exception.IncorrectUsage(f"Lattice vectors must be a 3x3 unit cell but have the shape {cell_shape}.")

def _create_atoms(self, step):
symbols = "".join(self.elements[step])
atoms = ase.Atoms(symbols)
Expand Down
11 changes: 11 additions & 0 deletions tests/third_party/view/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,14 @@ def test_different_number_of_steps_raises_error(view):
broken_view = copy.copy(view)
broken_view.positions = too_many_positions
broken_view.to_ngl()

def test_incorrect_shape_raises_error(view):
different_number_atoms = np.zeros((len(view.positions), 7, 3))
with pytest.raises(exception.IncorrectUsage):
View(view.elements, view.lattice_vectors, different_number_atoms)
not_a_three_component_vector = np.array(view.positions)[:,:,:2]
with pytest.raises(exception.IncorrectUsage):
View(view.elements, view.lattice_vectors, not_a_three_component_vector)
incorrect_unit_cell = np.zeros((len(view.lattice_vectors), 2, 4))
with pytest.raises(exception.IncorrectUsage):
View(view.elements, incorrect_unit_cell, view.positions)

0 comments on commit 2a17d36

Please sign in to comment.