Skip to content

Commit

Permalink
fixed lattice plane passing to contour and updated tests accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelWolloch committed May 16, 2024
1 parent e56f041 commit 5e476b7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
7 changes: 4 additions & 3 deletions src/py4vasp/_third_party/graph/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ class Contour(trace.Trace):
the dimensions should be the ones of the grid, if the data is 3d the first dimension
should be a 2 for a vector in the plane of the grid and the other two dimensions
should be the grid."""
lattice: Lattice
"""2 vectors spanning the plane in which the data is represented. Each vector should
have two components, so remove any element normal to the plane."""
lattice: Plane
"""Lattice plane in which the data is represented spanned by 2 vectors.
Each vector should have two components, so remove any element normal to
the plane. Can be generated with the 'plane' function in py4vasp._util.slicing."""
label: str
"Assign a label to the visualization that may be used to identify one among multiple plots."
isolevels: bool = False
Expand Down
26 changes: 19 additions & 7 deletions src/py4vasp/calculation/_partial_charge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from py4vasp._third_party.graph import Graph
from py4vasp._third_party.graph.contour import Contour
from py4vasp._util import import_, select
from py4vasp._util.slicing import plane
from py4vasp.calculation import _base, _structure

interpolate = import_.optional("scipy.interpolate")
Expand Down Expand Up @@ -201,15 +202,15 @@ def _constant_current_stm(self, smoothed_charge, current, spin):
spin_label = "both spin channels" if spin == "total" else f"spin {spin}"
topology = self._topology()
label = f"STM of {topology} for {spin_label} at constant current={current*1e9:.2f} nA"
return Contour(data=scan, lattice=self._in_plane_vectors(), label=label)
return Contour(data=scan, lattice=self._get_stm_plane(), label=label)

def _constant_height_stm(self, smoothed_charge, tip_height, spin):
zz = self._z_index_for_height(tip_height + self._get_highest_z_coord())
height_scan = smoothed_charge[:, :, zz] * self.stm_settings.enhancement_factor
spin_label = "both spin channels" if spin == "total" else f"spin {spin}"
topology = self._topology()
label = f"STM of {topology} for {spin_label} at constant height={float(tip_height):.2f} Angstrom"
return Contour(data=height_scan, lattice=self._in_plane_vectors(), label=label)
return Contour(data=height_scan, lattice=self._get_stm_plane(), label=label)

def _z_index_for_height(self, tip_height):
"""Return the z-index of the tip height in the charge density grid."""
Expand Down Expand Up @@ -267,11 +268,22 @@ def _smooth_stm_data(self, data):
data, sigma=sigma, truncate=self.stm_settings.truncate, mode="wrap"
)

def _in_plane_vectors(self):
"""Return the in-plane component of lattice vectors."""
lattice_vectors = self._structure.lattice_vectors()
_raise_error_if_vacuum_not_along_z(self._structure)
return lattice_vectors[:2, :2]
def _get_stm_plane(self):
"""Return lattice plane spanned by a and b vectors"""
# lv = self._structure.lattice_vectors()
# _raise_error_if_vacuum_not_along_z(self._structure)
# l0 = np.linalg.norm(lv[0]) / np.linalg.norm(lv[0,:2]) * lv[0,:2]
# l1 = np.linalg.norm(lv[1]) / np.linalg.norm(lv[1,:2]) * lv[1,:2]
# from py4vasp._util.slicing import Plane
# return Plane(vectors=np.vstack((l0, l1)),
# cell=lv,
# cut="c",
# )
return plane(
cell=self._structure.lattice_vectors(),
cut="c",
normal="z",
)

def _out_of_plane_vector(self):
"""Return out-of-plane component of lattice vectors."""
Expand Down
13 changes: 9 additions & 4 deletions tests/calculation/test_partial_charge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from py4vasp import calculation
from py4vasp._util.slicing import plane
from py4vasp.exception import IncorrectUsage, NoData, NotImplemented


Expand Down Expand Up @@ -89,7 +90,11 @@ def make_reference_partial_charge(raw_data, selection):
parchg = calculation.partial_charge.from_data(raw_partial_charge)
parchg.ref = types.SimpleNamespace()
parchg.ref.structure = calculation.structure.from_data(raw_partial_charge.structure)
parchg.ref.plane_vectors = parchg.ref.structure.lattice_vectors()[:2, :2]
parchg.ref.plane_vectors = plane(
cell=parchg.ref.structure.lattice_vectors(),
cut="c",
normal="z",
)
parchg.ref.partial_charge = raw_partial_charge.partial_charge
parchg.ref.bands = raw_partial_charge.bands
parchg.ref.kpoints = raw_partial_charge.kpoints
Expand Down Expand Up @@ -242,7 +247,7 @@ def test_to_stm_nonsplit_constant_height(
expected = PolarizedNonSplitPartialCharge.ref
assert type(actual.series.data) == np.ndarray
assert actual.series.data.shape == (expected.grid[0], expected.grid[1])
Assert.allclose(actual.series.lattice, expected.plane_vectors)
Assert.allclose(actual.series.lattice.vectors, expected.plane_vectors.vectors)
Assert.allclose(actual.series.supercell, np.asarray([supercell, supercell]))
# check different elements of the label
assert type(actual.series.label) is str
Expand All @@ -268,7 +273,7 @@ def test_to_stm_nonsplit_constant_current(
expected = PolarizedNonSplitPartialCharge.ref
assert type(actual.series.data) == np.ndarray
assert actual.series.data.shape == (expected.grid[0], expected.grid[1])
Assert.allclose(actual.series.lattice, expected.plane_vectors)
Assert.allclose(actual.series.lattice.vectors, expected.plane_vectors.vectors)
Assert.allclose(actual.series.supercell, supercell)
# check different elements of the label
assert type(actual.series.label) is str
Expand All @@ -294,7 +299,7 @@ def test_to_stm_nonsplit_constant_current_non_ortho(
expected = NonSplitPartialChargeCaAs3_110.ref
assert type(actual.series.data) == np.ndarray
assert actual.series.data.shape == (expected.grid[0], expected.grid[1])
Assert.allclose(actual.series.lattice, expected.plane_vectors)
Assert.allclose(actual.series.lattice.vectors, expected.plane_vectors.vectors)
Assert.allclose(actual.series.supercell, supercell)
# check different elements of the label
assert type(actual.series.label) is str
Expand Down

0 comments on commit 5e476b7

Please sign in to comment.