Skip to content

Commit

Permalink
#28 add unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
HanmeiTang committed Sep 26, 2019
1 parent f9da1f7 commit dc1afd8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 33 deletions.
50 changes: 29 additions & 21 deletions pymatgen_diffusion/aimd/tests/test_van_hove.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,39 @@

import unittest
import os
import json

import numpy as np
import matplotlib
from copy import deepcopy

matplotlib.use("pdf")

from monty.serialization import loadfn
from pymatgen.analysis.diffusion_analyzer import DiffusionAnalyzer
from pymatgen_diffusion.aimd.van_hove import VanHoveAnalysis, \
RadialDistributionFunction, EvolutionAnalyzer
from pymatgen.analysis.diffusion_analyzer import DiffusionAnalyzer

matplotlib.use("pdf")
tests_dir = os.path.dirname(os.path.abspath(__file__))


class VanHoveTest(unittest.TestCase):
def test_van_hove(self):
data_file = os.path.join(tests_dir, "cNa3PS4_pda.json")
with open(data_file, "r") as j:
data = json.load(j)
obj = DiffusionAnalyzer.from_dict(data)
# Parse the DiffusionAnalyzer object from json file directly
obj = loadfn(os.path.join(tests_dir, "cNa3PS4_pda.json"))

vh = VanHoveAnalysis(diffusion_analyzer=obj, avg_nsteps=5, ngrid=101,
rmax=10.0,
step_skip=5, sigma=0.1, species=["Li", "Na"])
rmax=10.0, step_skip=5, sigma=0.1,
species=["Li", "Na"])

check = np.shape(vh.gsrt) == (20, 101) and np.shape(vh.gdrt) == (
20, 101)
check = np.shape(vh.gsrt) == (20, 101) and \
np.shape(vh.gdrt) == (20, 101)
self.assertTrue(check)
self.assertAlmostEqual(vh.gsrt[0, 0], 3.98942280401, 10)
self.assertAlmostEqual(vh.gdrt[10, 0], 9.68574868168, 10)


class RDFTest(unittest.TestCase):
def test_rdf(self):
data_file = os.path.join(tests_dir, "cNa3PS4_pda.json")
with open(data_file, "r") as j:
data = json.load(j)
obj = DiffusionAnalyzer.from_dict(data)
# Parse the DiffusionAnalyzer object from json file directly
obj = loadfn(os.path.join(tests_dir, "cNa3PS4_pda.json"))

structure_list = []
for i, s in enumerate(obj.get_drift_corrected_structures()):
Expand Down Expand Up @@ -75,6 +70,21 @@ def test_rdf(self):
self.assertTrue(check)
self.assertAlmostEqual(obj.rdf.max(), 1.634448, 4)

# Test init using structures w/ different lattices
s0 = deepcopy(structure_list[0])
sl_1 = [s0, s0, s0 * [1, 2, 1]]
sl_2 = [s0 * [2, 1, 1], s0, s0]

obj_1 = RadialDistributionFunction(
structures=sl_1, ngrid=101, rmax=10.0, cell_range=1,
sigma=0.1, indices=indices, reference_indices=indices)
obj_2 = RadialDistributionFunction(
structures=sl_2, ngrid=101, rmax=10.0, cell_range=1,
sigma=0.1, indices=indices, reference_indices=indices)

self.assertEqual(obj_1.rho, obj_2.rho)
self.assertEqual(obj_1.rdf[0], obj_2.rdf[0])

def test_rdf_coordination_number(self):
# create a simple cubic lattice
coords = np.array([[0.5, 0.5, 0.5]])
Expand Down Expand Up @@ -131,10 +141,8 @@ def test_raises_ValueError_if_reference_species_not_in_structure(self):

class EvolutionAnalyzerTest(unittest.TestCase):
def test_get_df(self):
data_file = os.path.join(tests_dir, "cNa3PS4_pda.json")
with open(data_file, "r") as j:
data = json.load(j)
obj = DiffusionAnalyzer.from_dict(data)
# Parse the DiffusionAnalyzer object from json file directly
obj = loadfn(os.path.join(tests_dir, "cNa3PS4_pda.json"))

structure_list = []
for i, s in enumerate(obj.get_drift_corrected_structures()):
Expand Down
32 changes: 20 additions & 12 deletions pymatgen_diffusion/aimd/van_hove.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def __init__(self, diffusion_analyzer: DiffusionAnalyzer,
ngrid: int = 101, rmax: float = 10.0,
step_skip: int = 50,
sigma: float = 0.1, cell_range: int = 1,
species: Tuple = ("Li", "Na"),
reference_species: Tuple = None, indices: List = None):
species: Union[Tuple, List] = ("Li", "Na"),
reference_species: Union[Tuple, List] = None,
indices: List = None):
"""
Initiation.
Expand Down Expand Up @@ -314,7 +315,10 @@ def __init__(self, structures: List, indices: List, reference_indices: List,
raise ValueError("ngrid should be greater than 1!")
if sigma <= 0:
raise ValueError("sigma should be a positive number!")


if len(indices) < 1:
raise ValueError("Given species are not in the structure!")

lattices, rhos, fcoords_list, ref_fcoords_list = [], [], [], []

dr = rmax / (ngrid - 1)
Expand Down Expand Up @@ -346,9 +350,11 @@ def __init__(self, structures: List, indices: List, reference_indices: List,
ref_fcoords_list.append(all_fcoords[reference_indices, :])

rho = sum(rhos) / len(rhos) # The average density
self.rhos = rhos
self.rho = rho # This is the average density

for fcoords, ref_fcoords, latt, rho in zip(
fcoords_list, ref_fcoords_list, lattices, rhos):
for fcoords, ref_fcoords, latt in zip(
fcoords_list, ref_fcoords_list, lattices):
dcf = fcoords[:, None, None, :] + \
images[None, None, :, :] - ref_fcoords[None, :, None, :]
dcc = latt.get_cartesian_coords(dcf)
Expand All @@ -368,16 +374,17 @@ def __init__(self, structures: List, indices: List, reference_indices: List,
# Volume of the thin shell
ff = 4.0 / 3.0 * np.pi * \
(interval[indx + 1] ** 3 - interval[indx] ** 3)

rdf[:] += (norm.pdf(interval, interval[indx], sigma) * dn /
float(len(reference_indices)) / ff / rho / len(
fcoords_list) * dr)
# print(norm.pdf(interval, interval[indx], sigma) * dn /
# float(len(reference_indices)) / ff / rho / len(
# fcoords_list) * dr)
rdf[:] += norm.pdf(interval, interval[indx], sigma) * dn / \
float(len(reference_indices)) / ff / rho / \
len(fcoords_list) * dr

# additional dr factor renormalises overlapping gaussians.
raw_rdf[indx] += dn / float(
len(reference_indices)) / ff / rho / len(fcoords_list)

self.rho = rho # This is the average density
self.structures = structures
self.cell_range = cell_range
self.rmax = rmax
Expand All @@ -400,8 +407,9 @@ def __init__(self, structures: List, indices: List, reference_indices: List,
@classmethod
def from_species(cls, structures: List, ngrid: int = 101,
rmax: float = 10.0, cell_range: int = 1,
sigma: float = 0.1, species: tuple = ("Li", "Na"),
reference_species: tuple = None):
sigma: float = 0.1,
species: Union[Tuple, List] = ("Li", "Na"),
reference_species: Union[Tuple, List] = None):
"""
Initialize using species.
Expand Down

0 comments on commit dc1afd8

Please sign in to comment.