From 265db39da05c5032b40fd9bfa6ab3663ce572191 Mon Sep 17 00:00:00 2001
From: Tyson Smith
Date: Thu, 18 Apr 2024 16:45:17 -0700
Subject: [PATCH] Add type hinting to status.py and status_reporter.py

There is still some work to be done on the reduction related code.
---
 grizzly/common/status.py               | 1060 +++++++++++++-----------
 grizzly/common/status_reporter.py      |  487 ++++++-----
 grizzly/common/test_status.py          |   24 +-
 grizzly/common/test_status_reporter.py |   11 +-
 4 files changed, 840 insertions(+), 742 deletions(-)

diff --git a/grizzly/common/status.py b/grizzly/common/status.py
index 20cefa26..b25d2f5c 100644
--- a/grizzly/common/status.py
+++ b/grizzly/common/status.py
@@ -2,16 +2,32 @@
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 """Manage Grizzly status reports."""
-from collections import defaultdict, namedtuple
+from abc import ABC
+from collections import defaultdict
 from contextlib import closing, contextmanager
 from copy import deepcopy
+from dataclasses import dataclass
 from json import dumps, loads
 from logging import getLogger
 from os import getpid
-from sqlite3 import OperationalError, connect
+from pathlib import Path
+from sqlite3 import Connection, OperationalError, connect
 from time import perf_counter, time
+from typing import (
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
+)
 
-from ..common.utils import grz_tmp
+from .reporter import FuzzManagerReporter
+from .utils import grz_tmp
 
 __all__ = ("ReadOnlyStatus", "ReductionStatus", "Status", "SimpleStatus")
 __author__ = "Tyson Smith"
@@ -33,20 +49,32 @@
 LOG = getLogger(__name__)
 
-ProfileEntry = namedtuple("ProfileEntry", "count max min name total")
-ResultEntry = namedtuple("ResultEntry", "rid count desc")
+@dataclass(eq=False, frozen=True)
+class ProfileEntry:
+    count: int
+    max: float
+    min: float
+    name: str
+    total: float
+
+
+@dataclass(frozen=True)
+class ResultEntry:
+    rid: str
+    count: int
+    desc: Optional[str]
 
 
-def _db_version_check(con, expected=DB_VERSION):
+def _db_version_check(con: Connection, expected: int = DB_VERSION) -> bool:
     """Perform version check and remove obsolete tables if required.
 
     Args:
-        con (sqlite3.Connection): An open database connection.
-        expected (int): The latest database version.
+        con: An open database connection.
+        expected: The latest database version.
 
     Returns:
-        bool: True if database was reset otherwise False.
+        True if database was reset otherwise False.
     """
     assert expected > 0
     cur = con.cursor()
@@ -57,7 +85,7 @@ def _db_version_check(con, expected=DB_VERSION):
         cur.execute("BEGIN EXCLUSIVE;")
         # check db version again while locked to avoid race
         cur.execute("PRAGMA user_version;")
-        version = cur.fetchone()[0]
+        version = cast(int, cur.fetchone()[0])
         if version < expected:
             LOG.debug("db version %d < %d", version, expected)
             # remove ALL tables from the database
@@ -73,18 +101,308 @@ def _db_version_check(con, expected=DB_VERSION):
     return False
 
 
-class BaseStatus:
+class SimpleResultCounter:
+    __slots__ = ("_count", "_desc", "pid")
+
+    def __init__(self, pid: int) -> None:
+        assert pid >= 0
+        self._count: Dict[str, int] = defaultdict(int)
+        self._desc: Dict[str, str] = {}
+        self.pid = pid
+
+    def __iter__(self) -> Generator[ResultEntry, None, None]:
+        """Yield all result data.
+
+        Args:
+            None
+
+        Yields:
+            Contains ID, count and description for each result entry.
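+
+        Example (illustrative sketch; the result ID and description are
+        hypothetical):
+
+            counter = SimpleResultCounter(getpid())
+            counter.count("uid1", "assertion failure")
+            for entry in counter:
+                print(entry.rid, entry.count, entry.desc)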
+        """
+        for result_id, count in self._count.items():
+            if count > 0:
+                yield ResultEntry(result_id, count, self._desc.get(result_id, None))
+
+    def blockers(
+        self, iterations: int, iters_per_result: int = 100
+    ) -> Generator[ResultEntry, None, None]:
+        """Any result with an iterations-per-result ratio less than or equal to
+        the given limit is considered a 'blocker'. Results with a count <= 1 are
+        not included.
+
+        Args:
+            iterations: Total iterations.
+            iters_per_result: Iterations-per-result threshold.
+
+        Yields:
+            ResultEntry: ID, count and description of blocking result.
+        """
+        assert iters_per_result > 0
+        if iterations > 0:
+            for entry in self:
+                if entry.count > 1 and iterations / entry.count <= iters_per_result:
+                    yield entry
+
+    def count(self, result_id: str, desc: str) -> Tuple[int, bool]:
+        """Count occurrence of a result and store its description.
+
+        Args:
+            result_id: Result ID.
+            desc: User friendly description.
+
+        Returns:
+            Current count for given result_id and True if this is the first
+            occurrence otherwise False.
+        """
+        assert isinstance(result_id, str)
+        self._count[result_id] += 1
+        initial = False
+        if result_id not in self._desc:
+            self._desc[result_id] = desc
+            initial = True
+        return self._count[result_id], initial
+
+    def get(self, result_id: str) -> ResultEntry:
+        """Get count and description for given result ID.
+
+        Args:
+            result_id: Result ID.
+
+        Returns:
+            ResultEntry: Count and description.
+        """
+        assert isinstance(result_id, str)
+        return ResultEntry(
+            result_id, self._count.get(result_id, 0), self._desc.get(result_id, None)
+        )
+
+    @property
+    def total(self) -> int:
+        """Get total count of all results.
+
+        Args:
+            None
+
+        Returns:
+            Total result count.
+        """
+        return sum(self._count.values())
+
+
+class ReadOnlyResultCounter(SimpleResultCounter):
+    def count(self, result_id: str, desc: str) -> Tuple[int, bool]:
+        raise NotImplementedError("Read only!")  # pragma: no cover
+
+    @classmethod
+    def load(
+        cls, db_file: Path, time_limit: float = 0
+    ) -> List["ReadOnlyResultCounter"]:
+        """Load existing entries from the database and populate
+        ReadOnlyResultCounter objects.
+
+        Args:
+            db_file: Database file.
+            time_limit: Used to filter older entries.
+
+        Returns:
+            Loaded ReadOnlyResultCounter objects.
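+
+        Example (sketch; the database path is hypothetical):
+
+            counters = ReadOnlyResultCounter.load(Path("status.db"))
+            found = sum(counter.total for counter in counters)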
+ """ + assert time_limit >= 0 + with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: + cur = con.cursor() + try: + # collect entries + if time_limit: + cur.execute( + """SELECT pid, + result_id, + description, + count + FROM results + WHERE timestamp > ?;""", + (time() - time_limit,), + ) + else: + cur.execute( + """SELECT pid, result_id, description, count FROM results""" + ) + entries = cur.fetchall() + except OperationalError as exc: + if not str(exc).startswith("no such table:"): + raise # pragma: no cover + entries = [] + + loaded = {} + for pid, result_id, desc, count in entries: + if pid not in loaded: + loaded[pid] = cls(pid) + loaded[pid]._desc[result_id] = desc # pylint: disable=protected-access + loaded[pid]._count[result_id] = count # pylint: disable=protected-access + + return list(loaded.values()) + + +class ResultCounter(SimpleResultCounter): + __slots__ = ("_db_file", "_frequent", "_limit", "last_found") + + def __init__( + self, + pid: int, + db_file: Path, + life_time: int = RESULTS_EXPIRE, + report_limit: int = 0, + ) -> None: + super().__init__(pid) + assert db_file + assert report_limit >= 0 + self._db_file = db_file + self._frequent: Set[str] = set() + # use zero to disable report limit + self._limit = report_limit + self.last_found = 0.0 + self._init_db(db_file, pid, life_time) + + @staticmethod + def _init_db(db_file: Path, pid: int, life_time: float) -> None: + # prepare database + LOG.debug("resultcounter using db %s", db_file) + with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: + _db_version_check(con) + cur = con.cursor() + with con: + # create table if needed + cur.execute( + """CREATE TABLE IF NOT EXISTS results ( + count INTEGER NOT NULL, + description TEXT NOT NULL, + pid INTEGER NOT NULL, + result_id TEXT NOT NULL, + timestamp INTEGER NOT NULL, + PRIMARY KEY(pid, result_id));""" + ) + # remove expired entries + if life_time > 0: + cur.execute( + """DELETE FROM results WHERE timestamp <= ?;""", + (time() - life_time,), + ) + # avoid (unlikely) pid reuse collision + cur.execute("""DELETE FROM results WHERE pid = ?;""", (pid,)) + # remove results for jobs that have been removed + try: + cur.execute( + """DELETE FROM results + WHERE pid NOT IN (SELECT pid FROM status);""" + ) + except OperationalError as exc: + if not str(exc).startswith("no such table:"): + raise # pragma: no cover + + def count(self, result_id: str, desc: str) -> Tuple[int, bool]: + """Count results and write results to the database. + + Args: + result_id: Result ID. + desc: User friendly description. + + Returns: + Local count and initial report flag (includes parallel instances) + for given result_id. + """ + super().count(result_id, desc) + timestamp = time() + initial = False + with closing(connect(self._db_file, timeout=DB_TIMEOUT)) as con: + cur = con.cursor() + with con: + cur.execute( + """UPDATE results + SET timestamp = ?, + count = ? + WHERE pid = ? 
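+                           -- rows are keyed on the (pid, result_id) primary key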
+ AND result_id = ?;""", + (timestamp, self._count[result_id], self.pid, result_id), + ) + if cur.rowcount < 1: + cur.execute( + """SELECT pid FROM results WHERE result_id = ?;""", + (result_id,), + ) + initial = cur.fetchone() is None + cur.execute( + """INSERT INTO results( + pid, + result_id, + description, + timestamp, + count) + VALUES (?, ?, ?, ?, ?);""", + (self.pid, result_id, desc, timestamp, self._count[result_id]), + ) + self.last_found = timestamp + return self._count[result_id], initial + + def is_frequent(self, result_id: str) -> bool: + """Scan all results including results from other running instances + to determine if the limit has been exceeded. Local count must be >1 before + limit is checked. + + Args: + result_id: Result ID. + + Returns: + True if limit has been exceeded otherwise False. + """ + assert isinstance(result_id, str) + if self._limit < 1: + return False + if result_id in self._frequent: + return True + # get local total + total = self._count.get(result_id, 0) + # only check the db for parallel results if + # - result has been found locally more than once + # - limit has not been exceeded locally + if self._limit >= total > 1: + with closing(connect(self._db_file, timeout=DB_TIMEOUT)) as con: + cur = con.cursor() + # look up total count from all processes + cur.execute( + """SELECT COALESCE(SUM(count), 0) + FROM results WHERE result_id = ?;""", + (result_id,), + ) + global_total = cur.fetchone()[0] + assert global_total >= total + total = global_total + if total > self._limit: + self._frequent.add(result_id) + return True + return False + + def mark_frequent(self, result_id: str) -> None: + """Mark given results ID as frequent locally. + + Args: + result_id: Result ID. + + Returns: + None + """ + assert isinstance(result_id, str) + if result_id not in self._frequent: + self._frequent.add(result_id) + + +class BaseStatus(ABC): """Record and manage status information. Attributes: - _profiles (dict): Profiling data. - ignored (int): Ignored result count. - iteration (int): Iteration count. - log_size (int): Log size in bytes. - pid (int): Python process ID. - results (None): Placeholder for result data. - start_time (float): Start time of session. - test_name (str): Current test name. + _profiles: Profiling data. + ignored: Ignored result count. + iteration: Iteration count. + log_size: Log size in bytes. + pid: Python process ID. + start_time: Start time of session. + test_name: Current test name. """ __slots__ = ( @@ -93,28 +411,32 @@ class BaseStatus: "iteration", "log_size", "pid", - "results", "start_time", "test_name", ) - def __init__(self, pid, start_time, ignored=0, iteration=0, log_size=0): + def __init__( + self, + pid: int, + start_time: float, + ignored: int = 0, + iteration: int = 0, + log_size: int = 0, + ) -> None: assert pid >= 0 assert ignored >= 0 assert iteration >= 0 assert log_size >= 0 - assert isinstance(start_time, float) assert start_time >= 0 - self._profiles = {} + self._profiles: Dict[str, Dict[str, Union[float, int]]] = {} self.ignored = ignored self.iteration = iteration self.log_size = log_size self.pid = pid - self.results = None self.start_time = start_time self.test_name = None - def profile_entries(self): + def profile_entries(self) -> Generator[ProfileEntry, None, None]: """Used to retrieve profiling data. 
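+
+        Example (illustrative; assumes entries were recorded via measure()):
+
+            for entry in status.profile_entries():
+                average = entry.total / entry.count
+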
Args: @@ -125,30 +447,34 @@ def profile_entries(self): """ for name, entry in self._profiles.items(): yield ProfileEntry( - entry["count"], entry["max"], entry["min"], name, entry["total"] + cast(int, entry["count"]), + entry["max"], + entry["min"], + name, + entry["total"], ) @property - def rate(self): + def rate(self) -> float: """Calculate the average iteration rate in seconds. Args: None Returns: - float: Number of iterations performed per second. + Number of iterations performed per second. """ return self.iteration / self.runtime if self.runtime else 0 @property - def runtime(self): + def runtime(self) -> float: """Calculate the number of seconds since start() was called. Args: None Returns: - int: Total runtime in seconds. + Total runtime in seconds. """ return max(time() - self.start_time, 0) @@ -157,34 +483,49 @@ class ReadOnlyStatus(BaseStatus): """Store status information. Attributes: - _profiles (dict): Profiling data. - ignored (int): Ignored result count. - iteration (int): Iteration count. - log_size (int): Log size in bytes. - pid (int): Python process ID. - results (None): Placeholder for result data. - start_time (float): Start time of session. - test_name (str): Test name. - timestamp (float): Last time data was saved to database. + _profiles: Profiling data. + ignored: Ignored result count. + iteration: Iteration count. + log_size: Log size in bytes. + pid: Python process ID. + results: Result data. + start_time: Start time of session. + test_name: Test name. + timestamp: Last time data was saved to database. """ - __slots__ = ("timestamp",) + __slots__ = ("results", "timestamp") - def __init__(self, pid, start_time, timestamp, ignored=0, iteration=0, log_size=0): + def __init__( + self, + pid: int, + start_time: float, + timestamp: float, + ignored: int = 0, + iteration: int = 0, + log_size: int = 0, + results: Optional[ReadOnlyResultCounter] = None, + ) -> None: super().__init__( - pid, start_time, ignored=ignored, iteration=iteration, log_size=log_size + pid, + start_time, + ignored=ignored, + iteration=iteration, + log_size=log_size, ) - assert isinstance(timestamp, float) assert timestamp >= start_time + self.results = results or ReadOnlyResultCounter(pid) self.timestamp = timestamp @classmethod - def load_all(cls, db_file, time_limit=300): + def load_all( + cls, db_file: Path, time_limit: float = 300 + ) -> Generator["ReadOnlyStatus", None, None]: """Load all status reports found in `db_file`. Args: - db_file (Path): Database containing status data. - time_limit (int): Filter entries by age. Use zero for no limit. + db_file: Database containing status data. + time_limit: Filter entries by age. Use zero for no limit. Yields: ReadOnlyStatus: Successfully loaded objects. 
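+
+        Example (sketch; the database path is hypothetical):
+
+            for status in ReadOnlyStatus.load_all(Path("status.db")):
+                LOG.debug("pid %d: %d iterations", status.pid, status.iteration)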
@@ -210,11 +551,18 @@ def load_all(cls, db_file, time_limit=300): except OperationalError as exc: if not str(exc).startswith("no such table:"): raise # pragma: no cover - entries = () + entries = [] # Load all results results = ReadOnlyResultCounter.load(db_file, time_limit=0) for entry in entries: + # look up counter + current_counter = None + for counter in results: + if counter.pid == cast(int, entry[0]): + current_counter = counter + break + status = cls( entry[0], entry[5], @@ -222,26 +570,20 @@ def load_all(cls, db_file, time_limit=300): ignored=entry[2], iteration=entry[3], log_size=entry[4], + results=current_counter, ) status._profiles = loads(entry[1]) - for counter in results: - if counter.pid == status.pid: - status.results = counter - break - else: - # no existing ReadOnlyResultCounter with matching pid found - status.results = ReadOnlyResultCounter(status.pid) yield status @property - def runtime(self): + def runtime(self) -> float: """Calculate total runtime in seconds relative to 'timestamp'. Args: None Returns: - int: Total runtime in seconds. + Total runtime in seconds. """ return self.timestamp - self.start_time @@ -250,29 +592,31 @@ class SimpleStatus(BaseStatus): """Record and manage status information. Attributes: - _profiles (dict): Profiling data. - ignored (int): Ignored result count. - iteration (int): Iteration count. - log_size (int): Log size in bytes. - pid (int): Python process ID. - results (None): Placeholder for result data. - start_time (float): Start time of session. - test_name (str): Current test name. + _profiles: Profiling data. + ignored: Ignored result count. + iteration: Iteration count. + log_size: Log size in bytes. + pid: Python process ID. + results: + start_time: Start time of session. + test_name: Current test name. """ - def __init__(self, pid, start_time): + __slots__ = ("results",) + + def __init__(self, pid: int, start_time: float) -> None: super().__init__(pid, start_time) self.results = SimpleResultCounter(pid) @classmethod - def start(cls): + def start(cls) -> "SimpleStatus": """Create a unique SimpleStatus object. Args: None Returns: - SimpleStatus: Active status report. + Active status report. """ return cls(getpid(), time()) @@ -281,30 +625,30 @@ class Status(BaseStatus): """Status records status information and stores it in a database. Attributes: - _db_file (Path): Database file containing data. - _enable_profiling (bool): Profiling support status. - _profiles (dict): Profiling data. - ignored (int): Ignored result count. - iteration (int): Iteration count. - log_size (int): Log size in bytes. - pid (int): Python process ID. - results (ResultCounter): Results data. Used to count occurrences of results. - start_time (float): Start time of session. - test_name (str): Current test name. - timestamp (float): Last time data was saved to database. + _db_file: Database file containing data. + _enable_profiling: Profiling support status. + _profiles: Profiling data. + ignored: Ignored result count. + iteration: Iteration count. + log_size: Log size in bytes. + pid: Python process ID. + results: Results data. Used to count occurrences of results. + start_time: Start time of session. + test_name: Current test name. + timestamp: Last time data was saved to database. 
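+
+    Example (illustrative sketch; the database path is hypothetical):
+
+        status = Status.start(Path("status.db"), enable_profiling=True)
+        with status.measure("iteration"):
+            pass  # perform one fuzzing iteration here
+        status.iteration += 1
+        status.report()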
""" - __slots__ = ("_db_file", "_enable_profiling", "timestamp") + __slots__ = ("_db_file", "_enable_profiling", "results", "timestamp") def __init__( self, - pid, - start_time, - db_file, - enable_profiling=False, - life_time=REPORTS_EXPIRE, - report_limit=0, - ): + pid: int, + start_time: float, + db_file: Path, + enable_profiling: bool = False, + life_time: float = REPORTS_EXPIRE, + report_limit: int = 0, + ) -> None: super().__init__(pid, start_time) assert life_time >= 0 assert report_limit >= 0 @@ -315,7 +659,7 @@ def __init__( self.timestamp = start_time @staticmethod - def _init_db(db_file, pid, life_time): + def _init_db(db_file: Path, pid: int, life_time: float) -> None: # prepare database LOG.debug("status using db %s", db_file) with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: @@ -343,11 +687,11 @@ def _init_db(db_file, pid, life_time): cur.execute("""DELETE FROM status WHERE pid = ?;""", (pid,)) @contextmanager - def measure(self, name): + def measure(self, name: str) -> Generator[None, None, None]: """Used to simplify collecting profiling data. Args: - name (str): Used to group the entries. + name: Used to group the entries. Yields: None @@ -359,13 +703,13 @@ def measure(self, name): else: yield - def record(self, name, duration): + def record(self, name: str, duration: float) -> None: """Used to add profiling data. This is intended to be used to make rough calculations to identify major configuration issues. Args: - name (str): Used to group the entries. - duration (int, float): Stored to be later used for measurements. + name: Used to group the entries. + duration: Stored to be later used for measurements. Returns: None @@ -388,22 +732,23 @@ def record(self, name, duration): "total": duration, } - def report(self, force=False, report_rate=REPORT_RATE): + def report(self, force: bool = False, report_rate: int = REPORT_RATE) -> bool: """Write status report to database. Reports are only written periodically. It is limited by `report_rate`. The specified number of seconds must elapse before another write will be performed unless `force` is True. Args: - force (bool): Ignore report frequency limiting. - report_rate (int): Minimum number of seconds between writes to database. + force: Ignore report frequency limiting. + report_rate: Minimum number of seconds between writes to database. Returns: - bool: True if the report was successful otherwise False. + True if the report was successful otherwise False. """ now = time() if self.results.last_found > self.timestamp: LOG.debug("results have been found since last report, force update") force = True + assert report_rate >= 0 if not force and now < (self.timestamp + report_rate): return False assert self.start_time <= now @@ -456,13 +801,15 @@ def report(self, force=False, report_rate=REPORT_RATE): return True @classmethod - def start(cls, db_file, enable_profiling=False, report_limit=0): + def start( + cls, db_file: Path, enable_profiling: bool = False, report_limit: int = 0 + ) -> "Status": """Create a unique Status object. Args: - db_file (Path): Database containing status data. - enable_profiling (bool): Record profiling data. - report_limit (int): Number of times a unique result will be reported. + db_file: Database containing status data. + enable_profiling: Record profiling data. + report_limit: Number of times a unique result will be reported. Returns: Status: Active status report. 
@@ -478,288 +825,21 @@ def start(cls, db_file, enable_profiling=False, report_limit=0): return status -class SimpleResultCounter: - __slots__ = ("_count", "_desc", "pid") - - def __init__(self, pid): - assert pid >= 0 - self._count = defaultdict(int) - self._desc = {} - self.pid = pid - - def __iter__(self): - """Yield all result data. - - Args: - None - - Yields: - ResultEntry: Contains ID, count and description for each result entry. - """ - for result_id, count in self._count.items(): - if count > 0: - yield ResultEntry(result_id, count, self._desc.get(result_id, None)) - - def blockers(self, iterations, iters_per_result=100): - """Any result with an iterations-per-result ratio of less than or equal the - given limit are considered 'blockers'. Results with a count <= 1 are not - included. - - Args: - iterations (int): Total iterations. - iters_per_result (int): Iterations-per-result threshold. - - Yields: - ResultEntry: ID, count and description of blocking result. - """ - assert iters_per_result > 0 - if iterations > 0: - for entry in self: - if entry.count > 1 and iterations / entry.count <= iters_per_result: - yield entry - - def count(self, result_id, desc): - """ - - Args: - result_id (str): Result ID. - desc (str): User friendly description. - - Returns: - int: Current count for given result_id. - """ - assert isinstance(result_id, str) - self._count[result_id] += 1 - if result_id not in self._desc: - self._desc[result_id] = desc - return self._count[result_id] - - def get(self, result_id): - """Get count and description for given result id. - - Args: - result_id (str): Result ID. - - Returns: - ResultEntry: Count and description. - """ - assert isinstance(result_id, str) - return ResultEntry( - result_id, self._count.get(result_id, 0), self._desc.get(result_id, None) - ) - - @property - def total(self): - """Get total count of all results. - - Args: - None - - Returns: - int: Total result count. - """ - return sum(self._count.values()) - - -class ReadOnlyResultCounter(SimpleResultCounter): - def count(self, result_id, desc): - raise NotImplementedError("Read only!") # pragma: no cover - - @classmethod - def load(cls, db_file, time_limit=0): - """Load existing entries for database and populate a ReadOnlyResultCounter. - - Args: - db_file (Path): Database file. - time_limit (int): Used to filter older entries. - - Returns: - list: Loaded ReadOnlyResultCounter objects. 
- """ - assert time_limit >= 0 - with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: - cur = con.cursor() - try: - # collect entries - if time_limit: - cur.execute( - """SELECT pid, - result_id, - description, - count - FROM results - WHERE timestamp > ?;""", - (time() - time_limit,), - ) - else: - cur.execute( - """SELECT pid, result_id, description, count FROM results""" - ) - entries = cur.fetchall() - except OperationalError as exc: - if not str(exc).startswith("no such table:"): - raise # pragma: no cover - entries = () - - loaded = {} - for pid, result_id, desc, count in entries: - if pid not in loaded: - loaded[pid] = cls(pid) - loaded[pid]._desc[result_id] = desc # pylint: disable=protected-access - loaded[pid]._count[result_id] = count # pylint: disable=protected-access - - return list(loaded.values()) - - -class ResultCounter(SimpleResultCounter): - __slots__ = ("_db_file", "_frequent", "_limit", "last_found") - - def __init__(self, pid, db_file, life_time=RESULTS_EXPIRE, report_limit=0): - super().__init__(pid) - assert db_file - assert report_limit >= 0 - self._db_file = db_file - self._frequent = set() - # use zero to disable report limit - self._limit = report_limit - self.last_found = 0 - self._init_db(db_file, pid, life_time) - - @staticmethod - def _init_db(db_file, pid, life_time): - # prepare database - LOG.debug("resultcounter using db %s", db_file) - with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: - _db_version_check(con) - cur = con.cursor() - with con: - # create table if needed - cur.execute( - """CREATE TABLE IF NOT EXISTS results ( - count INTEGER NOT NULL, - description TEXT NOT NULL, - pid INTEGER NOT NULL, - result_id TEXT NOT NULL, - timestamp INTEGER NOT NULL, - PRIMARY KEY(pid, result_id));""" - ) - # remove expired entries - if life_time > 0: - cur.execute( - """DELETE FROM results WHERE timestamp <= ?;""", - (time() - life_time,), - ) - # avoid (unlikely) pid reuse collision - cur.execute("""DELETE FROM results WHERE pid = ?;""", (pid,)) - # remove results for jobs that have been removed - try: - cur.execute( - """DELETE FROM results - WHERE pid NOT IN (SELECT pid FROM status);""" - ) - except OperationalError as exc: - if not str(exc).startswith("no such table:"): - raise # pragma: no cover - - def count(self, result_id, desc): - """Count results and write results to the database. +class ReductionStep(NamedTuple): + name: str + duration: Optional[float] + successes: Optional[int] + attempts: Optional[int] + size: Optional[int] + iterations: Optional[int] - Args: - result_id (str): Result ID. - desc (str): User friendly description. - Returns: - tuple (int, bool): Local count and initial report (includes - parallel instances) for given result_id. - """ - super().count(result_id, desc) - timestamp = time() - initial = False - with closing(connect(self._db_file, timeout=DB_TIMEOUT)) as con: - cur = con.cursor() - with con: - cur.execute( - """UPDATE results - SET timestamp = ?, - count = ? - WHERE pid = ? 
- AND result_id = ?;""", - (timestamp, self._count[result_id], self.pid, result_id), - ) - if cur.rowcount < 1: - cur.execute( - """SELECT pid FROM results WHERE result_id = ?;""", - (result_id,), - ) - initial = cur.fetchone() is None - cur.execute( - """INSERT INTO results( - pid, - result_id, - description, - timestamp, - count) - VALUES (?, ?, ?, ?, ?);""", - (self.pid, result_id, desc, timestamp, self._count[result_id]), - ) - self.last_found = timestamp - return self._count[result_id], initial - - def is_frequent(self, result_id): - """Scan all results including results from other running instances - to determine if the limit has been exceeded. Local count must be >1 before - limit is checked. - - Args: - result_id (str): Result ID. - - Returns: - bool: True if limit has been exceeded otherwise False. - """ - assert isinstance(result_id, str) - if self._limit < 1: - return False - if result_id in self._frequent: - return True - # get local total - total = self._count.get(result_id, 0) - # only check the db for parallel results if - # - result has been found locally more than once - # - limit has not been exceeded locally - if self._limit >= total > 1: - with closing(connect(self._db_file, timeout=DB_TIMEOUT)) as con: - cur = con.cursor() - # look up total count from all processes - cur.execute( - """SELECT COALESCE(SUM(count), 0) - FROM results WHERE result_id = ?;""", - (result_id,), - ) - global_total = cur.fetchone()[0] - assert global_total >= total - total = global_total - if total > self._limit: - self._frequent.add(result_id) - return True - return False - - def mark_frequent(self, result_id): - """Mark given results ID as frequent locally. - - Args: - result_id (str): Result ID. - - Returns: - None - """ - assert isinstance(result_id, str) - if result_id not in self._frequent: - self._frequent.add(result_id) - - -ReductionStep = namedtuple( - "ReductionStep", "name, duration, successes, attempts, size, iterations" -) +class _MilestoneTimer(NamedTuple): + name: str + start: float + attempts: int + iterations: int + successes: int class ReductionStatus: @@ -767,40 +847,44 @@ class ReductionStatus: def __init__( self, - strategies=None, - testcase_size_cb=None, - crash_id=None, - db_file=None, - pid=None, - tool=None, - life_time=REPORTS_EXPIRE, - ): + strategies: Optional[List[str]] = None, + testcase_size_cb: Optional[Callable[[], int]] = None, + crash_id: Optional[int] = None, + db_file: Optional[Path] = None, + pid: Optional[int] = None, + tool: Optional[str] = None, + life_time: float = REPORTS_EXPIRE, + ) -> None: """Initialize a ReductionStatus instance. Arguments: - strategies (list(str)): List of strategies to be run. - testcase_size_cb (callable): Callback to get testcase size - crash_id (int): CrashManager ID of original testcase - db_file (Path): Database file containing data. - tool (str): The tool name used for reporting to FuzzManager. + strategies: List of strategies to be run. + testcase_size_cb: Callback to get testcase size. + crash_id: CrashManager ID of original testcase. + db_file: Database file containing data. + tool: The tool name used for reporting to FuzzManager. + life_time: """ - self.analysis = {} + self.analysis: Dict[str, float] = {} self.attempts = 0 self.iterations = 0 - self.run_params = {} - self.signature_info = {} + # TODO: make RunParams dataclass? + self.run_params: Dict[str, Union[bool, int]] = {} + # TODO: make SigInfo dataclass? 
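+        # a SigInfo dataclass could look like this (hypothetical sketch; field
+        # names are illustrative, mirroring the bool/str values stored here):
+        #     @dataclass(frozen=True)
+        #     class SigInfo:
+        #         stable: bool
+        #         signature: str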
+ self.signature_info: Dict[str, Union[bool, str]] = {} self.successes = 0 self.current_strategy_idx = None self._testcase_size_cb = testcase_size_cb self.crash_id = crash_id - self.finished_steps = [] - self._in_progress_steps = [] + self.finished_steps: List[ReductionStep] = [] + self._in_progress_steps: List[_MilestoneTimer] = [] self.strategies = strategies self._db_file = db_file self.pid = pid self.timestamp = time() self.tool = tool - self._current_size = None + self._current_size: Optional[int] = None + # this holds results from Reporter.submit() self.last_reports = [] # prepare database @@ -842,23 +926,23 @@ def __init__( @classmethod def start( cls, - db_file, - strategies=None, - testcase_size_cb=None, - crash_id=None, - tool=None, - ): + db_file: Path, + strategies: Optional[List[str]] = None, + testcase_size_cb: Optional[Callable[[], int]] = None, + crash_id: Optional[int] = None, + tool: Optional[str] = None, + ) -> "ReductionStatus": """Create a unique ReductionStatus object. Args: - db_file (Path): Database containing status data. - strategies (list(str)): List of strategies to be run. - testcase_size_cb (callable): Callback to get testcase size - crash_id (int): CrashManager ID of original testcase - tool (str): The tool name used for reporting to FuzzManager. + db_file: Database containing status data. + strategies: List of strategies to be run. + testcase_size_cb: Callback to get testcase size. + crash_id: CrashManager ID of original testcase. + tool: The tool name used for reporting to FuzzManager. Returns: - ReductionStatus: Active status report. + Active status report. """ status = cls( crash_id=crash_id, @@ -871,17 +955,17 @@ def start( status.report(force=True) return status - def report(self, force=False, report_rate=REPORT_RATE): + def report(self, force: bool = False, report_rate: float = REPORT_RATE) -> bool: """Write status report to database. Reports are only written periodically. It is limited by `report_rate`. The specified number of seconds must elapse before another write will be performed unless `force` is True. Args: - force (bool): Ignore report frequently limiting. - report_rate (int): Minimum number of seconds between writes. + force: Ignore report frequently limiting. + report_rate: Minimum number of seconds between writes. Returns: - bool: Returns true if the report was successful otherwise false. + True if the report was successful otherwise false. """ now = time() if not force and now < (self.timestamp + report_rate): @@ -896,7 +980,7 @@ def report(self, force=False, report_rate=REPORT_RATE): run_params = dumps(self.run_params) sig_info = dumps(self.signature_info) finished = dumps(self.finished_steps) - in_prog = dumps([step.serialize() for step in self._in_progress_steps]) + in_prog = dumps(self._in_progress_steps) strategies = dumps(self.strategies) last_reports = dumps(self.last_reports) @@ -980,16 +1064,18 @@ def report(self, force=False, report_rate=REPORT_RATE): return True @classmethod - def load_all(cls, db_file, time_limit=300): + def load_all( + cls, db_file: Path, time_limit: float = 300 + ) -> Generator["ReductionStatus", None, None]: """Load all reduction status reports found in `db_file`. Args: - db_file (Path): Database containing status data. - time_limit (int): Only include entries with a timestamp that is within the - given number of seconds. Use zero for no limit. + db_file: Database containing status data. + time_limit: Only include entries with a timestamp that is within the + given number of seconds. Use zero for no limit. 
Yields: - Status: Successfully loaded read-only status objects. + Successfully loaded read-only status objects. """ assert time_limit >= 0 with closing(connect(db_file, timeout=DB_TIMEOUT)) as con: @@ -1022,7 +1108,7 @@ def load_all(cls, db_file, time_limit=300): except OperationalError as exc: if not str(exc).startswith("no such table:"): raise # pragma: no cover - entries = () + entries = [] for entry in entries: pid = entry[0] @@ -1043,7 +1129,7 @@ def load_all(cls, db_file, time_limit=300): ReductionStep._make(step) for step in loads(entry[8]) ] status._in_progress_steps = [ - status._construct_milestone(*step) for step in loads(entry[9]) + _MilestoneTimer(*step) for step in loads(entry[9]) ] status._current_size = entry[11] status.current_strategy_idx = entry[12] @@ -1051,9 +1137,10 @@ def load_all(cls, db_file, time_limit=300): status.last_reports = loads(entry[15]) yield status - def _testcase_size(self): + def _testcase_size(self) -> Optional[int]: if self._db_file is None: return self._current_size + assert self._testcase_size_cb is not None return self._testcase_size_cb() def __deepcopy__(self, memo): @@ -1077,7 +1164,8 @@ def __deepcopy__(self, memo): result.finished_steps = deepcopy(self.finished_steps, memo) result.last_reports = deepcopy(self.last_reports, memo) # finish open timers - for step in reversed(self._in_progress_steps): + for tmr in reversed(self._in_progress_steps): + step = self._tmr_to_step(tmr) result.record( step.name, attempts=step.attempts, @@ -1088,37 +1176,51 @@ def __deepcopy__(self, memo): ) return result + def _tmr_to_step(self, tmr: _MilestoneTimer) -> ReductionStep: + if self._db_file is None: + duration = self.timestamp - tmr.start + else: + duration = time() - tmr.start + return ReductionStep( + tmr.name, + duration=duration, + successes=self.successes - tmr.successes, + attempts=self.attempts - tmr.attempts, + size=None, + iterations=self.iterations - tmr.iterations, + ) + @property - def current_strategy(self): + def current_strategy(self) -> Optional[ReductionStep]: if self._in_progress_steps: - return self._in_progress_steps[-1] + return self._tmr_to_step(self._in_progress_steps[-1]) if self.finished_steps: return self.finished_steps[-1] return None @property - def total(self): + def total(self) -> Optional[ReductionStep]: if self._in_progress_steps: - return self._in_progress_steps[0] + return self._tmr_to_step(self._in_progress_steps[0]) if self.finished_steps: return self.finished_steps[-1] return None @property - def original(self): + def original(self) -> Optional[ReductionStep]: if self.finished_steps: return self.finished_steps[0] return None def record( self, - name, - duration=None, - iterations=None, - attempts=None, - successes=None, - report=True, - ): + name: str, + duration: Optional[float] = None, + iterations: Optional[int] = None, + attempts: Optional[int] = None, + successes: Optional[int] = None, + report: bool = True, + ) -> None: """Record reduction status for a given point in time: - name of the milestone (eg. init, strategy name completed) @@ -1128,12 +1230,12 @@ def record( - # of successful attempts Arguments: - name (str): name of milestone - duration (float or None): seconds elapsed for period recorded - iterations (int or None): # of iterations performed - attempts (int or None): # of attempts performed - successes (int or None): # of attempts successful - report (bool): Automatically force a report. 
+ name: name of milestone + duration: seconds elapsed for period recorded + iterations: # of iterations performed + attempts: # of attempts performed + successes: # of attempts successful + report: Automatically force a report. Returns: None @@ -1151,94 +1253,54 @@ def record( if report: self.report(force=True) - def _construct_milestone(self, name, start, attempts, iterations, successes): - # pylint: disable=no-self-argument - class _MilestoneTimer: - def __init__(sub): - sub.name = name - sub._start_time = start - sub._start_attempts = attempts - sub._start_iterations = iterations - sub._start_successes = successes - - @property - def size(sub): - return self._testcase_size() # pylint: disable=protected-access - - @property - def attempts(sub): - return self.attempts - sub._start_attempts - - @property - def iterations(sub): - return self.iterations - sub._start_iterations - - @property - def successes(sub): - return self.successes - sub._start_successes - - @property - def duration(sub): - if self._db_file is None: # pylint: disable=protected-access - return self.timestamp - sub._start_time - return time() - sub._start_time - - def serialize(sub): - return ( - sub.name, - sub._start_time, - sub._start_attempts, - sub._start_iterations, - sub._start_successes, - ) - - return _MilestoneTimer() - @contextmanager - def measure(self, name, report=True): + def measure(self, name: str, report: bool = True) -> Generator[None, None, None]: """Time and record the period leading up to a reduction milestone. eg. a strategy being run. Arguments: - name (str): name of milestone - report (bool): Automatically force a report. + name: name of milestone + report: Automatically force a report. Yields: None """ - tmr = self._construct_milestone( + tmr = _MilestoneTimer( name, time(), self.attempts, self.iterations, self.successes ) self._in_progress_steps.append(tmr) yield assert self._in_progress_steps.pop() is tmr + step = self._tmr_to_step(tmr) self.record( name, - attempts=tmr.attempts, - duration=tmr.duration, - iterations=tmr.iterations, - successes=tmr.successes, + attempts=step.attempts, + duration=step.duration, + iterations=step.iterations, + successes=step.successes, report=report, ) - def copy(self): + def copy(self) -> "ReductionStatus": """Create a deep copy of this instance. Arguments: None Returns: - ReductionStatus: Clone of self + Clone of self """ return deepcopy(self) - def add_to_reporter(self, reporter, expected=True): + def add_to_reporter( + self, reporter: FuzzManagerReporter, expected: bool = True + ) -> None: """Add the reducer status to reported metadata for the given reporter. Arguments: - reporter (FuzzManagerReporter): Reporter to update. - expected (bool): Add detailed stats. + reporter: Reporter to update. + expected: Add detailed stats. 
Returns: None diff --git a/grizzly/common/status_reporter.py b/grizzly/common/status_reporter.py index 0fe20264..4d546cf8 100644 --- a/grizzly/common/status_reporter.py +++ b/grizzly/common/status_reporter.py @@ -12,14 +12,18 @@ try: from os import getloadavg + + GETLOADAVG_AVAILABLE = True except ImportError: # pragma: no cover # os.getloadavg() is not available on all platforms - getloadavg = None + GETLOADAVG_AVAILABLE = False + from os import SEEK_CUR, getenv from pathlib import Path from re import match from re import sub as re_sub from time import gmtime, localtime, strftime +from typing import Callable, Dict, Generator, List, Optional, Set, Tuple, Union from psutil import cpu_count, cpu_percent, disk_usage, virtual_memory @@ -39,6 +43,116 @@ LOG = getLogger(__name__) +class TracebackReport: + """Read Python tracebacks from log files and store it in a manner that is helpful + when generating reports. + """ + + MAX_LINES = 16 # should be no less than 6 + READ_LIMIT = 0x20000 # 128KB + + def __init__( + self, + log_file: Path, + lines: List[str], + is_kbi: bool = False, + prev_lines: Optional[List[str]] = None, + ) -> None: + self.is_kbi = is_kbi + self.lines = lines + self.log_file = log_file + self.prev_lines = prev_lines or [] + + @classmethod + def from_file( + cls, log_file: Path, max_preceding: int = 5, ignore_kbi: bool = False + ) -> Optional["TracebackReport"]: + """Create TracebackReport from a text file containing a Python traceback. + Only the first traceback in the file will be parsed. + + Args: + log_file: File to parse. + max_preceding: Number of lines to collect leading up to the traceback. + ignore_kbi: Skip/ignore KeyboardInterrupt. + + Returns: + TracebackReport containing data from givin log file. + """ + token = b"Traceback (most recent call last):" + assert len(token) < cls.READ_LIMIT + try: + with log_file.open("rb") as in_fp: + for chunk in iter(partial(in_fp.read, cls.READ_LIMIT), b""): + idx = chunk.find(token) + if idx > -1: + # calculate offset of data in the file + pos = in_fp.tell() - len(chunk) + idx + break + if len(chunk) == cls.READ_LIMIT: + # seek back to avoid missing beginning of token + in_fp.seek(len(token) * -1, SEEK_CUR) + else: + # no traceback here, move along + return None + # seek back 2KB to collect preceding lines + in_fp.seek(max(pos - 2048, 0)) + data = in_fp.read(cls.READ_LIMIT) + except OSError: # pragma: no cover + # in case the file goes away + return None + + data_lines = data.decode("ascii", errors="ignore").splitlines() + token_str = token.decode() + is_kbi = False + tb_start = None + tb_end = None + line_count = len(data_lines) + for line_num, log_line in enumerate(data_lines): + if tb_start is None and token_str in log_line: + tb_start = line_num + continue + if tb_start is not None: + log_line = log_line.strip() + if not log_line: + # stop at first empty line + tb_end = min(line_num, line_count) + break + if match(r"^\w+(\.\w+)*\:\s|^\w+(Interrupt|Error)$", log_line): + is_kbi = log_line.startswith("KeyboardInterrupt") + if is_kbi and ignore_kbi: + # ignore this exception since it is a KeyboardInterrupt + return None + # stop after error message + tb_end = min(line_num + 1, line_count) + break + assert tb_start is not None + if max_preceding > 0: + prev_start = max(tb_start - max_preceding, 0) + prev_lines = data_lines[prev_start:tb_start] + else: + prev_lines = None + if tb_end is None: + # limit if the end is not identified (failsafe) + tb_end = max(line_count, cls.MAX_LINES) + if tb_end - tb_start > cls.MAX_LINES: + # add 
first entry + lines = data_lines[tb_start : tb_start + 3] + lines += ["<--- TRACEBACK TRIMMED--->"] + # add end entries + lines += data_lines[tb_end - (cls.MAX_LINES - 3) : tb_end] + else: + lines = data_lines[tb_start:tb_end] + return cls(log_file, lines, is_kbi=is_kbi, prev_lines=prev_lines) + + def __len__(self) -> int: + return len(str(self)) + + def __str__(self) -> str: + return "\n".join( + [f"Log: '{self.log_file.name}'"] + self.prev_lines + self.lines + ) + + class StatusReporter: """Read and merge Grizzly status reports, including tracebacks if found. Output is a single textual report, e.g. for submission to EC2SpotManager. @@ -50,26 +164,35 @@ class StatusReporter: SUMMARY_LIMIT = 4095 # summary output must be no more than 4KB TIME_LIMIT = 120 # ignore older reports - def __init__(self, reports, tracebacks=None): + def __init__( + self, + reports: List[ReadOnlyStatus], + tracebacks: Optional[List[TracebackReport]] = None, + ) -> None: self.reports = reports self.tracebacks = tracebacks @property - def has_results(self): - return any(x.results.total for x in self.reports) + def has_results(self) -> bool: + return any(x.results.total for x in self.reports if x.results) @classmethod - def load(cls, db_file, tb_path=None, time_limit=TIME_LIMIT): + def load( + cls, + db_file: Path, + tb_path: Optional[Path] = None, + time_limit: float = TIME_LIMIT, + ) -> "StatusReporter": """Read Grizzly status reports and create a StatusReporter object. Args: - db_file (str): Status data file to load. - tb_path (Path): Directory to scan for files containing Python tracebacks. - time_limit (int): Only include entries with a timestamp that is within the - given number of seconds. Use zero for no limit. + db_file: Status data file to load. + tb_path: Directory to scan for files containing Python tracebacks. + time_limit: Only include entries with a timestamp that is within the + given number of seconds. Use zero for no limit. Returns: - StatusReporter: Contains available status reports and traceback reports. + Available status reports and traceback reports. """ return cls( list(ReadOnlyStatus.load_all(db_file, time_limit=time_limit)), @@ -77,7 +200,7 @@ def load(cls, db_file, tb_path=None, time_limit=TIME_LIMIT): ) @staticmethod - def format_entries(entries): + def format_entries(entries: List[Tuple[str, Optional[str]]]) -> str: """Generate formatted output from (label, body) pairs. Each entry must have a label and an optional body. @@ -95,33 +218,33 @@ def format_entries(entries): third : 3.0 Args: - entries list(2-tuple(str, str)): Data to merge. + entries: Data to merge. Returns: - str: Formatted output. + Formatted output. """ label_lengths = tuple(len(x[0]) for x in entries if x[1]) max_len = max(label_lengths) if label_lengths else 0 out = [] for label, body in entries: - if body: - out.append(f"{label}".rjust(max_len) + f" : {body}") - else: + if body is None: out.append(label) + else: + out.append(f"{label}".rjust(max_len) + f" : {body}") return "\n".join(out) - def results(self, max_len=85): + def results(self, max_len: int = 85) -> str: """Merged and generate formatted output from results. Args: - max_len (int): Maximum length of result description. + max_len: Maximum length of result description. Returns: - str: A formatted report. + A formatted report. 
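+
+        Example (sketch; the database path is hypothetical):
+
+            reporter = StatusReporter.load(Path("status.db"))
+            print(reporter.results(max_len=60))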
""" - blockers = set() - counts = defaultdict(int) - descs = {} + blockers: Set[str] = set() + counts: Dict[str, int] = defaultdict(int) + descs: Dict[str, Optional[str]] = {} # calculate totals for report in self.reports: for result in report.results: @@ -129,11 +252,12 @@ def results(self, max_len=85): counts[result.rid] += result.count blockers.update(x.rid for x in report.results.blockers(report.iteration)) # generate output - entries = [] + entries: List[Tuple[str, Optional[str]]] = [] for rid, count in sorted(counts.items(), key=lambda x: x[1], reverse=True): desc = descs[rid] + assert desc is not None # trim long descriptions - if len(descs[rid]) > max_len: + if len(desc) > max_len: desc = f"{desc[: max_len - 3]}..." label = f"*{count}" if rid in blockers else str(count) entries.append((label, desc)) @@ -144,19 +268,19 @@ def results(self, max_len=85): entries.append(("", None)) return self.format_entries(entries) - def specific(self, iters_per_result=100): + def specific(self, iters_per_result: int = 100) -> str: """Merged and generate formatted output from status reports. Args: - iters_per_result (int): Threshold for warning of potential blockers. + iters_per_result: Threshold for warning of potential blockers. Returns: - str: A formatted report. + A formatted report. """ if not self.reports: return "No status reports available" self.reports.sort(key=lambda x: x.start_time) - entries = [] + entries: List[Tuple[str, Optional[str]]] = [] for report in self.reports: label = ( f"PID {report.pid} started at " @@ -213,25 +337,25 @@ def specific(self, iters_per_result=100): def summary( self, - rate=True, - runtime=True, - sysinfo=False, - timestamp=False, - iters_per_result=100, - ): + rate: bool = True, + runtime: bool = True, + sysinfo: bool = False, + timestamp: bool = False, + iters_per_result: int = 100, + ) -> str: """Merge and generate a summary from status reports. Args: - rate (bool): Include iteration rate. - runtime (bool): Include total runtime in output. - sysinfo (bool): Include system info (CPU, disk, RAM... etc) in output. - timestamp (bool): Include time stamp in output. - iters_per_result (int): Threshold for warning of potential blockers. + rate: Include iteration rate. + runtime: Include total runtime in output. + sysinfo: Include system info (CPU, disk, RAM... etc) in output. + timestamp: Include time stamp in output. + iters_per_result: Threshold for warning of potential blockers. Returns: - str: A summary of merged reports. + A summary of merged reports. """ - entries = [] + entries: List[Tuple[str, Optional[str]]] = [] # Job specific status if self.reports: # calculate totals @@ -262,7 +386,7 @@ def summary( if total_iters: total_results = sum(results) result_pct = total_results / total_iters * 100 - buckets = set() + buckets: Set[str] = set() for report in self.reports: buckets.update(x.rid for x in report.results) disp = [f"{total_results} ({len(buckets)})"] @@ -320,15 +444,15 @@ def summary( return msg @staticmethod - def _merge_tracebacks(tracebacks, size_limit): + def _merge_tracebacks(tracebacks: List[TracebackReport], size_limit: int) -> str: """Merge traceback without exceeding size_limit. Args: - tracebacks (iterable): TracebackReport to merge. - size_limit (int): Maximum size in bytes of output. + tracebacks: TracebackReports to merge. + size_limit: Maximum size in bytes of output. Returns: - str: merged tracebacks. + Merged tracebacks. 
""" txt = [] txt.append(f"\n\nWARNING Tracebacks ({len(tracebacks)}) detected!") @@ -341,16 +465,16 @@ def _merge_tracebacks(tracebacks, size_limit): return "\n".join(txt) @staticmethod - def _sys_info(): + def _sys_info() -> List[Tuple[str, str]]: """Collect system information. Args: None Returns: - list(tuple): System information in tuples (label, display data). + System information. """ - entries = [] + entries: List[Tuple[str, str]] = [] # CPU and load disp = [] @@ -358,7 +482,7 @@ def _sys_info(): f"{cpu_count(logical=True)} ({cpu_count(logical=False)}) @ " f"{cpu_percent(interval=StatusReporter.CPU_POLL_INTERVAL):0.0f}%" ) - if getloadavg is not None: + if GETLOADAVG_AVAILABLE: disp.append(" (") # round the results of getloadavg(), precision varies across platforms disp.append(", ".join(f"{x:0.1f}" for x in getloadavg())) @@ -388,17 +512,18 @@ def _sys_info(): return entries @staticmethod - def _tracebacks(path, ignore_kbi=True, max_preceding=5): + def _tracebacks( + path: Path, ignore_kbi: bool = True, max_preceding: int = 5 + ) -> List[TracebackReport]: """Search screen logs for tracebacks. Args: - path (Path): Directory containing log files. - ignore_kbi (bool): Do not include KeyboardInterrupts in results - max_preceding (int): Maximum number of lines preceding traceback to - include. + path: Directory containing log files. + ignore_kbi: Do not include KeyboardInterrupts in results + max_preceding: Maximum number of lines preceding traceback to include. Returns: - list: A list of TracebackReports. + TracebackReports. """ tracebacks = [] for screen_log in (x for x in path.glob("screenlog.*") if x.is_file()): @@ -410,123 +535,24 @@ def _tracebacks(path, ignore_kbi=True, max_preceding=5): return tracebacks -class TracebackReport: - """Read Python tracebacks from log files and store it in a manner that is helpful - when generating reports. - """ - - MAX_LINES = 16 # should be no less than 6 - READ_LIMIT = 0x20000 # 128KB - - def __init__(self, log_file, lines, is_kbi=False, prev_lines=None): - assert isinstance(lines, list) - assert isinstance(log_file, Path) - assert isinstance(prev_lines, list) or prev_lines is None - self.is_kbi = is_kbi - self.lines = lines - self.log_file = log_file - self.prev_lines = prev_lines or [] - - @classmethod - def from_file(cls, log_file, max_preceding=5, ignore_kbi=False): - """Create TracebackReport from a text file containing a Python traceback. - Only the first traceback in the file will be parsed. - - Args: - log_file (Path): File to parse. - max_preceding (int): Number of lines to collect leading up to the traceback. - ignore_kbi (bool): Skip/ignore KeyboardInterrupt. - - Returns: - TracebackReport: Contains data from log_file. 
- """ - token = b"Traceback (most recent call last):" - assert len(token) < cls.READ_LIMIT - try: - with log_file.open("rb") as in_fp: - for chunk in iter(partial(in_fp.read, cls.READ_LIMIT), b""): - idx = chunk.find(token) - if idx > -1: - # calculate offset of data in the file - pos = in_fp.tell() - len(chunk) + idx - break - if len(chunk) == cls.READ_LIMIT: - # seek back to avoid missing beginning of token - in_fp.seek(len(token) * -1, SEEK_CUR) - else: - # no traceback here, move along - return None - # seek back 2KB to collect preceding lines - in_fp.seek(max(pos - 2048, 0)) - data = in_fp.read(cls.READ_LIMIT) - except OSError: # pragma: no cover - # in case the file goes away - return None - - data = data.decode("ascii", errors="ignore").splitlines() - token = token.decode() - is_kbi = False - tb_start = None - tb_end = None - line_count = len(data) - for line_num, log_line in enumerate(data): - if tb_start is None and token in log_line: - tb_start = line_num - continue - if tb_start is not None: - log_line = log_line.strip() - if not log_line: - # stop at first empty line - tb_end = min(line_num, line_count) - break - if match(r"^\w+(\.\w+)*\:\s|^\w+(Interrupt|Error)$", log_line): - is_kbi = log_line.startswith("KeyboardInterrupt") - if is_kbi and ignore_kbi: - # ignore this exception since it is a KeyboardInterrupt - return None - # stop after error message - tb_end = min(line_num + 1, line_count) - break - assert tb_start is not None - if max_preceding > 0: - prev_start = max(tb_start - max_preceding, 0) - prev_lines = data[prev_start:tb_start] - else: - prev_lines = None - if tb_end is None: - # limit if the end is not identified (failsafe) - tb_end = max(line_count, cls.MAX_LINES) - if tb_end - tb_start > cls.MAX_LINES: - # add first entry - lines = data[tb_start : tb_start + 3] - lines += ["<--- TRACEBACK TRIMMED--->"] - # add end entries - lines += data[tb_end - (cls.MAX_LINES - 3) : tb_end] - else: - lines = data[tb_start:tb_end] - return cls(log_file, lines, is_kbi=is_kbi, prev_lines=prev_lines) - - def __len__(self): - return len(str(self)) - - def __str__(self): - return "\n".join( - [f"Log: '{self.log_file.name}'"] + self.prev_lines + self.lines - ) - - class _TableFormatter: """Format data in a table.""" - def __init__(self, columns, formatters, vsep=" | ", hsep="-"): + def __init__( + self, + columns: Tuple[str, ...], + formatters: Tuple[Optional[Callable[..., str]]], + vsep: str = " | ", + hsep: str = "-", + ) -> None: """Initialize a TableFormatter instance. Arguments: - columns (iterable(str)): List of column names for the table header. - formatters (iterable(callable)): List of format functions for each column. - None will result in hiding that column. - vsep (str): Vertical separation between columns. - hsep (str): Horizontal separation between header and data. + columns: List of column names for the table header. + formatters: List of format functions for each column. + None will result in hiding that column. + vsep: Vertical separation between columns. + hsep: Horizontal separation between header and data. """ assert len(columns) == len(formatters) self._columns = tuple( @@ -536,18 +562,18 @@ def __init__(self, columns, formatters, vsep=" | ", hsep="-"): self._vsep = vsep self._hsep = hsep - def format_rows(self, rows): + def format_rows(self, rows: List[ReductionStep]) -> Generator[str, None, None]: """Format rows as a table and return a line generator. Arguments: - rows (list(list(str))): Tabular data. 
Each row must be the same length as - `columns` passed to `__init__`. + rows: Tabular data. Each row must be the same length as + `columns` passed to `__init__`. Yields: - str: Each line of formatted tabular data. + Each line of formatted tabular data. """ max_width = [len(col) for col in self._columns] - formatted = [] + formatted: List[List[str]] = [] for row in rows: assert len(row) == len(self._formatters) formatted.append([]) @@ -568,15 +594,15 @@ def format_rows(self, rows): ) yield format_str % self._columns yield self._hsep * (len(self._vsep) * (len(self._columns) - 1) + sum(max_width)) - for row in formatted: - yield format_str % tuple(row) + for fmt_row in formatted: + yield format_str % tuple(fmt_row) -def _format_seconds(duration): +def _format_seconds(duration: float) -> str: # format H:M:S, and then remove all leading zeros with regex minutes, seconds = divmod(int(duration), 60) hours, minutes = divmod(minutes, 60) - result = re_sub("^[0:]*", "", f"{hours}:{minutes:0>2d}:{seconds:0>2d}") + result = re_sub("^[0:]*", "", f"{hours}:{minutes:02d}:{seconds:02d}") # if the result is all zeroes, ensure one zero is output if not result: result = "0" @@ -586,7 +612,7 @@ def _format_seconds(duration): return result -def _format_duration(duration, total=0): +def _format_duration(duration: Optional[int], total: float = 0) -> str: result = "" if duration is not None: if total == 0: @@ -594,11 +620,11 @@ def _format_duration(duration, total=0): else: percent = int(100 * duration / total) result = _format_seconds(duration) - result += f" ({percent:>3d}%)" + result += f" ({percent:3d}%)" return result -def _format_number(number, total=0): +def _format_number(number: Optional[int], total: float = 0) -> str: result = "" if number is not None: if total == 0: @@ -617,28 +643,37 @@ class ReductionStatusReporter(StatusReporter): TIME_LIMIT = 120 # ignore older reports # pylint: disable=super-init-not-called - def __init__(self, reports, tracebacks=None): - self.reports = reports + def __init__( + self, + reports: List[ReductionStatus], + tracebacks: Optional[List[TracebackReport]] = None, + ) -> None: + self.reports: List[ReductionStatus] = reports self.tracebacks = tracebacks @property - def has_results(self): + def has_results(self) -> bool: return False # TODO @classmethod - def load(cls, db_file, tb_path=None, time_limit=TIME_LIMIT): + def load( + cls, + db_file: Path, + tb_path: Optional[Path] = None, + time_limit: float = TIME_LIMIT, + ) -> "ReductionStatusReporter": """Read Grizzly reduction status reports and create a ReductionStatusReporter object. Args: - path (str): Path to scan for status data files. - tb_path (str): Directory to scan for files containing Python tracebacks. - time_limit (int): Only include entries with a timestamp that is within the - given number of seconds. Use zero for no limit. + path: Path to scan for status data files. + tb_path: Directory to scan for files containing Python tracebacks. + time_limit: Only include entries with a timestamp that is within the + given number of seconds. Use zero for no limit. Returns: - ReductionStatusReporter: Contains available status reports and traceback - reports. + ReductionStatusReporter containing available status reports and traceback + reports. 
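+
+        Example (sketch; STATUS_DB_REDUCE is the module's default reduction
+        database):
+
+            reporter = ReductionStatusReporter.load(STATUS_DB_REDUCE)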
""" tracebacks = None if tb_path is None else cls._tracebacks(tb_path) return cls( @@ -647,7 +682,7 @@ def load(cls, db_file, tb_path=None, time_limit=TIME_LIMIT): ) @staticmethod - def _analysis_entry(report): + def _analysis_entry(report: ReductionStatus) -> Tuple[str, str]: return ( "Analysis", ", ".join( @@ -657,18 +692,18 @@ def _analysis_entry(report): ) @staticmethod - def _crash_id_entry(report): + def _crash_id_entry(report: ReductionStatus) -> Tuple[str, str]: crash_str = str(report.crash_id) if report.tool: crash_str += f" ({report.tool})" return ("Crash ID", crash_str) @staticmethod - def _last_reports_entry(report): + def _last_reports_entry(report: ReductionStatus) -> Tuple[str, str]: return ("Latest Reports", ", ".join(str(r) for r in report.last_reports)) @staticmethod - def _run_params_entry(report): + def _run_params_entry(report: ReductionStatus) -> Tuple[str, str]: return ( "Run Parameters", ", ".join( @@ -677,7 +712,7 @@ def _run_params_entry(report): ) @staticmethod - def _signature_info_entry(report): + def _signature_info_entry(report: ReductionStatus) -> Tuple[str, str]: return ( "Signature", ", ".join( @@ -685,25 +720,25 @@ def _signature_info_entry(report): ), ) - def specific( + def specific( # pylint: disable=arguments-renamed self, - sysinfo=False, - timestamp=False, - ): # pylint: disable=arguments-renamed + sysinfo: bool = False, + timestamp: bool = False, + ) -> str: """Generate formatted output from status report. Args: None Returns: - str: A formatted report. + A formatted report. """ if not self.reports: return "No status reports available" - reports = [] + reports: List[str] = [] for report in self.reports: - entries = [] + entries: List[Tuple[str, Optional[str]]] = [] if report.crash_id: entries.append(self._crash_id_entry(report)) if report.analysis: @@ -718,7 +753,7 @@ def specific( "Current Strategy", f"{report.current_strategy.name} " f"({report.current_strategy_idx!r} of " - f"{len(report.strategies)})", + f"{len(report.strategies) if report.strategies else 0})", ) ) if report.current_strategy and report.original: @@ -741,11 +776,13 @@ def specific( ) ) if report.total and report.current_strategy: + strategy_duration = report.current_strategy.duration or 0 + total_duration = report.total.duration or 0 entries.append( ( "Time Elapsed", - f"{_format_seconds(report.current_strategy.duration)} in " - f"strategy, {_format_seconds(report.total.duration)} total", + f"{_format_seconds(strategy_duration)} in " + f"strategy, {_format_seconds(total_duration)} total", ) ) @@ -760,31 +797,31 @@ def specific( reports.append(self.format_entries(entries)) return "\n\n".join(reports) - def summary( + def summary( # pylint: disable=arguments-differ self, - rate=False, - runtime=False, - sysinfo=False, - timestamp=False, - ): # pylint: disable=arguments-differ + rate: bool = False, + runtime: bool = False, + sysinfo: bool = False, + timestamp: bool = False, + ) -> str: """Merge and generate a summary from status reports. Args: - rate (bool): Ignored (compatibility). - runtime (bool): Ignored (compatibility). - sysinfo (bool): Include system info (CPU, disk, RAM... etc) in output. - timestamp (bool): Include time stamp in output. + rate: Ignored (compatibility). + runtime: Ignored (compatibility). + sysinfo: Include system info (CPU, disk, RAM... etc) in output. + timestamp: Include time stamp in output. Returns: - str: A summary of merged reports. + A summary of merged reports. 
""" if not self.reports: return "No status reports available" - reports = [] + reports: List[str] = [] for report in self.reports: - entries = [] - lines = [] + entries: List[Tuple[str, Optional[str]]] = [] + lines: List[str] = [] if report.crash_id: entries.append(self._crash_id_entry(report)) if report.analysis: @@ -838,14 +875,14 @@ def summary( return msg -def main(args=None): +def main(argv: Optional[List[str]] = None) -> int: """Merge Grizzly status files into a single report (main entrypoint). Args: - args (list/None): Argument list to parse instead of sys.argv (for testing). + argv: Argument list to parse instead of sys.argv (for testing). Returns: - None + int """ if bool(getenv("DEBUG")): # pragma: no cover log_level = DEBUG @@ -855,7 +892,9 @@ def main(args=None): log_fmt = "%(message)s" basicConfig(format=log_fmt, datefmt="%Y-%m-%d %H:%M:%S", level=log_level) - modes = { + modes: Dict[ + str, Tuple[Union[type[StatusReporter], type[ReductionStatusReporter]], Path] + ] = { "fuzzing": (StatusReporter, STATUS_DB_FUZZ), "reducing": (ReductionStatusReporter, STATUS_DB_REDUCE), } @@ -903,12 +942,12 @@ def main(args=None): type=Path, help="Scan path for Python tracebacks found in screenlog.# files", ) - args = parser.parse_args(args) + args = parser.parse_args(argv) if args.tracebacks and not args.tracebacks.is_dir(): parser.error("--tracebacks must be a directory") time_limit = report_types[args.type] if args.time_limit is None else args.time_limit - reporter_cls, status_db = modes.get(args.scan_mode) + reporter_cls, status_db = modes[args.scan_mode] reporter = reporter_cls.load( status_db, tb_path=args.tracebacks, diff --git a/grizzly/common/test_status.py b/grizzly/common/test_status.py index 6ef65c62..5edfdb2a 100644 --- a/grizzly/common/test_status.py +++ b/grizzly/common/test_status.py @@ -22,6 +22,7 @@ ReductionStatus, ReductionStep, ResultCounter, + ResultEntry, SimpleResultCounter, SimpleStatus, Status, @@ -38,7 +39,6 @@ def test_basic_status_01(): assert status.ignored == 0 assert status.iteration == 0 assert status.log_size == 0 - assert status.results is None assert not status._profiles assert status.runtime > 0 assert status.rate == 0 @@ -141,7 +141,7 @@ def test_status_03(tmp_path): assert status.iteration == loaded.iteration assert status.log_size == loaded.log_size assert status.pid == loaded.pid - assert loaded.results.get("uid1") == ("uid1", 1, "sig1") + assert loaded.results.get("uid1") == ResultEntry("uid1", 1, "sig1") assert "test" in loaded._profiles @@ -183,7 +183,7 @@ def test_status_05(mocker, tmp_path): assert status.iteration == loaded.iteration assert status.log_size == loaded.log_size assert status.pid == loaded.pid - assert loaded.results.get("uid1") == ("uid1", 1, "sig1") + assert loaded.results.get("uid1") == ResultEntry("uid1", 1, "sig1") # NOTE: this function must be at the top level to work on Windows @@ -572,16 +572,16 @@ def test_report_counter_01(tmp_path, keys, counts, limit): db_path = tmp_path / "storage.db" counter = ResultCounter(1, db_path, report_limit=limit) for report_id, counted in zip(keys, counts): - assert counter.get(report_id) == (report_id, 0, None) + assert counter.get(report_id) == ResultEntry(report_id, 0, None) assert not counter.is_frequent(report_id) # call count() with report_id 'counted' times for current in range(1, counted + 1): assert counter.count(report_id, "desc") == (current, (current == 1)) # test get() if sum(counts) > 0: - assert counter.get(report_id) == (report_id, counted, "desc") + assert counter.get(report_id) 
         else:
-            assert counter.get(report_id) == (report_id, counted, None)
+            assert counter.get(report_id) == ResultEntry(report_id, counted, None)
         # test is_frequent()
         if counted > limit > 0:
             assert counter.is_frequent(report_id)
@@ -592,8 +592,8 @@ def test_report_counter_01(tmp_path, keys, counts, limit):
             assert counter.is_frequent(report_id)
         else:
             assert limit == 0
-    for _report_id, counted, _desc in counter:
-        assert counted > 0
+    for result in counter:
+        assert result.count > 0
     assert counter.total == sum(counts)
 
 
@@ -670,17 +670,17 @@ def test_report_counter_03(mocker, tmp_path):
     # last 2 seconds
     loaded = ReadOnlyResultCounter.load(db_path, 2)[0]
     assert loaded.total == 1
-    assert loaded.get("b") == ("b", 1, "desc_b")
+    assert loaded.get("b") == ResultEntry("b", 1, "desc_b")
     # last 3 seconds
     loaded = ReadOnlyResultCounter.load(db_path, 3)[0]
-    assert loaded.get("a") == ("a", 2, "desc_a")
+    assert loaded.get("a") == ResultEntry("a", 2, "desc_a")
     assert loaded.total == 3
     # increase time limit
     fake_time.return_value = 4
     loaded = ReadOnlyResultCounter.load(db_path, 10)[0]
     assert loaded.total == counter.total == 3
-    assert loaded.get("a") == ("a", 2, "desc_a")
-    assert loaded.get("b") == ("b", 1, "desc_b")
+    assert loaded.get("a") == ResultEntry("a", 2, "desc_a")
+    assert loaded.get("b") == ResultEntry("b", 1, "desc_b")
 
 
 def test_report_counter_04(mocker, tmp_path):
diff --git a/grizzly/common/test_status_reporter.py b/grizzly/common/test_status_reporter.py
index e8ada0ad..e490d37d 100644
--- a/grizzly/common/test_status_reporter.py
+++ b/grizzly/common/test_status_reporter.py
@@ -253,13 +253,10 @@ def test_status_reporter_03(mocker, disk, memory, getloadavg):
         autospec=True,
         return_value=memory,
     )
-    if getloadavg is None:
-        # simulate platform that does not have os.getloadavg()
-        mocker.patch("grizzly.common.status_reporter.getloadavg", None)
-    else:
-        mocker.patch(
-            "grizzly.common.status_reporter.getloadavg", side_effect=getloadavg
-        )
+    mocker.patch(
+        "grizzly.common.status_reporter.GETLOADAVG_AVAILABLE", getloadavg is not None
+    )
+    mocker.patch("grizzly.common.status_reporter.getloadavg", side_effect=getloadavg)
     sysinfo = StatusReporter._sys_info()
     assert len(sysinfo) == 3
     assert sysinfo[0][0] == "CPU & Load"
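
Note on the test updates above: the assertions switch from bare tuples to
ResultEntry(...) because ResultEntry moves from a namedtuple to a frozen
dataclass. A namedtuple inherits tuple equality, so comparing against a plain
tuple passed before; a dataclass-generated __eq__ only matches instances of
the same class. A minimal standalone sketch of the difference (OldEntry and
NewEntry are illustrative names, not part of this patch):

    from collections import namedtuple
    from dataclasses import dataclass
    from typing import Optional

    # pre-patch shape: a namedtuple compares equal to a plain tuple
    OldEntry = namedtuple("OldEntry", "rid count desc")

    # post-patch shape: a frozen dataclass only compares equal to its own type
    @dataclass(frozen=True)
    class NewEntry:
        rid: str
        count: int
        desc: Optional[str]

    assert OldEntry("uid1", 1, "sig1") == ("uid1", 1, "sig1")  # tuple equality holds
    assert NewEntry("uid1", 1, "sig1") != ("uid1", 1, "sig1")  # == is False here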