Skip to content

Commit

Permalink
typing, docstring, cleanup geometry
Browse files Browse the repository at this point in the history
  • Loading branch information
mshuaibii committed Jul 24, 2024
1 parent d225119 commit e1c7916
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 122 deletions.
37 changes: 31 additions & 6 deletions src/fairchem/data/oc/core/interface_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from fairchem.data.oc.core.slab import Slab
from fairchem.data.oc.core.solvent import Solvent

from fairchem.data.oc.utils.geometry import BoxGeometry, PlaneBoundTriclinicGeometry
from fairchem.data.oc.utils.geometry import (
BoxGeometry,
Geometry,
PlaneBoundTriclinicGeometry,
)

# Code adapted from https://github.com/henriasv/molecular-builder/tree/master

Expand Down Expand Up @@ -120,13 +124,20 @@ def __init__(
self.pbc_shift = pbc_shift
self.packmol_tolerance = packmol_tolerance

self.n_mol_per_volume = solvent.get_molecules_per_volume
self.n_mol_per_volume = solvent.molecules_per_volume

self.atoms_list, self.metadata_list = self.create_interface_on_sites(
self.atoms_list, self.metadata_list
)

def create_interface_on_sites(self, atoms_list, metadata_list):
def create_interface_on_sites(
self, atoms_list: list[ase.Atoms], metadata_list: list[dict]
):
"""
Given adsorbate+slab configurations generated from
(Multi)AdsorbateSlabConfig and its corresponding metadata, create the
solvent/ion interface on top of the provided atoms objects.
"""
atoms_interface_list = []
metadata_interface_list = []

Expand Down Expand Up @@ -177,7 +188,15 @@ def create_interface_on_sites(self, atoms_list, metadata_list):

return atoms_interface_list, metadata_interface_list

def create_packmol_atoms(self, geometry, n_solvent_mols):
def create_packmol_atoms(self, geometry: Geometry, n_solvent_mols: int):
"""
Pack solvent molecules in a provided unit cell volume. Packmol is used
to randomly pack solvent molecules in the desired volume.
Arguments:
geometry (Geometry): Geometry object corresponding to the desired cell.
n_solvent_mols (int): Number of solvent molecules to pack in the volume.
"""
cell = geometry.cell
with tempfile.TemporaryDirectory() as tmp_dir:
output_path = os.path.join(tmp_dir, "out.pdb")
Expand Down Expand Up @@ -222,7 +241,10 @@ def create_packmol_atoms(self, geometry, n_solvent_mols):

return solvent_ions_atoms

def run_packmol(self, packmol_input):
def run_packmol(self, packmol_input: str):
"""
Run packmol.
"""
packmol_cmd = which("packmol")
if not packmol_cmd:
raise OSError("packmol not found.")
Expand All @@ -237,7 +259,10 @@ def run_packmol(self, packmol_input):
if err:
raise OSError(err.decode("utf-8"))

def randomize_coords(self, atoms):
def randomize_coords(self, atoms: ase.Atoms):
"""
Randomly place the atoms in its unit cell.
"""
cell_weights = np.random.rand(3)
cell_weights /= np.sum(cell_weights)
xyz = np.dot(cell_weights, atoms.cell)
Expand Down
2 changes: 1 addition & 1 deletion src/fairchem/data/oc/core/solvent.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _load_solvent(self, solvent: dict) -> None:
)

@property
def get_molecules_per_volume(self):
def molecules_per_volume(self):
"""
Convert the solvent density in g/cm3 to the number of molecules per
angstrom cubed of volume.
Expand Down
207 changes: 93 additions & 114 deletions src/fairchem/data/oc/utils/geometry.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,67 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import numpy as np

# Code adapted from https://github.com/henriasv/molecular-builder/tree/master
if TYPE_CHECKING:
from ase.cell import Cell

# Code adapted from https://github.com/henriasv/molecular-builder/tree/master

class Geometry:
"""Base class for geometries.

:param periodic_boundary_condition: self-explanatory
:type periodic_boundary_condition: array_like
:param minimum_image_convention: use the minimum image convention for
bookkeeping how the particles interact
:type minimum_image_convention: bool
class Geometry(ABC):
"""
Base class for geometries
"""

def __init__(
self,
periodic_boundary_condition=(False, False, False),
minimum_image_convention=True,
):
self.minimum_image_convention = minimum_image_convention
self.periodic_boundary_condition = periodic_boundary_condition

def __call__(self, atoms):
"""The empty geometry. False because we define no particle to be
in the dummy geometry.
:param atoms: atoms object from ase.Atom that is being modified
:type atoms: ase.Atom obj
:returns: ndarray of bools telling which atoms to remove
:rtype: ndarray of bool
"""
return np.zeros(len(atoms), dtype=np.bool)
@abstractmethod
def __init__(self):
pass

@staticmethod
def distance_point_line(vec, point_line, point_ext):
"""Returns the (shortest) distance between a line parallel to
a normal vector 'vec' through point 'point_line' and an external
point 'point_ext'.
:param vec: unit vector parallel to line
:type vec: ndarray
:param point_line: point on line
:type point_line: ndarray
:param point_ext: external points
:type point_ext: ndarray
:return: distance between line and external point(s)
:rtype: ndarray
def distance_point_plane(vec: np.array, point_plane: np.array, point_ext: np.array):
"""
return np.linalg.norm(np.cross(vec, point_ext - point_line), axis=1)

@staticmethod
def distance_point_plane(vec, point_plane, point_ext):
"""Returns the (shortest) distance between a plane with normal vector
Returns the (shortest) distance between a plane with normal vector
'vec' through point 'point_plane' and a point 'point_ext'.
:param vec: normal vector of plane
:type vec: ndarray
:param point_plane: point on line
:type point_plane: ndarray
:param point_ext: external point(s)
:type point_ext: ndarray
:return: distance between plane and external point(s)
:rtype: ndarray
Args:
vec (np.array): normal vector of plane
point_plane (np.array): point on line
point_ext (np.array): external point(s)
Returns:
(np.array) Distance between plane and external point(s)
"""
vec = np.atleast_2d(vec) # Ensure n is 2d
return np.abs(np.einsum("ik,jk->ij", point_ext - point_plane, vec))

@staticmethod
def vec_and_point_to_plane(vec, point):
"""Returns the (unique) plane, given a normal vector 'vec' and a
def vec_and_point_to_plane(vec: np.array, point: np.array):
"""
Returns the (unique) plane, given a normal vector 'vec' and a
point 'point' in the plane. ax + by + cz - d = 0
:param vec: normal vector of plane
:type vec: ndarray
:param point: point in plane
:type point: ndarray
:returns: parameterization of plane
:rtype: ndarray
Args:
vec (np.array): normal vector of plane
point (np.array): point in plane
Returns:
(np.array) Paramterization of plane
"""
return np.array((*vec, np.dot(vec, point)))

@staticmethod
def cell2planes(cell, pbc):
"""Get the parameterization of the sizes of a ase.Atom cell
def cell2planes(cell: Cell, pbc: float):
"""
Get the parameterization of the sizes of a ase.Atom cell
:param cell: ase.Atom cell
:type cell: obj
:param pbc: shift of boundaries to be used with periodic boundary condition
:type pbc: float
:returns: parameterization of cell plane sides
:rtype: list of ndarray
cell: ase.cell.Cell
pbc (float): shift of boundaries to be used with periodic boundary condition
Return
(List[np.array]) Parameterization of cell plane sides
3 planes intersect the origin by ase design.
"""
Expand All @@ -116,8 +86,11 @@ def cell2planes(cell, pbc):
return [plane1, plane2, plane3, plane4, plane5, plane6]

@staticmethod
def extract_box_properties(center, length, lo_corner, hi_corner):
"""Given two of the properties 'center', 'length', 'lo_corner',
def extract_box_properties(
center: np.array, length: np.array, lo_corner: np.array, hi_corner: np.array
):
"""
Given two of the properties 'center', 'length', 'lo_corner',
'hi_corner', return all the properties. The properties that
are not given are expected to be 'None'.
"""
Expand Down Expand Up @@ -170,36 +143,23 @@ def extract_box_properties(center, length, lo_corner, hi_corner):
del relation_list[i]
return relation_list

@abstractmethod
def packmol_structure(self, filename, number, side):
"""Make structure to be used in PACKMOL input script
:param number: number of solvent molecules
:type number: int
:param side: pack solvent inside/outside of geometry
:type side: str
:returns: string with information about the structure
:rtype: str
"""
structure = ""
structure += f"structure {filename}\n"
structure += f" number {number}\n"
structure += f" {side} {self.__repr__()} "
for param in self.params:
structure += f"{param} "
structure += "\nend structure\n"
return structure
How to write packmol input file. To be defined by inherited class.
"""


class PlaneBoundTriclinicGeometry(Geometry):
"""Triclinic crystal geometry based on ase.Atom cell
:param cell: ase.Atom cell
:type cell: obj
:param pbc: shift of boundaries to be used with periodic boundary condition
:type pbc: float
"""
Triclinic crystal geometry based on ase.Atom cell
"""

def __init__(self, cell, pbc=0.0):
def __init__(self, cell: Cell, pbc: float = 0.0):
"""
cell (ase.cell.Cell)
pbc (float): shift of boundaries to be used with periodic boundary condition
"""
self.planes = self.cell2planes(cell, pbc)
self.cell = cell
self.ll_corner = [0, 0, 0]
Expand All @@ -209,7 +169,18 @@ def __init__(self, cell, pbc=0.0):
self.ur_corner = a + b + c

def packmol_structure(self, filename, number, side):
"""Make structure to be used in PACKMOL input script"""
"""
Make file structure to be used in packmol input script
Args:
filename (str): output filename to save structure
number (int): number of solvent molecules
side (str): pack solvent inside/outside of geometry
Returns:
String with information about the structure
"""

structure = ""

if side == "inside":
Expand All @@ -226,27 +197,22 @@ def packmol_structure(self, filename, number, side):
structure += "end structure\n"
return structure

def __call__(self, position):
raise NotImplementedError


class BoxGeometry(Geometry):
"""Box geometry.
:param center: geometric center of box
:type center: array_like
:param length: length of box in all directions
:type length: array_like
:param lo_corner: lower corner
:type lo_corner: array_like
:param hi_corner: higher corner
:type hi_corner: array_like
"""
Box geometry for orthorhombic cells.
"""

def __init__(
self, center=None, length=None, lo_corner=None, hi_corner=None, **kwargs
):
super().__init__(**kwargs)
"""
Args:
center (np.array): geometric center of box
length (np.array): length of box in all directions
lo_corner (np.array): lower corner
hi_corner (np.array): higher corner
"""
props = self.extract_box_properties(center, length, lo_corner, hi_corner)
self.ll_corner, self.ur_corner, self.length_half, self.center = props
self.params = list(self.ll_corner) + list(self.ur_corner)
Expand All @@ -255,10 +221,23 @@ def __init__(
def __repr__(self):
return "box"

def __call__(self, atoms):
positions = atoms.get_positions()
dist = self.distance_point_plane(np.eye(3), self.center, positions)
return np.all((np.abs(dist) <= self.length_half), axis=1)
def packmol_structure(self, filename: str, number: int, side: str):
"""
Make file structure to be used in packmol input script
Args:
filename (str): output filename to save structure
number (int): number of solvent molecules
side (str): pack solvent inside/outside of geometry
def volume(self):
return np.prod(self.length)
Returns:
String with information about the structure
"""
structure = ""
structure += f"structure {filename}\n"
structure += f" number {number}\n"
structure += f" {side} {self.__repr__()} "
for param in self.params:
structure += f"{param} "
structure += "\nend structure\n"
return structure
2 changes: 1 addition & 1 deletion tests/data/oc/tests/test_interface_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_solvent_density(self):

for atoms, metadata in zip(adslabs.atoms_list, adslabs.metadata_list):
volume = metadata["solvent_volume"]
n_solvent_mols = int(volume * self.solvent.get_molecules_per_volume)
n_solvent_mols = int(volume * self.solvent.molecules_per_volume)
n_solvent_atoms = n_solvent_mols * len(self.solvent.atoms)
n_ions = len(self.ions)

Expand Down

0 comments on commit e1c7916

Please sign in to comment.