-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added option for continue training from checkpoint #49
Changes from 5 commits
5f7b77e
699d726
bb54b84
ea3254e
8c2dc5e
99e9abb
689bcae
e03c5bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -305,7 +305,9 @@ def forward( | |
if output_name in outputs: | ||
atomic_energies[output_name] = apply_composition_contribution( | ||
output_layer(hidden_features), | ||
self.composition_weights[self.output_to_index[output_name]], | ||
self.composition_weights[ # type: ignore | ||
self.output_to_index[output_name] | ||
], | ||
) | ||
|
||
# Sum the atomic energies coming from the BPNN to get the total energy | ||
|
@@ -331,6 +333,29 @@ def set_composition_weights( | |
) -> None: | ||
"""Set the composition weights for a given output.""" | ||
# all species that are not present retain their weight of zero | ||
self.composition_weights[self.output_to_index[output_name]][ | ||
self.composition_weights[self.output_to_index[output_name]][ # type: ignore | ||
self.all_species | ||
] = input_composition_weights | ||
|
||
def add_output(self, output_name: str) -> None: | ||
"""Add a new output to the model.""" | ||
# add a new row to the composition weights tensor | ||
Comment on lines
+341
to
+342
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't be both in the docstring? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Both? |
||
self.composition_weights = torch.cat( | ||
[ | ||
self.composition_weights, # type: ignore | ||
torch.zeros( | ||
1, | ||
self.composition_weights.shape[1], # type: ignore | ||
dtype=self.composition_weights.dtype, # type: ignore | ||
device=self.composition_weights.device, # type: ignore | ||
), | ||
] | ||
) # type: ignore | ||
self.output_to_index[output_name] = len(self.output_to_index) | ||
# add a new linear layer to the last layers | ||
hypers_bpnn = self.hypers["bpnn"] | ||
if hypers_bpnn["num_hidden_layers"] == 0: | ||
n_inputs_last_layer = hypers_bpnn["input_size"] | ||
else: | ||
n_inputs_last_layer = hypers_bpnn["num_neurons_per_layer"] | ||
self.last_layers[output_name] = LinearMap(self.all_species, n_inputs_last_layer) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from typing import Tuple | ||
|
||
from metatensor.torch.atomistic import ModelCapabilities | ||
|
||
|
||
def merge_capabilities( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should there be a test for this function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, my bad |
||
old_capabilities: ModelCapabilities, requested_capabilities: ModelCapabilities | ||
) -> Tuple[ModelCapabilities, ModelCapabilities]: | ||
""" | ||
Merge the capabilities of a model with the requested capabilities. | ||
|
||
:param old_capabilities: The old capabilities of the model. | ||
:param requested_capabilities: The requested capabilities. | ||
|
||
:return: The merged capabilities and the new capabilities that | ||
were not present in the old capabilities. The order will | ||
be preserved. | ||
""" | ||
# Check that the length units are the same: | ||
if old_capabilities.length_unit != requested_capabilities.length_unit: | ||
raise ValueError( | ||
"The length units of the old and new capabilities are not the same." | ||
) | ||
|
||
# Check that there are no new species: | ||
for species in requested_capabilities.species: | ||
if species not in old_capabilities.species: | ||
raise ValueError( | ||
f"The species {species} is not within " | ||
"the capabilities of the loaded model." | ||
) | ||
|
||
# Merge the outputs: | ||
outputs = {} | ||
for key, value in old_capabilities.outputs.items(): | ||
outputs[key] = value | ||
for key, value in requested_capabilities.outputs.items(): | ||
if key not in outputs: | ||
outputs[key] = value | ||
else: | ||
assert ( | ||
outputs[key].unit == value.unit | ||
), f"Output {key} has different units in the old and new capabilities." | ||
|
||
# Find the new outputs: | ||
new_outputs = {} | ||
for key, value in requested_capabilities.outputs.items(): | ||
if key not in old_capabilities.outputs: | ||
new_outputs[key] = value | ||
|
||
merged_capabilities = ModelCapabilities( | ||
length_unit=requested_capabilities.length_unit, | ||
species=old_capabilities.species, | ||
outputs=outputs, | ||
) | ||
|
||
new_capabilities = ModelCapabilities( | ||
length_unit=requested_capabilities.length_unit, | ||
species=old_capabilities.species, | ||
outputs=new_outputs, | ||
) | ||
|
||
return merged_capabilities, new_capabilities |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe adding a default=None, here changes it from a string.