enable use_buffers for SWA in AutoUnit (pytorch#844)
Summary:
Pull Request resolved: pytorch#844

Enables batch normalization statistics to be updated via the `use_buffers` flag. See https://fb.workplace.com/groups/1323951304836028/posts/1668131377084684/?comment_id=1668215887076233 for the user request.

This functionality existed in the base `AveragedModel` but was never exposed publicly in TNT.

Reviewed By: diego-urgell

Differential Revision: D58368554

fbshipit-source-id: fff608da39c471ee6e61af945a2f782e76ece5b9
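
As a usage sketch of the new flag (not part of this commit): the `SWAParams` fields below are taken from the diff further down, while the numeric values are arbitrary placeholders.

```python
from torchtnt.framework.auto_unit import SWAParams

# Minimal sketch; field names come from the SWAParams diff below, and the
# numeric values are placeholders, not recommendations.
swa_params = SWAParams(
    warmup_steps_or_epochs=5,
    step_or_epoch_update_freq=1,
    use_buffers=True,        # new in this commit: also average buffers (e.g. BatchNorm running stats)
    averaging_method="ema",
    ema_decay=0.999,
)
# Pass this to an AutoUnit subclass via its swa_params constructor argument;
# AutoUnit forwards use_buffers to its internal AveragedModel (see diff below).
```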
JKSenthil authored and facebook-github-bot committed Jun 11, 2024
1 parent 9d99ea6 commit f9f566b
Showing 2 changed files with 9 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchtnt/framework/auto_unit.py
```diff
@@ -84,6 +84,10 @@ class SWAParams:
     Args:
         warmup_steps_or_epochs: number of steps or epochs before starting SWA
         step_or_epoch_update_freq: number of steps or epochs between each SWA update
+        use_buffers: if ``True``, it will compute running averages for
+            both the parameters and the buffers of the model. (default: ``True``)
+            This will update activation statistics for Batch Normalization. This is an
+            alternative to calling `torch.optim.swa_utils.update_bn` post-training.
         averaging_method: whether to use SWA or EMA to average model weights
         ema_decay: the exponential decay applied to the averaged parameters. This param
             is only needed for EMA, and is ignored otherwise (for SWA).
@@ -101,6 +105,7 @@ class SWAParams:
 
     warmup_steps_or_epochs: int
     step_or_epoch_update_freq: int
+    use_buffers: bool = True
     averaging_method: Literal["ema", "swa"] = "ema"
     ema_decay: float = 0.999
     use_lit: bool = False
@@ -487,7 +492,7 @@ def __init__(
             self.swa_model = AveragedModel(
                 module_for_swa,
                 device=device,
-                use_buffers=True,
+                use_buffers=swa_params.use_buffers,
                 averaging_method=swa_params.averaging_method,
                 ema_decay=swa_params.ema_decay,
                 skip_deepcopy=skip_deepcopy,
```
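
For contrast, a minimal sketch of the alternative mentioned in the new docstring: leave buffers un-averaged and rebuild BatchNorm statistics after training with `torch.optim.swa_utils.update_bn`. This uses stock PyTorch APIs; the toy model and loader are placeholders.

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

# Toy BatchNorm-bearing model and stand-in data; both are placeholders.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))
loader = [torch.randn(4, 8) for _ in range(10)]

# With use_buffers=False (the stock PyTorch default), buffers are not
# averaged, so BatchNorm running stats must be recomputed after training.
swa_model = AveragedModel(model, use_buffers=False)

# ... training loop; call swa_model.update_parameters(model) periodically ...

# One extra pass over the data rebuilds the BatchNorm running statistics.
update_bn(loader, swa_model)
```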
3 changes: 3 additions & 0 deletions torchtnt/utils/swa.py
```diff
@@ -46,6 +46,9 @@ def __init__(
             to see what the model, device, and use_buffer arguments entail.
 
         Args:
+            use_buffers: if ``True``, it will compute running averages for
+                both the parameters and the buffers of the model. (default: ``False``)
+                This will update activation statistics for Batch Normalization.
             averaging_method: Whether to use EMA or SWA.
             ema_decay: The exponential decay applied to the averaged parameters. This param
                 is only needed for EMA, and is ignored otherwise (for SWA).
```
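
Finally, a sketch of driving the TNT wrapper directly, mirroring the AutoUnit call site in the auto_unit.py diff above (argument names are taken from that diff; the toy module and values are placeholders):

```python
import torch
from torchtnt.utils.swa import AveragedModel

module = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))

# Mirrors the AutoUnit call site above; per the docstring, use_buffers
# defaults to False here, so we opt in explicitly to average BatchNorm stats.
swa_model = AveragedModel(
    module,
    device=torch.device("cpu"),
    use_buffers=True,
    averaging_method="ema",
    ema_decay=0.999,
)
swa_model.update_parameters(module)  # call periodically during training
```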
