make AveragedModel OSS compatible
Summary: Some of the imports used by TorchTNT's AveragedModel (get_ema_multi_avg_fn and get_swa_multi_avg_fn from torch.optim.swa_utils) do not exist in torch version 2.0.0, so this diff guards against the resulting import errors at runtime.

Reviewed By: galrotem

Differential Revision: D56534614

fbshipit-source-id: fe6a5b56eabf51c101fa0ebc0f1eb2870df10c97
JKSenthil authored and facebook-github-bot committed Apr 25, 2024
1 parent 1a47149 commit 7737e13
Showing 3 changed files with 34 additions and 11 deletions.
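Before the file-by-file diffs, the core pattern in isolation: attempt an optional import, record whether it succeeded in a module-level flag, and have call sites check the flag so failures happen lazily with an actionable message. A minimal, self-contained sketch of the pattern (generic names for illustration, not the library's identifiers):

# Minimal sketch of an optional-import guard (generic names for illustration).
_FEATURE_AVAIL: bool = True

try:
    from some_optional_dependency import feature  # absent in older versions
except ImportError:
    _FEATURE_AVAIL = False


def use_feature() -> None:
    # Fail lazily, with an actionable message, only when the feature is used.
    if not _FEATURE_AVAIL:
        raise ImportError(
            "feature requires a newer version of some_optional_dependency"
        )
    feature()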
tests/framework/test_auto_unit.py (13 additions, 5 deletions)
@@ -41,6 +41,7 @@
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.prepare_module import DDPStrategy
 from torchtnt.utils.progress import Progress
+from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL
 from torchtnt.utils.test_utils import skip_if_not_distributed
 from torchtnt.utils.timer import Timer

@@ -149,6 +150,9 @@ def test_predict_step(self) -> None:
         predict(auto_unit, pred_dataloader, max_steps_per_epoch=1)
         mock_predict_step_end.assert_called_once()
 
+    @unittest.skipUnless(
+        _AVERAGED_MODEL_AVAIL, "AveragedModel not available in this version of PyTorch"
+    )
     def test_stochastic_weight_averaging_basic(self) -> None:
         """
         Basic stochastic weight averaging tests
@@ -182,6 +186,9 @@ def test_stochastic_weight_averaging_basic(self) -> None:
self.assertIn("swa_scheduler", auto_unit2.app_state())
self.assertIn("swa_scheduler", auto_unit2.tracked_lr_schedulers())

@unittest.skipUnless(
_AVERAGED_MODEL_AVAIL, "AveragedModel needed in version of Pytorch"
)
def test_stochastic_weight_averaging_update_freq(self) -> None:
"""
e2e stochastic weight averaging test to ensure that the SWA model is updated at the correct frequency
@@ -295,11 +302,12 @@ def test_auto_unit_ddp(self) -> None:
         Launch tests of AutoUnit with DDP strategy
         """
 
-        spawn_multi_process(
-            2,
-            "gloo",
-            self._test_stochastic_weight_averaging_with_ddp,
-        )
+        if _AVERAGED_MODEL_AVAIL:
+            spawn_multi_process(
+                2,
+                "gloo",
+                self._test_stochastic_weight_averaging_with_ddp,
+            )
         spawn_multi_process(
             2,
             "gloo",
tests/utils/test_swa.py (4 additions, 1 deletion)
@@ -13,7 +13,10 @@
 
 import torch
 
-from torchtnt.utils.swa import AveragedModel
+from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL, AveragedModel
+
+if not _AVERAGED_MODEL_AVAIL:
+    raise unittest.SkipTest("The latest PyTorch is required to run SWA tests")
 
 
 class TestSWA(unittest.TestCase):
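For context on the test_swa.py change: unittest treats a SkipTest exception raised at module import time as a request to skip the whole module, which is why a single module-level raise suffices here. A minimal sketch of that behavior (hypothetical flag and test class):

import unittest

FEATURE_AVAILABLE = False  # hypothetical stand-in for _AVERAGED_MODEL_AVAIL

if not FEATURE_AVAILABLE:
    # Raising SkipTest during import marks every test in this module as skipped.
    raise unittest.SkipTest("feature not available; skipping module")


class TestFeature(unittest.TestCase):
    def test_something(self) -> None:
        self.assertTrue(FEATURE_AVAILABLE)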
torchtnt/utils/swa.py (17 additions, 5 deletions)
@@ -10,11 +10,17 @@
 
 import torch
 
-from torch.optim.swa_utils import (
-    AveragedModel as PyTorchAveragedModel,
-    get_ema_multi_avg_fn,
-    get_swa_multi_avg_fn,
-)
+_AVERAGED_MODEL_AVAIL: bool = True
+
+try:
+    from torch.optim.swa_utils import (
+        AveragedModel as PyTorchAveragedModel,
+        get_ema_multi_avg_fn,
+        get_swa_multi_avg_fn,
+    )
+except ImportError:
+    _AVERAGED_MODEL_AVAIL = False
+
 
 TSWA_avg_fn = Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
 TSWA_multi_avg_fn = Callable[[List[torch.Tensor], List[torch.Tensor], int], None]
@@ -49,6 +55,12 @@ def __init__(
             number of updates. The EMA decay will start small and will approach the
             specified ema_decay as more updates occur.
         """
+        if not _AVERAGED_MODEL_AVAIL:
+            raise ImportError(
+                "AveragedModel is not available in this version of PyTorch. "
+                "Please install the latest version of PyTorch."
+            )
+
         # setup averaging method
         if averaging_method == "ema":
             if ema_decay < 0.0 or ema_decay > 1.0:
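To show what the guard means for callers, a usage sketch: it assumes, per the diff above, that the TorchTNT wrapper accepts the wrapped module plus averaging_method and ema_decay arguments; the model and hyperparameter values are hypothetical.

import torch

from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL, AveragedModel

model = torch.nn.Linear(4, 2)  # hypothetical model for illustration

if _AVERAGED_MODEL_AVAIL:
    # On newer PyTorch builds the wrapper constructs normally.
    swa_model = AveragedModel(model, averaging_method="ema", ema_decay=0.999)
else:
    # On torch 2.0.0 the guarded import failed, and constructing the wrapper
    # would raise ImportError, so fall back to skipping SWA entirely.
    swa_model = None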
