Feature/file reporter #78
Draft · wants to merge 4 commits into main
Changes from 3 commits
5 changes: 4 additions & 1 deletion src/nnbench/reporter/__init__.py
@@ -7,9 +7,12 @@
import types

 from .base import BenchmarkReporter
+from .file import FileReporter

 # internal, mutable
-_reporter_registry: dict[str, type[BenchmarkReporter]] = {}
+_reporter_registry: dict[str, type[BenchmarkReporter]] = {
+    "file": FileReporter,
+}

# external, immutable
reporter_registry: types.MappingProxyType[str, type[BenchmarkReporter]] = types.MappingProxyType(
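For context, the registry added in this hunk is consumed via the read-only `reporter_registry` view. A minimal self-contained sketch (the two classes are stand-ins, not imports from nnbench):

```python
# Sketch of how the reporter registry is looked up; BenchmarkReporter and
# FileReporter are stand-ins here, not the actual nnbench classes.
import types


class BenchmarkReporter:
    pass


class FileReporter(BenchmarkReporter):
    pass


# internal, mutable
_reporter_registry: dict[str, type[BenchmarkReporter]] = {
    "file": FileReporter,
}

# external, immutable view: lookups work, but mutation raises TypeError
reporter_registry = types.MappingProxyType(_reporter_registry)

reporter_cls = reporter_registry["file"]
print(reporter_cls.__name__)  # FileReporter
```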
205 changes: 205 additions & 0 deletions src/nnbench/reporter/file.py
@@ -0,0 +1,205 @@
from __future__ import annotations

import os
from typing import Any, List

from nnbench.reporter.base import BenchmarkReporter
from nnbench.types import BenchmarkRecord


class Parser:
Collaborator:

Just for clarification: We're not actually parsing things when we are reading files, that is the job of the respective modules (json|yaml|toml).

"""The base interface for parsing records form file.

Usage:
------
```
class MyCustomParser(Parser):
def parse_file(self, records):
# Implement your custom parsing logic here
...
def write_records(self, records):
# Implement your custom file writing logic here
...
# Register your custom parser with a distinct file type
MyCustomParser.register("my_custom_format")
# Usage:
new_record = ... # Load records in your custom format
append_record_to_records(records, new_record, "my_custom_format")
```
"""

    def parse_file(self, records: str) -> Any:
        """Parses raw file contents and returns a list of parsed records.

        Args:
        -----
        `records:` The raw file contents as a string.

        Returns:
        --------
        A list of parsed records.
        """
        raise NotImplementedError

    def write_records(self, records: list[BenchmarkRecord], record: BenchmarkRecord) -> str:
        """Appends a record to the existing records and serializes the result.

        Args:
        -----
        `records:` A list of parsed records.
        `record:` The record to append.

        Returns:
        --------
        A string form of the content to be written to a file.
        """
        raise NotImplementedError

    @classmethod
    def register(cls, file_type: str) -> None:
        """Registers a parser for a specific file type.

        Args:
        `file_type:` The file type (string)
        """
        parsers[file_type] = cls

    @staticmethod
    def get_parser(file_type: str) -> type[Parser] | None:
        """Gets the registered parser for a file type.

        Args:
        `file_type:` The file type (string)

        Returns:
        --------
        The registered parser class, or None if not found.
        """
        return parsers.get(file_type)


class JsonParser(Parser):
    def parse_file(self, records: str) -> list[dict]:
        import json

        try:
            return json.loads(records if records else "[]")
        except json.JSONDecodeError:
            raise ValueError("Unexpected records passed")

    def write_records(self, parsed_records: list[BenchmarkRecord], record: BenchmarkRecord) -> str:
        import json

        parsed_records.append(record)
        return json.dumps(parsed_records)


class YamlParser(Parser):
    def parse_file(self, records: str) -> list[dict]:
        import yaml

        return yaml.safe_load(records) if records else []

    def write_records(self, parsed_records: list[BenchmarkRecord], record: BenchmarkRecord) -> str:
        import yaml

        parsed_records.append(record)
        # Coerce values to plain floats so yaml.dump can serialize them.
        for element in record["benchmarks"]:
            element["value"] = float(element["value"])
        return yaml.dump(parsed_records)


# Register custom parsers here
parsers: dict[str, type[Parser]] = {"json": JsonParser, "yaml": YamlParser}
Collaborator:

I like the idea of making it a module-wide default, but we should take care that

a) the variable is private (i.e. has a leading underscore) to prevent accidental export, and
b) we should make the value structure of the map as easy as possible.

To address b), I would start by making it a tuple[ser, de] where ser is a Callable[[IO, dict[str, Any]], None], i.e., a function taking a file descriptor in write mode and a record and writing it to a file, and de being a Callable[[IO], dict[str, Any]], a function taking a file descriptor in read mode and returning the loaded record.

You can then register the SerDe factories based on whether the necessary packages are installed (json is available out of the box, yaml and toml are not).

Author:

Should I use a class-based approach to register the SerDe factories (like I used in the Parser class)?
One other option is a simple register method which takes three arguments (the ser and de functions and a file_type).
Also, I think the class-based approach is more concise for defining the ser and de methods on the user side.

Collaborator:

I'm happy with either, though the optional import part (i.e., erroring if toml/yaml are not installed) will be a bit easier in the class case.

For now, I think the quickest way is a functional approach, though. Maybe like this:

_file_loaders: dict[str, tuple[Any, Any]] = {}

def yaml_load(fp: IO, options=None):
    try:
        import yaml
    except ImportError:
        raise ModuleNotFoundError("`pyyaml` is not installed")
    
    # takes no options, but the slot is useful for passing options to file loaders.
    obj = yaml.safe_load(fp)
    return BenchmarkRecord(context=obj["context"], benchmarks=obj["benchmarks"])

def yaml_save(record: BenchmarkRecord, fp: IO, options=None) -> None:
    try:
        import yaml
    except ImportError:
        raise ModuleNotFoundError("`pyyaml` is not installed")

    yaml.safe_dump(record, fp, **(options or {}))


_file_loaders["yaml"] = (yaml_save, yaml_load)

With an option of defining e.g. a register_file_io(ser, de) later to do the dict insertion if we want.



def parse_records(records: str, file_type: str) -> Any:
    """Parses records based on the specified file type.

    This function retrieves and calls the registered parser for
    the given file type.

    Args:
    `records:` The raw file contents as a string.
    `file_type:` The file type (string).

    Returns:
    A list of parsed records.
    """
    parser = Parser.get_parser(file_type)
    if parser is None:
        raise ValueError(f"Unsupported file type: {file_type}")

    return parser().parse_file(records)


def append_record_to_records(parsed_records: Any, record: BenchmarkRecord, file_type: str) -> str:
    """Appends a record to the parsed records and serializes the result.

    This function retrieves the registered parser for the given file
    type and uses it to append `record` to `parsed_records`.

    Args:
    `parsed_records:` A list of parsed records.
    `record:` The record to append.
    `file_type:` The file type (string).

    Returns:
    A string form of the content to be written to a file.
    """
    parser = Parser.get_parser(file_type)
    if parser is None:
        raise ValueError(f"Unsupported file type: {file_type}")

    return parser().write_records(parsed_records, record)
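Taken together, the two helpers give a read-modify-write round trip. A self-contained sketch, with a minimal inline JSON parser standing in for the classes defined above:

```python
# Self-contained sketch of the parse/append round trip; JsonParser here is a
# minimal stand-in for the class defined in this module.
import json


class JsonParser:
    def parse_file(self, records: str) -> list[dict]:
        return json.loads(records) if records else []

    def write_records(self, parsed_records: list[dict], record: dict) -> str:
        parsed_records.append(record)
        return json.dumps(parsed_records)


parsers = {"json": JsonParser}


def parse_records(records: str, file_type: str):
    parser = parsers.get(file_type)
    if parser is None:
        raise ValueError(f"Unsupported file type: {file_type}")
    return parser().parse_file(records)


def append_record_to_records(parsed_records, record, file_type: str) -> str:
    parser = parsers.get(file_type)
    if parser is None:
        raise ValueError(f"Unsupported file type: {file_type}")
    return parser().write_records(parsed_records, record)


# Load existing records, append one, and serialize back to a string.
existing = parse_records('[{"name": "b1", "value": 1.0}]', "json")
updated = append_record_to_records(existing, {"name": "b2", "value": 2.0}, "json")
print(updated)
```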


class FileReporter(BenchmarkReporter):
    def __init__(self, dir: str):
        self.dir = dir
        if not os.path.exists(dir):
            self.initialize()

    def initialize(self) -> None:
        try:
            os.makedirs(self.dir, exist_ok=True)
        except OSError as e:
            self.finalize()
            raise ValueError(f"Could not create directory: {self.dir}") from e

    def read(self, file_name: str) -> BenchmarkRecord:
        if not self.dir:
            raise RuntimeError("Directory is not initialized")
        file_path = os.path.join(self.dir, file_name)
        # Use the text after the last dot as the file type,
        # e.g. "json" for "results.json".
        file_type = file_name.rsplit(".", 1)[-1]
        try:
            with open(file_path) as file:
                data = file.read()
            return parse_records(data, file_type)
        except FileNotFoundError:
            raise ValueError(f"Could not read the file: {file_path}")
Collaborator:

This needs the restructured file loading dict I talked about earlier, but in essence, all that should happen here is the file being loaded with open (like you did), and then calling the deserializer on the opened file.

In particular, no error handling for open() should be necessary, since those are informative enough for the user on their own.


    def write(self, record: BenchmarkRecord, file_name: str) -> None:
        if not self.dir:
            raise RuntimeError("Directory is not initialized")

        file_path = os.path.join(self.dir, file_name)
        if not os.path.exists(file_path):  # Create the file
            with open(file_path, "w") as file:
                file.write("")
        try:
            parsed_records = self.read(file_name)
            file_type = file_name.rsplit(".", 1)[-1]
            new_records = append_record_to_records(parsed_records, record, file_type)
            with open(file_path, "w") as file:
                file.write(new_records)
        except FileNotFoundError:
            raise ValueError(f"Could not read the file: {file_path}")
Collaborator:

Same here, just load the serializer from the dict and call it on the opened file.


    def finalize(self) -> None:
        del self.dir
Collaborator:

This does not do what you think it does: it just unbinds the self.dir attribute and stages its value for garbage collection.
To safely remove the directory (which you might not want to do anyway; we could add a flag in the constructor for that?), you should call shutil.rmtree(self.dir, ignore_errors=True). (Though you might want to check existence first and set ignore_errors to False.)
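A finalize() along those lines might look like this. It is a sketch: the delete_on_finalize flag is the hypothetical constructor argument floated above, not part of the PR:

```python
# Sketch of a finalize() that actually removes the directory tree.
# The delete_on_finalize flag is hypothetical, per the review suggestion.
import os
import shutil
import tempfile


class FileReporter:
    def __init__(self, dir: str, delete_on_finalize: bool = False):
        self.dir = dir
        self.delete_on_finalize = delete_on_finalize
        os.makedirs(self.dir, exist_ok=True)

    def finalize(self) -> None:
        # `del self.dir` would only unbind the attribute; shutil.rmtree
        # removes the directory and its contents from disk.
        if self.delete_on_finalize and os.path.exists(self.dir):
            shutil.rmtree(self.dir, ignore_errors=False)


tmp = os.path.join(tempfile.mkdtemp(), "reports")
reporter = FileReporter(tmp, delete_on_finalize=True)
reporter.finalize()
print(os.path.exists(tmp))  # False
```

Checking existence first lets ignore_errors stay False, so genuine failures (e.g. permissions) still surface.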