add memory snapshot callback #2788
Conversation
Thanks for adding this, Cheng! Overall, this PR looks good to me barring some style suggestions.
Additionally, here are two higher level use cases to consider for this PR:
- First, should we add this callback to the profiler instead, so we can support N memory traces rather than just one? I can see pros and cons for both, so I just want to make sure we are aligned on this decision. cc: @mvpatel2000
- Second, is it possible to capture this memory trace when a run OOMs?
Here's the original ticket that this PR addresses, which has some ideas on the second point:
https://databricks.atlassian.net/browse/GRT-2231
My thoughts on the two questions:
LGTM. Can you please provide a manual test and a screenshot just to show it works on a real run?
- I agree this should be separate from the profiler. It's convenient to have the modularity, especially because the profiler slows things down a lot.
- I think merging OOM capture into this callback is reasonable, given it's the same torch system. I suggest a follow-on PR since this one is basically done (see the sketch after this thread).
Agreed with 1 and 2 from Mihir. Up to you whether or not you want to create a follow-up PR or add the OOM capture to this PR 👍
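For context on the OOM point, here is a minimal, hypothetical sketch of how torch's snapshot machinery could dump a trace when a step OOMs. It assumes PyTorch >= 2.1 and uses the private `torch.cuda.memory._record_memory_history` / `_dump_snapshot` APIs; the function and file names are placeholders, not what a follow-on PR would necessarily do.

```python
import torch

def run_with_oom_snapshot(step_fn, dataloader, out_path="oom_snapshot.pickle"):
    """Run training steps and dump a memory snapshot if an OOM occurs (sketch)."""
    # Start recording allocator events so a later snapshot carries full history.
    # Private API in PyTorch >= 2.1; the signature may change between releases.
    torch.cuda.memory._record_memory_history(max_entries=100_000)
    try:
        for batch in dataloader:
            step_fn(batch)  # forward/backward/optimizer step (user-provided)
    except torch.cuda.OutOfMemoryError:
        # Persist the recorded history; the pickle can be inspected with the
        # interactive viewer at https://pytorch.org/memory_viz.
        torch.cuda.memory._dump_snapshot(out_path)
        raise
    finally:
        # Stop recording to avoid ongoing overhead.
        torch.cuda.memory._record_memory_history(enabled=None)
```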
Merged commits (squashed):
* add memory snapshot callback
* fix check
* fix check
* Update composer/callbacks/memory_snapshot.py (Co-authored-by: Charles Tang <[email protected]>)
* address comments
* fix upload filename print
* fix cpu check
* fix cpu check
* add pt version check
* add pt version check
* fix remote upload
* fix test
* fix cpu test
* fix gpu test
* fix gpu test
* fix gpu test
* fix gpu test
* fix gpu test
* do plotting before saving
* fix test
* fix test
* fix test

Co-authored-by: Charles Tang <[email protected]>
Co-authored-by: Mihir Patel <[email protected]>
This PR adds a callback to capture a memory snapshot. When enabled, it generates an HTML file with an interactive visualization; https://github.com/pytorch/pytorch.github.io/blob/site/assets/images/understanding-gpu-memory-1/snapshot.html shows an example visualization.
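For reference, a rough sketch of how such an HTML view can be produced from a raw snapshot using PyTorch's private visualization helper (`torch.cuda._memory_viz.trace_plot`, present in recent releases). This illustrates the assumed mechanism only, not the exact code in the callback.

```python
import torch
from torch.cuda._memory_viz import trace_plot  # private API; may move between releases

# Record allocator history, run some workload, then snapshot and render to HTML.
torch.cuda.memory._record_memory_history()
# ... run a few training batches here ...
snapshot = torch.cuda.memory._snapshot()  # raw allocator state + event history
with open("memory_snapshot.html", "w") as f:
    f.write(trace_plot(snapshot))  # interactive timeline like the example linked above
torch.cuda.memory._record_memory_history(enabled=None)
```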
To enable the callback from a YAML config file, the corresponding change in llm-foundry is also needed: mosaicml/llm-foundry#810.
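For anyone enabling it from Python rather than YAML, a minimal sketch is below. It assumes the callback is exported as `composer.callbacks.MemorySnapshot` (see composer/callbacks/memory_snapshot.py) with a no-argument constructor; `model` and `train_dataloader` are placeholders for your own objects.

```python
from composer import Trainer
from composer.callbacks import MemorySnapshot  # assumed export from this PR

trainer = Trainer(
    model=model,                        # your ComposerModel (placeholder)
    train_dataloader=train_dataloader,  # placeholder dataloader
    max_duration="3ba",                 # a few batches is enough to capture a snapshot
    callbacks=[MemorySnapshot()],       # enable the new callback
)
trainer.fit()
```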
Below is an example memory snapshot over three batches of MPT-7B with micro batch size 4 and FSDP full shard, on 8xH100.
For more details on the memory snapshot, refer to