Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add assertions to support MetricWriter #129

Open
wants to merge 2 commits into
base: ms_metric-writer-feature-branch
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions fgpyo/util/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
from typing import Generic
from typing import Iterator
from typing import List
from typing import Type
from typing import TypeVar

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -406,6 +407,16 @@ def _read_header(
return MetricFileHeader(preamble=preamble, fieldnames=fieldnames)


def _is_metric_class(cls: Any) -> TypeGuard[Metric]:
msto marked this conversation as resolved.
Show resolved Hide resolved
"""True if the given class is a Metric."""

return (
isclass(cls)
and issubclass(cls, Metric)
and (dataclasses.is_dataclass(cls) or attr.has(cls))
)


def _is_dataclass_instance(metric: Metric) -> TypeGuard[DataclassInstance]:
msto marked this conversation as resolved.
Show resolved Hide resolved
"""
Test if the given metric is a dataclass instance.
Expand Down Expand Up @@ -466,3 +477,98 @@ def asdict(metric: Metric) -> Dict[str, Any]:
"The provided metric is not an instance of a `dataclass` or `attr.s`-decorated Metric "
f"class: {metric.__class__}"
)


def _get_fieldnames(metric_class: Type[Metric]) -> List[str]:
"""
Get the fieldnames of the specified metric class.

Args:
metric_class: A Metric class.

Returns:
A list of fieldnames.
"""
_assert_is_metric_class(metric_class)

if dataclasses.is_dataclass(metric_class):
return [f.name for f in dataclasses.fields(metric_class)]
elif attr.has(metric_class):
return [f.name for f in attr.fields(metric_class)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use metric.header() method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Short answer: this is an artifact from the initial proof-of-concept for this class being implemented in https://github.com/msto/dataclass_io/ (i.e. only for dataclasses, not considering Metric)

Long answer: metric.header() and inspect.get_fields() have a lot of type: ignores , and I'd rather use an implementation that's type-safe. (I have a mental TODO to clean up the volume of type: ignores in the repo, and I can leave a TODO that this function should be deprecated/merged into inspect once that happens.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should use it and fix it?

else:
assert False, "Unreachable"


def _assert_file_header_matches_metric(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just saying out loud this will not work if the path is a stream, like standard input, since it consumes the header.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add an assertion that the path is not /dev/stdin. Are there other streams that are likely to be represented as a Path?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just streams I think

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bump

path: Path,
metric_class: Type[MetricType],
delimiter: str,
) -> None:
"""
Check that the specified file has a header and its fields match those of the provided Metric.

Args:
path: A path to a `Metric` file.
metric_class: The `Metric` class to validate against.
delimiter: The delimiter to use when reading the header.

Raises:
ValueError: If the provided file does not include a header.
ValueError: If the header of the provided file does not match the provided Metric.
"""
# NB: _get_fieldnames() will validate that `metric_class` is a valid Metric class.
fieldnames: List[str] = _get_fieldnames(metric_class)

header: MetricFileHeader
with path.open("r") as fin:
try:
header = metric_class._read_header(fin, delimiter=delimiter)
except ValueError:
raise ValueError(f"Could not find a header in the provided file: {path}")

if header.fieldnames != fieldnames:
raise ValueError(
"The provided file does not have the same field names as the provided Metric:\n"
f"\tMetric: {metric_class.__name__}\n"
f"\tFile: {path}\n"
f"\tExpected fields: {', '.join(fieldnames)}\n"
f"\tActual fields: {', '.join(header.fieldnames)}\n"
)


def _assert_fieldnames_are_metric_attributes(
msto marked this conversation as resolved.
Show resolved Hide resolved
specified_fieldnames: List[str],
metric_class: Type[MetricType],
) -> None:
"""
Check that all of the specified fields are attributes on the given Metric.

Raises:
ValueError: if any of the specified fieldnames are not an attribute on the given Metric.
"""
_assert_is_metric_class(metric_class)

invalid_fieldnames = {
f for f in specified_fieldnames if f not in _get_fieldnames(metric_class)
msto marked this conversation as resolved.
Show resolved Hide resolved
}

if len(invalid_fieldnames) > 0:
raise ValueError(
"One or more of the specified fields are not attributes on the Metric "
+ f"{metric_class.__name__}: "
+ ", ".join(invalid_fieldnames)
)


def _assert_is_metric_class(cls: Type[Metric]) -> None:
"""
Assert that the given class is a Metric.

Args:
cls: A class object.

Raises:
TypeError: If the given class is not a Metric.
"""
if not _is_metric_class(cls):
raise TypeError(f"Not a dataclass or attr decorated Metric: {cls}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we feel about _assert functions raising errors other than AssertionError? I would probably be surprised by this.

Copy link
Contributor Author

@msto msto Jun 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think assert should be used outside of unit tests, and I prefer to raise more specific exceptions.

The Google style guide recommends/requires this as well:

Do not use assert statements in place of conditionals or validating preconditions. They must not be critical to the application logic. A litmus test would be that the assert could be removed without breaking the code. assert conditionals are not guaranteed to be evaluated. For pytest based tests, assert is okay and expected to verify expectations.

The primary reason to avoid the use of assert statements is that they are intended for debugging and can be disabled.

Additionally, using assert makes it easy to fail to cover all branches of a program:

# Covered by any unit test that calls the parent function
assert is_ok(foo), "Foo was not ok"

if not is_ok(foo):
    # Not covered unless the condition where `is_ok(foo)` is False is tested explicitly
    raise ValueError("Foo was not ok")  

I'm happy to rename the functions to something other than assert_* (I've used validate_* in the past), I was just trying to remain consistent with naming convention in the project.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be a good case for having a custom exception:

class ValueAssertion(AssertionError, ValueError):
    pass

Then you would catch the exception with either AssertionError or ValueError.

I do like having the explicit branch with if since code coverage analysis will identify it as a covered/uncovered branch depending on actual test coverage. With single-line statements (e.g. assert), you don't get any branch information so it is easy to forget to test the exception paths.

156 changes: 156 additions & 0 deletions fgpyo/util/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import enum
import gzip
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
Expand All @@ -29,6 +30,10 @@
from fgpyo.util.inspect import is_attr_class
from fgpyo.util.inspect import is_dataclasses_class
from fgpyo.util.metric import Metric
from fgpyo.util.metric import _assert_fieldnames_are_metric_attributes
from fgpyo.util.metric import _assert_file_header_matches_metric
from fgpyo.util.metric import _assert_is_metric_class
from fgpyo.util.metric import _get_fieldnames
from fgpyo.util.metric import _is_attrs_instance
from fgpyo.util.metric import _is_dataclass_instance
from fgpyo.util.metric import asdict
Expand Down Expand Up @@ -590,3 +595,154 @@ def test_read_header_can_read_picard(tmp_path: Path) -> None:
header = Metric._read_header(metrics_file, comment_prefix="#")

assert header.fieldnames == ["SAMPLE", "FOO", "BAR"]


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
def test_get_fieldnames(data_and_classes: DataBuilder) -> None:
"""Test we can get the fieldnames of a metric."""

assert _get_fieldnames(data_and_classes.Person) == ["name", "age"]


def test_fieldnames_raises_if_not_a_metric() -> None:
"""Test we raise if we get a non-metric."""

@dataclass
msto marked this conversation as resolved.
Show resolved Hide resolved
class BadMetric:
foo: str
bar: int

with pytest.raises(TypeError, match="Not a dataclass or attr decorated Metric"):
_get_fieldnames(BadMetric) # type: ignore[arg-type]


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
def test_assert_is_metric_class(data_and_classes: DataBuilder) -> None:
"""
Test that we can validate if a class is a Metric.
"""
try:
_assert_is_metric_class(data_and_classes.DummyMetric)
except TypeError:
raise AssertionError("Failed to validate a valid Metric") from None


def test_assert_is_metric_class_raises_if_not_decorated() -> None:
"""
Test that we raise an error if the provided type is a Metric subclass but not decorated as a
dataclass or attr.
"""

class BadMetric(Metric["BadMetric"]):
foo: str
bar: int

with pytest.raises(TypeError, match="Not a dataclass or attr decorated Metric"):
_assert_is_metric_class(BadMetric)


def test_assert_is_metric_class_raises_if_not_a_metric() -> None:
"""
Test that we raise an error if the provided type is decorated as a
dataclass or attr but does not subclass Metric.
"""

@dataclass
class BadMetric:
foo: str
bar: int

with pytest.raises(TypeError, match="Not a dataclass or attr decorated Metric"):
_assert_is_metric_class(BadMetric)

@attr.s
class BadMetric:
foo: str
bar: int

with pytest.raises(TypeError, match="Not a dataclass or attr decorated Metric"):
_assert_is_metric_class(BadMetric)


# fmt: off
@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
@pytest.mark.parametrize(
"fieldnames",
[
["name", "age"], # The fieldnames are all the attributes of the provided metric
["age", "name"], # The fieldnames are out of order
["name"], # The fieldnames are a subset of the attributes of the provided metric
],
)
# fmt: on
def test_assert_fieldnames_are_metric_attributes(
data_and_classes: DataBuilder,
fieldnames: List[str],
) -> None:
"""
Should not raise an error if the provided fieldnames are all attributes of the provided metric.
"""
try:
_assert_fieldnames_are_metric_attributes(fieldnames, data_and_classes.Person)
except Exception:
raise AssertionError("Fieldnames should be valid") from None


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
@pytest.mark.parametrize(
"fieldnames",
[
["name", "age", "foo"],
["name", "foo"],
["foo", "name", "age"],
["foo"],
],
)
def test_assert_fieldnames_are_metric_attributes_raises(
data_and_classes: DataBuilder,
fieldnames: List[str],
) -> None:
"""
Should raise an error if any of the provided fieldnames are not an attribute on the metric.
"""
with pytest.raises(ValueError, match="One or more of the specified fields are not "):
_assert_fieldnames_are_metric_attributes(fieldnames, data_and_classes.Person)


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
def test_assert_file_header_matches_metric(tmp_path: Path, data_and_classes: DataBuilder) -> None:
"""
Should not raise an error if the provided file header matches the provided metric.
"""
metric_path = tmp_path / "metrics.tsv"
with metric_path.open("w") as metrics_file:
metrics_file.write("name\tage\n")

try:
_assert_file_header_matches_metric(metric_path, data_and_classes.Person, delimiter="\t")
except Exception:
raise AssertionError("File header should be valid") from None


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
@pytest.mark.parametrize(
"header",
[
["name"],
["age"],
["name", "age", "foo"],
["foo", "name", "age"],
],
)
def test_assert_file_header_matches_metric_raises(
tmp_path: Path, data_and_classes: DataBuilder, header: List[str]
) -> None:
"""
Should raise an error if the provided file header does not match the provided metric.
"""
metric_path = tmp_path / "metrics.tsv"
with metric_path.open("w") as metrics_file:
metrics_file.write("\t".join(header) + "\n")

with pytest.raises(ValueError, match="The provided file does not have the same field names"):
_assert_file_header_matches_metric(metric_path, data_and_classes.Person, delimiter="\t")
Loading