remove orig_params check #2981

Merged
merged 67 commits on May 16, 2024
Changes from 65 commits
Commits
67 commits
b276016
remove orig_params check
milocress Feb 8, 2024
c9fea5d
expect cpu-cpu test to fail
milocress Feb 14, 2024
13ccb2b
expect cpu-cpu test to fail
milocress Feb 14, 2024
3139b5c
fix formatting
milocress Feb 14, 2024
03cbf4a
Merge branch 'dev' into milo/remove-orig-params-check
milocress Feb 14, 2024
412a672
Merge branch 'dev' into milo/remove-orig-params-check
milocress Feb 29, 2024
aea6ebd
Update test_checkpoint.py
milocress Mar 19, 2024
8d65cfb
Merge branch 'dev' into milo/remove-orig-params-check
milocress Mar 20, 2024
0fe6c1f
Merge branch 'dev' into milo/remove-orig-params-check
milocress Mar 20, 2024
faf8c71
expect True,True to succeed?
milocress Mar 20, 2024
0939597
remove commented code
milocress Mar 20, 2024
80f79dc
rerun tests
mvpatel2000 Mar 20, 2024
f483dfb
how about now
milocress Mar 20, 2024
7ccac42
updated tests
milocress Mar 20, 2024
a01c954
fix tests
milocress Mar 20, 2024
ca3929f
Update test_checkpoint.py
milocress Mar 21, 2024
dd7583c
WIP debugging
milocress Apr 3, 2024
3da6a5f
merged
milocress Apr 3, 2024
469f31a
wip debug
milocress Apr 9, 2024
34ba1bd
Merge branch 'dev' into milo/remove-orig-params-check
milocress Apr 12, 2024
bd71a66
Merge branch 'dev' into milo/remove-orig-params-check
milocress May 2, 2024
b6eea9c
Update test_checkpoint.py
milocress May 2, 2024
6a8ff28
Update test_checkpoint.py
milocress May 2, 2024
43aa2d4
Update test_checkpoint.py
milocress May 2, 2024
7de12a2
Update test_checkpoint.py
milocress May 2, 2024
a38348f
Update test_checkpoint.py
milocress May 2, 2024
b567fa9
Update test_checkpoint.py
milocress May 2, 2024
072844f
WIP
milocress May 7, 2024
3ba9215
Merge branch 'dev' into milo/remove-orig-params-check
mvpatel2000 May 7, 2024
7db4478
counterexample is hanging guh
milocress May 13, 2024
6e73f60
merged
milocress May 13, 2024
2038088
have repro
milocress May 13, 2024
9945aaf
less composer repro
milocress May 13, 2024
46c87b8
found bug
milocress May 13, 2024
d7c9879
make counterexample match composer wrapper model situation
milocress May 14, 2024
3c28db9
still doesn't work
milocress May 14, 2024
8e1457e
update counterexample with printed models to show they're wrappped th…
milocress May 14, 2024
d80d9da
trim unnecessary stuff
milocress May 14, 2024
5624b47
Merge branch 'dev' into milo/remove-orig-params-check
milocress May 14, 2024
a37962d
counterexample for the record
milocress May 14, 2024
fdb681b
simplified counterexample
milocress May 14, 2024
e983170
fix tests
milocress May 14, 2024
9789245
fix quality and delete counterexample
milocress May 14, 2024
b213b92
fix
milocress May 14, 2024
79fbed0
asdict
milocress May 14, 2024
cea4402
remove self (lol if only)
milocress May 14, 2024
01d18fa
fix test
milocress May 14, 2024
9eb1a35
fix fix
milocress May 14, 2024
4b430a8
fix
milocress May 15, 2024
2d6aea9
reset save folder changes
milocress May 15, 2024
37031b7
remove meta for now
milocress May 15, 2024
049158f
replace load_monolith_ with load_fsdp_monolith_
milocress May 15, 2024
b477bc3
Merge branch 'dev' into milo/remove-orig-params-check
milocress May 15, 2024
1c7d3c2
change load_fsdp_monolith_rank0_only to load_monolith_rank0_only
milocress May 15, 2024
6e0eac0
merged fsdp change
milocress May 15, 2024
4507db1
resolve conflicts?
milocress May 15, 2024
33e007f
expect failure when no sync module states
milocress May 15, 2024
a70e852
fix constructor()
milocress May 15, 2024
aba4387
WIP
milocress May 15, 2024
44e6f7f
Merge branch 'dev' into milo/remove-orig-params-check
milocress May 15, 2024
eb186f2
WIP
milocress May 16, 2024
73e97b7
merged
milocress May 16, 2024
40dc38b
Update test_fsdp_checkpoint.py
mvpatel2000 May 16, 2024
115156a
make PR ready for review
milocress May 16, 2024
c2d8b66
Merge branch 'dev' into milo/remove-orig-params-check
milocress May 16, 2024
cddb4c1
remove device
milocress May 16, 2024
e1b7439
merged
milocress May 16, 2024
5 changes: 0 additions & 5 deletions composer/core/state.py
@@ -474,11 +474,6 @@ def __init__(
         if self.load_monolith_rank0_only:
             assert fsdp_config is not None
             error_message = ''
-            if fsdp_config['use_orig_params'] == True:
-                error_message += textwrap.dedent(
-                    "load_monolith_rank0_only requires fsdp_config['use_orig_params'] to be False. "
-                    "Either set fsdp_config['use_orig_params'] = False or set load_monolith_rank0_only = False. ",
-                )
             if fsdp_config['sync_module_states'] == False:
                 error_message += textwrap.dedent(
                     "load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
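In effect, a monolith rank-0-only load no longer forces use_orig_params to be False; only sync_module_states is still validated. The sketch below is illustrative only (it is not part of the diff, and the exact Trainer wiring is assumed); it shows a config that the old check rejected and the new validation accepts:

# Hypothetical config for illustration; the keys match the ones validated above.
fsdp_config = {
    'use_orig_params': True,          # no longer forced to False for monolith loads
    'sync_module_states': True,       # still required when load_monolith_rank0_only is True
    'state_dict_type': 'full',
    'load_monolith_rank0_only': True,
}
# Before this PR, State.__init__ raised a ValueError for this combination;
# after the change it only errors when sync_module_states is False.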
102 changes: 1 addition & 101 deletions tests/trainer/test_checkpoint.py
@@ -1511,106 +1511,6 @@ def test_set_dataloaders_to_cur_epoch(
         # Epoch count starts at 0
         assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1

-    @pytest.mark.parametrize(
-        'world_size',
-        [
-            pytest.param(2, marks=pytest.mark.world_size(2)),
-        ],
-    )
-    @pytest.mark.parametrize(
-        'device',
-        [
-            pytest.param('gpu', marks=pytest.mark.gpu),
-        ],
-    )
-    @pytest.mark.parametrize(
-        'use_orig_params,sync_module_states,model_1_init_device,model_2_init_device',
-        [
-            pytest.param(False, True, 'cpu', 'cpu'),  # success
-            pytest.param(False, True, 'cpu', 'meta'),  # success
-            pytest.param(True, True, 'cpu', 'cpu'),  # fail
-            pytest.param(False, False, 'cpu', 'cpu'),  # fail
-            pytest.param(False, True, 'meta', 'cpu'),  # fail
-        ],
-    )
-    @pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*')
-    @pytest.mark.filterwarnings(
-        'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*',
-    )
-    def test_fsdp_monolith_resumption(
-        self,
-        device: str,
-        world_size: int,
-        use_orig_params: bool,
-        sync_module_states: bool,
-        model_1_init_device: str,
-        model_2_init_device: str,
-        tmp_path: pathlib.Path,
-    ):
-        save_interval = '1ba'
-        save_filename = 'ba{batch}-rank{rank}.pt'
-        resume_file = 'ba1-rank{rank}.pt'
-        final_checkpoint = 'latest-rank{rank}.pt'
-        fsdp_config = {
-            'use_orig_params': use_orig_params,
-            'sync_module_states': sync_module_states,
-            'state_dict_type': 'full',
-        }
-
-        # All ranks use rank 0 folder
-        tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
-        save_folder = pathlib.Path(tmp_paths[0])
-
-        trainer_1 = self.get_trainer(
-            save_folder=os.path.join(save_folder, 'first'),
-            save_filename=save_filename,
-            save_interval=save_interval,
-            eval_interval=save_interval,
-            fsdp_config=fsdp_config,
-            device=device,
-            precision='amp_fp16',
-            max_duration='1ep',
-            train_subset_num_batches=2,
-        )
-
-        trainer_1.fit()
-        trainer_1.close()
-
-        self._assert_expected_num_checkpoints(
-            save_folder=os.path.join(save_folder, 'first'),
-            save_interval=save_interval,
-            num_epochs=1,  # set in get_trainer()
-            num_batches_per_epoch=2,  # set in get_trainer()
-            is_deepspeed=False,
-        )
-
-        resume_file = os.path.join(save_folder, 'first', resume_file)
-        model_init_device = [model_1_init_device, model_2_init_device][dist.get_global_rank()]
-        fsdp_config['load_monolith_rank0_only'] = True
-
-        success = use_orig_params == False and sync_module_states == True and model_1_init_device == 'cpu'
-        with contextlib.nullcontext() if success else pytest.raises(ValueError):
-            trainer_2 = self.get_trainer(
-                model_init_device=model_init_device,
-                save_folder=os.path.join(save_folder, 'second'),
-                save_filename=save_filename,
-                save_interval=save_interval,
-                eval_interval=save_interval,
-                fsdp_config=fsdp_config,
-                device=device,
-                precision='amp_fp16',
-                max_duration='1ep',
-                train_subset_num_batches=2,
-                load_path=resume_file,  # <-- resume training from file
-            )
-            trainer_2.fit()
-            trainer_2.close()
-
-            _assert_checkpoints_equivalent(
-                save_folder / 'first' / final_checkpoint,
-                save_folder / 'second' / final_checkpoint,
-            )
-
     @pytest.mark.parametrize('spin_dataloaders', [False, True])
     def test_spin_dataloaders(
         self,
@@ -1674,8 +1574,8 @@ def test_format_load_path(self, tmp_path: pathlib.Path):
             os.path.join(save_folder, 'second', 'latest-rank{rank}.pt'),
         )

+    @staticmethod
     def _assert_expected_num_checkpoints(
-        self,
         save_folder: str,
         save_interval: str,
         num_epochs: int,
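The change to a @staticmethod pairs with the new import in tests/trainer/test_fsdp_checkpoint.py below: the checkpoint-counting helper can now be reused from another test module without instantiating the test class. A minimal sketch of the calling pattern (names abbreviated, body elided, not the real implementation):

# Stand-in for the real test class; only the calling convention matters here.
class TestCheckpointResumption:

    @staticmethod
    def _assert_expected_num_checkpoints(save_folder: str, save_interval: str) -> None:
        ...  # count checkpoint files under save_folder and assert the expected number

# From another module, no instance is needed:
TestCheckpointResumption._assert_expected_num_checkpoints(save_folder='/tmp/first', save_interval='1ba')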
85 changes: 83 additions & 2 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -37,6 +37,7 @@
 from tests.common import RandomClassificationDataset, deep_compare
 from tests.common.compare import deep_compare
 from tests.common.markers import world_size
+from tests.trainer.test_checkpoint import TestCheckpointResumption, _assert_checkpoints_equivalent


 # This model is to be used explicitly for this unit test because some old reference checkpoints
@@ -120,7 +121,7 @@ def get_trainer(
         train_metrics=train_metrics,
         val_metrics=val_metrics,
     )
-    model.to(model_init_device)
+    model.module.to(model_init_device)
     dataset = RandomClassificationDataset(shape=(num_features,), size=128)
     dataloader = DataLoader(
         dataset,
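The changed line now targets the wrapped torch module directly rather than the whole Composer model wrapper, which matters once 'meta' is one of the parametrized init devices. A minimal, self-contained sketch of the pattern (the Wrapper class is a hypothetical stand-in, not Composer's actual model class):

from torch import nn

class Wrapper(nn.Module):
    # Hypothetical stand-in: the trainable module lives under `.module`.
    def __init__(self) -> None:
        super().__init__()
        self.module = nn.Linear(8, 8)

model = Wrapper()
# Move only the inner module, as the changed line above now does.
model.module.to('meta')
print(next(model.module.parameters()).is_meta)  # True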
@@ -325,7 +326,6 @@ def test_fsdp_full_state_dict_load(
         autoresume=autoresume,
         optimizer=optimizer,
         fsdp_config=fsdp_config,
-        save_weights_only=save_weights_only,
     )
     trainer1.fit()
     state_dict_from_trainer1 = trainer1.state.state_dict()
@@ -1127,3 +1127,84 @@ def set_up_planner(

     trainer2.fit()
     trainer2.close()
+
+
+@pytest.mark.gpu
+@world_size(2)
+@pytest.mark.parametrize('use_orig_params', [True, False])
+@pytest.mark.parametrize('sync_module_states', [True, False])
+@pytest.mark.parametrize('model_1_init_device', ['cpu', 'meta'])
+@pytest.mark.parametrize('model_2_init_device', ['cpu', 'meta'])
+@pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*')
+@pytest.mark.filterwarnings(
+    'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*',
+)
+def test_fsdp_monolith_resumption(
+    device: str,
+    world_size: int,
+    use_orig_params: bool,
+    sync_module_states: bool,
+    tmp_path: pathlib.Path,
+    model_1_init_device: str,
+    model_2_init_device: str,
+):
+    save_interval = '1ba'
+    save_filename = 'ba{batch}-rank{rank}.pt'
+    resume_file = 'ba1-rank{rank}.pt'
+    final_checkpoint = 'latest-rank{rank}.pt'
+    fsdp_config = FSDPConfig(
+        use_orig_params=use_orig_params,
+        sync_module_states=sync_module_states,
+        state_dict_type='full',
+    )
+
+    # All ranks use rank 0 folder
+    tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
+    save_folder = pathlib.Path(tmp_paths[0])
+
+    trainer_1 = get_trainer(
+        save_folder=os.path.join(save_folder, 'first'),
+        save_filename=save_filename,
+        save_interval=save_interval,
+        fsdp_config=fsdp_config,
+        precision='amp_fp16',
+        max_duration='1ep',
+    )
+
+    trainer_1.fit()
+    trainer_1.close()
+
+    TestCheckpointResumption._assert_expected_num_checkpoints(
+        save_folder=os.path.join(save_folder, 'first'),
+        save_interval=save_interval,
+        num_epochs=1,  # set in get_trainer()
+        num_batches_per_epoch=8,  # set in get_trainer()
+        is_deepspeed=False,
+    )
+
+    resume_file = os.path.join(save_folder, 'first', resume_file)
+    model_init_device = [model_1_init_device, model_2_init_device][dist.get_global_rank()]
+    fsdp_config_dict = dataclasses.asdict(fsdp_config)
+    fsdp_config_dict['load_monolith_rank0_only'] = True
+    fsdp_config = FSDPConfig(**fsdp_config_dict)
+
+    success = (sync_module_states == True and model_1_init_device == 'cpu')
+
+    with (does_not_raise() if success else pytest.raises(ValueError)):
+        trainer_2 = get_trainer(
+            model_init_device=model_init_device,
+            save_folder=os.path.join(save_folder, 'second'),
+            save_filename=save_filename,
+            save_interval=save_interval,
+            fsdp_config=fsdp_config,
+            precision='amp_fp16',
+            max_duration='1ep',
+            load_path=resume_file,  # <-- resume training from file
+        )
+        trainer_2.fit()
+        trainer_2.close()
+
+        _assert_checkpoints_equivalent(
+            save_folder / 'first' / final_checkpoint,
+            save_folder / 'second' / final_checkpoint,
+        )
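Compared with the test deleted from tests/trainer/test_checkpoint.py, the success predicate above no longer mentions use_orig_params: only sync_module_states and the rank-0 init device decide whether resumption is expected to raise. A small hypothetical helper (not in the PR) that mirrors the predicate, with a few of the parametrized combinations spelled out:

# Mirrors `success = (sync_module_states == True and model_1_init_device == 'cpu')` above.
def expect_monolith_resumption_success(
    use_orig_params: bool,
    sync_module_states: bool,
    model_1_init_device: str,
    model_2_init_device: str,
) -> bool:
    # use_orig_params and model_2_init_device no longer affect the outcome
    return sync_module_states and model_1_init_device == 'cpu'

assert expect_monolith_resumption_success(True, True, 'cpu', 'meta')         # newly allowed
assert not expect_monolith_resumption_success(True, False, 'cpu', 'cpu')     # sync_module_states missing
assert not expect_monolith_resumption_success(False, True, 'meta', 'cpu')    # rank 0 model on meta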