Skip to content

Commit

Permalink
Change default CHGnet.load(check_cuda_mem: bool) to False (#164)
Browse files Browse the repository at this point in the history
* tweak docs examples/QueryMPtrj.md

* change check_cuda_mem default to False

* update documented default True->False

* assert check_cuda_mem defaults to False in test_model_load_version_params
  • Loading branch information
janosh authored Jun 11, 2024
1 parent d9019b9 commit d3f1b30
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
rev: v0.4.8
hooks:
- id: ruff
args: [--fix]
Expand Down
4 changes: 2 additions & 2 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
model: CHGNet | None = None,
*,
use_device: str | None = None,
check_cuda_mem: bool = True,
check_cuda_mem: bool = False,
stress_weight: float | None = 1 / 160.21766208,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
**kwargs,
Expand All @@ -73,7 +73,7 @@ def __init__(
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = True
Default = False
stress_weight (float): the conversion factor to convert GPa to eV/A^3.
Default = 1/160.21
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
Expand Down
4 changes: 2 additions & 2 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def load(
*,
model_name: str = "0.3.0",
use_device: str | None = None,
check_cuda_mem: bool = True,
check_cuda_mem: bool = False,
verbose: bool = True,
) -> CHGNet:
"""Load pretrained CHGNet model.
Expand All @@ -692,7 +692,7 @@ def load(
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = True
Default = False
verbose (bool): whether to print model device information
Default = True
Raises:
Expand Down
4 changes: 2 additions & 2 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
torch_seed: int | None = None,
data_seed: int | None = None,
use_device: str | None = None,
check_cuda_mem: bool = True,
check_cuda_mem: bool = False,
**kwargs,
) -> None:
"""Initialize all hyper-parameters for trainer.
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = True
Default = False
**kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
"""
# Store trainer args for reproducibility
Expand Down
3 changes: 2 additions & 1 deletion chgnet/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
def determine_device(
use_device: str | None = None,
*,
check_cuda_mem: bool = True,
check_cuda_mem: bool = False,
) -> str:
"""Determine the device to use for torch model.
Args:
use_device (str): User specify device name
check_cuda_mem (bool): Whether to return cuda with most available memory
Default = False
Returns:
device (str): device name to be passed to model.to(device)
Expand Down
48 changes: 24 additions & 24 deletions examples/QueryMPtrj.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@ opt_task_types = [
optimization_task_ids = {}
for doc in material_ids:
material_id = doc.material_id
tmp = mpr.materials.get_data_by_id(material_id)
mp_doc = mpr.materials.get_data_by_id(material_id)

for task_id, task_type in tmp.calc_types.items():
for task_id, task_type in mp_doc.calc_types.items():
if task_type in opt_task_types:
optimization_task_ids[material_id.string].append(task_id)
```

### Query Materials Project Thermodoc entry and the relaxation tasks
### Query Materials Project `ThermoDoc` entry and the relaxation tasks

The thermodoc entry is the entry you normally see on the MP website
The `ThermoDoc` entry is the entry you normally see on the MP website

```python
# ThermoDoc: Query MP main entries
main_entry = mpr.get_entry_by_material_id(material_id=material_id)[0]
# Query one relaxation task
taskdoc = mpr.tasks.get_data_by_id(task_id, fields=["input", "output", "calcs_reversed", 'task_id', "run_type"])
task_doc = mpr.tasks.get_data_by_id(task_id, fields=["input", "output", "calcs_reversed", 'task_id', "run_type"])
```

## Filtering the data
Expand All @@ -61,58 +61,58 @@ This is done in two steps:
Check whether a task is compatible with the Materials Project main entry by comparing its DFT settings
and converged results with those of the MP main entry.

- Note this step can no longer work for the current MP data, since a lot of `thermodoc` entry (main entry) have changed to `r2SCAN`
- Note: this step no longer works for the current MP data, since many `ThermoDoc` entries (main entries) have changed to `r2SCAN`.

```python
def calc_type_equal(
taskdoc,
task_doc,
main_entry,
trjdata
trj_data
) -> bool:
# Check the LDAU of task
try:
is_hubbard = taskdoc.calcs_reversed[0].input['parameters']['LDAU']
is_hubbard = task_doc.calcs_reversed[0].input['parameters']['LDAU']
except:
is_hubbard = taskdoc.calcs_reversed[0].input['incar']['LDAU']
is_hubbard = task_doc.calcs_reversed[0].input['incar']['LDAU']

# Make sure we don't include both GGA and GGA+U for the same mp_id
if main_entry.parameters['is_hubbard'] != is_hubbard:
print(f'{main_entry.entry_id}, {taskdoc.task_id} is_hubbard= {is_hubbard}')
trjdata.exception[taskdoc.task_id] = f'is_hubbard inconsistent task is_hubbard={is_hubbard}'
print(f'{main_entry.entry_id}, {task_doc.task_id} is_hubbard= {is_hubbard}')
trj_data.exception[task_doc.task_id] = f'is_hubbard inconsistent task is_hubbard={is_hubbard}'
return False
elif is_hubbard == True:
# If the task is calculated with GGA+U
# Make sure the +U values are the same for each element
composition = taskdoc.output.structure.composition
composition = task_doc.output.structure.composition
hubbards = {element.symbol: U for element, U in
zip(composition.elements,
taskdoc.calcs_reversed[0].input['incar']['LDAUU'])}
task_doc.calcs_reversed[0].input['incar']['LDAUU'])}
if main_entry.parameters['hubbards'] != hubbards:
thermo_hubbards = main_entry.parameters['hubbards']
trjdata.exception[taskdoc.task_id] = f'hubbards inconsistent task hubbards={hubbards}, thermo hubbards={thermo_hubbards}'
trj_data.exception[task_doc.task_id] = f'hubbards inconsistent task hubbards={hubbards}, thermo hubbards={thermo_hubbards}'
return False
else:
# Check the energy convergence of the task wrt. the main entry
return check_energy_convergence(
taskdoc,
task_doc,
main_entry.uncorrected_energy_per_atom,
trjdata=trjdata
trj_data=trj_data
)
else:
# Check energy convergence for pure GGA tasks
check_energy_convergence(
taskdoc,
task_doc,
main_entry.uncorrected_energy_per_atom,
trjdata=trjdata
trj_data=trj_data
)

def check_energy_convergence(
taskdoc,
task_doc,
relaxed_entry_uncorrected_energy_per_atom,
trjdata
trj_data
) -> bool:
task_energy = taskdoc.calcs_reversed[0].output['ionic_steps'][-1]['e_fr_energy']
n_atom = taskdoc.calcs_reversed[0].output['ionic_steps'][-1][
task_energy = task_doc.calcs_reversed[0].output['ionic_steps'][-1]['e_fr_energy']
n_atom = task_doc.calcs_reversed[0].output['ionic_steps'][-1][
'structure'].composition.num_atoms
e_per_atom = task_energy / n_atom
# This is the energy difference of the last frame of the task vs main_entry energy
Expand All @@ -125,7 +125,7 @@ def check_energy_convergence(
# The task is falsely relaxed, we will discard the whole task
# This step will filter out tasks that relaxed into different spin states
# that caused large energy discrepancies
trjdata.exception[taskdoc.task_id] =
trj_data.exception[task_doc.task_id] =
f'e_diff is too large, '
f'task last step energy_per_atom = {e_per_atom}, '
f'relaxed_entry_uncorrected_e_per_atom = {relaxed_entry_uncorrected_energy_per_atom}'
Expand Down
14 changes: 7 additions & 7 deletions examples/crystaltoolkit_relax_viewer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
" # https://github.com/materialsproject/crystaltoolkit\n",
" # (only needed on Google Colab or if you didn't install these packages yet)\n",
" !git clone --depth 1 https://github.com/CederGroupHub/chgnet\n",
" !pip install './chgnet[examples]'\n"
" !pip install './chgnet[examples]'"
]
},
{
Expand All @@ -47,7 +47,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"from pymatgen.core import Structure\n"
"from pymatgen.core import Structure"
]
},
{
Expand All @@ -66,7 +66,7 @@
"\n",
" url = \"https://github.com/CederGroupHub/chgnet/raw/-/examples/mp-18767-LiMnO2.cif\"\n",
" cif = urlopen(url).read().decode(\"utf-8\")\n",
" structure = Structure.from_str(cif, fmt=\"cif\")\n"
" structure = Structure.from_str(cif, fmt=\"cif\")"
]
},
{
Expand Down Expand Up @@ -94,7 +94,7 @@
"# stretch the cell by a small amount\n",
"structure.scale_lattice(structure.volume * 1.1)\n",
"\n",
"print(f\"perturbed: {structure.get_space_group_info()}\")\n"
"print(f\"perturbed: {structure.get_space_group_info()}\")"
]
},
{
Expand Down Expand Up @@ -212,7 +212,7 @@
"\n",
"from chgnet.model import StructOptimizer\n",
"\n",
"trajectory = StructOptimizer().relax(structure)[\"trajectory\"]\n"
"trajectory = StructOptimizer().relax(structure)[\"trajectory\"]"
]
},
{
Expand All @@ -229,7 +229,7 @@
" np.linalg.norm(force, axis=1).mean() # mean of norm of force on each atom\n",
" for force in trajectory.forces\n",
"]\n",
"df_traj.index.name = \"step\"\n"
"df_traj.index.name = \"step\""
]
},
{
Expand All @@ -250,7 +250,7 @@
"mp_id = \"mp-18767\"\n",
"\n",
"dft_energy = -59.09\n",
"print(f\"{dft_energy=:.2f} eV (see https://materialsproject.org/materials/{mp_id})\")\n"
"print(f\"{dft_energy=:.2f} eV (see https://materialsproject.org/materials/{mp_id})\")"
]
},
{
Expand Down
17 changes: 12 additions & 5 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import inspect

import numpy as np
import pytest
from pymatgen.core import Structure
Expand Down Expand Up @@ -249,10 +251,15 @@ def test_model_load_version_params(
with pytest.raises(ValueError, match=f"Unknown {model_name=}"):
CHGNet.load(model_name=model_name)

# # set CHGNET_DEVICE to "cuda" and test
monkeypatch.setenv("CHGNET_DEVICE", env_device := "foobar")
with pytest.raises(
RuntimeError,
match=f"Expected one of cpu, .+type at start of device string: {env_device}",
bad_env_device = "foobar"
err_msg = f"Expected one of cpu, .+type at start of device string: {bad_env_device}"
with ( # noqa: PT012
monkeypatch.context() as ctx,
pytest.raises(RuntimeError, match=err_msg),
):
ctx.setenv("CHGNET_DEVICE", bad_env_device)
CHGNet.load()

# check check_cuda_mem defaults to False
inspect_signature = inspect.signature(CHGNet.load)
assert inspect_signature.parameters["check_cuda_mem"].default is False

0 comments on commit d3f1b30

Please sign in to comment.