
Tensor Parallelism #1521

Merged: 76 commits into mosaicml:main on Sep 27, 2024

Conversation

@eitanturok (Contributor) commented Sep 12, 2024:

Implement Tensor Parallelism (TP) in foundry.

To do:

  • Make a tp_strategy registry (a rough sketch of one such strategy follows this list)
  • Validate the TP config: if world_size == 1 or we are training an MoE, don't apply TP
  • Get the same loss with fsdp and fsdp-tp
  • Add tests
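Not the PR's actual code, but as a rough illustration of the layer-plan idea behind a tp_strategy: the sketch below builds a plan that shards each FFN's up-projection column-wise and its down-projection row-wise using PyTorch's tensor-parallel styles. The function name `build_ffn_tp_strategy` and the submodule paths (`ffn.up_proj`, `ffn.down_proj`) are assumptions for illustration only.

```python
# Hypothetical sketch of an ffn tp_strategy; submodule names are assumed.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


def build_ffn_tp_strategy(model) -> dict:
    """Build a layer plan mapping FFN submodules to tensor-parallel styles."""
    layer_plan = {}
    for name, _ in model.named_modules():
        # Shard the up-projection across columns and the down-projection
        # across rows, so the intermediate activation stays sharded and only
        # the down-projection output needs a collective.
        if name.endswith('ffn.up_proj'):
            layer_plan[name] = ColwiseParallel()
        elif name.endswith('ffn.down_proj'):
            layer_plan[name] = RowwiseParallel()
    return layer_plan
```

Such a plan would then be handed to the trainer's TP config, or applied directly with torch.distributed.tensor.parallel.parallelize_module, alongside FSDP.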

Updates:
I compared training 125M-parameter models for 100 steps on C4 with tp-fsdp vs. fsdp.

  • brown = tp-fsdp
  • gray = fsdp

loss_train_total: [plot]

throughput_batches_per_sec: [plot]

memory_peak_reserved_mem: [plot]

It is okay that we don't see performance improvements here yet; we'll get those later, in follow-up PRs.

@eitanturok (Contributor, Author) commented Sep 12, 2024:

Currently, the ffn strategy gives different results when we train fsdp vs fsdp-tp.

See mcli runs:

  • mpt-125m-tp-fsdp-IHKG5s
  • mpt-125m-fsdp-8OgGjt

and here are their losses, which are visibly different:

[plot: training losses for the two runs]

Currently investigating, though I think this has more to do with my specific layer plan/strategy than with anything else.
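For context, the parity being checked is roughly the following: with identical seeds and data order, the per-step training losses of an fsdp run and an fsdp-tp run should agree within tolerance. A minimal sketch of that check (hypothetical helper, not the PR's test code):

```python
# Hypothetical loss-parity check between an fsdp run and an fsdp-tp run.
import torch


def assert_loss_parity(fsdp_losses: list, fsdp_tp_losses: list) -> None:
    """Assert the two logged loss curves agree within a loose tolerance."""
    torch.testing.assert_close(
        torch.tensor(fsdp_tp_losses),
        torch.tensor(fsdp_losses),
        rtol=1e-2,
        atol=1e-3,
    )
```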

@snarayan21 (Contributor) commented:

Will review once the loss discrepancy has been addressed. Good to see it's at least mechanically working, though.

@eitanturok (Contributor, Author) commented:

Should we include a tp-mpt-125m.yaml in the repo?

@dakinggg (Collaborator) commented:

@eitanturok does checkpointing work now?

@mvpatel2000 (Collaborator) left a review:

LGTM!

(Resolved review thread on llmfoundry/command_utils/train.py; now outdated.)
@mvpatel2000 (Collaborator) commented:

> @eitanturok does checkpointing work now?

No, and it won't with FSDPv1.

@dakinggg (Collaborator) commented:

@mvpatel2000 @eitanturok OK, let's leave the yaml out then.

@dakinggg (Collaborator) commented:

Also, can we log a warning when using TP that checkpointing is known not to work?

@eitanturok (Contributor, Author) commented:

@dakinggg I just added a warning that checkpointing does not work, along with a link to the exact PyTorch issue.

One of the tests verifies that the trainer works, but it takes too long because it downloads a dataset. I will fix this and then I think we will be good to go.
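As a rough sketch of what such a guard and warning might look like (function and argument names here are illustrative, and the PyTorch issue URL referenced above is not reproduced here):

```python
import warnings


def validate_tp_config(tp_config, world_size: int, is_moe: bool):
    """Hypothetical validation of a tensor-parallelism config."""
    if world_size == 1 or is_moe:
        # TP is a no-op on a single rank, and MoE models are not supported.
        return None
    warnings.warn(
        'Checkpointing is known not to work with tensor parallelism under '
        'FSDP1; see the PyTorch issue linked in the code for details.')
    return tp_config
```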

@eitanturok merged commit ee45600 into mosaicml:main on Sep 27, 2024. 9 checks passed.