Skip to content

Commit

Permalink
First step towards better selection logic (#138)
Browse files Browse the repository at this point in the history
* Replace - with : in selection logic
* Replace x2-y2 with dx2y2
* Make the 'all' option a bit harder to reach
  • Loading branch information
martin-schlipf authored Jan 31, 2022
1 parent f083850 commit 0801601
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 84 deletions.
6 changes: 5 additions & 1 deletion src/py4vasp/_util/selection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
range_separator = ":"
all = "__all__"


class SelectionTree:
def __init__(self, parent=None):
self._new_child = True
Expand All @@ -21,7 +25,7 @@ def nodes(self):
def parse_character(self, character):
if character in (" ", ","):
return self._parse_separator()
elif character == "-":
elif character == range_separator:
return self._parse_range(character)
elif character == "(":
return self._children[-1]
Expand Down
2 changes: 0 additions & 2 deletions src/py4vasp/data/_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,3 @@ class Selection(NamedTuple):
"Indices from which the specified quantity is read."
label: str = ""
"Label identifying the quantity."
default = "*"
"Character identifying the default selection"
4 changes: 2 additions & 2 deletions src/py4vasp/data/dielectric_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _parse_selection(selection, data):
class _Choice(typing.NamedTuple):
component: str
direction: str = "isotropic"
real_or_imag: str = _Selection.default
real_or_imag: str = _selection.all


def _default_choice(data):
Expand Down Expand Up @@ -142,7 +142,7 @@ def _update_choice(current_choice, part):


def _setup_component_choices(choice):
if choice.real_or_imag == _Selection.default:
if choice.real_or_imag == _selection.all:
yield choice._replace(real_or_imag="real")
yield choice._replace(real_or_imag="imag")
else:
Expand Down
7 changes: 4 additions & 3 deletions src/py4vasp/data/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import py4vasp._util.documentation as _documentation
import py4vasp._util.convert as _convert
import py4vasp._util.sanity_check as _check
import py4vasp._util.selection as _selection


_energy_docs = f"""
Expand Down Expand Up @@ -82,7 +83,7 @@ def _step_string(self):
{_trajectory.trajectory_examples("energy", "read")}"""
)
def _to_dict(self, selection=_Selection.default):
def _to_dict(self, selection=_selection.all):
return {
label: self._raw_data.values[self._steps, index]
for label, index in self._parse_selection(selection)
Expand Down Expand Up @@ -136,7 +137,7 @@ def _to_numpy(self, selection="TOTEN"):
)
return _unpack_if_only_one_element(result)

def _labels(self, selection=_Selection.default):
def _labels(self, selection=_selection.all):
"Return the labels corresponding to a particular selection defaulting to all labels."
return [label for label, _ in self._parse_selection(selection)]

Expand All @@ -149,7 +150,7 @@ def _parse_selection(self, selection):
yield get_label(index).strip(), index

def _find_selection_indices(self, selection):
if selection == _Selection.default:
if selection == _selection.all:
return range(len(self._raw_data.labels))
else:
selection_parts = _split_selection_in_parts(selection)
Expand Down
44 changes: 23 additions & 21 deletions src/py4vasp/data/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@


_spin_not_set = "not set"
_range_separator = "-"
_range = re.compile(r"^(\d+)" + re.escape(_range_separator) + r"(\d+)$")
_range = re.compile(r"^(\d+)" + re.escape(_selection.range_separator) + r"(\d+)$")

_selection_doc = r"""
selection : str
Expand All @@ -24,13 +23,12 @@
- To specify the **atom**, you can either use its element name (Si, Al, ...)
or its index as given in the input file (1, 2, ...). For the latter
option it is also possible to specify ranges (e.g. 1-4).
option it is also possible to specify ranges (e.g. 1:4).
- To select a particular **orbital** you can give a string (s, px, dxz, ...)
or select multiple orbitals by their angular momentum (s, p, d, f).
- For the **spin**, you have the options up, down, or total.
For all of these options a wildcard \* exists, which selects all elements. You
separate multiple selections by commas or whitespace and can nest them using
You separate multiple selections by commas or whitespace and can nest them using
parenthesis, e.g. `Sr(s, p)` or `s(up), p(down)`. The order of the selections
does not matter, but it is case sensitive to distinguish p (angular momentum
l = 1) from P (phosphorus).
Expand All @@ -47,7 +45,7 @@ def _selection_examples(instance_name, function_name):
>>> calc.{instance_name}.{function_name}("d(Mn, Co, Fe)")
Select the spin-up contribution of the first three atoms combined
>>> calc.{instance_name}.{function_name}("up(1-3)")
>>> calc.{instance_name}.{function_name}("up(1{_selection.range_separator}3)")
"""


Expand Down Expand Up @@ -120,15 +118,14 @@ def _to_dict(self, selection=None, projections=None):
f"""Map selection strings onto corresponding Selection objects.
With the selection strings, you specify which atom, orbital, and spin component
you are interested in. *Note* that for all parameters you can pass "*" to
default to all (atoms, orbitals, or spins).
you are interested in.
Parameters
----------
atom : str
Element name or index of the atom in the input file of Vasp. If a
range is specified (e.g. 1-3) a pointer to multiple indices will be
created.
range is specified (e.g. 1{_selection.range_separator}3) a pointer to
multiple indices will be created.
orbital : str
Character identifying the angular momentum of the orbital. You may
select a specific one (e.g. px) or all of the same character (e.g. d).
Expand All @@ -146,9 +143,9 @@ def _to_dict(self, selection=None, projections=None):
)
def _select(
self,
atom=_Selection.default,
orbital=_Selection.default,
spin=_Selection.default,
atom=_selection.all,
orbital=_selection.all,
spin=_selection.all,
):
dicts = self._init_dicts()
_raise_error_if_not_found_in_dict(orbital, dicts["orbital"])
Expand Down Expand Up @@ -180,8 +177,8 @@ def _select(
def _parse_selection(self, selection):
dicts = self._init_dicts()
default_index = Projector.Index(
atom=_Selection.default,
orbital=_Selection.default,
atom=_selection.all,
orbital=_selection.all,
spin=_spin_not_set,
)
tree = _selection.SelectionTree.from_selection(selection)
Expand All @@ -203,7 +200,7 @@ def _init_atom_dict(self):
def _init_orbital_dict(self):
num_orbitals = len(self._raw_data.orbital_types)
all_orbitals = _Selection(indices=slice(num_orbitals))
orbital_dict = {_Selection.default: all_orbitals}
orbital_dict = {_selection.all: all_orbitals}
for i, orbital in enumerate(self._orbital_types()):
orbital_dict[orbital] = _Selection(indices=slice(i, i + 1), label=orbital)
if "px" in orbital_dict:
Expand All @@ -213,8 +210,13 @@ def _init_orbital_dict(self):
return orbital_dict

def _orbital_types(self):
clean_string = lambda ion_type: _convert.text_to_string(ion_type).strip()
return (clean_string(orbital) for orbital in self._raw_data.orbital_types)
clean_string = lambda orbital: _convert.text_to_string(orbital).strip()
for orbital in self._raw_data.orbital_types:
orbital = clean_string(orbital)
if orbital == "x2-y2":
yield "dx2y2"
else:
yield orbital

def _init_spin_dict(self):
num_spins = self._raw_data.number_spins
Expand All @@ -223,7 +225,7 @@ def _init_spin_dict(self):
"up": _Selection(indices=slice(1), label="up"),
"down": _Selection(indices=slice(1, 2), label="down"),
"total": _Selection(indices=slice(num_spins), label="total"),
_Selection.default: _Selection(indices=slice(num_spins)),
_selection.all: _Selection(indices=slice(num_spins)),
}

def _get_indices(self, selection):
Expand Down Expand Up @@ -293,7 +295,7 @@ def _parse_recursive(dicts, tree, current_index):

def _update_index(dicts, index, part):
part = part.strip()
if part == _Selection.default:
if part == _selection.all:
pass
elif part in dicts["atom"] or _range.match(part):
index = index._replace(atom=part)
Expand All @@ -314,7 +316,7 @@ def _setup_spin_indices(index, spin_polarized):
if index.spin != _spin_not_set:
yield index
elif not spin_polarized:
yield index._replace(spin=_Selection.default)
yield index._replace(spin=_selection.all)
else:
for key in ("up", "down"):
yield index._replace(spin=key)
Expand Down
3 changes: 2 additions & 1 deletion src/py4vasp/data/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import py4vasp.exceptions as exception
import py4vasp._util.sanity_check as _check
import py4vasp._util.convert as _convert
import py4vasp._util.selection as _selection
import numpy as np
import pandas as pd
import functools
Expand Down Expand Up @@ -126,7 +127,7 @@ def _create_repr(self, number_suffix):

def _default_selection(self):
num_atoms = self._number_atoms()
return {_Selection.default: _Selection(indices=slice(num_atoms))}
return {_selection.all: _Selection(indices=slice(num_atoms))}

def _specific_selection(self):
start = 0
Expand Down
6 changes: 3 additions & 3 deletions tests/_util/test_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def test_no_whitespace():


def test_ranges():
selection = "foo(1 - 3) 2 - 6"
selection = "foo(1 : 3) 2 : 6"
tree = SelectionTree.from_selection(selection)
level1 = tree.nodes[0]
assert str(level1) == "foo"
assert str(level1.nodes[0]) == "1-3"
assert str(tree.nodes[1]) == "2-6"
assert str(level1.nodes[0]) == "1:3"
assert str(tree.nodes[1]) == "2:6"
10 changes: 6 additions & 4 deletions tests/data/test_dos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from py4vasp.data import Dos
from unittest.mock import patch
import py4vasp.exceptions as exception
import py4vasp._util.selection as _selection
import pytest
import numpy as np
import types
Expand Down Expand Up @@ -117,14 +118,15 @@ def test_Fe3O4_to_frame(Fe3O4, Assert):


def test_Sr2TiO4_projectors_to_frame(Sr2TiO4_projectors, Assert):
all = _selection.all
equivalent_selections = [
"s Sr(d) Ti O(px,dxy) 2(p) 4 3(dz2) 1-2(p)",
"2( p), dz2(3) Sr(d) p(1-2), *(s), 4 Ti(*) px(O) O(dxy)",
"s Sr(d) Ti O(px,dxy) 2(p) 4 3(dz2) 1:2(p)",
f"2( p), dz2(3) Sr(d) p(1:2), {all}(s), 4 Ti({all}) px(O) O(dxy)",
]
for selection in equivalent_selections:
actual = Sr2TiO4_projectors.to_frame(selection)
Assert.allclose(actual.s, Sr2TiO4_projectors.ref.s)
Assert.allclose(actual["1-2_p"], Sr2TiO4_projectors.ref.Sr_p)
Assert.allclose(actual["1:2_p"], Sr2TiO4_projectors.ref.Sr_p)
Assert.allclose(actual.Sr_d, Sr2TiO4_projectors.ref.Sr_d)
Assert.allclose(actual.Sr_2_p, Sr2TiO4_projectors.ref.Sr_2_p)
Assert.allclose(actual.Ti, Sr2TiO4_projectors.ref.Ti)
Expand Down Expand Up @@ -224,7 +226,7 @@ def test_Sr2TiO4_projectors_print(Sr2TiO4_projectors, format_):
energies: [-1.00, 3.00] 50 points
projectors:
atoms: Sr, Ti, O
orbitals: s, py, pz, px, dxy, dyz, dz2, dxz, x2-y2, fy3x2, fxyz, fyz2, fz3, fxz2, fzx2, fx3
orbitals: s, py, pz, px, dxy, dyz, dz2, dxz, dx2y2, fy3x2, fxyz, fyz2, fz3, fxz2, fzx2, fx3
""".strip()
assert actual == {"text/plain": reference}

Expand Down
3 changes: 2 additions & 1 deletion tests/data/test_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import types
import py4vasp.exceptions as exception
import py4vasp._util.selection as selection


@pytest.fixture
Expand Down Expand Up @@ -103,7 +104,7 @@ def test_plot_all(energy, Assert):


def check_plot_all(energy, steps, Assert):
fig = energy[steps].plot("*")
fig = energy[steps].plot(selection.all)
assert fig.layout.yaxis.title.text == "Energy (eV)"
assert fig.layout.yaxis2.title.text == "Temperature (K)"
Assert.allclose(fig.data[0].y, energy.ref.total_energy[steps])
Expand Down
Loading

0 comments on commit 0801601

Please sign in to comment.