
Commit 73d6096

test.yml set fail-fast: false, fix some doc strings missing Returns/Raises, change parse_vasp_dir error type FileNotFoundError->NotADirectoryError
janosh committed Sep 10, 2024
1 parent ce1f29c commit 73d6096
Showing 9 changed files with 90 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -10,7 +10,7 @@ on:
jobs:
tests:
strategy:
-fail-fast: true
+fail-fast: false
matrix:
os: [ubuntu-latest, macos-14, windows-latest]
version:
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.6.2
+rev: v0.6.4
hooks:
- id: ruff
args: [--fix]
@@ -48,7 +48,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
-rev: v9.9.0
+rev: v9.10.0
hooks:
- id: eslint
types: [file]
41 changes: 37 additions & 4 deletions chgnet/trainer/trainer.py
@@ -110,6 +110,12 @@ def __init__(
that are not included in the trainer_args. Default = None
**kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
+Raises:
+NotImplementedError: If the optimizer or scheduler is not implemented
+ImportError: If wandb_path is specified but wandb is not installed
+ValueError: If wandb_path is specified but not in the format
+'project/run_name'
"""
# Store trainer args for reproducibility
self.trainer_args = {
@@ -271,6 +277,9 @@ def train(
wandb_log_freq ("epoch" | "batch"): Frequency of logging to wandb.
'epoch' logs once per epoch, 'batch' logs after every batch.
Default = "batch"
+Raises:
+ValueError: If model is not initialized
"""
if self.model is None:
raise ValueError("Model needs to be initialized")
@@ -579,7 +588,11 @@ def _validate(
return {k: round(mae_error.avg, 6) for k, mae_error in mae_errors.items()}

def get_best_model(self) -> CHGNet:
-"""Get best model recorded in the trainer."""
+"""Get best model recorded in the trainer.
+Returns:
+CHGNet: the model with lowest validation set energy error
+"""
if self.best_model is None:
raise RuntimeError("the model needs to be trained first")
MAE = min(self.training_history["e"]["val"]) # noqa: N806
@@ -649,7 +662,14 @@ def save_checkpoint(self, epoch: int, mae_error: dict, save_dir: str) -> None:

@classmethod
def load(cls, path: str) -> Self:
-"""Load trainer state_dict."""
+"""Load trainer state_dict.
+Args:
+path (str): path to the saved model
+Returns:
+Trainer: the loaded trainer
+"""
state = torch.load(path, map_location=torch.device("cpu"))
model = CHGNet.from_dict(state["model"])
print(f"Loaded model params = {sum(p.numel() for p in model.parameters()):,}")
@@ -664,8 +684,21 @@ def load(cls, path: str) -> Self:
return trainer

@staticmethod
-def move_to(obj, device) -> Tensor | list[Tensor]:
-"""Move object to device."""
+def move_to(
+obj: Tensor | list[Tensor], device: torch.device
+) -> Tensor | list[Tensor]:
+"""Move object to device.
+Args:
+obj (Tensor | list[Tensor]): object(s) to move to device
+device (torch.device): device to move object to
+Raises:
+TypeError: if obj is not a tensor or list of tensors
+Returns:
+Tensor | list[Tensor]: moved object(s)
+"""
if torch.is_tensor(obj):
return obj.to(device)
if isinstance(obj, list):
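For illustration, a minimal usage sketch of the trainer methods documented above; the checkpoint path and tensors are placeholders, not part of this commit:

import torch
from chgnet.trainer import Trainer

trainer = Trainer.load("path/to/checkpoint.pth.tar")  # placeholder checkpoint path

# after training, the CHGNet with the lowest validation energy MAE;
# raises RuntimeError if the trainer has not recorded a best model yet
best_model = trainer.get_best_model()

# move_to accepts a Tensor or a list of Tensors; anything else raises TypeError
batch = [torch.zeros(2, 3), torch.ones(4)]
batch = Trainer.move_to(batch, torch.device("cpu"))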
19 changes: 11 additions & 8 deletions chgnet/utils/common_utils.py
@@ -39,6 +39,9 @@ def cuda_devices_sorted_by_free_mem() -> list[int]:
"""List available CUDA devices sorted by increasing available memory.
To get the device with the most free memory, use the last list item.
+Returns:
+list[int]: CUDA device numbers sorted by increasing free memory.
"""
if not torch.cuda.is_available():
return []
@@ -94,10 +97,10 @@ def mae(prediction: Tensor, target: Tensor) -> Tensor:


def read_json(filepath: str) -> dict:
-"""Read the json file.
+"""Read the JSON file.
Args:
-filepath (str): file name of json to read.
+filepath (str): file name of JSON to read.
Returns:
dict: data stored in filepath
@@ -107,27 +110,27 @@ def read_json(filepath: str) -> dict:


def write_json(dct: dict, filepath: str) -> dict:
-"""Write the json file.
+"""Write the JSON file.
Args:
dct (dict): dictionary to write
-filepath (str): file name of json to write.
-Returns:
-written dictionary
+filepath (str): file name of JSON to write.
"""

def handler(obj: object) -> int | object:
"""Convert numpy int64 to int.
Fixes TypeError: Object of type int64 is not JSON serializable
reported in https://github.com/CederGroupHub/chgnet/issues/168.
+Returns:
+int | object: object for serialization
"""
if isinstance(obj, np.integer):
return int(obj)
return obj

-with open(filepath, "w") as file:
+with open(filepath, mode="w") as file:
json.dump(dct, file, default=handler)


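A quick sketch of the serialization issue the numpy handler above works around, assuming read_json and write_json are re-exported from chgnet.utils (file name and values are placeholders):

import numpy as np
from chgnet.utils import read_json, write_json

labels = {"mp-149": {"energy_per_atom": -5.4, "n_sites": np.int64(2)}}

# without the handler, json.dump would raise
# "TypeError: Object of type int64 is not JSON serializable"
write_json(labels, "labels.json")
assert read_json("labels.json")["mp-149"]["n_sites"] == 2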
12 changes: 11 additions & 1 deletion chgnet/utils/vasp_utils.py
@@ -30,9 +30,16 @@ def parse_vasp_dir(
Exception to VASP calculation that did not achieve electronic convergence.
Default = True
save_path (str): path to save the parsed VASP labels
+Raises:
+NotADirectoryError: if the base_dir is not a directory
+Returns:
+dict: a dictionary of lists with keys for structure, uncorrected_total_energy,
+energy_per_atom, force, magmom, stress.
"""
if os.path.isdir(base_dir) is False:
-raise FileNotFoundError(f"{base_dir=} is not a directory")
+raise NotADirectoryError(f"{base_dir=} is not a directory")

oszicar_path = zpath(f"{base_dir}/OSZICAR")
vasprun_path = zpath(f"{base_dir}/vasprun.xml")
@@ -170,6 +177,9 @@ def solve_charge_by_mag(
(3.5, 4.2): 3,
(4.2, 5): 2
))
+Returns:
+Structure: pymatgen Structure with oxidation states assigned based on magmoms.
"""
out_structure = structure.copy()
out_structure.remove_oxidation_states()
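A hedged usage sketch of parse_vasp_dir with the new error type, assuming it is re-exported from chgnet.utils (the VASP directory path is a placeholder):

from chgnet.utils import parse_vasp_dir

# parse one VASP run; the result is a dict of lists keyed by structure,
# uncorrected_total_energy, energy_per_atom, force, magmom and stress
dataset = parse_vasp_dir("path/to/vasp_run", save_path="vasp_labels.json")
print(f"parsed {len(dataset['structure'])} ionic steps")

# any base_dir that is not a directory now raises NotADirectoryError
# (previously FileNotFoundError)
try:
    parse_vasp_dir("path/to/vasp_run/OUTCAR")
except NotADirectoryError as exc:
    print(exc)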
18 changes: 16 additions & 2 deletions examples/make_graphs.py
@@ -58,8 +58,22 @@ def make_graphs(
make_partition(labels, graph_dir, train_ratio, val_ratio)


-def make_one_graph(mp_id: str, graph_id: str, data, graph_dir) -> dict | bool:
-"""Convert a structure to a CrystalGraph and save it."""
+def make_one_graph(
+mp_id: str, graph_id: str, data: StructureJsonData, graph_dir: str
+) -> dict | bool:
+"""Convert a structure to a CrystalGraph and save it.
+Args:
+mp_id (str): The material id.
+graph_id (str): The graph id.
+data (StructureJsonData): The dataset. Warning: Dicts are popped from the data,
+i.e. modifying the data in place.
+graph_dir (str): The directory to save the graphs.
+Returns:
+dict | bool: The label dictionary if the graph is saved successfully, False
+otherwise.
+"""
dct = data.data[mp_id].pop(graph_id)
struct = Structure.from_dict(dct.pop("structure"))
try:
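For illustration, how make_one_graph's return contract might be consumed inside make_graphs; the material and graph ids are placeholders, and data and graph_dir stand for the StructureJsonData instance and output directory the script already uses:

labels: dict[str, dict] = {}
for mp_id, graph_id in [("mp-149", "mp-149-0"), ("mp-13", "mp-13-0")]:
    result = make_one_graph(mp_id, graph_id, data, graph_dir)
    if result is not False:  # False means the graph could not be built and saved
        labels.setdefault(mp_id, {})[graph_id] = result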
1 change: 1 addition & 0 deletions pyproject.toml
@@ -102,6 +102,7 @@ docstring-code-format = true
"ANN201",
"D100",
"D103",
+"DOC201", # doc string missing Return section
"FBT001",
"FBT002",
"INP001",
2 changes: 1 addition & 1 deletion site/make_docs.py
@@ -49,5 +49,5 @@
markdown = markdown.replace(
"\n**Global Variables**\n---------------\n- **TYPE_CHECKING**\n\n", ""
)
-with open(path, "w") as file:
+with open(path, mode="w") as file:
file.write(markdown)
26 changes: 10 additions & 16 deletions tests/test_dataset.py
@@ -20,23 +20,17 @@
def structure_data() -> StructureData:
"""Create a graph with 3 nodes and 3 directed edges."""
random.seed(42)
-structures, energies, forces, stresses, magmoms, structure_ids = (
-[],
-[],
-[],
-[],
-[],
-[],
-)
+structures, energies, forces = [], [], []
+stresses, magmoms, structure_ids = [], [], []

for index in range(100):
-struct = NaCl.copy()
-struct.perturb(0.1)
-structures.append(struct)
-energies.append(np.random.random(1))
-forces.append(np.random.random([2, 3]))
-stresses.append(np.random.random([3, 3]))
-magmoms.append(np.random.random([2, 1]))
-structure_ids.append(index)
+structures += [NaCl.copy().perturb(0.1)]
+energies += [np.random.random(1)]
+forces += [np.random.random([2, 3])]
+stresses += [np.random.random([3, 3])]
+magmoms += [np.random.random([2, 1])]
+structure_ids += [index]

return StructureData(
structures=structures,
energies=energies,
