From 53c6f911f2d127ceff57012d54e29eb6a985d664 Mon Sep 17 00:00:00 2001 From: Saurabh Mishra Date: Fri, 12 Jul 2024 09:19:53 -0700 Subject: [PATCH] Fix UTs which were broken due to additional observability warnings (#862) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/862 Fix UTs which were broken due to additional observability warnings. DCP test should only validate the first warning in the list of warnings. Reviewed By: galrotem Differential Revision: D59684986 fbshipit-source-id: fa0eccc2f0a85e84c518cd15b4dfa09e34bb78ad --- tests/framework/callbacks/test_dcp_saver.py | 6 ++---- tests/framework/callbacks/test_torchsnapshot_saver.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py index 09bde932bb..7f66e301ab 100644 --- a/tests/framework/callbacks/test_dcp_saver.py +++ b/tests/framework/callbacks/test_dcp_saver.py @@ -134,10 +134,8 @@ def test_save_restore_dataloader_state(self) -> None: # load_state_dict is not called again on dataloader because there is no dataloader in manifest self.assertEqual(stateful_dataloader.load_state_dict_call_count, 1) self.assertEqual( - log.output, - [ - "WARNING:torchtnt.utils.rank_zero_log:train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot" - ], + log.output[0], + "WARNING:torchtnt.utils.rank_zero_log:train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot", ) def test_restore_from_latest(self) -> None: diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py index 9fc1d27e95..ddf617bb89 100644 --- a/tests/framework/callbacks/test_torchsnapshot_saver.py +++ b/tests/framework/callbacks/test_torchsnapshot_saver.py @@ -126,10 +126,8 @@ def test_save_restore_dataloader_state(self) -> None: # load_state_dict is not called again on dataloader because there is no dataloader in manifest self.assertEqual(stateful_dataloader.load_state_dict_call_count, 1) self.assertEqual( - log.output, - [ - "WARNING:torchtnt.utils.rank_zero_log:train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot" - ], + log.output[0], + "WARNING:torchtnt.utils.rank_zero_log:train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot", ) def test_restore_from_latest(self) -> None: