-
Notifications
You must be signed in to change notification settings - Fork 455
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
193 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from multiprocessing.connection import Connection | ||
from multiprocessing import Pipe, Process | ||
from contextlib import contextmanager | ||
import os | ||
import subprocess | ||
|
||
# Adapted from optimum-benchmark, I don't trust pytorch peak memory memory info when external libs are used. | ||
class MemoryTracker: | ||
def __init__(self): | ||
self.peak_memory: int = 0 | ||
self.device_index = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) | ||
|
||
@contextmanager | ||
def track(self, interval: float = 0.1): | ||
print(f"Tracking memory for device {self.device_index}") | ||
yield from self._track_peak_memory(interval) | ||
|
||
def _track_peak_memory(self, interval: float): | ||
child_connection, parent_connection = Pipe() | ||
# instantiate process | ||
mem_process: Process = PeakMemoryMeasureProcess(self.device_index, child_connection, interval) | ||
mem_process.start() | ||
# wait until we get memory | ||
parent_connection.recv() | ||
yield | ||
# start parent connection | ||
parent_connection.send(0) | ||
# receive peak memory | ||
self.peak_memory = parent_connection.recv() | ||
|
||
|
||
class PeakMemoryMeasureProcess(Process): | ||
def __init__(self, device_index: int, child_connection: Connection, interval: float): | ||
super().__init__() | ||
self.device_index = device_index | ||
self.interval = interval | ||
self.connection = child_connection | ||
self.mem_usage = 0 | ||
|
||
def run(self): | ||
self.connection.send(0) | ||
stop = False | ||
|
||
command = f"nvidia-smi --query-gpu=memory.used --format=csv --id={self.device_index}" | ||
|
||
while True: | ||
# py3nvml is broken since it outputs only the reserved memory, and nvidia-smi has only the MiB precision. | ||
gpu_mem_mb = subprocess.check_output(command.split()).decode("ascii").split("\n")[1].split()[0] | ||
gpu_mem_mb = int(gpu_mem_mb) * 1.048576 | ||
self.mem_usage = max(self.mem_usage, gpu_mem_mb) | ||
|
||
if stop: | ||
break | ||
stop = self.connection.poll(self.interval) | ||
|
||
# send results to parent pipe | ||
self.connection.send(self.mem_usage) | ||
self.connection.close() |