Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZBL short-range potential #335

Merged
merged 23 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ build:
pre_build:
- set -e && cd examples/ase && bash train.sh
- set -e && cd examples/programmatic/llpr && bash train.sh
- set -e && cd examples/zbl && bash train.sh

# Build documentation in the docs/ directory with Sphinx
sphinx:
Expand Down
12 changes: 10 additions & 2 deletions docs/generate_examples/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,16 @@
sphinx_gallery_conf = {
"filename_pattern": "/*",
"copyfile_regex": r".*\.(pt|sh|xyz|yaml)",
"examples_dirs": [os.path.join(ROOT, "examples", "ase"), os.path.join(ROOT, "examples", "programmatic", "llpr")],
"gallery_dirs": [os.path.join(ROOT, "docs", "src", "examples", "ase"), os.path.join(ROOT, "docs", "src", "examples", "programmatic", "llpr")],
"examples_dirs": [
os.path.join(ROOT, "examples", "ase"),
os.path.join(ROOT, "examples", "programmatic", "llpr"),
os.path.join(ROOT, "examples", "zbl")
],
"gallery_dirs": [
os.path.join(ROOT, "docs", "src", "examples", "ase"),
os.path.join(ROOT, "docs", "src", "examples", "programmatic", "llpr"),
os.path.join(ROOT, "docs", "src", "examples", "zbl")
],
"min_reported_time": 5,
"matplotlib_animations": True,
}
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/additive/composition.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Composition model
#################

.. automodule:: metatrain.utils.additive.composition
:members:
:undoc-members:
:show-inheritance:
12 changes: 12 additions & 0 deletions docs/src/dev-docs/utils/additive/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Data
====
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Title seems wrong for this section


API for handling additive models in ``metatrain``. These are models that
can be added to one or more architectures.

.. toctree::
:maxdepth: 1

remove_additive
composition
zbl
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/additive/remove_additive.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Removing additive contributions
###############################

.. automodule:: metatrain.utils.additive.remove
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/additive/zbl.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
ZBL short-range potential
#########################

.. automodule:: metatrain.utils.additive.zbl
:members:
:undoc-members:
:show-inheritance:
7 changes: 0 additions & 7 deletions docs/src/dev-docs/utils/composition.rst

This file was deleted.

2 changes: 1 addition & 1 deletion docs/src/dev-docs/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ This is the API for the ``utils`` module of ``metatrain``.
.. toctree::
:maxdepth: 1

additive/index
data/index
architectures
composition
devices
dtype
errors
Expand Down
1 change: 1 addition & 0 deletions docs/src/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ This sections includes some more advanced tutorials on the usage of the

../examples/ase/run_ase
../examples/programmatic/llpr/llpr
../examples/zbl/dimers
5 changes: 0 additions & 5 deletions examples/ase/run_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@

# %%
#
# .. note::
# We have to import ``rascaline.torch`` even though it is not used explicitly in this
# tutorial. The SOAP-BPNN model contains compiled extensions and therefore the import
# is required.
#
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well spotted, thanks!

# Setting up the simulation
# -------------------------
#
Expand Down
7 changes: 5 additions & 2 deletions examples/programmatic/llpr/llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@
# how to create a Dataset object from them.

from metatrain.utils.data import Dataset, read_systems, read_targets # noqa: E402
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists # noqa: E402
from metatrain.utils.neighbor_lists import ( # noqa: E402
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)


qm9_systems = read_systems("qm9_reduced_100.xyz")
Expand All @@ -67,7 +70,7 @@
}
targets, _ = read_targets(target_config)

requested_neighbor_lists = model.requested_neighbor_lists()
requested_neighbor_lists = get_requested_neighbor_lists(model)
qm9_systems = [
get_system_with_neighbor_lists(system, requested_neighbor_lists)
for system in qm9_systems
Expand Down
2 changes: 2 additions & 0 deletions examples/zbl/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Running molecular dynamics with ASE
===================================
148 changes: 148 additions & 0 deletions examples/zbl/dimers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
Training a model with ZBL corrections
=====================================

This tutorial demonstrates how to train a model with ZBL corrections.

The training set for this example consists of a
subset of the ethanol moleculs from the `rMD17 dataset
<https://iopscience.iop.org/article/10.1088/2632-2153/abba6f/meta>`_.

The models are trained using the following training options, respectively:

.. literalinclude:: options_no_zbl.yaml
:language: yaml

.. literalinclude:: options_zbl.yaml
:language: yaml

As you can see, they are identical, except for the ``zbl`` key in the
``model`` section.
You can train the same models yourself with

.. literalinclude:: train.sh
:language: bash

A detailed step-by-step introduction on how to train a model is provided in
the :ref:`label_basic_usage` tutorial.
"""

# %%
#
# First, we start by importing the necessary libraries, including the integration of ASE
# calculators for metatensor atomistic models.

import ase
import matplotlib.pyplot as plt
import numpy as np
import torch
from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator


# %%
#
# Setting up the dimers
# ---------------------
#
# We set up a series of dimers with different atom pairs and distances. We will
# calculate the energies of these dimers using the models trained with and without ZBL
# corrections.

distances = np.linspace(0.5, 6.0, 200)
pairs = {}
for pair in [("H", "H"), ("H", "C"), ("C", "C"), ("C", "O"), ("O", "O"), ("H", "O")]:
structures = []
for distance in distances:
atoms = ase.Atoms(
symbols=[pair[0], pair[1]],
positions=[[0, 0, 0], [0, 0, distance]],
)
structures.append(atoms)
pairs[pair] = structures

# %%
#
# We now load the two exported models, one with and one without ZBL corrections

calc_no_zbl = MetatensorCalculator(
"model_no_zbl.pt", extensions_directory="extensions/"
)
calc_zbl = MetatensorCalculator("model_zbl.pt", extensions_directory="extensions/")


# %%
#
# Calculate and plot energies without ZBL
# ---------------------------------------
#
# We calculate the energies of the dimer curves for each pair of atoms and
# plot the results, using the non-ZBL-corrected model.

for pair, structures_for_pair in pairs.items():
energies = []
for atoms in structures_for_pair:
atoms.set_calculator(calc_no_zbl)
with torch.jit.optimized_execution(False):
energies.append(atoms.get_potential_energy())
energies = np.array(energies) - energies[-1]
plt.plot(distances, energies, label=f"{pair[0]}-{pair[1]}")
plt.title("Dimer curves - no ZBL")
plt.xlabel("Distance (Å)")
plt.ylabel("Energy (eV)")
plt.legend()
plt.tight_layout()
plt.show()

# %%
#
# Calculate and plot energies from the ZBL-corrected model
# --------------------------------------------------------
#
# We repeat the same procedure as above, but this time with the ZBL-corrected model.

for pair, structures_for_pair in pairs.items():
energies = []
for atoms in structures_for_pair:
atoms.set_calculator(calc_zbl)
with torch.jit.optimized_execution(False):
energies.append(atoms.get_potential_energy())
energies = np.array(energies) - energies[-1]
plt.plot(distances, energies, label=f"{pair[0]}-{pair[1]}")
plt.title("Dimer curves - with ZBL")
plt.xlabel("Distance (Å)")
plt.ylabel("Energy (eV)")
plt.legend()
plt.tight_layout()
plt.show()

# %%
#
# It can be seen that all the dimer curves include a strong repulsion
# at short distances, which is due to the ZBL contribution. Even the H-H dimer,
# whose ZBL correction is very weak due to the small covalent radii of hydrogen,
# would show a strong repulsion closer to the origin (here, we only plotted
# starting from a distance of 0.5 Å). Let's zoom in on the H-H dimer to see
# this effect more clearly.

new_distances = np.linspace(0.1, 2.0, 200)

structures = []
for distance in new_distances:
atoms = ase.Atoms(
symbols=["H", "H"],
positions=[[0, 0, 0], [0, 0, distance]],
)
structures.append(atoms)

for atoms in structures:
atoms.set_calculator(calc_zbl)
with torch.jit.optimized_execution(False):
energies = [atoms.get_potential_energy() for atoms in structures]
energies = np.array(energies) - energies[-1]
plt.plot(new_distances, energies, label="H-H")
plt.title("Dimer curve - H-H with ZBL")
plt.xlabel("Distance (Å)")
plt.ylabel("Energy (eV)")
plt.legend()
plt.tight_layout()
plt.show()
1 change: 1 addition & 0 deletions examples/zbl/ethanol_reduced_100.xyz
21 changes: 21 additions & 0 deletions examples/zbl/options_no_zbl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
seed: 42

architecture:
name: experimental.soap_bpnn
model:
zbl: false
training:
num_epochs: 10

# training set section
training_set:
systems:
read_from: ethanol_reduced_100.xyz
length_unit: angstrom
targets:
energy:
key: "energy"
unit: "eV" # very important to run simulations

validation_set: 0.1
test_set: 0.0
21 changes: 21 additions & 0 deletions examples/zbl/options_zbl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
seed: 42

architecture:
name: experimental.soap_bpnn
model:
zbl: true
training:
num_epochs: 10

# training set section
training_set:
systems:
read_from: ethanol_reduced_100.xyz
length_unit: angstrom
targets:
energy:
key: "energy"
unit: "eV" # very important to run simulations

validation_set: 0.1
test_set: 0.0
4 changes: 4 additions & 0 deletions examples/zbl/train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash

mtt train options_no_zbl.yaml -o model_no_zbl.pt
mtt train options_zbl.yaml -o model_zbl.pt
7 changes: 5 additions & 2 deletions src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from ..utils.evaluate_model import evaluate_model
from ..utils.logging import MetricLogger
from ..utils.metrics import RMSEAccumulator
from ..utils.neighbor_lists import get_system_with_neighbor_lists
from ..utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from ..utils.omegaconf import expand_dataset_config
from ..utils.per_atom import average_by_num_atoms
from .formatter import CustomHelpFormatter
Expand Down Expand Up @@ -177,7 +180,7 @@ def _eval_targets(
# if already present (e.g. if this function is called after training)
for sample in dataset:
system = sample["system"]
get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
get_system_with_neighbor_lists(system, get_requested_neighbor_lists(model))

# Infer the device and dtype from the model
model_tensor = next(itertools.chain(model.parameters(), model.buffers()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ model:
bpnn:
hidden_sizes: [32, 32]
output_size: 1
zbl: false

training:
batch_size: 8
Expand Down
Loading