Skip to content

Commit

Permalink
Update default value for TorchCompileConfig to be None
Browse files Browse the repository at this point in the history
Summary:
If we set dynamics to False, then automatic_dynamic_shapes will not take effect. This is not ideal because for sparse arch compilation, we expect that there are a number of dynamic shapes, if we do not apply automatic_dynamic_shapes, compiler will quit compilation after 8 trials.

Thus, we change it to None in this diff which is also the default setting for `torch.compile`:

https://www.internalfb.com/code/fbsource/[08c8192803b22f831329c8f736a3f7c6093ea4a8]/fbcode/caffe2/torch/__init__.py?lines=2321

Differential Revision: D63476875
  • Loading branch information
Microve authored and facebook-github-bot committed Sep 26, 2024
1 parent 001ccd5 commit 85e0361
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ class TorchCompileConfig:
Configs for torch.compile
fullgraph: bool = False, whether to compile the whole graph or not
dynamic: bool = False, whether to use dynamic shapes or not
dynamic: Optional[bool] = None, whether to use dynamic shapes or not, if None, automatic_dynamic_shapes will apply
backend: str = "inductor", which compiler to use (either inductor or aot)
compile_on_iter: int = 3, compile the model on which iteration
this is useful when we want to profile the first few iterations of training
and then start using compiled model from iteration #3 onwards
"""

fullgraph: bool = False
dynamic: bool = False
dynamic: Optional[bool] = None
backend: str = "inductor"
compile_on_iter: int = 3

Expand Down

0 comments on commit 85e0361

Please sign in to comment.