Skip to content

Commit

Permalink
Merge pull request #3 from sanketpurandare/run_est
Browse the repository at this point in the history
Merge Andrew's updates
  • Loading branch information
sanketpurandare authored Oct 28, 2024
2 parents 94e4a1f + adf2b31 commit e7f5335
Show file tree
Hide file tree
Showing 11 changed files with 495 additions and 58 deletions.
8 changes: 8 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,11 @@
.github/scripts/gql_mocks.json linguist-generated=true
third_party/LICENSES_BUNDLED.txt linguist-generated=true
tools/build/bazel/requirements.txt linguist-generated=true
torch/distributed/_tools/a100_models/bmm.joblib filter=lfs diff=lfs merge=lfs -text
torch/distributed/_tools/a100_models/mm.joblib filter=lfs diff=lfs merge=lfs -text
torch/distributed/_tools/a100_models/sdpa.joblib filter=lfs diff=lfs merge=lfs -text
torch/distributed/_tools/a100_models/sdpa_backward.joblib filter=lfs diff=lfs merge=lfs -text
torch/distributed/_tools/h100_models/sdpa.joblib filter=lfs diff=lfs merge=lfs -text
torch/distributed/_tools/h100_models/sdpa_backward.joblib filter=lfs diff=lfs merge=lfs -text
torch/distributed/_tools/h100_models/bmm.joblib filter=lfs diff=lfs merge=lfs -text
torch/distributed/_tools/h100_models/mm.joblib filter=lfs diff=lfs merge=lfs -text
20 changes: 17 additions & 3 deletions test/distributed/_tools/test_runtime_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def _init_model_and_args(
def test_transformer_runtime(
self,
):
print("Transformer Test")
"""Runs a basic GPT-2 model"""
vocab_size = 8192
bsz, seq_len = 8, 1024
Expand All @@ -155,22 +156,30 @@ def test_transformer_runtime(
roofline_estimate = self._runtime_estimate(
"operator-level-cost-model", self._train_step, fake_args
)
learned_estimate = self._runtime_estimate(
"operator-level-learned-model", self._train_step, fake_args
)
benchmark_accuracy = actual_runtime / benchmark_estimate
roofline_accuracy = actual_runtime / roofline_estimate
learned_accuracy = actual_runtime / learned_estimate
print(
f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}"
f"\n Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}"
f"\nActual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}"
f"\nActual: {actual_runtime} Learned Estimate: {learned_estimate} Accuracy: {learned_accuracy}"
)

# No accuracy check for benchmark in CI as it is highly variable
# self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2)
# self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3)


@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_conv_model_runtime(
self,
):
"""Runs a simple CNN model"""
print("CNN Test")
num_classes = 100
bsz, img_sz = 256, 128
model_args = ConvArgs(img_sz, num_classes)
Expand All @@ -184,11 +193,16 @@ def test_conv_model_runtime(
roofline_estimate = self._runtime_estimate(
"operator-level-cost-model", self._train_step, fake_args
)
learned_estimate = self._runtime_estimate(
"operator-level-learned-model", self._train_step, fake_args
)
benchmark_accuracy = actual_runtime / benchmark_estimate
roofline_accuracy = actual_runtime / roofline_estimate
learned_accuracy = actual_runtime / learned_estimate
print(
f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}\n"
f"Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}"
f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}"
f"\nActual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}"
f"\nActual: {actual_runtime} Learned Estimate: {learned_estimate} Accuracy: {learned_accuracy}"
)
# No accuracy check for benchmark in CI as it is highly variable
# self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2)
Expand Down
3 changes: 3 additions & 0 deletions torch/distributed/_tools/a100_models/bmm.joblib
Git LFS file not shown
3 changes: 3 additions & 0 deletions torch/distributed/_tools/a100_models/mm.joblib
Git LFS file not shown
3 changes: 3 additions & 0 deletions torch/distributed/_tools/a100_models/sdpa.joblib
Git LFS file not shown
3 changes: 3 additions & 0 deletions torch/distributed/_tools/a100_models/sdpa_backward.joblib
Git LFS file not shown
3 changes: 3 additions & 0 deletions torch/distributed/_tools/h100_models/bmm.joblib
Git LFS file not shown
3 changes: 3 additions & 0 deletions torch/distributed/_tools/h100_models/mm.joblib
Git LFS file not shown
3 changes: 3 additions & 0 deletions torch/distributed/_tools/h100_models/sdpa.joblib
Git LFS file not shown
3 changes: 3 additions & 0 deletions torch/distributed/_tools/h100_models/sdpa_backward.joblib
Git LFS file not shown
Loading

0 comments on commit e7f5335

Please sign in to comment.