
Improving communication overlap for the case of multi kernel queue usage #1308

Open
wants to merge 12 commits into main

Conversation

youngeunkwon0405

Description

The current TP overlap relies on a single kernel queue (CUDA_DEVICE_MAX_CONNECTIONS=1) to order kernel launches and thereby control compute-communication overlap; when multiple kernel queues are used, this ordering is no longer guaranteed and the overlap is lost.

This PR instead enforces launch ordering between the communication kernel and the compute kernel using the launch completion event feature, ensuring the overlap regardless of the number of kernel queues.

This feature is specific to Hopper and applies only to bulk overlap cases.
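
For reference, here is a minimal, hedged sketch of how a launch completion event can enforce this ordering with the CUDA runtime launch-attribute API (available since CUDA 12.3). The kernel names, streams, and launch sizes below are made up for illustration; the SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT macro in this PR is assumed to wrap a similar mechanism, not necessarily this exact code:

#include <cuda_runtime.h>

__global__ void comm_kernel() { /* communication work (illustrative) */ }
__global__ void compute_kernel() { /* GEMM-like compute work (illustrative) */ }

void launch_with_ordering(cudaStream_t comm_stream, cudaStream_t compute_stream) {
  // Event that fires once the communication kernel has *launched*
  // (not completed). Must be created without timing.
  cudaEvent_t comm_launch_event;
  cudaEventCreateWithFlags(&comm_launch_event, cudaEventDisableTiming);

  // Attach the launch completion event to the communication kernel launch.
  cudaLaunchAttribute attr{};
  attr.id = cudaLaunchAttributeLaunchCompletionEvent;
  attr.val.launchCompletionEvent.event = comm_launch_event;

  cudaLaunchConfig_t cfg{};
  cfg.gridDim = dim3(16);        // e.g. number of communication SMs
  cfg.blockDim = dim3(32 * 32);  // e.g. warps * 32 threads
  cfg.stream = comm_stream;
  cfg.attrs = &attr;
  cfg.numAttrs = 1;
  cudaLaunchKernelEx(&cfg, comm_kernel);

  // The compute stream only waits until the communication kernel has been
  // scheduled onto the device, so both kernels can still run concurrently
  // even when multiple hardware queues (CUDA_DEVICE_MAX_CONNECTIONS > 1)
  // would otherwise allow the compute kernel to start first.
  cudaStreamWaitEvent(compute_stream, comm_launch_event, 0);
  compute_kernel<<<112, 128, 0, compute_stream>>>();

  cudaDeviceSynchronize();
  cudaEventDestroy(comm_launch_event);
}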

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@youngeunkwon0405
Author

@erhoo82 Hi Sangkug, this is a PR for launch ordering work. Could you please assign a reviewer?

Comment on lines +1905 to +1911
if (comm_launch_event) {
  SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
  callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} else {
  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
  callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}
Collaborator

Same here for duplicated kernel launch code.

Suggested change

(current)
if (comm_launch_event) {
  SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
  callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} else {
  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
  callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}

(suggested)
if (comm_launch_event) {
  SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
} else {
  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
}
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)

Author

Hi @denera, the suggested coding style causes a compile error, which is why I had to duplicate the kernel launch.
Since both SETUP_LAUNCH_CONFIG and callranks_** are preprocessor macros, there is a variable scope issue: the compute kernel call must sit in the same scope as (or a scope nested inside) the one in which the SETUP macro declares its launch configuration. The same issue applies to the other comments. If you have a better solution, please let me know.
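
To illustrate the scope problem, here is a simplified, hypothetical sketch (not the actual TE macros; SETUP_CONFIG and CALL_RANKS below stand in for SETUP_LAUNCH_CONFIG* and callranks_*): the SETUP macro declares a local launch configuration that the call macro references by name, so hoisting the call out of the if/else leaves it referring to a variable that is no longer in scope.

#include <cuda_runtime.h>

__global__ void my_kernel() {}

// Hypothetical stand-ins for the real macros: SETUP_CONFIG declares a local
// `cfg`, and CALL_RANKS expands to a launch that uses that same `cfg`.
#define SETUP_CONFIG(sms, threads, stream) \
  cudaLaunchConfig_t cfg{};                \
  cfg.gridDim = dim3(sms);                 \
  cfg.blockDim = dim3(threads);            \
  cfg.stream = (stream)

#define CALL_RANKS() cudaLaunchKernelEx(&cfg, my_kernel)

void example(int sms, int warps, cudaStream_t stream, bool use_event) {
  if (use_event) {
    SETUP_CONFIG(sms, warps * 32, stream);
    // ... attach the completion event to cfg here ...
    CALL_RANKS();   // OK: `cfg` is visible in this branch
  } else {
    SETUP_CONFIG(sms, warps * 32, stream);
    CALL_RANKS();   // OK: a separate `cfg` local to this branch
  }
  // CALL_RANKS();  // would not compile: `cfg` went out of scope with the braces
}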

@denera
Collaborator

denera commented Nov 4, 2024

@youngeunkwon0405 The TP overlap unit tests explicitly set CUDA_DEVICE_MAX_CONNECTIONS=1 in tests/pytorch/distributed/test_comm_gemm_overlap.py:43. Could you update this to not set the environment variable for Hopper so the changes in this PR are tested in our CI?

Also please launch the L1 tests with /te-ci pytorch L1 when you update the unit tests. Thanks!

@youngeunkwon0405
Author

youngeunkwon0405 commented Nov 7, 2024


Hi @denera, I have updated the test_comm_gemm_overlap.py file in the latest commit. Does it meet your expectations?

Also, could you please elaborate on the following? I am new to writing tests and to the CI process.

please launch the L1 tests with /te-ci pytorch L1 when you update the unit tests.

I have run only the modified test case; the result is below.
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0 -- /usr/bin/python
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/workspace/.hypothesis/examples')
rootdir: /lustre/fsw/coreai_dlalgo_llm/youngeunk/nemo/nemo.dev/mount/TransformerEngine-youngeunk
plugins: xdoctest-1.0.2, typeguard-4.3.0, xdist-3.6.1, shard-0.1.2, rerunfailures-14.0, mock-3.14.0, flakefinder-1.1.0, hypothesis-5.35.1, hydra-core-1.3.2, anyio-4.4.0
collecting ... collected 6 items
Running 6 items in this shard: tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[ALL-GATHER - BF16 - 1 connections], tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[REDUCE-SCATTER - BF16 - 1 connections], tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[REDUCE-SCATTER - FP8 - 1 connections], tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[ALL-GATHER - BF16 - 8 connections], tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[REDUCE-SCATTER - BF16 - 8 connections], tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[REDUCE-SCATTER - FP8 - 8 connections]

../lustre/fsw/coreai_dlalgo_llm/youngeunk/nemo/nemo.dev/mount/TransformerEngine-youngeunk/tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[ALL-GATHER - BF16 - 1 connections] PASSED
../lustre/fsw/coreai_dlalgo_llm/youngeunk/nemo/nemo.dev/mount/TransformerEngine-youngeunk/tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[REDUCE-SCATTER - BF16 - 1 connections] PASSED
../lustre/fsw/coreai_dlalgo_llm/youngeunk/nemo/nemo.dev/mount/TransformerEngine-youngeunk/tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[REDUCE-SCATTER - FP8 - 1 connections] PASSED
../lustre/fsw/coreai_dlalgo_llm/youngeunk/nemo/nemo.dev/mount/TransformerEngine-youngeunk/tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[ALL-GATHER - BF16 - 8 connections] PASSED
../lustre/fsw/coreai_dlalgo_llm/youngeunk/nemo/nemo.dev/mount/TransformerEngine-youngeunk/tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[REDUCE-SCATTER - BF16 - 8 connections] PASSED
../lustre/fsw/coreai_dlalgo_llm/youngeunk/nemo/nemo.dev/mount/TransformerEngine-youngeunk/tests/pytorch/distributed/test_comm_gemm_overlap.py::test_bulk_overlaps[REDUCE-SCATTER - FP8 - 8 connections] PASSED

========================= 6 passed in 93.35s (0:01:33) =========================

@denera
Collaborator

denera commented Nov 14, 2024

/te-ci pytorch L1

@denera left a comment (Collaborator)

LGTM, pending rebase on latest TE/main and clean CI results.

@youngeunkwon0405
Author

@denera Rebased onto the latest main. Could you please let me know what the next step would be?
