Commit a5f8b48

Merge remote-tracking branch 'origin/main' into fix-gh-168
janosh committed Jun 21, 2024
2 parents 9484896 + 9717a32
Showing 12 changed files with 122 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -38,7 +38,7 @@ jobs:
python setup.py build_ext --inplace
- uv pip install -e .[test] --system --resolution=${{ matrix.version.resolution }}
+ uv pip install -e .[test,logging] --system --resolution=${{ matrix.version.resolution }}
- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml
3 changes: 3 additions & 0 deletions .gitignore
@@ -27,3 +27,6 @@ coverage.xml
.ipynb_checkpoints
bond_graph_error.cif
test.py

+ # training logs
+ wandb
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.4.7
+ rev: v0.4.8
hooks:
- id: ruff
args: [--fix]
4 changes: 2 additions & 2 deletions chgnet/model/dynamics.py
@@ -57,7 +57,7 @@ def __init__(
model: CHGNet | None = None,
*,
use_device: str | None = None,
- check_cuda_mem: bool = True,
+ check_cuda_mem: bool = False,
stress_weight: float | None = 1 / 160.21766208,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
**kwargs,
@@ -73,7 +73,7 @@ def __init__(
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
- Default = True
+ Default = False
stress_weight (float): the conversion factor to convert GPa to eV/A^3.
Default = 1/160.21
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
4 changes: 2 additions & 2 deletions chgnet/model/model.py
@@ -679,7 +679,7 @@ def load(
*,
model_name: str = "0.3.0",
use_device: str | None = None,
- check_cuda_mem: bool = True,
+ check_cuda_mem: bool = False,
verbose: bool = True,
) -> CHGNet:
"""Load pretrained CHGNet model.
@@ -692,7 +692,7 @@ def load(
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
- Default = True
+ Default = False
verbose (bool): whether to print model device information
Default = True
Raises:
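With this commit, `check_cuda_mem` defaults to `False` in both `StructOptimizer` and `CHGNet.load`, so probing CUDA devices for the one with the most free memory is now opt-in. A minimal sketch of restoring the old behavior, written against only the `CHGNet.load` signature shown above:

```python
from chgnet.model import CHGNet

# Opt back in to picking the CUDA device with the most available memory
# (the pre-commit default behavior).
model = CHGNet.load(model_name="0.3.0", check_cuda_mem=True)
```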
56 changes: 52 additions & 4 deletions chgnet/trainer/trainer.py
@@ -21,6 +21,12 @@
from chgnet.model.model import CHGNet
from chgnet.utils import AverageMeter, determine_device, mae, write_json

+ try:
+     import wandb
+ except ImportError:
+     wandb = None


if TYPE_CHECKING:
from torch.utils.data import DataLoader

@@ -49,7 +55,10 @@ def __init__(
torch_seed: int | None = None,
data_seed: int | None = None,
use_device: str | None = None,
- check_cuda_mem: bool = True,
+ check_cuda_mem: bool = False,
+ wandb_path: str | None = None,
+ wandb_init_kwargs: dict | None = None,
+ extra_run_config: dict | None = None,
**kwargs,
) -> None:
"""Initialize all hyper-parameters for trainer.
@@ -87,16 +96,23 @@ def __init__(
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
- Default = True
+ Default = False
+ wandb_path (str | None): The project and run name separated by a slash:
+     "project/run_name". If None, wandb logging is not used.
+     Default = None
+ wandb_init_kwargs (dict): Additional kwargs to pass to wandb.init.
+     Default = None
+ extra_run_config (dict): Additional hyper-params to be recorded by wandb
+     that are not included in the trainer_args. Default = None
**kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
"""
# Store trainer args for reproducibility
self.trainer_args = {
k: v
for k, v in locals().items()
if k not in {"self", "__class__", "model", "kwargs"}
- }
- self.trainer_args.update(kwargs)
+ } | kwargs

self.model = model
self.targets = targets
@@ -199,6 +215,27 @@
] = {key: {"train": [], "val": [], "test": []} for key in self.targets}
self.best_model = None

+ # Initialize wandb if project/run specified
+ if wandb_path:
+     if wandb is None:
+         raise ImportError(
+             "Weights and Biases not installed. pip install wandb to use "
+             "wandb logging."
+         )
+     if wandb_path.count("/") == 1:
+         project, run_name = wandb_path.split("/")
+     else:
+         raise ValueError(
+             f"{wandb_path=} should be in the format 'project/run_name' "
+             "(no extra slashes)"
+         )
+     wandb.init(
+         project=project,
+         name=run_name,
+         config=self.trainer_args | (extra_run_config or {}),
+         **(wandb_init_kwargs or {}),
+     )

def train(
self,
train_loader: DataLoader,
@@ -261,6 +298,13 @@ def train(

self.save_checkpoint(epoch, val_mae, save_dir=save_dir)

+ # Log train/val metrics to wandb
+ if wandb is not None and self.trainer_args.get("wandb_path"):
+     wandb.log(
+         {f"train_{k}_mae": v for k, v in train_mae.items()}
+         | {f"val_{k}_mae": v for k, v in val_mae.items()}
+     )

if test_loader is not None:
# test best model
print("---------Evaluate Model on Test Set---------------")
@@ -283,6 +327,10 @@
self.training_history[key]["test"] = test_mae[key]
self.save(filename=os.path.join(save_dir, test_file))

+ # Log test metrics to wandb
+ if wandb is not None and self.trainer_args.get("wandb_path"):
+     wandb.log({f"test_{k}_mae": v for k, v in test_mae.items()})

def _train(self, train_loader: DataLoader, current_epoch: int) -> dict:
"""Train all data for one epoch.
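Taken together, these trainer changes add optional Weights & Biases logging, installable via the new `logging` extra added in pyproject.toml below (`pip install chgnet[logging]`). A minimal sketch of how the new arguments might be used; the project/run names, loaders, and extra config here are hypothetical:

```python
from chgnet.model import CHGNet
from chgnet.trainer import Trainer

# train_loader and val_loader are hypothetical DataLoaders built elsewhere,
# e.g. with the chgnet.data.dataset utilities.
trainer = Trainer(
    model=CHGNet.load(),
    wandb_path="my-project/chgnet-finetune",  # "project/run_name"; None disables logging
    wandb_init_kwargs={"tags": ["finetune"]},  # forwarded to wandb.init
    extra_run_config={"dataset": "MPtrj-subset"},  # extra hyper-params recorded by wandb
)
trainer.train(train_loader, val_loader)
```

Per the `__init__` logic above, a malformed `wandb_path` (zero or multiple slashes) raises a `ValueError`, and passing a path without installing wandb raises an `ImportError`.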
3 changes: 2 additions & 1 deletion chgnet/utils/common_utils.py
@@ -12,13 +12,14 @@
def determine_device(
use_device: str | None = None,
*,
- check_cuda_mem: bool = True,
+ check_cuda_mem: bool = False,
) -> str:
"""Determine the device to use for torch model.
Args:
use_device (str): User specify device name
check_cuda_mem (bool): Whether to return cuda with most available memory
+     Default = False
Returns:
device (str): device name to be passed to model.to(device)
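For reference, a short sketch of calling this helper directly; the import path matches the one used in the trainer diff above:

```python
from chgnet.utils import determine_device

# With the new default, CUDA memory probing is skipped unless explicitly requested.
device = determine_device(use_device=None, check_cuda_mem=False)
print(device)  # e.g. "cuda" or "cpu" depending on what is available
```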
48 changes: 24 additions & 24 deletions examples/QueryMPtrj.md
@@ -31,22 +31,22 @@ opt_task_types = [
optimization_task_ids = {}
for doc in material_ids:
material_id = doc.material_id
- tmp = mpr.materials.get_data_by_id(material_id)
+ mp_doc = mpr.materials.get_data_by_id(material_id)

- for task_id, task_type in tmp.calc_types.items():
+ for task_id, task_type in mp_doc.calc_types.items():
if task_type in opt_task_types:
optimization_task_ids[material_id.string].append(task_id)
```

- ### Query Materials Project Thermodoc entry and the relaxation tasks
+ ### Query Materials Project `ThermoDoc` entry and the relaxation tasks

- The thermodoc entry is the entry you normally see on the MP website
+ The `ThermoDoc` entry is the entry you normally see on the MP website

```python
# ThermoDoc: Query MP main entries
main_entry = mpr.get_entry_by_material_id(material_id=material_id)[0]
# Query one relaxation task
- taskdoc = mpr.tasks.get_data_by_id(task_id, fields=["input", "output", "calcs_reversed", 'task_id', "run_type"])
+ task_doc = mpr.tasks.get_data_by_id(task_id, fields=["input", "output", "calcs_reversed", 'task_id', "run_type"])
```

## Filtering the data
@@ -61,58 +61,58 @@ This is done in two steps:
Check whether a task is compatible to Materials Project main entry, by comparing its DFT settings
and converged results with MP main entry.

- - Note this step can no longer work for the current MP data, since a lot of `thermodoc` entry (main entry) have changed to `r2SCAN`
+ - Note this step no longer works for the current MP data, since many `ThermoDoc` entries (main entries) have changed to `r2SCAN`

```python
def calc_type_equal(
- taskdoc,
+ task_doc,
main_entry,
- trjdata
+ trj_data
) -> bool:
# Check the LDAU of task
try:
- is_hubbard = taskdoc.calcs_reversed[0].input['parameters']['LDAU']
+ is_hubbard = task_doc.calcs_reversed[0].input['parameters']['LDAU']
except:
- is_hubbard = taskdoc.calcs_reversed[0].input['incar']['LDAU']
+ is_hubbard = task_doc.calcs_reversed[0].input['incar']['LDAU']

# Make sure we don't include both GGA and GGA+U for the same mp_id
if main_entry.parameters['is_hubbard'] != is_hubbard:
- print(f'{main_entry.entry_id}, {taskdoc.task_id} is_hubbard= {is_hubbard}')
- trjdata.exception[taskdoc.task_id] = f'is_hubbard inconsistent task is_hubbard={is_hubbard}'
+ print(f'{main_entry.entry_id}, {task_doc.task_id} is_hubbard= {is_hubbard}')
+ trj_data.exception[task_doc.task_id] = f'is_hubbard inconsistent task is_hubbard={is_hubbard}'
return False
elif is_hubbard == True:
# If the task is calculated with GGA+U
# Make sure the +U values are the same for each element
- composition = taskdoc.output.structure.composition
+ composition = task_doc.output.structure.composition
hubbards = {element.symbol: U for element, U in
zip(composition.elements,
- taskdoc.calcs_reversed[0].input['incar']['LDAUU'])}
+ task_doc.calcs_reversed[0].input['incar']['LDAUU'])}
if main_entry.parameters['hubbards'] != hubbards:
thermo_hubbards = main_entry.parameters['hubbards']
- trjdata.exception[taskdoc.task_id] = f'hubbards inconsistent task hubbards={hubbards}, thermo hubbards={thermo_hubbards}'
+ trj_data.exception[task_doc.task_id] = f'hubbards inconsistent task hubbards={hubbards}, thermo hubbards={thermo_hubbards}'
return False
else:
# Check the energy convergence of the task wrt. the main entry
return check_energy_convergence(
- taskdoc,
+ task_doc,
main_entry.uncorrected_energy_per_atom,
- trjdata=trjdata
+ trj_data=trj_data
)
else:
# Check energy convergence for pure GGA tasks
return check_energy_convergence(
- taskdoc,
+ task_doc,
main_entry.uncorrected_energy_per_atom,
- trjdata=trjdata
+ trj_data=trj_data
)

def check_energy_convergence(
- taskdoc,
+ task_doc,
relaxed_entry_uncorrected_energy_per_atom,
- trjdata
+ trj_data
) -> bool:
- task_energy = taskdoc.calcs_reversed[0].output['ionic_steps'][-1]['e_fr_energy']
- n_atom = taskdoc.calcs_reversed[0].output['ionic_steps'][-1][
+ task_energy = task_doc.calcs_reversed[0].output['ionic_steps'][-1]['e_fr_energy']
+ n_atom = task_doc.calcs_reversed[0].output['ionic_steps'][-1][
'structure'].composition.num_atoms
e_per_atom = task_energy / n_atom
# This is the energy difference of the last frame of the task vs main_entry energy
@@ -125,7 +125,7 @@ def check_energy_convergence(
# The task is falsely relaxed, we will discard the whole task
# This step will filter out tasks that relaxed into different spin states
# that caused large energy discrepancies
- trjdata.exception[taskdoc.task_id] = (
+ trj_data.exception[task_doc.task_id] = (
    f'e_diff is too large, '
    f'task last step energy_per_atom = {e_per_atom}, '
    f'relaxed_entry_uncorrected_e_per_atom = {relaxed_entry_uncorrected_energy_per_atom}'
)
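A hypothetical harness for these helpers: the `TrjData` class below is a stand-in for whatever container the MPtrj pipeline actually uses, since the functions above rely only on its `exception` dict; `task_doc` and `main_entry` come from the queries shown earlier.

```python
class TrjData:
    """Hypothetical stand-in that collects per-task exceptions."""

    def __init__(self) -> None:
        self.exception = {}


trj_data = TrjData()
if calc_type_equal(task_doc, main_entry, trj_data):
    ...  # task is consistent with the main entry; keep its ionic steps
else:
    print(trj_data.exception)  # inspect why the task was rejected
```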
14 changes: 7 additions & 7 deletions examples/crystaltoolkit_relax_viewer.ipynb
@@ -36,7 +36,7 @@
" # https://github.com/materialsproject/crystaltoolkit\n",
" # (only needed on Google Colab or if you didn't install these packages yet)\n",
" !git clone --depth 1 https://github.com/CederGroupHub/chgnet\n",
" !pip install './chgnet[examples]'\n"
" !pip install './chgnet[examples]'"
]
},
{
@@ -47,7 +47,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"from pymatgen.core import Structure\n"
"from pymatgen.core import Structure"
]
},
{
@@ -66,7 +66,7 @@
"\n",
" url = \"https://github.com/CederGroupHub/chgnet/raw/-/examples/mp-18767-LiMnO2.cif\"\n",
" cif = urlopen(url).read().decode(\"utf-8\")\n",
" structure = Structure.from_str(cif, fmt=\"cif\")\n"
" structure = Structure.from_str(cif, fmt=\"cif\")"
]
},
{
@@ -94,7 +94,7 @@
"# stretch the cell by a small amount\n",
"structure.scale_lattice(structure.volume * 1.1)\n",
"\n",
"print(f\"perturbed: {structure.get_space_group_info()}\")\n"
"print(f\"perturbed: {structure.get_space_group_info()}\")"
]
},
{
@@ -212,7 +212,7 @@
"\n",
"from chgnet.model import StructOptimizer\n",
"\n",
"trajectory = StructOptimizer().relax(structure)[\"trajectory\"]\n"
"trajectory = StructOptimizer().relax(structure)[\"trajectory\"]"
]
},
{
@@ -229,7 +229,7 @@
" np.linalg.norm(force, axis=1).mean() # mean of norm of force on each atom\n",
" for force in trajectory.forces\n",
"]\n",
"df_traj.index.name = \"step\"\n"
"df_traj.index.name = \"step\""
]
},
{
Expand All @@ -250,7 +250,7 @@
"mp_id = \"mp-18767\"\n",
"\n",
"dft_energy = -59.09\n",
"print(f\"{dft_energy=:.2f} eV (see https://materialsproject.org/materials/{mp_id})\")\n"
"print(f\"{dft_energy=:.2f} eV (see https://materialsproject.org/materials/{mp_id})\")"
]
},
{
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ test = ["pytest-cov>=4", "pytest>=8"]
# needed to run interactive example notebooks
examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
docs = ["lazydocs>=0.4"]
+ logging = ["wandb>=0.17"]

[project.urls]
Source = "https://github.com/CederGroupHub/chgnet"
@@ -89,6 +90,7 @@ ignore = [
pydocstyle.convention = "google"
isort.required-imports = ["from __future__ import annotations"]
isort.split-on-trailing-comma = false
+ isort.known-third-party = ["wandb"]

[tool.ruff.format]
docstring-code-format = true