Skip to content

Commit

Permalink
Infrastructure for memory usage tests (#299)
Browse files Browse the repository at this point in the history
* added MemoryThreshold context manager

* added unit test for memory usage of median computation
  • Loading branch information
emolter authored Oct 14, 2024
1 parent 28b1245 commit dfe1d6d
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 0 deletions.
1 change: 1 addition & 0 deletions changes/299.general.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add infrastructure for testing memory usage
51 changes: 51 additions & 0 deletions src/stcal/testing_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import tracemalloc

MEMORY_UNIT_CONVERSION = {"B": 1, "KB": 1024, "MB": 1024 ** 2, "GB": 1024 ** 3, "TB": 1024 ** 4}

class MemoryThresholdExceeded(Exception):
pass


class MemoryThreshold:
"""
Context manager to check peak memory usage against an expected threshold.
example usage:
with MemoryThreshold(expected_usage):
# code that should not exceed expected
If the code in the with statement uses more than the expected_usage
memory a ``MemoryThresholdExceeded`` exception
will be raised.
Note that this class does not prevent allocations beyond the threshold
and only checks the actual peak allocations to the threshold at the
end of the with statement.
"""

def __init__(self, expected_usage):
"""
Parameters
----------
expected_usage : str
Expected peak memory usage expressed as a whitespace-separated string
with a number and a memory unit (e.g. "100 KB").
Supported units are "B", "KB", "MB", "GB", "TB".
"""
expected, self.units = expected_usage.upper().split()
self.expected_usage_bytes = float(expected) * MEMORY_UNIT_CONVERSION[self.units]

def __enter__(self):
tracemalloc.start()
return self

def __exit__(self, exc_type, exc_value, traceback):
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

if peak > self.expected_usage_bytes:
scaling = MEMORY_UNIT_CONVERSION[self.units]
msg = ("Peak memory usage exceeded expected usage: "
f"{peak / scaling:.2f} {self.units} > "
f"{self.expected_usage_bytes / scaling:.2f} {self.units} ")
raise MemoryThresholdExceeded(msg)
38 changes: 38 additions & 0 deletions tests/outlier_detection/test_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_OnDiskMedian,
nanmedian3D,
)
from stcal.testing_helpers import MemoryThreshold


def test_disk_appendable_array(tmp_path):
Expand Down Expand Up @@ -194,3 +195,40 @@ def test_nanmedian3D():

assert med.dtype == np.float32
assert np.allclose(med, np.nanmedian(cube, axis=0), equal_nan=True)


@pytest.mark.parametrize("in_memory", [True, False])
def test_memory_computer(in_memory, tmp_path):
"""
Analytically calculate how much memory the median computation
is supposed to take, then ensure that the implementation
stays near that.
in_memory=True case allocates the following memory:
- one cube size
- median array == one frame size
in_memory=False case allocates the following memory:
- one buffer size, which by default is the frame size
- median array == one frame size
add a half-frame-size buffer to the expected memory usage in both cases
"""
shp = (20, 500, 500)
cube_size = np.dtype("float32").itemsize * shp[0] * shp[1] * shp[2] #bytes
frame_size = cube_size / shp[0]

# calculate expected memory usage
if in_memory:
expected_mem = cube_size + frame_size*1.5
else:
expected_mem = frame_size * 2.5

# compute the median while tracking memory usage
with MemoryThreshold(str(expected_mem) + " B"):
computer = MedianComputer(shp, in_memory=in_memory, tempdir=tmp_path)
for i in range(shp[0]):
frame = np.full(shp[1:], i, dtype=np.float32)
computer.append(frame, i)
del frame
computer.evaluate()
16 changes: 16 additions & 0 deletions tests/test_infrastructure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Tests of custom testing infrastructure"""

import pytest
import numpy as np
from stcal.testing_helpers import MemoryThreshold, MemoryThresholdExceeded


def test_memory_threshold():
with MemoryThreshold("10 KB"):
buff = np.ones(1000, dtype=np.uint8)


def test_memory_threshold_exceeded():
with pytest.raises(MemoryThresholdExceeded):
with MemoryThreshold("500. B"):
buff = np.ones(10000, dtype=np.uint8)

0 comments on commit dfe1d6d

Please sign in to comment.