Skip to content

Commit

Permalink
Use dataclass instead of namedtuple for reduction status.
Browse files Browse the repository at this point in the history
This fixes a few mypy errors.
  • Loading branch information
jschwartzentruber committed May 30, 2024
1 parent f95118e commit 3d08359
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 44 deletions.
17 changes: 8 additions & 9 deletions grizzly/common/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import defaultdict
from contextlib import closing, contextmanager
from copy import deepcopy
from dataclasses import dataclass
from dataclasses import astuple, dataclass
from json import dumps, loads
from logging import getLogger
from os import getpid
Expand All @@ -19,7 +19,6 @@
Dict,
Generator,
List,
NamedTuple,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -826,7 +825,8 @@ def start(
return status


class ReductionStep(NamedTuple):
@dataclass(frozen=True)
class ReductionStep:
name: str
duration: Optional[float]
successes: Optional[int]
Expand All @@ -835,7 +835,8 @@ class ReductionStep(NamedTuple):
iterations: Optional[int]


class _MilestoneTimer(NamedTuple):
@dataclass(frozen=True)
class _MilestoneTimer:
name: str
start: float
attempts: int
Expand Down Expand Up @@ -980,8 +981,8 @@ def report(self, force: bool = False, report_rate: float = REPORT_RATE) -> bool:
analysis = dumps(self.analysis)
run_params = dumps(self.run_params)
sig_info = dumps(self.signature_info)
finished = dumps(self.finished_steps)
in_prog = dumps(self._in_progress_steps)
finished = dumps([astuple(step) for step in self.finished_steps])
in_prog = dumps([astuple(step) for step in self._in_progress_steps])
strategies = dumps(self.strategies)
last_reports = dumps(self.last_reports)

Expand Down Expand Up @@ -1126,9 +1127,7 @@ def load_all(
status.run_params = loads(entry[4])
status.signature_info = loads(entry[5])
status.successes = entry[6]
status.finished_steps = [
ReductionStep._make(step) for step in loads(entry[8])
]
status.finished_steps = [ReductionStep(*step) for step in loads(entry[8])]
status._in_progress_steps = [
_MilestoneTimer(*step) for step in loads(entry[9])
]
Expand Down
64 changes: 32 additions & 32 deletions grizzly/common/status_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Manage Grizzly status reports."""
from argparse import ArgumentParser
from collections import defaultdict
from dataclasses import astuple, fields
from datetime import timedelta
from functools import partial
from itertools import zip_longest
Expand All @@ -14,9 +15,8 @@
from pathlib import Path
from platform import system
from re import match
from re import sub as re_sub
from time import gmtime, localtime, strftime
from typing import Callable, Dict, Generator, List, Optional, Set, Tuple, Type
from typing import Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, Type

from psutil import cpu_count, cpu_percent, disk_usage, getloadavg, virtual_memory

Expand Down Expand Up @@ -529,51 +529,52 @@ class _TableFormatter:

def __init__(
self,
columns: Tuple[str, ...],
formatters: Tuple[Optional[Callable[..., str]]],
column_names: Tuple[str, ...],
formatters: Tuple[Optional[Callable[..., str]], ...],
vsep: str = " | ",
hsep: str = "-",
) -> None:
"""Initialize a TableFormatter instance.
Arguments:
columns: List of column names for the table header.
column_names: List of column names for the table header.
formatters: List of format functions for each column.
None will result in hiding that column.
vsep: Vertical separation between columns.
hsep: Horizontal separation between header and data.
"""
assert len(columns) == len(formatters)
assert len(column_names) == len(formatters)
self._columns = tuple(
column for (column, fmt) in zip(columns, formatters) if fmt is not None
column for (column, fmt) in zip(column_names, formatters) if fmt is not None
)
self._formatters = formatters
self._formatters = tuple(formatters)
self._vsep = vsep
self._hsep = hsep

def format_rows(self, rows: List[ReductionStep]) -> Generator[str, None, None]:
def format_rows(self, rows: Iterable[ReductionStep]) -> Generator[str, None, None]:
"""Format rows as a table and return a line generator.
Arguments:
rows: Tabular data. Each row must be the same length as
`columns` passed to `__init__`.
`column_names` passed to `__init__`.
Yields:
Each line of formatted tabular data.
"""
max_width = [len(col) for col in self._columns]
formatted: List[List[str]] = []
for row in rows:
assert len(row) == len(self._formatters)
data = astuple(row)
assert len(data) == len(self._formatters)
formatted.append([])
offset = 0
for idx, (data, formatter) in enumerate(zip(row, self._formatters)):
for idx, (datum, formatter) in enumerate(zip(data, self._formatters)):
if formatter is None:
offset += 1
continue
data = formatter(data)
max_width[idx - offset] = max(max_width[idx - offset], len(data))
formatted[-1].append(data)
datum_str = formatter(datum)
max_width[idx - offset] = max(max_width[idx - offset], len(datum_str))
formatted[-1].append(datum_str)

# build a format_str to space out the columns with separators using `max_width`
# the first column is left-aligned, and other fields are right-aligned.
Expand All @@ -588,17 +589,15 @@ def format_rows(self, rows: List[ReductionStep]) -> Generator[str, None, None]:


def _format_seconds(duration: float) -> str:
# format H:M:S, and then remove all leading zeros with regex
# format H:M:S, without leading zeros
minutes, seconds = divmod(int(duration), 60)
hours, minutes = divmod(minutes, 60)
result = re_sub("^[0:]*", "", f"{hours}:{minutes:02d}:{seconds:02d}")
# if the result is all zeroes, ensure one zero is output
if not result:
result = "0"
if hours:
return f"{hours}:{minutes:02d}:{seconds:02d}"

Check warning on line 596 in grizzly/common/status_reporter.py

View check run for this annotation

Codecov / codecov/patch

grizzly/common/status_reporter.py#L596

Added line #L596 was not covered by tests
if minutes:
return f"{minutes}:{seconds:02d}"
# a bare number is ambiguous. output 's' for seconds
if ":" not in result:
result += "s"
return result
return f"{seconds}s"


def _format_duration(duration: Optional[int], total: float = 0) -> str:
Expand Down Expand Up @@ -823,15 +822,16 @@ def summary( # pylint: disable=arguments-differ
entries.append(self._last_reports_entry(report))
if report.total and report.original:
tabulator = _TableFormatter(
ReductionStep._fields,
ReductionStep(
name=str,
# duration and attempts are % of total/last, size % of init/1st
duration=partial(_format_duration, total=report.total.duration),
attempts=partial(_format_number, total=report.total.attempts),
successes=partial(_format_number, total=report.total.successes),
iterations=None, # hide
size=partial(_format_number, total=report.original.size),
tuple(f.name for f in fields(ReductionStep)),
# this tuple must match the order of fields
# defined on ReductionStep!
(
str, # name
partial(_format_duration, total=report.total.duration),
partial(_format_number, total=report.total.successes),
partial(_format_number, total=report.total.attempts),
partial(_format_number, total=report.original.size),
None, # iterations (hidden)
),
)
lines.extend(tabulator.format_rows(report.finished_steps))
Expand Down
9 changes: 6 additions & 3 deletions grizzly/common/test_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# pylint: disable=protected-access

from contextlib import closing
from dataclasses import fields
from itertools import count
from multiprocessing import Event, Process
from sqlite3 import connect
Expand Down Expand Up @@ -523,10 +524,12 @@ def test_reduce_status_06(mocker, tmp_path):
assert len(loaded_status.finished_steps) == 2
assert len(loaded_status._in_progress_steps) == 0
assert loaded_status.original == status.original
for field in ReductionStep._fields:
if field == "size":
for field in fields(ReductionStep):
if field.name == "size":
continue
assert getattr(loaded_status.total, field) == getattr(status.total, field)
assert getattr(loaded_status.total, field.name) == getattr(
status.total, field.name
)
assert loaded_status.total.size is None


Expand Down

0 comments on commit 3d08359

Please sign in to comment.