Add FSDP custom wrap with torch 2.1 (mosaicml#2460)
* add torch2

* add code

* tag more changes

* Update composer/trainer/mosaic_fsdp.py

Co-authored-by: Vitaliy Chiley <[email protected]>

* monkeypatch init

* raise pins

* add print

* more logs

* change if statements

* remove imports

* remove imports

* fix init

* fix versioning

* add hybrid shard

* checkdown

* revert hsdp

* add peak memory stats

* lint

* imports

* Update composer/trainer/mosaic_fsdp.py

Co-authored-by: Daniel King <[email protected]>

* fix wrap

* fix gate

* lint

* test

* change thresh

* import typing

* fix checks

* nuke pyright

* typo

* Update composer/trainer/mosaic_fsdp.py

Co-authored-by: Brian <[email protected]>

* Update composer/trainer/mosaic_fsdp.py

Co-authored-by: Brian <[email protected]>

* Update composer/trainer/mosaic_fsdp_utils.py

Co-authored-by: Brian <[email protected]>

* resolve comments

* add comments

* add comments

---------

Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Daniel King <[email protected]>
Co-authored-by: Brian <[email protected]>
4 people committed Sep 8, 2023
1 parent d6aa754 commit 92c84aa
Showing 3 changed files with 323 additions and 10 deletions.
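The substance of the change is a version-gated monkey patch: patch_pytorch() inspects torch.__version__ and swaps Composer's custom _auto_wrap or __init__ onto FullyShardedDataParallel for the matching release. Below is a minimal sketch of that pattern, assuming only torch and packaging are installed; logged_init is an illustrative stand-in for the commit's init_fn_t2p0p1 / init_fn_t2p1p0, not code from this PR.

import torch
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel

# Keep a handle on the stock constructor so the patch can delegate to it.
_original_init = FullyShardedDataParallel.__init__

def logged_init(self, *args, **kwargs):
    # Illustrative stand-in for init_fn_t2p0p1 / init_fn_t2p1p0, which route
    # construction through Composer's custom _auto_wrap.
    print(f'Constructing FSDP under torch {torch.__version__}')
    _original_init(self, *args, **kwargs)

# Gate the patch on the runtime torch version, as patch_pytorch() does.
if version.parse('2.0.1') <= version.parse(torch.__version__) < version.parse('2.1.1'):
    FullyShardedDataParallel.__init__ = logged_init  # type: ignore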
20 changes: 12 additions & 8 deletions composer/trainer/mosaic_fsdp.py
@@ -20,31 +20,35 @@ def patch_pytorch():
         raise NotImplementedError(f'Not supported for torch < 1.13.1')

     elif version.parse(torch.__version__) < version.parse('2.0.0'):
-        # FullyShardedDataParallel monkey path for torch < 2.0 ie torch == 1.13.1
+        # Monkey patch for torch < 2.0 ie torch == 1.13.1

-        # monkey patch _auto_wrap with _custom_auto_wrap fn
+        # Monkey patch _auto_wrap with _custom_auto_wrap fn
         FullyShardedDataParallel._auto_wrap = custom_auto_wrap_t1p13p1  # type: ignore

     elif version.parse(torch.__version__) < version.parse('2.0.1'):
         raise NotImplementedError(f'Not supported for torch == 2.0.0')

-    elif version.parse(torch.__version__) == version.parse('2.0.1'):
+    elif version.parse(torch.__version__) < version.parse('2.0.2'):
         # Monkey patch for torch == 2.0.1

         # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
         from composer.trainer.mosaic_fsdp_utils import init_fn_t2p0p1
-        FullyShardedDataParallel.__init__ = init_fn_t2p0p1
+        FullyShardedDataParallel.__init__ = init_fn_t2p0p1  # type: ignore

         # Monkey patch sharding method
         ChunkShardingSpec.build_metadata = build_metadata
         ChunkShardingSpec.shard = shard

-    elif version.parse(torch.__version__) < version.parse('2.2.0'):
-        # Monkey path for torch < 2.2.0 ie torch == 2.1.0
+    elif version.parse(torch.__version__) < version.parse('2.1.1'):
+        # Monkey path for torch < 2.1.1 ie torch == 2.1.0
+
+        # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
+        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p1p0
+        FullyShardedDataParallel.__init__ = init_fn_t2p1p0  # type: ignore

         # Monkey patch sharding method
         ChunkShardingSpec.build_metadata = build_metadata
         ChunkShardingSpec.shard = shard

-    elif version.parse(torch.__version__) >= version.parse('2.2.0'):
-        raise NotImplementedError(f'Not supported for torch >= 2.2.0')
+    elif version.parse(torch.__version__) >= version.parse('2.1.1'):
+        raise NotImplementedError(f'FullyShardedDataParallel is not supported for torch >= 2.2.0')
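For orientation, here is how the patched constructor would be exercised: call patch_pytorch() before wrapping anything in FSDP. A minimal usage sketch, assuming a distributed process group is already initialized (e.g. launched via torchrun); the model and wrap policy below are placeholders, not from this PR.

import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

from composer.trainer.mosaic_fsdp import patch_pytorch

# Apply the version-appropriate monkey patches before any FSDP wrapping.
patch_pytorch()

# Placeholder model; any nn.Module works.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# Standard torch auto-wrap policy; the patched __init__/_auto_wrap run
# inside the FullyShardedDataParallel constructor.
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100)
fsdp_model = FullyShardedDataParallel(model, auto_wrap_policy=policy)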
