Skip to content

Commit

Permalink
Pass in strict flag to allow non-strict state_dict loading (#691)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #691

This allows users to ignore keys in the state_dict that aren't part of
the given module, or part of the module that aren't in the state_dict.

The default value for the flag (true) keeps the status quo and what the pytorch
interface uses.

Reviewed By: anshulverma

Differential Revision: D53066198

fbshipit-source-id: 8a849f46d09d6e7d9185d589b966f7a7a089d9fc
  • Loading branch information
schwarzmx authored and facebook-github-bot committed Jan 26, 2024
1 parent 466a0cd commit c25e140
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/framework/callbacks/test_torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def state_dict(self) -> Dict[str, Any]:
self.state_dict_call_count += 1
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True) -> None:
self.load_state_dict_call_count += 1
return None

Expand Down

0 comments on commit c25e140

Please sign in to comment.