Add plugin interface and add GPU mem snapshot as an example (#777)
Summary:
Pull Request resolved: #777

WTTS.
- Add a basic `Plugin` interface to extend bulk gen functionality:
```python
from abc import ABC


class Plugin(ABC):
    def step(self) -> None:
        ...

    def shutdown(self) -> None:
        ...
```

- Add GPU mem snapshot as an example (a hedged sketch follows this list).
- Add an `is_started` status to TNT's `MemorySnapshotProfiler` and expose `log_memory_snapshot`.
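For illustration, a minimal sketch of what the GPU mem snapshot example might look like as a `Plugin`. The `Plugin` ABC and `MemorySnapshotProfiler` come from this PR; the class name `MemorySnapshotPlugin` and the wiring are assumptions, not the actual example code:

```python
from abc import ABC

from torchtnt.utils.memory_snapshot_profiler import MemorySnapshotProfiler


class Plugin(ABC):  # the interface from the summary above
    def step(self) -> None:
        ...

    def shutdown(self) -> None:
        ...


class MemorySnapshotPlugin(Plugin):  # hypothetical name, for illustration
    def __init__(self, profiler: MemorySnapshotProfiler) -> None:
        self.profiler = profiler

    def step(self) -> None:
        # The profiler decides when to record based on its configured
        # start_step/stop_step, logging a snapshot when it stops.
        self.profiler.step()

    def shutdown(self) -> None:
        # Safe to call unconditionally: stop() returns early unless
        # recording was actually started (the new is_started guard).
        self.profiler.stop()
```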

Reviewed By: skcoirz

Differential Revision: D55724465

fbshipit-source-id: 87226f9be2da801119d575a9b43a48109aa86a82
jaconey authored and facebook-github-bot committed Apr 5, 2024
1 parent 4c90a5f commit 23191d4
Showing 1 changed file with 13 additions and 3 deletions.

torchtnt/utils/memory_snapshot_profiler.py
```diff
@@ -131,6 +131,7 @@ def __init__(
         )
 
         self.step_num: int = 0
+        self.is_started: bool = False
 
         if not is_torch_version_geq_2_0():
             raise RuntimeError("CUDA memory snapshot requires torch>=2.0")
@@ -146,20 +147,26 @@ def __init__(
         )
 
     def start(self) -> None:
+        if self.is_started:
+            return
         if not torch.cuda.is_available():
             logger.warn("CUDA unavailable. Not recording memory history.")
             return
 
         logger.info("Starting to record memory history.")
         torch.cuda.memory._record_memory_history(max_entries=self.params.max_entries)
+        self.is_started = True
 
     def stop(self) -> None:
+        if not self.is_started:
+            return
         if not torch.cuda.is_available():
             logger.warn("CUDA unavailable. Not recording memory history.")
             return
 
         logger.info("Stopping recording memory history.")
         torch.cuda.memory._record_memory_history(enabled=None)
+        self.is_started = False
 
     def step(self) -> None:
         self.step_num += 1
@@ -169,7 +176,10 @@ def step(self) -> None:
         ):
             self.start()
         if self.params.stop_step is not None and self.step_num == self.params.stop_step:
-            log_memory_snapshot(
-                output_dir=self.output_dir, file_prefix=f"step_{self.step_num}"
-            )
+            self.log_memory_snapshot()
             self.stop()
+
+    def log_memory_snapshot(self) -> None:
+        log_memory_snapshot(
+            output_dir=self.output_dir, file_prefix=f"step_{self.step_num}"
+        )
```
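With the `is_started` guard, `start` and `stop` are now idempotent, and the newly exposed `log_memory_snapshot` can be called on demand. A rough usage sketch; the constructor arguments are inferred from the diff (`output_dir`, plus params carrying `start_step`/`stop_step`/`max_entries`) rather than taken from the full torchtnt signature:

```python
from torchtnt.utils.memory_snapshot_profiler import MemorySnapshotProfiler

profiler = MemorySnapshotProfiler(output_dir="/tmp/snapshots")

profiler.start()
profiler.start()  # returns early: is_started is already True

for _ in range(10):  # stand-in for a real generation/training loop
    profiler.step()  # starts/stops recording and logs a snapshot per params

profiler.log_memory_snapshot()  # now public: dump a snapshot on demand
profiler.stop()  # returns early if recording was already stopped
```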
