Fix FSDP checkpoint bug (missing keys)
chiayewken committed Mar 28, 2023
1 parent 85e2d96 commit f3e8815
Showing 1 changed file with 56 additions and 3 deletions.
59 changes: 56 additions & 3 deletions training.py
@@ -7,7 +7,12 @@
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp import (
    MixedPrecision,
    FullyShardedDataParallel,
    StateDictType,
    FullStateDictConfig,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Adafactor
@@ -18,6 +23,51 @@
os.environ["TOKENIZERS_PARALLELISM"] = "true"


class MyFSDPStrategy(FSDPStrategy):
    @staticmethod
    def clean_up_state_names(state: dict, prefix="_forward_module.") -> dict:
        """
        To restore the original transformer state dict, remove the FSDP name prefix from every key.
        """
        new = {}
        for k in state.keys():
            assert k.startswith(prefix)
            new[k[len(prefix):]] = state[k]
        return new

    def lightning_module_state_dict(self):
        """
        Returns the model state for checkpointing.
        The original FSDPStrategy returns the state of the unwrapped LightningModule, which is incomplete,
        but we need the FSDP-wrapped module to gather the full state dict so checkpoints load properly.
        See: https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html?highlight=transformer
        """
        # model = self.lightning_module  # unwrapped module: its state dict would be incomplete
        model = self.model
        assert model is not None

        with FullyShardedDataParallel.state_dict_type(
            module=model,
            state_dict_type=StateDictType.FULL_STATE_DICT,
            state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            state = model.state_dict()

        state = self.clean_up_state_names(state)
        print(dict(my_fsdp=type(model), state=len(state), io=self.checkpoint_io))
        return state

    def save_checkpoint(self, checkpoint: dict, filepath: str, **kwargs) -> None:
        """
        Save model/training states as a checkpoint file through state-dump and file-write.
        The default TorchCheckpointIO serializes the dict to bytes and then writes the bytes to file,
        which can take up more CPU memory, so we bypass it and save the dict directly to file.
        """
        if self.is_global_zero:
            print(dict(save_checkpoint_unused_kwargs=kwargs))
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            torch.save(checkpoint, filepath)
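
The checkpoint written by this strategy can be loaded back like any ordinary torch.save file. The snippet below is a minimal sketch, not part of this commit: it assumes the usual Lightning checkpoint layout (weights under the "state_dict" key), that the LightningModule stores the transformer as self.model, that torch.compile was not used, and it uses placeholder paths and model names.

import torch
from transformers import AutoModelForSeq2SeqLM

ckpt = torch.load("outputs/last.ckpt", map_location="cpu")  # hypothetical checkpoint path
state = ckpt["state_dict"]  # already cleaned of the "_forward_module." prefix by MyFSDPStrategy

# The keys still carry the LightningModule attribute prefix ("model."), so strip it
# before loading the weights into a bare seq2seq transformer.
state = {k[len("model."):]: v for k, v in state.items() if k.startswith("model.")}

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")  # placeholder model name
missing, unexpected = model.load_state_dict(state, strict=False)
print(dict(missing_keys=len(missing), unexpected_keys=len(unexpected)))  # both should be 0 after this fix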


def init_args(raw_args):
    # Training args should follow FlanT5 (Scaling Instruction-Finetuned Language Models)
    parser = argparse.ArgumentParser()
@@ -35,6 +85,7 @@ def init_args(raw_args):
parser.add_argument("--use_compile", action="store_true")
parser.add_argument("--use_gradient_checkpointing", action="store_true")
parser.add_argument("--use_fsdp", action="store_true")
parser.add_argument("--debug", action="store_true")

    args = parser.parse_args(raw_args)
    return args
@@ -48,6 +99,7 @@ def __init__(self, hparams):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.hparams.model_name_or_path
        )
        print(dict(orig_state_dict=len(self.model.state_dict())))
        if self.hparams.use_compile:
            self.model = torch.compile(self.model)
        if self.hparams.use_gradient_checkpointing:
@@ -143,7 +195,7 @@ def main(raw_args=None):
    if args.use_fsdp:
        # https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/
        # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html
        strategy = FSDPStrategy(
        strategy = MyFSDPStrategy(
            auto_wrap_policy=functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={T5Block},
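
For context, the auto-wrap policy above tells FSDP to shard the model one T5Block at a time, and the MixedPrecision class imported at the top of the file is typically passed to the same strategy. The collapsed lines of this hunk are not reproduced here; the snippet below is only a generic sketch of that pattern, and every value in it is an assumption rather than the script's actual configuration.

import functools
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

# Wrap each T5Block as its own FSDP unit so parameters, gradients and optimizer
# state are sharded per block instead of for the whole model at once.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},
)

# A bf16 policy like this is commonly passed as FSDPStrategy(mixed_precision=...);
# the dtypes used by training.py live in the collapsed lines and may differ.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)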
@@ -161,12 +213,13 @@
precision="bf16-mixed",
accelerator="gpu",
strategy=strategy,
accumulate_grad_batches=args.gradient_accumulation_steps,
accumulate_grad_batches=1 if args.debug else args.gradient_accumulation_steps,
default_root_dir=args.output_dir,
gradient_clip_val=None if args.use_fsdp else 1.0,
max_epochs=args.train_epochs,
callbacks=[saver],
logger=False,
overfit_batches=10 if args.debug else 0,
)

    trainer.fit(model)
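
Because init_args and main both accept raw_args, the script can also be driven programmatically, which pairs well with the quick overfit-style smoke test that the new --debug flag enables. The call below is hypothetical: only the flags visible in this diff are used, and --model_name_or_path / --output_dir are assumed from the attribute accesses above (self.hparams.model_name_or_path, args.output_dir) rather than taken from the collapsed parser definitions.

from training import main

main(
    raw_args=[
        "--model_name_or_path", "google/flan-t5-base",  # placeholder model
        "--output_dir", "outputs_debug",                 # placeholder directory
        "--use_fsdp",
        "--use_gradient_checkpointing",
        "--debug",
    ]
)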
