Skip to content

Commit

Permalink
Change atomic_types and gradients from sets to unique lists (#296)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: frostedoyster <[email protected]>
  • Loading branch information
PicoCentauri and frostedoyster authored Jul 16, 2024
1 parent 0024253 commit 576b2b0
Show file tree
Hide file tree
Showing 28 changed files with 273 additions and 134 deletions.
4 changes: 2 additions & 2 deletions src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,10 @@ def eval_model(
# TODO: allow the user to specify which outputs to evaluate
eval_targets = {}
eval_info_dict = TargetInfoDict()
gradients = {"positions"}
gradients = ["positions"]
if all(not torch.all(system.cell == 0) for system in eval_systems):
# only add strain if all structures have cells
gradients.add("strain")
gradients.append("strain")
for key in model.capabilities().outputs.keys():
eval_info_dict[key] = TargetInfo(
quantity=model.capabilities().outputs[key].quantity,
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/alchemical_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
super().__init__()
self.hypers = model_hypers
self.dataset_info = dataset_info
self.atomic_types = sorted(dataset_info.atomic_types)
self.atomic_types = dataset_info.atomic_types

if len(dataset_info.targets) != 1:
raise ValueError("The AlchemicalModel only supports a single target")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_to(device, dtype):

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_prediction_subset_elements():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_rotational_invariance():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_regression_init():
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=targets
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)

Expand Down Expand Up @@ -97,7 +97,7 @@ def test_regression_train():
hypers = DEFAULT_HYPERS.copy()

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_alchemical_model_inference():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=set(unique_numbers),
atomic_types=unique_numbers,
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_torchscript():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)

Expand All @@ -24,7 +24,7 @@ def test_torchscript_save_load():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/gap/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
for key, value in dataset_info.targets.items()
}

self.atomic_types = sorted(dataset_info.atomic_types)
self.atomic_types = dataset_info.atomic_types
self.hypers = model_hypers

# creates a composition weight tensor that can be directly indexed by species,
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/gap/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_ethanol_regression_train_and_invariance():
)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
)

gap = GAP(hypers["model"], dataset_info)
Expand Down
6 changes: 3 additions & 3 deletions src/metatrain/experimental/gap/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_regression_init():
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=targets
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
)
GAP(DEFAULT_HYPERS["model"], dataset_info)

Expand Down Expand Up @@ -61,7 +61,7 @@ def test_regression_train_and_invariance():
target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
)

gap = GAP(DEFAULT_HYPERS["model"], dataset_info)
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_ethanol_regression_train_and_invariance():
)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
)

gap = GAP(hypers["model"], dataset_info)
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/experimental/gap/tests/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_torchscript():
target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
)
conf = {
"mtt::U0": {
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_torchscript_save():
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=targets
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
)
gap = GAP(DEFAULT_HYPERS["model"], dataset_info)
torch.jit.save(
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/pet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
model_hypers["TARGET_AGGREGATION"] = "sum"
self.hypers = model_hypers
self.cutoff = self.hypers["R_CUT"]
self.atomic_types: List[int] = sorted(dataset_info.atomic_types)
self.atomic_types: List[int] = dataset_info.atomic_types
self.dataset_info = dataset_info
self.pet = None
self.checkpoint_path: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/pet/tests/test_exported.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_to(device):
dtype = torch.float32 # for now
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info)
Expand Down
6 changes: 3 additions & 3 deletions src/metatrain/experimental/pet/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_prediction():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info)
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_per_atom_predictions_functionality():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info)
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_selected_atoms_functionality():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,15 @@ def test_predictions_compatibility(cutoff):
are consistent with the predictions of the original PET implementation."""

structure = ase.io.read(DATASET_PATH)
atomic_types = set(structure.numbers)

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=atomic_types,
atomic_types=structure.numbers,
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
capabilities = ModelCapabilities(
length_unit="Angstrom",
atomic_types=sorted(atomic_types),
atomic_types=dataset_info.atomic_types,
outputs={
"energy": ModelOutput(
quantity="energy",
Expand All @@ -116,7 +115,7 @@ def test_predictions_compatibility(cutoff):
hypers["R_CUT"] = cutoff
model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info)
ARCHITECTURAL_HYPERS = Hypers(model.hypers)
raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(dataset_info.atomic_types))
model.set_trained_model(raw_pet)

system = systems_to_torch(structure)
Expand All @@ -142,7 +141,7 @@ def test_predictions_compatibility(cutoff):
ARCHITECTURAL_HYPERS = Hypers(DEFAULT_HYPERS["model"])
batch = get_pyg_graphs(
[structure],
sorted(atomic_types),
dataset_info.atomic_types,
cutoff,
ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
ARCHITECTURAL_HYPERS.USE_LONG_RANGE,
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/experimental/pet/tests/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_torchscript():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info)
Expand All @@ -30,7 +30,7 @@ def test_torchscript_save_load():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info)
Expand Down
8 changes: 5 additions & 3 deletions src/metatrain/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
self.hypers = model_hypers
self.dataset_info = dataset_info
self.new_outputs = list(dataset_info.targets.keys())
self.atomic_types = sorted(dataset_info.atomic_types)
self.atomic_types = dataset_info.atomic_types

self.soap_calculator = rascaline.torch.SoapPowerSpectrum(
radial_basis={"Gto": {}}, **self.hypers["soap"]
Expand Down Expand Up @@ -198,7 +198,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn":
# merge old and new dataset info
merged_info = self.dataset_info.union(dataset_info)
new_atomic_types = merged_info.atomic_types - self.dataset_info.atomic_types
new_atomic_types = [
at for at in merged_info.atomic_types if at not in self.atomic_types
]
new_targets = merged_info.targets - self.dataset_info.targets

if len(new_atomic_types) > 0:
Expand All @@ -212,7 +214,7 @@ def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn":
self.add_output(output_name)

self.dataset_info = merged_info
self.atomic_types = sorted(self.dataset_info.atomic_types)
self.atomic_types = sorted(self.atomic_types)

for target_name, target in new_targets.items():
self.outputs[target_name] = ModelOutput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_continue(monkeypatch, tmp_path):
target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
)
model = SoapBpnn(MODEL_HYPERS, dataset_info)
output_before = model(systems[:5], {"mtt::U0": model.outputs["mtt::U0"]})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_to(device, dtype):

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = SoapBpnn(MODEL_HYPERS, dataset_info).to(dtype=dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_prediction_subset_elements():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)

Expand All @@ -37,7 +37,7 @@ def test_prediction_subset_atoms():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)

Expand Down Expand Up @@ -103,7 +103,7 @@ def test_output_last_layer_features():
"""Tests that the model can output its last layer features."""
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)

Expand Down Expand Up @@ -174,7 +174,7 @@ def test_output_per_atom():
"""Tests that the model can output per-atom quantities."""
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_rotational_invariance():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = SoapBpnn(MODEL_HYPERS, dataset_info)
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/experimental/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_regression_init():
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=targets
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
)
model = SoapBpnn(MODEL_HYPERS, dataset_info)

Expand Down Expand Up @@ -76,7 +76,7 @@ def test_regression_train():
hypers["training"]["num_epochs"] = 2

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
)
model = SoapBpnn(MODEL_HYPERS, dataset_info)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_torchscript():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = SoapBpnn(MODEL_HYPERS, dataset_info)
Expand All @@ -38,7 +38,7 @@ def test_torchscript_with_identity():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
hypers = copy.deepcopy(MODEL_HYPERS)
Expand All @@ -64,7 +64,7 @@ def test_torchscript_save_load():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types={1, 6, 7, 8},
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
)
model = SoapBpnn(MODEL_HYPERS, dataset_info)
Expand Down
Loading

0 comments on commit 576b2b0

Please sign in to comment.