Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add bytes*.fill() helper #761

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/merkle_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
MIDDLE = bytes([2])
TRUNCATED = bytes([3])

BLANK = bytes32([0] * 32)
BLANK = bytes32.zeros

prehashed: Dict[bytes, _Hash] = {}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_merkle_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def check_tree(leafs: List[bytes32]) -> None:
)

for i in range(256):
item = bytes32([i] + [2] * 31)
item = bytes32.fill(i.to_bytes(), fill=b"\x02", align="<")
py_included, py_proof = py_tree.is_included_already_hashed(item)
assert not py_included
ru_included, ru_proof = ru_tree.is_included_already_hashed(item)
Expand Down
37 changes: 37 additions & 0 deletions tests/test_sized_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from chia_rs.sized_bytes import bytes8


def test_fill_empty() -> None:
assert bytes8.fill(b"", b"\x01") == bytes8([1, 1, 1, 1, 1, 1, 1, 1])


def test_fill_non_empty_with_single() -> None:
assert bytes8.fill(b"\x02", b"\x01") == bytes8([1, 1, 1, 1, 1, 1, 1, 2])


def test_fill_non_empty_with_double() -> None:
assert bytes8.fill(b"\x02\x02", b"\x01\x01") == bytes8([1, 1, 1, 1, 1, 1, 2, 2])


def test_fill_needed_with_0_length_fill_raises() -> None:
with pytest.raises(ValueError):
bytes8.fill(b"\x00", fill=b"")


def test_fill_not_needed_with_0_length_fill_works() -> None:
blob = b"\x00" * 8
assert bytes8.fill(blob, fill=b"") == bytes8(blob)


def test_fill_not_multiple_raises() -> None:
with pytest.raises(ValueError):
bytes8.fill(b"\x00", fill=b"\x01\x01")

def test_align_left() -> None:
assert bytes8.fill(b"\x01", fill=b"\x02", align="<") == bytes8([1, 2, 2, 2, 2, 2, 2, 2])

def test_invalid_alignment() -> None:
with pytest.raises(ValueError):
bytes8.fill(b"", fill=b"\x00", align="|")
22 changes: 22 additions & 0 deletions wheel/python/chia_rs/sized_byte_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import (
BinaryIO,
Iterable,
Literal,
Optional,
SupportsBytes,
SupportsIndex,
Expand Down Expand Up @@ -78,6 +79,27 @@ def random(
def secret(cls: Type[_T_SizedBytes]) -> _T_SizedBytes:
return cls.random(r=system_random)

@classmethod
def fill(cls: Type[_T_SizedBytes], blob: bytes, fill: bytes, align: Literal["<", ">"] = ">") -> _T_SizedBytes:
if len(blob) == cls._size:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if len(blob) > cls._size, do we want to truncate and construct cls then too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am tending towards no. It seems like with these classes that truncation would be surprising. I think most use cases for this are tests where you are trying to create unique values. If your blob is too long and is truncated you hazard unexpectedly having multiple of the same value generated.

This isn't meant to be a 'format' level tool, but I will note that format just creates a larger output in this case.

return cls(blob)

fill_length = len(fill)
if fill_length == 0:
raise ValueError("fill required but length is zero")

div, mod = divmod(cls._size - len(blob), fill_length)
if mod != 0:
raise ValueError("invalid fill value, range to be filled must be multiple of fil size")

all_fill = fill * div
if align == "<":
return cls(blob + all_fill)
elif align == ">":
return cls(all_fill + blob)

raise ValueError(f"invalid alignment: {align!r}")

def __str__(self) -> str:
return self.hex()

Expand Down
Loading