From 2e8578acbe9009699092ee8d21e69056cbcfa2eb Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Fri, 12 Jul 2024 20:20:36 +0300 Subject: [PATCH] [BE] Single ruff config (#751) * ruff toml * ruff fix * fix and format * update ruff * remove ruff config in pyproject * lint workflow use ruff.toml * unpin ruff * fix circular imports * clean up imports * revert to 3.9 annotations, dataclasses_json does not work with __future__ annotations * pin ruff --- .github/workflows/lint.yml | 5 +- .pre-commit-config.yaml | 4 +- packages/fairchem-core/pyproject.toml | 72 +- packages/fairchem-data-oc/pyproject.toml | 5 - ruff.toml | 76 + .../2023_neurips_challenge/challenge_eval.py | 6 +- .../AdsorbML/adsorbml/scripts/dense_eval.py | 2 + .../AdsorbML/adsorbml/scripts/process_mlrs.py | 7 +- .../AdsorbML/adsorbml/scripts/utils.py | 4 +- .../adsorbml/scripts/write_top_k_vasp.py | 2 + .../applications/cattsunami/core/__init__.py | 4 +- .../applications/cattsunami/core/autoframe.py | 50 +- .../applications/cattsunami/core/ocpneb.py | 16 +- .../applications/cattsunami/core/reaction.py | 10 +- .../cattsunami/databases/__init__.py | 2 + .../run_validation/run_validation.py | 26 +- src/fairchem/core/_cli.py | 23 +- src/fairchem/core/common/data_parallel.py | 2 +- src/fairchem/core/common/logger.py | 2 +- src/fairchem/core/common/test_utils.py | 8 +- src/fairchem/core/common/tutorial_utils.py | 2 +- .../equiformer_v2/equiformer_v2_oc20.py | 2 +- .../core/models/equiformer_v2/input_block.py | 4 +- .../core/models/gemnet_oc/gemnet_oc.py | 4 +- src/fairchem/core/trainers/base_trainer.py | 4 +- src/fairchem/data/oc/core/__init__.py | 6 +- src/fairchem/data/oc/core/adsorbate.py | 63 +- .../data/oc/core/adsorbate_slab_config.py | 24 +- src/fairchem/data/oc/core/bulk.py | 61 +- .../oc/core/multi_adsorbate_slab_config.py | 20 +- src/fairchem/data/oc/core/slab.py | 91 +- .../data/oc/databases/pkls/__init__.py | 2 + src/fairchem/data/oc/databases/update.py | 30 +- .../scripts/precompute_sample_structures.py | 8 +- src/fairchem/data/oc/structure_generator.py | 9 +- src/fairchem/data/oc/utils/__init__.py | 2 + src/fairchem/data/oc/utils/flag_anomaly.py | 5 +- src/fairchem/data/oc/utils/vasp.py | 7 +- .../data/odac/force_field/FF_analysis.py | 36 +- .../promising_mof_energies/energy.py | 1239 +++++++++++------ src/fairchem/data/odac/setup_vasp.py | 32 +- .../om/biomolecules/geom/sample_geom_drugs.py | 10 +- .../geom/write_geom_drugs_structures.py | 7 +- src/fairchem/data/om/omdata/orca/calc.py | 6 +- src/fairchem/data/om/omdata/orca/recipes.py | 11 +- src/fairchem/demo/ocpapi/client/client.py | 20 +- src/fairchem/demo/ocpapi/client/models.py | 1 + src/fairchem/demo/ocpapi/client/ui.py | 8 +- src/fairchem/demo/ocpapi/version.py | 2 + .../demo/ocpapi/workflows/adsorbates.py | 77 +- src/fairchem/demo/ocpapi/workflows/context.py | 8 +- src/fairchem/demo/ocpapi/workflows/filter.py | 29 +- src/fairchem/demo/ocpapi/workflows/log.py | 2 + src/fairchem/demo/ocpapi/workflows/retry.py | 27 +- .../ocpapi/tests/unit/client/test_models.py | 9 +- 55 files changed, 1356 insertions(+), 838 deletions(-) create mode 100644 ruff.toml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 13faf8c65..2e5be1d36 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -25,7 +25,4 @@ jobs: - name: ruff run: | ruff --version - ruff check --statistics --config packages/fairchem-core/pyproject.toml src/fairchem/core/ - ruff check --statistics --config packages/fairchem-data-oc/pyproject.toml src/fairchem/data/oc/ - #ruff check --statistics --config packages/fairchem-data-om/pyproject.toml src/fairchem/data/om/ - #ruff check --statistics --config packages/fairchem-demo-ocpapi/pyproject.toml src/fairchem/demo/ocpapi/ + ruff check src diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 298c37b89..33b0e7999 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,13 +5,13 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.4 + rev: v0.5.1 hooks: - id: ruff args: [ --fix ] - id: ruff-format - repo: https://github.com/adamchainz/blacken-docs - rev: 1.16.0 + rev: 1.18.0 hooks: - id: blacken-docs - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/packages/fairchem-core/pyproject.toml b/packages/fairchem-core/pyproject.toml index ffa4d43de..5116e5a5b 100644 --- a/packages/fairchem-core/pyproject.toml +++ b/packages/fairchem-core/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ ] [project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev] -dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.4.10"] +dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"] docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi", "umap-learn", "vdict"] adsorbml = ["dscribe","x3dase","scikit-image"] @@ -74,73 +74,3 @@ testpaths = ["tests"] [tool.coverage.run] source = ["fairchem.core"] - -[tool.ruff] -line-length = 88 -lint.select = [ - "B", # flake8-bugbear - "C4", # flake8-comprehensions - "E", # pycodestyle error - "EXE", # flake8-executable - "F", # pyflakes - "FA", # flake8-future-annotations - "FBT003", # boolean-positional-value-in-call - "FLY", # flynt - "I", # isort - "ICN", # flake8-import-conventions - "PD", # pandas-vet - "PERF", # perflint - "PIE", # flake8-pie - "PL", # pylint - "PT", # flake8-pytest-style - "PYI", # flakes8-pyi - "Q", # flake8-quotes - "RET", # flake8-return - "RSE", # flake8-raise - "RUF", # Ruff-specific rules - "SIM", # flake8-simplify - "SLOT", # flake8-slots - "TCH", # flake8-type-checking - "TID", # tidy imports - "TID", # flake8-tidy-imports - "UP", # pyupgrade - "W", # pycodestyle warning - "YTT", # flake8-2020 -] -lint.ignore = [ - "PLR", # Design related pylint codes - "E501", # Line too long - "B028", # No explicit stacklevel - "EM101", # Exception must not use a string literal - "EM102", # Exception must not use an f-string literal - "G004", # f-string in Logging statement - "RUF015", # Prefer next(iter()) - "RET505", # Unnecessary `elif` after `return` - "PT004", # Fixture does not return anthing - "B017", # pytest.raises - "PT011", # pytest.raises - "PT012", # pytest.raises" - "E741", # ambigous variable naming, i.e. one letter - "FBT003", # boolean positional variable in function call - "PERF203", # `try`-`except` within a loop incurs performance overhead (no overhead in Py 3.11+) - "EXE002", # The file is executable but no shebang is present (not sure why some files come up as this) -] - -lint.typing-modules = ["mypackage._compat.typing"] -src = ["src"] -lint.unfixable = [ - "T20", # Removes print statements - "F841", # Removes unused variables -] -lint.pydocstyle.convention = "google" -lint.isort.known-first-party = ["fairchem.core"] -lint.isort.required-imports = ["from __future__ import annotations"] - -[tool.ruff.lint.per-file-ignores] -# Ignore `E402` (import violations) in all `__init__.py` files, and in `path/to/file.py`. -"src/fairchem/core/__init__.py" = ["I002"] -"src/fairchem/core/conf.py" = ["I002"] -"src/fairchem/core/common/*" = ["PLW0603"] # Using the global statement to update [] is discouraged -"src/fairchem/core/scripts/*" = ["PLW0603"] # Using the global statement to update [] is discouraged -"src/fairchem/core/models/*" = ["PERF401"] # Use a list comprehension to create a transformed list -"src/fairchem/core/models/gemnet*" = ["B023"] # Function definition does not bind loop variable `first_sph` diff --git a/packages/fairchem-data-oc/pyproject.toml b/packages/fairchem-data-oc/pyproject.toml index b41b3be4d..76f55cdb0 100644 --- a/packages/fairchem-data-oc/pyproject.toml +++ b/packages/fairchem-data-oc/pyproject.toml @@ -42,8 +42,3 @@ content-type = "text/markdown" fragments = [ { path = "src/fairchem/data/oc/README.md" } ] - -[tool.ruff.lint.per-file-ignores] -# Ignore `E402` (import violations) in all `__init__.py` files, and in `path/to/file.py`. -"src/fairchem/data/oc/core/__init__.py" = ["F401"] -"src/fairchem/data/oc/utils/__init__.py" = ["F401"] diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..1a3a4eeab --- /dev/null +++ b/ruff.toml @@ -0,0 +1,76 @@ +include = ["src/fairchem/core/**/*.py", "src/fairchem/data/oc/**/*.py"] +line-length = 88 + +[lint] +select = [ + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "E", # pycodestyle error + "EXE", # flake8-executable + "F", # pyflakes + "FA", # flake8-future-annotations + "FBT003", # boolean-positional-value-in-call + "FLY", # flynt + "I", # isort + "ICN", # flake8-import-conventions + "PD", # pandas-vet + "PERF", # perflint + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "PYI", # flakes8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # Ruff-specific rules + "SIM", # flake8-simplify + "SLOT", # flake8-slots + "TCH", # flake8-type-checking + "TID", # tidy imports + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle warning + "YTT", # flake8-2020 +] +ignore = [ + "PLR", # Design related pylint codes + "E501", # Line too long + "B028", # No explicit stacklevel + "EM101", # Exception must not use a string literal + "EM102", # Exception must not use an f-string literal + "G004", # f-string in Logging statement + "RUF015", # Prefer next(iter()) + "RET505", # Unnecessary `elif` after `return` + "PT004", # Fixture does not return anthing + "B017", # pytest.raises + "PT011", # pytest.raises + "PT012", # pytest.raises" + "E741", # ambigous variable naming, i.e. one letter + "FBT003", # boolean positional variable in function call + "PERF203", # `try`-`except` within a loop incurs performance overhead (no overhead in Py 3.11+) + "EXE002", # The file is executable but no shebang is present (not sure why some files come up as this) + "PLC2401", # non ASCII characters +] +unfixable = [ + "T20", # Removes print statements + "F841", # Removes unused variables +] + +[lint.isort] +known-first-party = ["fairchem.core"] +required-imports = ["from __future__ import annotations"] + +[lint.pydocstyle] +convention = "google" + +[lint.per-file-ignores] +# Ignore `E402` (import violations) in all `__init__.py` files, and in `path/to/file.py`. +"src/fairchem/core/__init__.py" = ["I002"] +"src/fairchem/core/conf.py" = ["I002"] +"src/fairchem/core/common/*" = ["PLW0603"] # Using the global statement to update [] is discouraged +"src/fairchem/core/scripts/*" = ["PLW0603"] # Using the global statement to update [] is discouraged +"src/fairchem/core/models/*" = ["PERF401"] # Use a list comprehension to create a transformed list +"src/fairchem/core/models/gemnet*" = ["B023"] # Function definition does not bind loop variable `first_sph` +# Ignore `E402` (import violations) in all `__init__.py` files, and in `path/to/file.py`. +"src/fairchem/data/oc/core/__init__.py" = ["F401"] +"src/fairchem/data/oc/utils/__init__.py" = ["F401"] diff --git a/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py b/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py index cbc721447..d7e801fe0 100644 --- a/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py +++ b/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import pickle from collections import defaultdict @@ -19,9 +21,7 @@ def is_successful(best_pred_energy, best_dft_energy, SUCCESS_THRESHOLD=0.1): # Given best ML and DFT energy, compute various success metrics: # success_parity: base success metric (ML - DFT <= SUCCESS_THRESHOLD) diff = best_pred_energy - best_dft_energy - success_parity = diff <= SUCCESS_THRESHOLD - - return success_parity + return diff <= SUCCESS_THRESHOLD def compute_valid_ml_success(ml_data, dft_data): diff --git a/src/fairchem/applications/AdsorbML/adsorbml/scripts/dense_eval.py b/src/fairchem/applications/AdsorbML/adsorbml/scripts/dense_eval.py index b4f783e24..83d68659c 100644 --- a/src/fairchem/applications/AdsorbML/adsorbml/scripts/dense_eval.py +++ b/src/fairchem/applications/AdsorbML/adsorbml/scripts/dense_eval.py @@ -67,6 +67,8 @@ } """ +from __future__ import annotations + import argparse import pickle from collections import defaultdict diff --git a/src/fairchem/applications/AdsorbML/adsorbml/scripts/process_mlrs.py b/src/fairchem/applications/AdsorbML/adsorbml/scripts/process_mlrs.py index 43f308c08..ff7e2a856 100644 --- a/src/fairchem/applications/AdsorbML/adsorbml/scripts/process_mlrs.py +++ b/src/fairchem/applications/AdsorbML/adsorbml/scripts/process_mlrs.py @@ -15,6 +15,8 @@ - errors_by_sid.pkl: any errors that occurred """ +from __future__ import annotations + import argparse import multiprocessing as mp import os @@ -52,8 +54,7 @@ def parse_args(): ) parser.add_argument("--surface-dir", type=str, help="Path to surface DFT outputs") - args = parser.parse_args() - return args + return parser.parse_args() def min_diff(atoms_init, atoms_final): @@ -159,7 +160,7 @@ def process_mlrs(arg): continue anomalies[sid] = anomaly if not anomaly: - grouped_configs[system].append(tuple([adslab_idx, predE, mlrs])) + grouped_configs[system].append((adslab_idx, predE, mlrs)) # group configs by system and sort sorted_grouped_configs = {} diff --git a/src/fairchem/applications/AdsorbML/adsorbml/scripts/utils.py b/src/fairchem/applications/AdsorbML/adsorbml/scripts/utils.py index be431e5f7..aaeedd66e 100644 --- a/src/fairchem/applications/AdsorbML/adsorbml/scripts/utils.py +++ b/src/fairchem/applications/AdsorbML/adsorbml/scripts/utils.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import glob import os -from pymatgen.io.vasp.outputs import * +from pymatgen.io.vasp.outputs import Oszicar """ This script provides utility functions that can be useful for trying to diff --git a/src/fairchem/applications/AdsorbML/adsorbml/scripts/write_top_k_vasp.py b/src/fairchem/applications/AdsorbML/adsorbml/scripts/write_top_k_vasp.py index 0f957cbf7..1fb846c9f 100644 --- a/src/fairchem/applications/AdsorbML/adsorbml/scripts/write_top_k_vasp.py +++ b/src/fairchem/applications/AdsorbML/adsorbml/scripts/write_top_k_vasp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import os import pickle diff --git a/src/fairchem/applications/cattsunami/core/__init__.py b/src/fairchem/applications/cattsunami/core/__init__.py index e97618b9b..6060a0491 100644 --- a/src/fairchem/applications/cattsunami/core/__init__.py +++ b/src/fairchem/applications/cattsunami/core/__init__.py @@ -1,4 +1,6 @@ -from .reaction import Reaction +from __future__ import annotations + from .ocpneb import OCPNEB +from .reaction import Reaction __all__ = ["Reaction", "OCPNEB"] diff --git a/src/fairchem/applications/cattsunami/core/autoframe.py b/src/fairchem/applications/cattsunami/core/autoframe.py index 157bc0315..326d5df61 100644 --- a/src/fairchem/applications/cattsunami/core/autoframe.py +++ b/src/fairchem/applications/cattsunami/core/autoframe.py @@ -3,18 +3,24 @@ and final frames for NEB calculations. """ -import numpy as np +from __future__ import annotations + +import copy +from copy import deepcopy +from itertools import combinations, product +from typing import TYPE_CHECKING + import ase +import networkx as nx +import numpy as np +import torch from ase.data import atomic_numbers, covalent_radii -from scipy.spatial.distance import euclidean -from itertools import combinations, product from ase.optimize import BFGS -import copy -import torch from fairchem.data.oc.utils import DetectTrajAnomaly -import networkx as nx -from copy import deepcopy -from fairchem.applications.cattsunami.core import Reaction +from scipy.spatial.distance import euclidean + +if TYPE_CHECKING: + from fairchem.applications.cattsunami.core import Reaction class AutoFrame: @@ -76,13 +82,11 @@ def only_keep_unique_systems(self, systems, energies): # Iterate over the systems and see where there are matches (systems where every adsorbate atom is overlapping) for idx in range(len(systems)): if not any( - [ - self.are_all_adsorbate_atoms_overlapping( - adsorbates_stripped_out[unique_system], - adsorbates_stripped_out[idx], - ) - for unique_system in unique_systems - ] + self.are_all_adsorbate_atoms_overlapping( + adsorbates_stripped_out[unique_system], + adsorbates_stripped_out[idx], + ) + for unique_system in unique_systems ): unique_systems.append(idx) unique_energies.append(energies[idx]) @@ -162,9 +166,9 @@ def __init__( product1_energies: list, product2_systems: list, product2_energies: list, - r_product1_max: float = None, - r_product2_max: float = None, - r_product2_min: float = None, + r_product1_max: float | None = None, + r_product2_max: float | None = None, + r_product2_min: float | None = None, ): """ Initialize class to handle the automatic generation of NEB frames for dissociation. @@ -1071,8 +1075,7 @@ def interpolate_and_correct_frames( reaction, map_idx, ) - images = interpolate(initial, final, n_frames) - return images + return interpolate(initial, final, n_frames) def get_shortest_path( @@ -1122,7 +1125,7 @@ def get_shortest_path( # atoms (2) the bound atom of product 1 (3) the atom of product 2 which formed a new bond shortest_path_final_positions = [] equivalent_idx_factors = len(initial) * np.array(list(range(9))) - for idx, atom in enumerate(initial): + for idx, _atom in enumerate(initial): equivalent_indices = equivalent_idx_factors + idx final_distances = [ euclidean(initial.positions[idx], new_atoms_final.positions[i]) @@ -1421,8 +1424,7 @@ def get_product2_idx( edge for edge in reaction.edge_list_initial if edge not in edge_list_final ][0] flat_nodes = [item for sublist in traversal_rxt1_final for item in sublist] - product2_binding_idx = [idx for idx in broken_edge if idx not in flat_nodes][0] - return product2_binding_idx + return [idx for idx in broken_edge if idx not in flat_nodes][0] def traverse_adsorbate_general( @@ -1556,7 +1558,7 @@ def interpolate(initial_frame: ase.Atoms, final_frame: ase.Atoms, num_frames: in atoms_frames.append(atoms_now) # Iteratively update positions to avoid overlap - for i in range(100): + for _i in range(100): rate = 0.1 frame_dist = [] diff --git a/src/fairchem/applications/cattsunami/core/ocpneb.py b/src/fairchem/applications/cattsunami/core/ocpneb.py index c4607272a..9a6391ca6 100644 --- a/src/fairchem/applications/cattsunami/core/ocpneb.py +++ b/src/fairchem/applications/cattsunami/core/ocpneb.py @@ -1,14 +1,16 @@ +from __future__ import annotations + import logging import numpy as np import torch - from ase.optimize.precon import Precon, PreconImages +from torch.utils.data import DataLoader + from fairchem.core.common.registry import registry from fairchem.core.common.utils import setup_imports, setup_logging from fairchem.core.datasets import data_list_collater from fairchem.core.preprocessing import AtomsToGraphs -from torch.utils.data import DataLoader try: from ase.neb import DyNEB, NEBState @@ -178,7 +180,7 @@ def get_forces(self): fixed_atoms = np.array( [idx for idx, tag in enumerate(self.images[0].get_tags()) if tag == 0] ) - for i in range(0, self.nimages - 2): + for i in range(self.nimages - 2): for fixed_atom in fixed_atoms: forces[fixed_atom + len(images[0]) * i] = [0, 0, 0] @@ -244,14 +246,10 @@ def set_positions(self, positions): image.set_positions(positions[n1:n2]) n1 = n2 self.cached = False + return None def get_precon_forces(self, forces, energies, images): - if ( - self.precon is None - or isinstance(self.precon, str) - or isinstance(self.precon, Precon) - or isinstance(self.precon, list) - ): + if self.precon is None or isinstance(self.precon, (str, Precon, list)): self.precon = PreconImages(self.precon, images) # apply preconditioners to transform forces diff --git a/src/fairchem/applications/cattsunami/core/reaction.py b/src/fairchem/applications/cattsunami/core/reaction.py index 916c69ee2..2c699b0f3 100644 --- a/src/fairchem/applications/cattsunami/core/reaction.py +++ b/src/fairchem/applications/cattsunami/core/reaction.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle import random @@ -11,9 +13,9 @@ def __init__( self, reaction_db_path: str, adsorbate_db_path: str, - reaction_id_from_db: int = None, - reaction_str_from_db: str = None, - reaction_type: str = None, + reaction_id_from_db: int | None = None, + reaction_str_from_db: str | None = None, + reaction_type: str | None = None, ): self.reaction_db_path = reaction_db_path reaction_db = pickle.load(open(reaction_db_path, "rb")) @@ -92,6 +94,6 @@ def get_desorption_mapping(self, reactant): Get mapping for desorption reaction """ mapping = {} - for idx, atom in enumerate(reactant): + for idx, _atom in enumerate(reactant): mapping[idx] = idx return [mapping] diff --git a/src/fairchem/applications/cattsunami/databases/__init__.py b/src/fairchem/applications/cattsunami/databases/__init__.py index b7553b56d..6dc29dd6a 100644 --- a/src/fairchem/applications/cattsunami/databases/__init__.py +++ b/src/fairchem/applications/cattsunami/databases/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os DISSOCIATION_REACTION_DB_PATH = os.path.join( diff --git a/src/fairchem/applications/cattsunami/run_validation/run_validation.py b/src/fairchem/applications/cattsunami/run_validation/run_validation.py index 92e427adc..ec5683f74 100644 --- a/src/fairchem/applications/cattsunami/run_validation/run_validation.py +++ b/src/fairchem/applications/cattsunami/run_validation/run_validation.py @@ -5,17 +5,18 @@ from __future__ import annotations +import argparse +import os from typing import TYPE_CHECKING +import numpy as np +import pandas as pd +import torch from ase.io import read from ase.optimize import BFGS -import torch -import argparse -from fairchem.core.common.relaxation.ase_utils import OCPCalculator from fairchem.applications.cattsunami.core.ocpneb import OCPNEB -import os -import pandas as pd -import numpy as np + +from fairchem.core.common.relaxation.ase_utils import OCPCalculator if TYPE_CHECKING: import ase @@ -141,9 +142,14 @@ def all_converged(row, ml=True): Returns: bool: whether the system is converged """ - if row.converged_ml and row.converged and ml: - return True - elif row.converged_ml and row.converged and (not np.isnan(row.E_TS_SP)): + if ( + row.converged_ml + and row.converged + and ml + or row.converged_ml + and row.converged + and (not np.isnan(row.E_TS_SP)) + ): return True return False @@ -340,8 +346,8 @@ def get_single_point( # If single points are to be performed, perform them if args.get_ts_sp: - from vasp_interactive import VaspInteractive from fairchem.data.oc.utils.vasp import calculate_surface_k_points + from vasp_interactive import VaspInteractive os.makedirs( f"{args.output_file_path}/{model_id}/vasp_files/{neb_id}", diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py index 047cade3c..69204a8ce 100644 --- a/src/fairchem/core/_cli.py +++ b/src/fairchem/core/_cli.py @@ -94,20 +94,33 @@ def main(): else: # Run locally on a single node, n-processes if args.distributed: - logging.info(f"Running in distributed local mode with {args.num_gpus} ranks") + logging.info( + f"Running in distributed local mode with {args.num_gpus} ranks" + ) # HACK to disable multiprocess dataloading in local mode # there is an open issue where LMDB's environment cannot be pickled and used # during torch multiprocessing https://github.com/pytorch/examples/issues/526 if "optim" in config and "num_workers" in config["optim"]: config["optim"]["num_workers"] = 0 - logging.info("WARNING: running in local mode, setting dataloading num_workers to 0, see https://github.com/pytorch/examples/issues/526") - - launch_config = LaunchConfig(min_nodes=1, max_nodes=1, nproc_per_node=args.num_gpus, rdzv_backend="c10d", max_restarts=0) + logging.info( + "WARNING: running in local mode, setting dataloading num_workers to 0, see https://github.com/pytorch/examples/issues/526" + ) + + launch_config = LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=args.num_gpus, + rdzv_backend="c10d", + max_restarts=0, + ) elastic_launch(launch_config, runner_wrapper)(args.distributed, config) else: logging.info("Running in non-distributed local mode") - assert args.num_gpus == 1, "Can only run with a single gpu in non distributed local mode, use --distributed flag instead if using >1 gpu" + assert ( + args.num_gpus == 1 + ), "Can only run with a single gpu in non distributed local mode, use --distributed flag instead if using >1 gpu" runner_wrapper(args.distributed, config) + if __name__ == "__main__": main() diff --git a/src/fairchem/core/common/data_parallel.py b/src/fairchem/core/common/data_parallel.py index 99e2cd9e9..4d5836b78 100644 --- a/src/fairchem/core/common/data_parallel.py +++ b/src/fairchem/core/common/data_parallel.py @@ -160,7 +160,7 @@ def __init__( shuffle=shuffle, drop_last=drop_last, batch_size=batch_size, - seed=seed + seed=seed, ) self.batch_sampler = BatchSampler( self.single_sampler, diff --git a/src/fairchem/core/common/logger.py b/src/fairchem/core/common/logger.py index c1d501920..639d83c29 100644 --- a/src/fairchem/core/common/logger.py +++ b/src/fairchem/core/common/logger.py @@ -86,7 +86,7 @@ def __init__(self, config) -> None: ) def watch(self, model, log_freq: int = 1000) -> None: - wandb.watch(model, log_freq = log_freq) + wandb.watch(model, log_freq=log_freq) def log(self, update_dict, step: int, split: str = "") -> None: update_dict = super().log(update_dict, step, split) diff --git a/src/fairchem/core/common/test_utils.py b/src/fairchem/core/common/test_utils.py index b58f2f934..8aaf82210 100644 --- a/src/fairchem/core/common/test_utils.py +++ b/src/fairchem/core/common/test_utils.py @@ -24,6 +24,7 @@ class ForkedPdb(pdb.Pdb): from fairchem.core.common.test_utils import ForkedPdb ForkedPdb().set_trace() """ + def interaction(self, *args, **kwargs): _stdin = sys.stdin try: @@ -33,6 +34,7 @@ def interaction(self, *args, **kwargs): finally: sys.stdin = _stdin + @dataclass class PGConfig: backend: str @@ -41,6 +43,7 @@ class PGConfig: port: str = "12345" use_gp: bool = True + def spawn_multi_process( config: PGConfig, test_method: callable, @@ -104,6 +107,9 @@ def _init_pg_and_rank_and_launch_test( ) # setup gp if pg_setup_params.use_gp: - config = {"gp_gpus": pg_setup_params.gp_group_size, "distributed_backend": pg_setup_params.backend} + config = { + "gp_gpus": pg_setup_params.gp_group_size, + "distributed_backend": pg_setup_params.backend, + } setup_gp(config) mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme diff --git a/src/fairchem/core/common/tutorial_utils.py b/src/fairchem/core/common/tutorial_utils.py index 7d1a5bf1e..616537996 100644 --- a/src/fairchem/core/common/tutorial_utils.py +++ b/src/fairchem/core/common/tutorial_utils.py @@ -175,7 +175,7 @@ def nested_set(dic, keys, value): nested_set(config, keys, update[_key]) # TODO : Do not rename keys in utils.py when reading a config - config["evaluation_metrics"]=config["eval_metrics"] + config["evaluation_metrics"] = config["eval_metrics"] config.pop("eval_metrics") out = dump(config) diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py index e82b5d6e0..8edf81319 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py @@ -402,7 +402,7 @@ def _init_gp_partitions( edge_distance, edge_distance_vec, ): - """ Graph Parallel + """Graph Parallel This creates the required partial tensors for each rank given the full tensors. The tensors are split on the dimension along the node index using node_partition. """ diff --git a/src/fairchem/core/models/equiformer_v2/input_block.py b/src/fairchem/core/models/equiformer_v2/input_block.py index bb03e2b5f..371672521 100644 --- a/src/fairchem/core/models/equiformer_v2/input_block.py +++ b/src/fairchem/core/models/equiformer_v2/input_block.py @@ -79,7 +79,9 @@ def __init__( self.rescale_factor = rescale_factor - def forward(self, atomic_numbers, edge_distance, edge_index, num_nodes, node_offset=0): + def forward( + self, atomic_numbers, edge_distance, edge_index, num_nodes, node_offset=0 + ): if self.use_atom_edge_embedding: source_element = atomic_numbers[edge_index[0]] # Source atom atomic number target_element = atomic_numbers[edge_index[1]] # Target atom atomic number diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py index fcf9e0177..e1176d00c 100644 --- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py +++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py @@ -864,9 +864,7 @@ def subselect_edges( empty_image = subgraph["num_neighbors"] == 0 if torch.any(empty_image): - raise ValueError( - f"An image has no neighbors: sid={data.sid[empty_image]}" - ) + raise ValueError(f"An image has no neighbors: sid={data.sid[empty_image]}") return subgraph def generate_graph_dict(self, data, cutoff, max_neighbors): diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 92952b805..1daa5007f 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -440,7 +440,9 @@ def load_model(self) -> None: # only "watch" model if user specify watch: True because logging gradients # spews too much data into W&B and makes the UI slow to respond if "watch" in self.config["logger"]: - self.logger.watch(self.model, log_freq = int(self.config["logger"]["watch"])) + self.logger.watch( + self.model, log_freq=int(self.config["logger"]["watch"]) + ) self.logger.log_summary({"num_params": self.model.num_params}) if distutils.initialized() and not self.config["noddp"]: diff --git a/src/fairchem/data/oc/core/__init__.py b/src/fairchem/data/oc/core/__init__.py index 16ed8977c..1cbc64c78 100644 --- a/src/fairchem/data/oc/core/__init__.py +++ b/src/fairchem/data/oc/core/__init__.py @@ -1,5 +1,7 @@ -from .bulk import Bulk -from .slab import Slab +from __future__ import annotations + from .adsorbate import Adsorbate from .adsorbate_slab_config import AdsorbateSlabConfig +from .bulk import Bulk from .multi_adsorbate_slab_config import MultipleAdsorbateSlabConfig +from .slab import Slab diff --git a/src/fairchem/data/oc/core/adsorbate.py b/src/fairchem/data/oc/core/adsorbate.py index 9b388bd9c..afb93dc17 100644 --- a/src/fairchem/data/oc/core/adsorbate.py +++ b/src/fairchem/data/oc/core/adsorbate.py @@ -1,12 +1,15 @@ +from __future__ import annotations + import pickle import warnings -from typing import Any, Dict, Tuple +from typing import TYPE_CHECKING, Any -import ase import numpy as np - from fairchem.data.oc.databases.pkls import ADSORBATE_PKL_PATH +if TYPE_CHECKING: + import ase + class Adsorbate: """ @@ -34,11 +37,11 @@ class Adsorbate: def __init__( self, adsorbate_atoms: ase.Atoms = None, - adsorbate_id_from_db: int = None, - adsorbate_smiles_from_db: str = None, + adsorbate_id_from_db: int | None = None, + adsorbate_smiles_from_db: str | None = None, adsorbate_db_path: str = ADSORBATE_PKL_PATH, - adsorbate_db: Dict[int, Tuple[Any, ...]] = None, - adsorbate_binding_indices: list = None, + adsorbate_db: dict[int, tuple[Any, ...]] | None = None, + adsorbate_binding_indices: list | None = None, ): self.adsorbate_id_from_db = adsorbate_id_from_db self.adsorbate_db_path = adsorbate_db_path @@ -62,27 +65,29 @@ def __init__( ) else: self.binding_indices = adsorbate_binding_indices - elif adsorbate_id_from_db is not None: - adsorbate_db = adsorbate_db or pickle.load(open(adsorbate_db_path, "rb")) - self._load_adsorbate(adsorbate_db[adsorbate_id_from_db]) - elif adsorbate_smiles_from_db is not None: - adsorbate_db = adsorbate_db or pickle.load(open(adsorbate_db_path, "rb")) - adsorbate_obj_tuple = [ - (idx, adsorbate_info) - for idx, adsorbate_info in adsorbate_db.items() - if adsorbate_info[1] == adsorbate_smiles_from_db - ] - if len(adsorbate_obj_tuple) < 1: - warnings.warn( - "An adsorbate with that SMILES string was not found. Choosing one at random instead." - ) - self._get_adsorbate_from_random(adsorbate_db) - else: - self._load_adsorbate(adsorbate_obj_tuple[0][1]) - self.adsorbate_id_from_db = adsorbate_obj_tuple[0][0] else: - adsorbate_db = adsorbate_db or pickle.load(open(adsorbate_db_path, "rb")) - self._get_adsorbate_from_random(adsorbate_db) + if adsorbate_db is None: + with open(adsorbate_db_path, "rb") as fp: + adsorbate_db = pickle.load(fp) + + if adsorbate_id_from_db is not None: + self._load_adsorbate(adsorbate_db[adsorbate_id_from_db]) + elif adsorbate_smiles_from_db is not None: + adsorbate_obj_tuple = [ + (idx, adsorbate_info) + for idx, adsorbate_info in adsorbate_db.items() + if adsorbate_info[1] == adsorbate_smiles_from_db + ] + if len(adsorbate_obj_tuple) < 1: + warnings.warn( + "An adsorbate with that SMILES string was not found. Choosing one at random instead." + ) + self._get_adsorbate_from_random(adsorbate_db) + else: + self._load_adsorbate(adsorbate_obj_tuple[0][1]) + self.adsorbate_id_from_db = adsorbate_obj_tuple[0][0] + else: + self._get_adsorbate_from_random(adsorbate_db) def __len__(self): return len(self.atoms) @@ -100,7 +105,7 @@ def _get_adsorbate_from_random(self, adsorbate_db): self.adsorbate_id_from_db = np.random.randint(len(adsorbate_db)) self._load_adsorbate(adsorbate_db[self.adsorbate_id_from_db]) - def _load_adsorbate(self, adsorbate: Tuple[Any, ...]) -> None: + def _load_adsorbate(self, adsorbate: tuple[Any, ...]) -> None: """ Saves the fields from an adsorbate stored in a database. Fields added after the first revision are conditionally added for backwards @@ -114,7 +119,7 @@ def _load_adsorbate(self, adsorbate: Tuple[Any, ...]) -> None: def randomly_rotate_adsorbate( - adsorbate_atoms: ase.Atoms, mode: str = "random", binding_idx: int = None + adsorbate_atoms: ase.Atoms, mode: str = "random", binding_idx: int | None = None ): assert mode in ["random", "heuristic", "random_site_heuristic_placement"] atoms = adsorbate_atoms.copy() diff --git a/src/fairchem/data/oc/core/adsorbate_slab_config.py b/src/fairchem/data/oc/core/adsorbate_slab_config.py index 2d065e780..f3118cef0 100644 --- a/src/fairchem/data/oc/core/adsorbate_slab_config.py +++ b/src/fairchem/data/oc/core/adsorbate_slab_config.py @@ -1,18 +1,22 @@ +from __future__ import annotations + import copy import logging from itertools import product +from typing import TYPE_CHECKING -import ase import numpy as np import scipy from ase.data import atomic_numbers, covalent_radii from ase.geometry import wrap_positions +from fairchem.data.oc.core.adsorbate import randomly_rotate_adsorbate from pymatgen.analysis.adsorption import AdsorbateSiteFinder from pymatgen.io.ase import AseAtomsAdaptor from scipy.optimize import fsolve -from fairchem.data.oc.core import Adsorbate, Slab -from fairchem.data.oc.core.adsorbate import randomly_rotate_adsorbate +if TYPE_CHECKING: + import ase + from fairchem.data.oc.core.slab import Adsorbate, Slab # warnings.filterwarnings("ignore", "The iteration is not making good progress") @@ -69,7 +73,8 @@ def __init__( mode: str = "random", ): assert mode in ["random", "heuristic", "random_site_heuristic_placement"] - assert interstitial_gap < 5 and interstitial_gap >= 0 + assert interstitial_gap < 5 + assert interstitial_gap >= 0 self.slab = slab self.adsorbate = adsorbate @@ -119,15 +124,16 @@ def get_binding_sites(self, num_sites: int): simplices = dt.simplices # Only keep triangles with at least one vertex in central cell. - pruned_simplices = [] - for tri in simplices: + pruned_simplices = [ + tri + for tri in simplices if np.any( [ tiled_surface_atoms_idx[ver] in unit_surface_atoms_idx for ver in tri ] - ): - pruned_simplices.append(tri) + ) + ] simplices = np.array(pruned_simplices) # Uniformly sample sites on each triangle. @@ -452,7 +458,7 @@ def get_random_sites_on_triangle( + r1_sqrt * (1 - r2) * vertices[1] + r1_sqrt * r2 * vertices[2] ) - return [i for i in sites] + return list(sites) def custom_tile_atoms(atoms: ase.Atoms): diff --git a/src/fairchem/data/oc/core/bulk.py b/src/fairchem/data/oc/core/bulk.py index 4bb10fecb..9568ad362 100644 --- a/src/fairchem/data/oc/core/bulk.py +++ b/src/fairchem/data/oc/core/bulk.py @@ -1,14 +1,17 @@ +from __future__ import annotations + import os import pickle import warnings -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any -import ase import numpy as np - from fairchem.data.oc.core.slab import Slab from fairchem.data.oc.databases.pkls import BULK_PKL_PATH +if TYPE_CHECKING: + import ase + class Bulk: """ @@ -35,10 +38,10 @@ class Bulk: def __init__( self, bulk_atoms: ase.Atoms = None, - bulk_id_from_db: int = None, - bulk_src_id_from_db: str = None, + bulk_id_from_db: int | None = None, + bulk_src_id_from_db: str | None = None, bulk_db_path: str = BULK_PKL_PATH, - bulk_db: List[Dict[str, Any]] = None, + bulk_db: list[dict[str, Any]] | None = None, ): self.bulk_id_from_db = bulk_id_from_db self.bulk_db_path = bulk_db_path @@ -46,29 +49,31 @@ def __init__( if bulk_atoms is not None: self.atoms = bulk_atoms.copy() self.src_id = None - elif bulk_id_from_db is not None: - bulk_db = bulk_db or pickle.load(open(bulk_db_path, "rb")) - bulk_obj = bulk_db[bulk_id_from_db] - self.atoms, self.src_id = bulk_obj["atoms"], bulk_obj["src_id"] - elif bulk_src_id_from_db is not None: - bulk_db = bulk_db or pickle.load(open(bulk_db_path, "rb")) - bulk_obj_tuple = [ - (idx, bulk) - for idx, bulk in enumerate(bulk_db) - if bulk["src_id"] == bulk_src_id_from_db - ] - if len(bulk_obj_tuple) < 1: - warnings.warn( - "A bulk with that src id was not found. Choosing one at random instead" - ) - self._get_bulk_from_random(bulk_db) - else: - bulk_obj = bulk_obj_tuple[0][1] - self.bulk_id_from_db = bulk_obj_tuple[0][0] - self.atoms, self.src_id = bulk_obj["atoms"], bulk_obj["src_id"] else: - bulk_db = bulk_db or pickle.load(open(bulk_db_path, "rb")) - self._get_bulk_from_random(bulk_db) + if bulk_db is None: + with open(bulk_db_path, "rb") as fp: + bulk_db = pickle.load(fp) + + if bulk_id_from_db is not None: + bulk_obj = bulk_db[bulk_id_from_db] + self.atoms, self.src_id = bulk_obj["atoms"], bulk_obj["src_id"] + elif bulk_src_id_from_db is not None: + bulk_obj_tuple = [ + (idx, bulk) + for idx, bulk in enumerate(bulk_db) + if bulk["src_id"] == bulk_src_id_from_db + ] + if len(bulk_obj_tuple) < 1: + warnings.warn( + "A bulk with that src id was not found. Choosing one at random instead" + ) + self._get_bulk_from_random(bulk_db) + else: + bulk_obj = bulk_obj_tuple[0][1] + self.bulk_id_from_db = bulk_obj_tuple[0][0] + self.atoms, self.src_id = bulk_obj["atoms"], bulk_obj["src_id"] + else: + self._get_bulk_from_random(bulk_db) def _get_bulk_from_random(self, bulk_db): self.bulk_id_from_db = np.random.randint(len(bulk_db)) diff --git a/src/fairchem/data/oc/core/multi_adsorbate_slab_config.py b/src/fairchem/data/oc/core/multi_adsorbate_slab_config.py index 64ff33ccb..1b05fbed7 100644 --- a/src/fairchem/data/oc/core/multi_adsorbate_slab_config.py +++ b/src/fairchem/data/oc/core/multi_adsorbate_slab_config.py @@ -1,10 +1,15 @@ -from typing import List +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np from ase import Atoms from ase.data import covalent_radii +from fairchem.data.oc.core.adsorbate_slab_config import AdsorbateSlabConfig -from fairchem.data.oc.core import Adsorbate, AdsorbateSlabConfig, Slab +if TYPE_CHECKING: + from fairchem.data.oc.core.adsorbate import Adsorbate + from fairchem.data.oc.core.slab import Slab class MultipleAdsorbateSlabConfig(AdsorbateSlabConfig): @@ -57,14 +62,15 @@ class MultipleAdsorbateSlabConfig(AdsorbateSlabConfig): def __init__( self, slab: Slab, - adsorbates: List[Adsorbate], + adsorbates: list[Adsorbate], num_sites: int = 100, num_configurations: int = 1, interstitial_gap: float = 0.1, mode: str = "random_site_heuristic_placement", ): assert mode in ["random", "heuristic", "random_site_heuristic_placement"] - assert interstitial_gap < 5 and interstitial_gap >= 0 + assert interstitial_gap < 5 + assert interstitial_gap >= 0 self.slab = slab self.adsorbates = adsorbates @@ -142,7 +148,7 @@ def place_adsorbates_on_sites( pseudo_atoms, ) - for idx, adsorbate in enumerate(self.adsorbates[1:]): + for _idx, adsorbate in enumerate(self.adsorbates[1:]): binding_idx = adsorbate.binding_indices[0] binding_atom = adsorbate.atoms.get_atomic_numbers()[binding_idx] covalent_radius = covalent_radii[binding_atom] @@ -226,6 +232,4 @@ def update_distance_map(prev_distance_map, site_idx, adsorbate, pseudo_atoms): # update previous distance mapping by taking the minimum per-element distance between # the new distance mapping for the placed site and the previous mapping. - updated_distance_map = np.minimum(prev_distance_map, new_site_distances) - - return updated_distance_map + return np.minimum(prev_distance_map, new_site_distances) diff --git a/src/fairchem/data/oc/core/slab.py b/src/fairchem/data/oc/core/slab.py index 913ec4bd4..3ee382838 100644 --- a/src/fairchem/data/oc/core/slab.py +++ b/src/fairchem/data/oc/core/slab.py @@ -1,14 +1,15 @@ +from __future__ import annotations + import math import os import pickle from collections import defaultdict +from typing import TYPE_CHECKING -import ase import numpy as np from ase.constraints import FixAtoms from pymatgen.analysis.local_env import VoronoiNN from pymatgen.core.composition import Composition -from pymatgen.core.structure import Structure from pymatgen.core.surface import ( SlabGenerator, get_symmetrically_distinct_miller_indices, @@ -16,6 +17,10 @@ from pymatgen.io.ase import AseAtomsAdaptor from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +if TYPE_CHECKING: + import ase + from pymatgen.core.structure import Structure + class Slab: """ @@ -45,9 +50,9 @@ def __init__( self, bulk=None, slab_atoms: ase.Atoms = None, - millers: tuple = None, - shift: float = None, - top: bool = None, + millers: tuple | None = None, + shift: float | None = None, + top: bool | None = None, oriented_bulk: Structure = None, min_ab: float = 0.8, ): @@ -64,10 +69,8 @@ def __init__( Composition(self.atoms.get_chemical_formula()).reduced_formula == Composition(bulk.atoms.get_chemical_formula()).reduced_formula ), "Mismatched bulk and surface" - assert ( - np.linalg.norm(self.atoms.cell[0]) >= min_ab - and np.linalg.norm(self.atoms.cell[1]) >= min_ab - ), "Slab not tiled" + assert np.linalg.norm(self.atoms.cell[0]) >= min_ab, "Slab not tiled" + assert np.linalg.norm(self.atoms.cell[1]) >= min_ab, "Slab not tiled" assert self.has_surface_tagged(), "Slab not tagged" assert len(self.atoms.constraints) > 0, "Sub-surface atoms not constrained" @@ -100,17 +103,14 @@ def from_bulk_get_random_slab( def from_bulk_get_specific_millers( cls, specific_millers, bulk=None, min_ab=8.0, save_path=None ): - assert type(specific_millers) == tuple + assert isinstance(specific_millers, tuple) assert len(specific_millers) == 3 if save_path is not None: all_slabs = Slab.from_bulk_get_all_slabs( bulk, max(np.abs(specific_millers)), min_ab, save_path ) - slabs_with_millers = [ - slab for slab in all_slabs if slab.millers == specific_millers - ] - return slabs_with_millers + return [slab for slab in all_slabs if slab.millers == specific_millers] else: # If we're not saving all slabs, just tile and tag those with correct millers assert bulk is not None @@ -119,18 +119,16 @@ def from_bulk_get_specific_millers( max_miller=max(np.abs(specific_millers)), specific_millers=[specific_millers], ) - slabs = [] - for s in untiled_slabs: - slabs.append( - ( - tile_and_tag_atoms(s[0], bulk.atoms, min_ab=min_ab), - s[1], - s[2], - s[3], - s[4], - ) + slabs = [ + ( + tile_and_tag_atoms(s[0], bulk.atoms, min_ab=min_ab), + s[1], + s[2], + s[3], + s[4], ) - + for s in untiled_slabs + ] return [cls(bulk, s[0], s[1], s[2], s[3], s[4]) for s in slabs] @classmethod @@ -143,17 +141,16 @@ def from_bulk_get_all_slabs( bulk.atoms, max_miller=max_miller, ) - slabs = [] - for s in untiled_slabs: - slabs.append( - ( - tile_and_tag_atoms(s[0], bulk.atoms, min_ab=min_ab), - s[1], - s[2], - s[3], - s[4], - ) + slabs = [ + ( + tile_and_tag_atoms(s[0], bulk.atoms, min_ab=min_ab), + s[1], + s[2], + s[3], + s[4], ) + for s in untiled_slabs + ] # if path is provided, save out the pkl if save_path is not None: @@ -172,11 +169,11 @@ def from_precomputed_slabs_pkl( min_ab=8.0, ): assert bulk is not None - assert precomputed_slabs_pkl is not None and os.path.exists( - precomputed_slabs_pkl - ) + assert precomputed_slabs_pkl is not None + assert os.path.exists(precomputed_slabs_pkl) - slabs = pickle.load(open(precomputed_slabs_pkl, "rb")) + with open(precomputed_slabs_pkl, "rb") as fp: + slabs = pickle.load(fp) is_slab_obj = np.all([isinstance(s, Slab) for s in slabs]) if is_slab_obj: @@ -280,7 +277,7 @@ def set_fixed_atom_constraints(atoms): # list should contain a `True` if we want an atom to be constrained, and # `False` otherwise. atoms = atoms.copy() - mask = [True if atom.tag == 0 else False for atom in atoms] + mask = [atom.tag == 0 for atom in atoms] atoms.constraints += [FixAtoms(mask=mask)] return atoms @@ -347,8 +344,7 @@ def tile_atoms(atoms: ase.Atoms, min_ab: float = 8): na = int(math.ceil(min_ab / a_length)) nb = int(math.ceil(min_ab / b_length)) n_abc = (na, nb, 1) - atoms_tiled = atoms.repeat(n_abc) - return atoms_tiled + return atoms.repeat(n_abc) def find_surface_atoms_by_height(surface_atoms): @@ -377,11 +373,10 @@ def find_surface_atoms_by_height(surface_atoms): scaled_max_height = max(scaled_position[2] for scaled_position in scaled_positions) scaled_threshold = scaled_max_height - 2.0 / unit_cell_height - tags = [ + return [ 0 if scaled_position[2] < scaled_threshold else 1 for scaled_position in scaled_positions ] - return tags def find_surface_atoms_with_voronoi_given_height(bulk_atoms, slab_atoms, height_tags): @@ -442,8 +437,7 @@ def calculate_center_of_mass(struct): Calculates the center of mass of the slab. """ weights = [site.species.weight for site in struct] - center_of_mass = np.average(struct.frac_coords, weights=weights, axis=0) - return center_of_mass + return np.average(struct.frac_coords, weights=weights, axis=0) def calculate_coordination_of_bulk_atoms(bulk_atoms): @@ -485,7 +479,7 @@ def calculate_coordination_of_bulk_atoms(bulk_atoms): def compute_slabs( bulk_atoms: ase.Atoms = None, max_miller: int = 2, - specific_millers: list = None, + specific_millers: list | None = None, ): """ Enumerates all the symmetrically distinct slabs of a bulk structure. @@ -641,5 +635,4 @@ def standardize_bulk(atoms: ase.Atoms): """ struct = AseAtomsAdaptor.get_structure(atoms) sga = SpacegroupAnalyzer(struct, symprec=0.1) - standardized_struct = sga.get_conventional_standard_structure() - return standardized_struct + return sga.get_conventional_standard_structure() diff --git a/src/fairchem/data/oc/databases/pkls/__init__.py b/src/fairchem/data/oc/databases/pkls/__init__.py index affb2c698..b58ecbdb8 100644 --- a/src/fairchem/data/oc/databases/pkls/__init__.py +++ b/src/fairchem/data/oc/databases/pkls/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os BULK_PKL_PATH = os.path.join(__path__[0], "bulks.pkl") diff --git a/src/fairchem/data/oc/databases/update.py b/src/fairchem/data/oc/databases/update.py index ed2622d44..f9ca1f645 100644 --- a/src/fairchem/data/oc/databases/update.py +++ b/src/fairchem/data/oc/databases/update.py @@ -2,6 +2,9 @@ Script for updating ase pkl and db files from v3.19 to v3.21. Run it with ase v3.19. """ + +from __future__ import annotations + import pickle import ase.io @@ -25,27 +28,26 @@ def set_pbc_patch(self, pbc): def update_pkls(): - data = pickle.load( - open( - "ocdata/databases/pkls/adsorbates.pkl", - "rb", - ) - ) + with open( + "ocdata/databases/pkls/adsorbates.pkl", + "rb", + ) as fp: + data = pickle.load(fp) + for idx in data: pbc = data[idx][0].cell._pbc data[idx][0]._pbc = pbc with open( "ocdata/databases/pkls/adsorbates_new.pkl", "wb", - ) as f: - pickle.dump(data, f) + ) as fp: + pickle.dump(data, fp) - data = pickle.load( - open( - "ocdata/databases/pkls/bulks.pkl", - "rb", - ) - ) + with open( + "ocdata/databases/pkls/bulks.pkl", + "rb", + ) as fp: + data = pickle.load(fp) bulks = [] for info in tqdm(data): diff --git a/src/fairchem/data/oc/scripts/precompute_sample_structures.py b/src/fairchem/data/oc/scripts/precompute_sample_structures.py index 490d0f845..cd5b7acf6 100644 --- a/src/fairchem/data/oc/scripts/precompute_sample_structures.py +++ b/src/fairchem/data/oc/scripts/precompute_sample_structures.py @@ -6,6 +6,8 @@ [GASpy](https://github.com/ulissigroup/GASpy) with permission of author. """ +from __future__ import annotations + __authors__ = ["Kevin Tran", "Aini Palizhati", "Siddharth Goyal", "Zachary Ulissi"] __email__ = ["ktran@andrew.cmu.edu"] @@ -96,8 +98,7 @@ def standardize_bulk(atoms): """ struct = AseAtomsAdaptor.get_structure(atoms) sga = SpacegroupAnalyzer(struct, symprec=0.1) - standardized_struct = sga.get_conventional_standard_structure() - return standardized_struct + return sga.get_conventional_standard_structure() def is_structure_invertible(structure): @@ -151,8 +152,7 @@ def flip_struct(struct): atoms.cell[1] = -atoms.cell[1] atoms.wrap() - flipped_struct = AseAtomsAdaptor.get_structure(atoms) - return flipped_struct + return AseAtomsAdaptor.get_structure(atoms) def precompute_enumerate_surface(bulk_database, bulk_index, opfile): diff --git a/src/fairchem/data/oc/structure_generator.py b/src/fairchem/data/oc/structure_generator.py index f2638ebdf..bc6eea11d 100644 --- a/src/fairchem/data/oc/structure_generator.py +++ b/src/fairchem/data/oc/structure_generator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import logging import multiprocessing as mp @@ -7,10 +9,9 @@ import traceback import numpy as np -from tqdm import tqdm - from fairchem.data.oc.core import Adsorbate, AdsorbateSlabConfig, Bulk from fairchem.data.oc.utils.vasp import write_vasp_input_files +from tqdm import tqdm class StructureGenerator: @@ -395,7 +396,7 @@ def run_placements(inputs): args = parse_args() if args.indices_file: - with open(args.indices_file, "r") as f: + with open(args.indices_file) as f: all_indices = f.read().splitlines() chunks = np.array_split(all_indices, args.chunks) inds_to_run = chunks[args.chunk_index] @@ -416,7 +417,7 @@ def run_placements(inputs): print("Placements successfully generated!") elif args.bulk_indices_file: - with open(args.bulk_indices_file, "r") as f: + with open(args.bulk_indices_file) as f: all_indices = f.read().splitlines() chunks = np.array_split(all_indices, args.chunks) diff --git a/src/fairchem/data/oc/utils/__init__.py b/src/fairchem/data/oc/utils/__init__.py index c30147489..a76c10974 100644 --- a/src/fairchem/data/oc/utils/__init__.py +++ b/src/fairchem/data/oc/utils/__init__.py @@ -1 +1,3 @@ +from __future__ import annotations + from .flag_anomaly import DetectTrajAnomaly diff --git a/src/fairchem/data/oc/utils/flag_anomaly.py b/src/fairchem/data/oc/utils/flag_anomaly.py index c4adaec75..f8d67a604 100644 --- a/src/fairchem/data/oc/utils/flag_anomaly.py +++ b/src/fairchem/data/oc/utils/flag_anomaly.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from ase import neighborlist from ase.neighborlist import natural_cutoffs @@ -116,8 +118,7 @@ def _get_connectivity(self, atoms, cutoff_multiplier=1.0): cutoff, self_interaction=False, bothways=True ) ase_neighbor_list.update(atoms) - matrix = neighborlist.get_connectivity_matrix(ase_neighbor_list.nl).toarray() - return matrix + return neighborlist.get_connectivity_matrix(ase_neighbor_list.nl).toarray() def is_adsorbate_intercalated(self): """ diff --git a/src/fairchem/data/oc/utils/vasp.py b/src/fairchem/data/oc/utils/vasp.py index ea18cc619..2ceafbee3 100644 --- a/src/fairchem/data/oc/utils/vasp.py +++ b/src/fairchem/data/oc/utils/vasp.py @@ -5,6 +5,8 @@ [GASpy](https://github.com/ulissigroup/GASpy) with permission of authors. """ +from __future__ import annotations + __author__ = "Kevin Tran" __email__ = "ktran@andrew.cmu.edu" @@ -70,7 +72,7 @@ def _clean_up_inputs(atoms, vasp_flags): atoms.set_cell(atoms.cell[[1, 0, 2], :]) # Calculate and set the k points - if "kpts" not in vasp_flags.keys(): + if "kpts" not in vasp_flags: k_pts = calculate_surface_k_points(atoms) vasp_flags["kpts"] = k_pts @@ -92,12 +94,11 @@ def calculate_surface_k_points(atoms): a0 = np.linalg.norm(cell[0], ord=order) b0 = np.linalg.norm(cell[1], ord=order) multiplier = 40 - k_pts = ( + return ( max(1, int(round(multiplier / a0))), max(1, int(round(multiplier / b0))), 1, ) - return k_pts def write_vasp_input_files(atoms, outdir=".", vasp_flags=None): diff --git a/src/fairchem/data/odac/force_field/FF_analysis.py b/src/fairchem/data/odac/force_field/FF_analysis.py index 5b7720cf8..f792ce123 100644 --- a/src/fairchem/data/odac/force_field/FF_analysis.py +++ b/src/fairchem/data/odac/force_field/FF_analysis.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import numpy as np @@ -58,21 +60,15 @@ def binned_average(DFT_ads, pred_err, bins): bin0 = -1000 avgs = [] for i, bin in enumerate(bins): - if i == 0: - left = bin0 - else: - left = bins[i - 1] + left = bin0 if i == 0 else bins[i - 1] bin_errs = [] for DFT, pred in zip( DFT_ads, pred_err ): # this is a horribly inefficient way to do this... - if DFT > left and DFT < bin: + if left < DFT and bin > DFT: bin_errs.append(pred) - if bin_errs: - bin_avg = np.mean(bin_errs) - else: - bin_avg = 0 + bin_avg = np.mean(bin_errs) if bin_errs else 0 avgs.append(bin_avg) return avgs @@ -290,20 +286,16 @@ def chem_err(DFT, FF): get_Fig4c(DFT_CO2, err_CO2) get_Fig4d(DFT_H2O, err_H2O) - print("Overall MAE: {} eV".format(np.mean(err_CO2 + err_H2O))) - print("CO2 error: {} eV".format(np.mean(err_CO2))) - print("H2O error: {} eV".format(np.mean(err_H2O))) + print(f"Overall MAE: {np.mean(err_CO2 + err_H2O)} eV") + print(f"CO2 error: {np.mean(err_CO2)} eV") + print(f"H2O error: {np.mean(err_H2O)} eV") print( - "Overall physisorption error: {} eV".format( - phys_err(DFT_CO2 + DFT_H2O, FF_CO2 + FF_H2O) - ) + f"Overall physisorption error: {phys_err(DFT_CO2 + DFT_H2O, FF_CO2 + FF_H2O)} eV" ) - print("CO2 physisorption error: {} eV".format(phys_err(DFT_CO2, FF_CO2))) - print("H2O physisorption error: {} eV".format(phys_err(DFT_H2O, FF_H2O))) + print(f"CO2 physisorption error: {phys_err(DFT_CO2, FF_CO2)} eV") + print(f"H2O physisorption error: {phys_err(DFT_H2O, FF_H2O)} eV") print( - "Overall chemisorption error: {} eV".format( - chem_err(DFT_CO2 + DFT_H2O, FF_CO2 + FF_H2O) - ) + f"Overall chemisorption error: {chem_err(DFT_CO2 + DFT_H2O, FF_CO2 + FF_H2O)} eV" ) - print("CO2 chemisorption error: {} eV".format(chem_err(DFT_CO2, FF_CO2))) - print("H2O chemisorption error: {} eV".format(chem_err(DFT_H2O, FF_H2O))) + print(f"CO2 chemisorption error: {chem_err(DFT_CO2, FF_CO2)} eV") + print(f"H2O chemisorption error: {chem_err(DFT_H2O, FF_H2O)} eV") diff --git a/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py b/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py index 3069d15f9..6a9d37924 100644 --- a/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py +++ b/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py @@ -1,451 +1,898 @@ -import pandas as pd -import numpy as np -import os -import matplotlib.pyploat as plt -from pymatgen.core import Lattice, Structure, Molecule, Element -import seaborn as sns -import filecmp +from __future__ import annotations +import matplotlib.pyploat as plt +import pandas as pd -raw_ads_energy_data=pd.read_csv('adsorption_energy.txt',header=None,sep=' ') -complete_data=pd.DataFrame(index=range(raw_ads_energy_data.shape[0]),columns=['MOF','defect_conc','defect_index','n_CO2','n_H2O','configuration_index','ads_energy_ev'])#,'LCD','PLD','metal','OMS']) +raw_ads_energy_data = pd.read_csv("adsorption_energy.txt", header=None, sep=" ") +complete_data = pd.DataFrame( + index=range(raw_ads_energy_data.shape[0]), + columns=[ + "MOF", + "defect_conc", + "defect_index", + "n_CO2", + "n_H2O", + "configuration_index", + "ads_energy_ev", + ], +) # ,'LCD','PLD','metal','OMS']) for i in range(raw_ads_energy_data.shape[0]): - temp_split_string=raw_ads_energy_data.iloc[i,0].split('_w_') - temp_0_parts = temp_split_string[0].rsplit('_', 2) - #non-defective - if len(temp_0_parts)<3: - complete_data.iloc[i,0]=temp_split_string[0] - complete_data.iloc[i,1]=None - complete_data.iloc[i,2]=None - elif len(temp_0_parts)==3: - if temp_0_parts[-1] in ['0', '1' , '2' , '3'] and float(temp_0_parts[1])<0.21: - #defective - complete_data.iloc[i,0]=temp_0_parts[0] - complete_data.iloc[i,1]=temp_0_parts[1] - complete_data.iloc[i,2]=temp_0_parts[2] + temp_split_string = raw_ads_energy_data.iloc[i, 0].split("_w_") + temp_0_parts = temp_split_string[0].rsplit("_", 2) + # non-defective + if len(temp_0_parts) < 3: + complete_data.iloc[i, 0] = temp_split_string[0] + complete_data.iloc[i, 1] = None + complete_data.iloc[i, 2] = None + elif len(temp_0_parts) == 3: + if temp_0_parts[-1] in ["0", "1", "2", "3"] and float(temp_0_parts[1]) < 0.21: + # defective + complete_data.iloc[i, 0] = temp_0_parts[0] + complete_data.iloc[i, 1] = temp_0_parts[1] + complete_data.iloc[i, 2] = temp_0_parts[2] else: - complete_data.iloc[i,0]=temp_split_string[0] - complete_data.iloc[i,1]=None - complete_data.iloc[i,2]=None - - - - if temp_split_string[1].split('_')[-2]=='random': - complete_data.iloc[i,-2]='random_'+temp_split_string[1].split('_')[-1] -# elif temp_split_string[1].split('_')[-1]=='new': -# complete_data.iloc[i,-2]='new_'+temp_split_string[1].split('_')[-1] + complete_data.iloc[i, 0] = temp_split_string[0] + complete_data.iloc[i, 1] = None + complete_data.iloc[i, 2] = None + + if temp_split_string[1].split("_")[-2] == "random": + complete_data.iloc[i, -2] = "random_" + temp_split_string[1].split("_")[-1] + # elif temp_split_string[1].split('_')[-1]=='new': + # complete_data.iloc[i,-2]='new_'+temp_split_string[1].split('_')[-1] else: - complete_data.iloc[i,-2]=temp_split_string[1].split('_')[-1] - -# if len(temp_split_string[1].split('_'))==2: -# if temp_split_string[1].split('_')[0]=='CO2': -# complete_data.iloc[i,2]=0 -# elif temp_split_string[1].split('_')[0]=='H2O': -# complete_data.iloc[i,1]=0 - - if temp_split_string[1].rsplit('_',1)[0]=='CO2' or temp_split_string[1].rsplit('_',1)[0]=='CO2_random': - complete_data.iloc[i,3]=1 - complete_data.iloc[i,4]=0 - elif temp_split_string[1].rsplit('_',1)[0]=='H2O' or temp_split_string[1].rsplit('_',1)[0]=='H2O_random' or temp_split_string[1].rsplit('_',1)[0]=='H2O_new': - complete_data.iloc[i,3]=0 - complete_data.iloc[i,4]=1 - elif temp_split_string[1].rsplit('_',1)[0]=='CO2_H2O' or temp_split_string[1].rsplit('_',1)[0]=='CO2_H2O_random': - complete_data.iloc[i,3]=1 - complete_data.iloc[i,4]=1 - #co2+2h2o - elif temp_split_string[1].rsplit('_',1)[0]=='CO2_2H2O' or temp_split_string[1].rsplit('_',1)[0]=='CO2_2H2O_random': - complete_data.iloc[i,3]=1 - complete_data.iloc[i,4]=2 + complete_data.iloc[i, -2] = temp_split_string[1].split("_")[-1] + + # if len(temp_split_string[1].split('_'))==2: + # if temp_split_string[1].split('_')[0]=='CO2': + # complete_data.iloc[i,2]=0 + # elif temp_split_string[1].split('_')[0]=='H2O': + # complete_data.iloc[i,1]=0 + + if ( + temp_split_string[1].rsplit("_", 1)[0] == "CO2" + or temp_split_string[1].rsplit("_", 1)[0] == "CO2_random" + ): + complete_data.iloc[i, 3] = 1 + complete_data.iloc[i, 4] = 0 + elif ( + temp_split_string[1].rsplit("_", 1)[0] == "H2O" + or temp_split_string[1].rsplit("_", 1)[0] == "H2O_random" + or temp_split_string[1].rsplit("_", 1)[0] == "H2O_new" + ): + complete_data.iloc[i, 3] = 0 + complete_data.iloc[i, 4] = 1 + elif ( + temp_split_string[1].rsplit("_", 1)[0] == "CO2_H2O" + or temp_split_string[1].rsplit("_", 1)[0] == "CO2_H2O_random" + ): + complete_data.iloc[i, 3] = 1 + complete_data.iloc[i, 4] = 1 + # co2+2h2o + elif ( + temp_split_string[1].rsplit("_", 1)[0] == "CO2_2H2O" + or temp_split_string[1].rsplit("_", 1)[0] == "CO2_2H2O_random" + ): + complete_data.iloc[i, 3] = 1 + complete_data.iloc[i, 4] = 2 else: print(temp_split_string) - - complete_data.iloc[i,-1]=raw_ads_energy_data.iloc[i,1] - -complete_data['Name']=raw_ads_energy_data.iloc[:,0] + complete_data.iloc[i, -1] = raw_ads_energy_data.iloc[i, 1] -###set 2eV bound -complete_data=complete_data[(complete_data['ads_energy_ev']/(complete_data['n_CO2']+complete_data['n_H2O']))<2].copy() +complete_data["Name"] = raw_ads_energy_data.iloc[:, 0] -#split into pristine and defective here -complete_data_merged_pristine=complete_data[complete_data['defect_conc'].isnull()].copy() -complete_data_merged_pristine=complete_data.reset_index(drop=True) -complete_data_merged_defective=complete_data[complete_data['defect_conc'].notnull()].copy() -complete_data_merged_defective=complete_data.reset_index(drop=True) -complete_data_merged_defective['defective_MOF_name']=complete_data_merged_defective['MOF']+"_"+complete_data_merged_defective['defect_conc']+"_"+complete_data_merged_defective['defect_index'] +###set 2eV bound +complete_data = complete_data[ + (complete_data["ads_energy_ev"] / (complete_data["n_CO2"] + complete_data["n_H2O"])) + < 2 +].copy() + + +# split into pristine and defective here +complete_data_merged_pristine = complete_data[ + complete_data["defect_conc"].isnull() +].copy() +complete_data_merged_pristine = complete_data.reset_index(drop=True) +complete_data_merged_defective = complete_data[ + complete_data["defect_conc"].notnull() +].copy() +complete_data_merged_defective = complete_data.reset_index(drop=True) +complete_data_merged_defective["defective_MOF_name"] = ( + complete_data_merged_defective["MOF"] + + "_" + + complete_data_merged_defective["defect_conc"] + + "_" + + complete_data_merged_defective["defect_index"] +) ####get the lowest energy -complete_data_merged_pristine_co2=complete_data_merged_pristine[(complete_data_merged_pristine['n_CO2']==1) & (complete_data_merged_pristine['n_H2O']==0)].copy() -complete_data_merged_pristine_h2o=complete_data_merged_pristine[(complete_data_merged_pristine['n_CO2']==0) & (complete_data_merged_pristine['n_H2O']==1)].copy() -complete_data_merged_pristine_co_ads=complete_data_merged_pristine[(complete_data_merged_pristine['n_CO2']==1) & (complete_data_merged_pristine['n_H2O']==1)].copy() -complete_data_merged_pristine_co_ads_2=complete_data_merged_pristine[(complete_data_merged_pristine['n_CO2']==1) & (complete_data_merged_pristine['n_H2O']==2)].copy() -complete_data_merged_defective_co2=complete_data_merged_defective[(complete_data_merged_defective['n_CO2']==1) & (complete_data_merged_defective['n_H2O']==0)].copy() -complete_data_merged_defective_h2o=complete_data_merged_defective[(complete_data_merged_defective['n_CO2']==0) & (complete_data_merged_defective['n_H2O']==1)].copy() -complete_data_merged_defective_co_ads=complete_data_merged_defective[(complete_data_merged_defective['n_CO2']==1) & (complete_data_merged_defective['n_H2O']==1)].copy() -complete_data_merged_defective_co_ads_2=complete_data_merged_defective[(complete_data_merged_defective['n_CO2']==1) & (complete_data_merged_defective['n_H2O']==2)].copy() - -lowest_energy_data_co2=pd.DataFrame(columns=complete_data_merged_pristine_co2.columns) +complete_data_merged_pristine_co2 = complete_data_merged_pristine[ + (complete_data_merged_pristine["n_CO2"] == 1) + & (complete_data_merged_pristine["n_H2O"] == 0) +].copy() +complete_data_merged_pristine_h2o = complete_data_merged_pristine[ + (complete_data_merged_pristine["n_CO2"] == 0) + & (complete_data_merged_pristine["n_H2O"] == 1) +].copy() +complete_data_merged_pristine_co_ads = complete_data_merged_pristine[ + (complete_data_merged_pristine["n_CO2"] == 1) + & (complete_data_merged_pristine["n_H2O"] == 1) +].copy() +complete_data_merged_pristine_co_ads_2 = complete_data_merged_pristine[ + (complete_data_merged_pristine["n_CO2"] == 1) + & (complete_data_merged_pristine["n_H2O"] == 2) +].copy() +complete_data_merged_defective_co2 = complete_data_merged_defective[ + (complete_data_merged_defective["n_CO2"] == 1) + & (complete_data_merged_defective["n_H2O"] == 0) +].copy() +complete_data_merged_defective_h2o = complete_data_merged_defective[ + (complete_data_merged_defective["n_CO2"] == 0) + & (complete_data_merged_defective["n_H2O"] == 1) +].copy() +complete_data_merged_defective_co_ads = complete_data_merged_defective[ + (complete_data_merged_defective["n_CO2"] == 1) + & (complete_data_merged_defective["n_H2O"] == 1) +].copy() +complete_data_merged_defective_co_ads_2 = complete_data_merged_defective[ + (complete_data_merged_defective["n_CO2"] == 1) + & (complete_data_merged_defective["n_H2O"] == 2) +].copy() + +lowest_energy_data_co2 = pd.DataFrame(columns=complete_data_merged_pristine_co2.columns) for i in range(complete_data_merged_pristine_co2.shape[0]): - - - current_entry=complete_data_merged_pristine_co2.iloc[i,:] - current_MOF=complete_data_merged_pristine_co2.iloc[i,0] - current_n_CO2=complete_data_merged_pristine_co2.iloc[i,3] - current_n_H2O=complete_data_merged_pristine_co2.iloc[i,4] - current_configuration_index=complete_data_merged_pristine_co2.iloc[i,5] - current_lowest_energy=complete_data_merged_pristine_co2.iloc[i,6] - current_name=complete_data_merged_pristine_co2.iloc[i,7] - #if this case is not included - - if lowest_energy_data_co2[(lowest_energy_data_co2['MOF']==current_MOF) \ - & (lowest_energy_data_co2['defect_conc'].isnull()) & (lowest_energy_data_co2['defect_index'].isnull())].empty: - lowest_energy_data_co2=lowest_energy_data_co2.append(current_entry) - #if this case is already included + current_entry = complete_data_merged_pristine_co2.iloc[i, :] + current_MOF = complete_data_merged_pristine_co2.iloc[i, 0] + current_n_CO2 = complete_data_merged_pristine_co2.iloc[i, 3] + current_n_H2O = complete_data_merged_pristine_co2.iloc[i, 4] + current_configuration_index = complete_data_merged_pristine_co2.iloc[i, 5] + current_lowest_energy = complete_data_merged_pristine_co2.iloc[i, 6] + current_name = complete_data_merged_pristine_co2.iloc[i, 7] + # if this case is not included + + if lowest_energy_data_co2[ + (lowest_energy_data_co2["MOF"] == current_MOF) + & (lowest_energy_data_co2["defect_conc"].isnull()) + & (lowest_energy_data_co2["defect_index"].isnull()) + ].empty: + lowest_energy_data_co2 = lowest_energy_data_co2.append(current_entry) + # if this case is already included else: - #find the index of the this case's entry - index_this_case=lowest_energy_data_co2[(lowest_energy_data_co2['MOF']==current_MOF) \ - & (lowest_energy_data_co2['defect_conc'].isnull()) & (lowest_energy_data_co2['defect_index'].isnull())].index[0] - if current_lowest_energy None: """ Args: @@ -76,7 +78,7 @@ def __init__( to call the API should be made. """ super().__init__(method=method, url=url, cause="Exceeded rate limit") - self.retry_after: Optional[timedelta] = retry_after + self.retry_after: timedelta | None = retry_after class Client: @@ -167,7 +169,7 @@ async def get_adsorbates(self) -> Adsorbates: ) return Adsorbates.from_json(response) - async def get_slabs(self, bulk: Union[str, Bulk]) -> Slabs: + async def get_slabs(self, bulk: str | Bulk) -> Slabs: """ Get a unique list of slabs for the input bulk structure. @@ -235,7 +237,7 @@ async def get_adsorbate_slab_configs( async def submit_adsorbate_slab_relaxations( self, adsorbate: str, - adsorbate_configs: List[Atoms], + adsorbate_configs: list[Atoms], bulk: Bulk, slab: Slab, model: str, @@ -318,8 +320,8 @@ async def get_adsorbate_slab_relaxations_request( async def get_adsorbate_slab_relaxations_results( self, system_id: str, - config_ids: Optional[List[int]] = None, - fields: Optional[List[str]] = None, + config_ids: list[int] | None = None, + fields: list[str] | None = None, ) -> AdsorbateSlabRelaxationsResults: """ Fetches relaxation results for the input system. @@ -342,7 +344,7 @@ async def get_adsorbate_slab_relaxations_results( Returns: The relaxation results for each configuration in the system. """ - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if fields: params["field"] = fields if config_ids: @@ -415,7 +417,7 @@ async def _run_request(self, path: str, method: str, **kwargs) -> str: if response.status_code >= 300: # Exceeded server side rate limit if response.status_code == 429: - retry_after: Optional[str] = response.headers.get("Retry-After", None) + retry_after: str | None = response.headers.get("Retry-After", None) raise RateLimitExceededException( method=method, url=url, diff --git a/src/fairchem/demo/ocpapi/client/models.py b/src/fairchem/demo/ocpapi/client/models.py index 3e8a96438..37666c473 100644 --- a/src/fairchem/demo/ocpapi/client/models.py +++ b/src/fairchem/demo/ocpapi/client/models.py @@ -1,3 +1,4 @@ +# dataclasses_json breaks if using __future__.annotations, so keep 3.9 typing annotations for compatibility from dataclasses import dataclass, field from enum import Enum from typing import List, Optional, Tuple diff --git a/src/fairchem/demo/ocpapi/client/ui.py b/src/fairchem/demo/ocpapi/client/ui.py index 4d0f63d68..51310bed3 100644 --- a/src/fairchem/demo/ocpapi/client/ui.py +++ b/src/fairchem/demo/ocpapi/client/ui.py @@ -1,12 +1,12 @@ -from typing import Dict, Optional +from __future__ import annotations # Map of known API hosts to UI hosts -_API_TO_UI_HOSTS: Dict[str, str] = { +_API_TO_UI_HOSTS: dict[str, str] = { "open-catalyst-api.metademolab.com": "open-catalyst.metademolab.com", } -def get_results_ui_url(api_host: str, system_id: str) -> Optional[str]: +def get_results_ui_url(api_host: str, system_id: str) -> str | None: """ Generates the URL at which results for the input system can be visualized. @@ -19,6 +19,6 @@ def get_results_ui_url(api_host: str, system_id: str) -> Optional[str]: The URL at which the input system can be visualized. None if the API host is not recognized. """ - if ui_host := _API_TO_UI_HOSTS.get(api_host, None): + if ui_host := _API_TO_UI_HOSTS.get(api_host): return f"https://{ui_host}/results/{system_id}" return None diff --git a/src/fairchem/demo/ocpapi/version.py b/src/fairchem/demo/ocpapi/version.py index 3277f64c2..dd97582b3 100644 --- a/src/fairchem/demo/ocpapi/version.py +++ b/src/fairchem/demo/ocpapi/version.py @@ -1 +1,3 @@ +from __future__ import annotations + VERSION = "1.0.0" diff --git a/src/fairchem/demo/ocpapi/workflows/adsorbates.py b/src/fairchem/demo/ocpapi/workflows/adsorbates.py index a6a54ec22..51f43acbc 100644 --- a/src/fairchem/demo/ocpapi/workflows/adsorbates.py +++ b/src/fairchem/demo/ocpapi/workflows/adsorbates.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import logging from contextlib import AsyncExitStack, asynccontextmanager, suppress @@ -9,16 +11,9 @@ AsyncGenerator, Awaitable, Callable, - Dict, - List, - Optional, - Tuple, ) from dataclasses_json import Undefined, dataclass_json -from tqdm import tqdm -from tqdm.contrib.logging import logging_redirect_tqdm - from fairchem.demo.ocpapi.client import ( Adsorbates, AdsorbateSlabConfigs, @@ -35,6 +30,8 @@ Status, get_results_ui_url, ) +from tqdm import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm from .context import set_context_var from .filter import prompt_for_slabs_to_keep @@ -43,7 +40,7 @@ # Context instance that stores information about the adsorbate and bulk # material as a tuple in that order -_CTX_AD_BULK: ContextVar[Tuple[str, str]] = ContextVar(f"{__name__}:ad_bulk") +_CTX_AD_BULK: ContextVar[tuple[str, str]] = ContextVar(f"{__name__}:ad_bulk") # Context intance that stores information about a slab _CTX_SLAB: ContextVar[Slab] = ContextVar(f"{__name__}:slab") @@ -58,7 +55,7 @@ def _setup_log_record_factory() -> None: def new_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: # Save information about the bulk and absorbate if set - parts: List[str] = [] + parts: list[str] = [] if (ad_bulk := _CTX_AD_BULK.get(None)) is not None: parts.append(f"[{ad_bulk[0]}/{ad_bulk[1]}]") @@ -89,15 +86,13 @@ class AdsorbatesException(Exception): Base exception for all others in this module. """ - pass - class UnsupportedModelException(AdsorbatesException): """ Exception raised when a model is not supported in the API. """ - def __init__(self, model: str, allowed_models: List[str]) -> None: + def __init__(self, model: str, allowed_models: list[str]) -> None: """ Args: model: The model that was requested. @@ -141,19 +136,19 @@ class Lifetime(Enum): SAVE = auto() """ - The relaxation will be available on API servers indefinitely. It will not + The relaxation will be available on API servers indefinitely. It will not be possible to delete the relaxation in the future. """ MARK_EPHEMERAL = auto() """ - The relaxation will be saved on API servers, but can be deleted at any time + The relaxation will be saved on API servers, but can be deleted at any time in the future. """ DELETE = auto() """ - The relaxation will be deleted from API servers as soon as the results have + The relaxation will be deleted from API servers as soon as the results have been fetched. """ @@ -170,9 +165,9 @@ class AdsorbateSlabRelaxations: The slab on which the adsorbate was placed. """ - configs: List[AdsorbateSlabRelaxationResult] + configs: list[AdsorbateSlabRelaxationResult] """ - Details of the relaxation of each adsorbate placement, including the + Details of the relaxation of each adsorbate placement, including the final position. """ @@ -186,7 +181,7 @@ class AdsorbateSlabRelaxations: The API host on which the relaxations were run. """ - ui_url: Optional[str] + ui_url: str | None """ The URL at which results can be visualized. """ @@ -215,7 +210,7 @@ class AdsorbateBindingSites: The type of the model that was run. """ - slabs: List[AdsorbateSlabRelaxations] + slabs: list[AdsorbateSlabRelaxations] """ The list of slabs that were generated from the bulk structure. Each contains its own list of adsorbate placements. @@ -235,7 +230,7 @@ async def _ensure_model_supported(client: Client, model: str) -> None: UnsupportedModelException: If the model is not supported. """ models: Models = await client.get_models() - allowed_models: List[str] = [m.id for m in models.models] + allowed_models: list[str] = [m.id for m in models.models] if model not in allowed_models: raise UnsupportedModelException( model=model, @@ -286,7 +281,7 @@ async def _ensure_adsorbate_supported(client: Client, adsorbate: str) -> None: async def _get_slabs( client: Client, bulk: Bulk, -) -> List[Slab]: +) -> list[Slab]: """ Enumerates surfaces for the input bulk material. @@ -352,8 +347,8 @@ async def _get_absorbate_configs_on_slab_with_logging( async def _get_adsorbate_configs_on_slabs( client: Client, adsorbate: str, - slabs: List[Slab], -) -> List[AdsorbateSlabConfigs]: + slabs: list[Slab], +) -> list[AdsorbateSlabConfigs]: """ Finds candidate adsorbate binding sites on each of the input slabs. @@ -366,7 +361,7 @@ async def _get_adsorbate_configs_on_slabs( List of slabs and, for each, the positions of the adsorbate atoms in the potential binding site. """ - tasks: List[asyncio.Task] = [ + tasks: list[asyncio.Task] = [ asyncio.create_task( _get_absorbate_configs_on_slab_with_logging( client=client, @@ -396,7 +391,7 @@ async def _get_adsorbate_configs_on_slabs( async def _submit_relaxations( client: Client, adsorbate: str, - adsorbate_configs: List[Atoms], + adsorbate_configs: list[Atoms], bulk: Bulk, slab: Slab, model: str, @@ -436,7 +431,7 @@ async def _submit_relaxations( async def _submit_relaxations_with_progress_logging( client: Client, adsorbate: str, - adsorbate_configs: List[Atoms], + adsorbate_configs: list[Atoms], bulk: Bulk, slab: Slab, model: str, @@ -487,10 +482,10 @@ async def log_waiting() -> None: @retry_api_calls(max_attempts=3) async def get_adsorbate_slab_relaxation_results( system_id: str, - config_ids: Optional[List[int]] = None, - fields: Optional[List[str]] = None, + config_ids: list[int] | None = None, + fields: list[str] | None = None, client: Client = DEFAULT_CLIENT, -) -> List[AdsorbateSlabRelaxationResult]: +) -> list[AdsorbateSlabRelaxationResult]: """ Wrapper around Client.get_adsorbate_slab_relaxations_results() that handles retries, including re-fetching individual configurations that @@ -517,7 +512,7 @@ async def get_adsorbate_slab_relaxation_results( ) # Save a copy of all results that were fetched - fetched: List[AdsorbateSlabRelaxationResult] = list(results.configs) + fetched: list[AdsorbateSlabRelaxationResult] = list(results.configs) # If any results were omitted, fetch them before returning if results.omitted_config_ids: @@ -538,9 +533,9 @@ async def wait_for_adsorbate_slab_relaxations( check_immediately: bool = False, slow_interval_sec: float = 30, fast_interval_sec: float = 10, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, client: Client = DEFAULT_CLIENT, -) -> Dict[int, Status]: +) -> dict[int, Status]: """ Blocks until all relaxations in the input system have finished, whether successfully or not. @@ -579,7 +574,7 @@ async def wait_for_adsorbate_slab_relaxations( while True: # Get the current results. Only fetch the energy; this hits an index # that will return results more quickly. - results: List[ + results: list[ AdsorbateSlabRelaxationResult ] = await get_adsorbate_slab_relaxation_results( client=client, @@ -642,7 +637,7 @@ async def _ensure_system_deleted( async def _run_relaxations_on_slab( client: Client, adsorbate: str, - adsorbate_configs: List[Atoms], + adsorbate_configs: list[Atoms], bulk: Bulk, slab: Slab, model: str, @@ -703,7 +698,7 @@ async def _run_relaxations_on_slab( ) # Fetch the final results - results: List[ + results: list[ AdsorbateSlabRelaxationResult ] = await get_adsorbate_slab_relaxation_results( client=client, @@ -739,7 +734,7 @@ async def _relax_binding_sites_on_slabs( client: Client, adsorbate: str, bulk: Bulk, - adslabs: List[AdsorbateSlabConfigs], + adslabs: list[AdsorbateSlabConfigs], model: str, lifetime: Lifetime, ) -> AdsorbateBindingSites: @@ -780,7 +775,7 @@ async def _relax_binding_sites_on_slabs( ) # Run relaxations for all configurations on all slabs - tasks: List[asyncio.Task] = [ + tasks: list[asyncio.Task] = [ asyncio.create_task( _run_relaxations_on_slab( client=client, @@ -813,7 +808,7 @@ async def _relax_binding_sites_on_slabs( _DEFAULT_ADSLAB_FILTER: Callable[ - [List[AdsorbateSlabConfigs]], Awaitable[List[AdsorbateSlabConfigs]] + [list[AdsorbateSlabConfigs]], Awaitable[list[AdsorbateSlabConfigs]] ] = prompt_for_slabs_to_keep() @@ -822,7 +817,7 @@ async def find_adsorbate_binding_sites( bulk: str, model: str = "equiformer_v2_31M_s2ef_all_md", adslab_filter: Callable[ - [List[AdsorbateSlabConfigs]], Awaitable[List[AdsorbateSlabConfigs]] + [list[AdsorbateSlabConfigs]], Awaitable[list[AdsorbateSlabConfigs]] ] = _DEFAULT_ADSLAB_FILTER, client: Client = DEFAULT_CLIENT, lifetime: Lifetime = Lifetime.SAVE, @@ -889,13 +884,13 @@ async def find_adsorbate_binding_sites( # Fetch all slabs for the bulk log.info("Generating surfaces") - slabs: List[Slab] = await _get_slabs( + slabs: list[Slab] = await _get_slabs( client=client, bulk=bulk_obj, ) # Finding candidate binding site on each slab - adslabs: List[AdsorbateSlabConfigs] = await _get_adsorbate_configs_on_slabs( + adslabs: list[AdsorbateSlabConfigs] = await _get_adsorbate_configs_on_slabs( client=client, adsorbate=adsorbate, slabs=slabs, diff --git a/src/fairchem/demo/ocpapi/workflows/context.py b/src/fairchem/demo/ocpapi/workflows/context.py index bc7b3df09..8352a70ff 100644 --- a/src/fairchem/demo/ocpapi/workflows/context.py +++ b/src/fairchem/demo/ocpapi/workflows/context.py @@ -1,6 +1,10 @@ +from __future__ import annotations + from contextlib import contextmanager -from contextvars import ContextVar -from typing import Any, Generator +from typing import TYPE_CHECKING, Any, Generator + +if TYPE_CHECKING: + from contextvars import ContextVar @contextmanager diff --git a/src/fairchem/demo/ocpapi/workflows/filter.py b/src/fairchem/demo/ocpapi/workflows/filter.py index f7ea73a63..661e872d4 100644 --- a/src/fairchem/demo/ocpapi/workflows/filter.py +++ b/src/fairchem/demo/ocpapi/workflows/filter.py @@ -1,6 +1,9 @@ -from typing import Iterable, List, Set, Tuple +from __future__ import annotations -from fairchem.demo.ocpapi.client import AdsorbateSlabConfigs, SlabMetadata +from typing import TYPE_CHECKING, Iterable + +if TYPE_CHECKING: + from fairchem.demo.ocpapi.client import AdsorbateSlabConfigs, SlabMetadata class keep_all_slabs: @@ -10,8 +13,8 @@ class keep_all_slabs: async def __call__( self, - adslabs: List[AdsorbateSlabConfigs], - ) -> List[AdsorbateSlabConfigs]: + adslabs: list[AdsorbateSlabConfigs], + ) -> list[AdsorbateSlabConfigs]: return adslabs @@ -21,19 +24,19 @@ class keep_slabs_with_miller_indices: Slabs with other miller indices will be ignored. """ - def __init__(self, miller_indices: Iterable[Tuple[int, int, int]]) -> None: + def __init__(self, miller_indices: Iterable[tuple[int, int, int]]) -> None: """ Args: miller_indices: The list of miller indices that will be allowed. Slabs with any other miller indices will be dropped by this filter. """ - self._unique_millers: Set[Tuple[int, int, int]] = set(miller_indices) + self._unique_millers: set[tuple[int, int, int]] = set(miller_indices) async def __call__( self, - adslabs: List[AdsorbateSlabConfigs], - ) -> List[AdsorbateSlabConfigs]: + adslabs: list[AdsorbateSlabConfigs], + ) -> list[AdsorbateSlabConfigs]: return [ adslab for adslab in adslabs @@ -50,7 +53,7 @@ class prompt_for_slabs_to_keep: @staticmethod def _sort_key( adslab: AdsorbateSlabConfigs, - ) -> Tuple[Tuple[int, int, int], float, str]: + ) -> tuple[tuple[int, int, int], float, str]: """ Generates a sort key from the input adslab. Returns the miller indices, shift, and top/bottom label so that they will be sorted by those values @@ -61,8 +64,8 @@ def _sort_key( async def __call__( self, - adslabs: List[AdsorbateSlabConfigs], - ) -> List[AdsorbateSlabConfigs]: + adslabs: list[AdsorbateSlabConfigs], + ) -> list[AdsorbateSlabConfigs]: from inquirer import Checkbox, prompt # Break early if no adslabs were provided @@ -76,7 +79,7 @@ async def __call__( # will be presented to the user in the prompt. The second item in each # tuple (indices from the input list of adslabs) will be returned from # the prompt. - choices: List[Tuple[str, int]] = [ + choices: list[tuple[str, int]] = [ ( ( f"{adslab.slab.metadata.millers} " @@ -98,7 +101,7 @@ async def __call__( ), choices=choices, ) - selected_indices: List[int] = prompt([checkbox])["adslabs"] + selected_indices: list[int] = prompt([checkbox])["adslabs"] # Return the adslabs that were chosen return [adslabs[i] for i in selected_indices] diff --git a/src/fairchem/demo/ocpapi/workflows/log.py b/src/fairchem/demo/ocpapi/workflows/log.py index 049be042f..4f175b548 100644 --- a/src/fairchem/demo/ocpapi/workflows/log.py +++ b/src/fairchem/demo/ocpapi/workflows/log.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging log = logging.getLogger("ocpapi") diff --git a/src/fairchem/demo/ocpapi/workflows/retry.py b/src/fairchem/demo/ocpapi/workflows/retry.py index 98a8a8df3..4ef1e129f 100644 --- a/src/fairchem/demo/ocpapi/workflows/retry.py +++ b/src/fairchem/demo/ocpapi/workflows/retry.py @@ -1,10 +1,15 @@ -import logging +from __future__ import annotations + from dataclasses import dataclass -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal -from tenacity import RetryCallState -from tenacity import retry as tenacity_retry +from fairchem.demo.ocpapi.client import ( + NonRetryableRequestException, + RateLimitExceededException, + RequestException, +) from tenacity import ( + RetryCallState, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt, @@ -12,13 +17,11 @@ wait_fixed, wait_random, ) +from tenacity import retry as tenacity_retry from tenacity.wait import wait_base -from fairchem.demo.ocpapi.client import ( - NonRetryableRequestException, - RateLimitExceededException, - RequestException, -) +if TYPE_CHECKING: + import logging @dataclass @@ -48,7 +51,7 @@ class _wait_check_retry_after(wait_base): def __init__( self, default_wait: wait_base, - rate_limit_logging: Optional[RateLimitLogging] = None, + rate_limit_logging: RateLimitLogging | None = None, ) -> None: """ Args: @@ -84,8 +87,8 @@ def __call__(self, retry_state: RetryCallState) -> float: def retry_api_calls( - max_attempts: Union[int, NoLimitType] = 3, - rate_limit_logging: Optional[RateLimitLogging] = None, + max_attempts: int | NoLimitType = 3, + rate_limit_logging: RateLimitLogging | None = None, fixed_wait_sec: float = 2, max_jitter_sec: float = 1, ) -> Any: diff --git a/tests/demo/ocpapi/tests/unit/client/test_models.py b/tests/demo/ocpapi/tests/unit/client/test_models.py index 74c6c5df3..e3cb4089a 100644 --- a/tests/demo/ocpapi/tests/unit/client/test_models.py +++ b/tests/demo/ocpapi/tests/unit/client/test_models.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from dataclasses import dataclass from typing import ( @@ -17,8 +19,8 @@ from ase.atoms import Atoms as ASEAtoms from ase.calculators.singlepoint import SinglePointCalculator from ase.constraints import FixAtoms - -from fairchem.demo.ocpapi.client import ( +from fairchem.demo.ocpapi.client import Status +from fairchem.demo.ocpapi.client.models import ( Adsorbates, AdsorbateSlabConfigs, AdsorbateSlabRelaxationResult, @@ -33,9 +35,8 @@ Slab, SlabMetadata, Slabs, - Status, + _DataModel, ) -from fairchem.demo.ocpapi.client.models import _DataModel T = TypeVar("T", bound=_DataModel)