Commit

[ZeRO-3] Ensured passing neox deepspeed_config when using partitioned init (EleutherAI#1191)

* added ds zero.Init() to get_model

* Clean up conditional with block

* pre-commit

* ensured deepspeed configs are passed to init

---------

Co-authored-by: Quentin Anthony <[email protected]>
R0n12 and Quentin-Anthony authored Apr 1, 2024
1 parent 51a7de9 commit 01657aa
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion megatron/training.py
@@ -428,7 +428,9 @@ def get_model(neox_args, use_cache=False):
     old_use_mup = neox_args.use_mup
     neox_args.use_mup = False
 
-    with deepspeed.zero.Init() if neox_args.zero_stage == 3 else nullcontext() as gs:
+    with deepspeed.zero.Init(
+        config_dict_or_path=neox_args.deepspeed_config
+    ) if neox_args.zero_stage == 3 else nullcontext() as gs:
         model = GPT2ModelPipe(
             neox_args=neox_args,
             num_tokentypes=0,
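
For readers who find the inline conditional in the `with` statement hard to parse, the same selection logic can be sketched as a small helper. This is a minimal illustration rather than the repository's code: `zero3_init_context`, `RecordingInit`, and the sample config dictionary are hypothetical names introduced here, and the init class is injected as a parameter so the sketch runs without DeepSpeed installed.

```python
from contextlib import nullcontext


def zero3_init_context(zero_stage, ds_config, init_factory):
    """Return the context manager under which the model should be built.

    init_factory is a hypothetical stand-in for deepspeed.zero.Init,
    injected so that this sketch does not need DeepSpeed to run.
    """
    if zero_stage == 3:
        # Pass the config explicitly so that partitioned init receives
        # the same settings as the rest of the run.
        return init_factory(config_dict_or_path=ds_config)
    # For other stages, fall back to a no-op context manager.
    return nullcontext()


class RecordingInit:
    """Toy context manager that only records the config it was given."""

    def __init__(self, config_dict_or_path=None):
        self.config = config_dict_or_path

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False


ds_config = {"zero_optimization": {"stage": 3}}  # illustrative config only

with zero3_init_context(3, ds_config, RecordingInit) as ctx:
    received = ctx.config  # the config reached the init context

with zero3_init_context(1, ds_config, RecordingInit) as ctx:
    skipped = ctx is None  # nullcontext() yields None on entry
```

In a real training script, `deepspeed.zero.Init` would be passed as `init_factory`, which mirrors the `config_dict_or_path` keyword used in the diff above.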