Skip to content

Commit

Permalink
tests for frame fragmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
genesiscrew committed May 21, 2024
1 parent f681856 commit 16ff8cc
Showing 1 changed file with 128 additions and 1 deletion.
129 changes: 128 additions & 1 deletion tests/test_frame.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pytest
import zigpy_zboss.types as t
from zigpy_zboss.frames import Frame, InvalidFrame, CRC8, HLPacket
from zigpy_zboss.frames import (
Frame, InvalidFrame, CRC8,
HLPacket, ZBNCP_LL_BODY_SIZE_MAX, LLHeader
)


def test_frame_deserialization():
Expand Down Expand Up @@ -112,3 +115,127 @@ def test_first_frag_flag_deserialization():
assert frame.hl_packet.header.control_type == t.ControlType.RSP
assert frame.hl_packet.header.id == 0x1234
assert frame.hl_packet.data == b"test_data"


def test_handle_tx_fragmentation():
"""Test the handle_tx_fragmentation method for proper fragmentation."""
# Create an HLCommonHeader with specific fields
hl_header = t.HLCommonHeader(
version=0x01, type=t.ControlType.RSP, id=0x1234
)
large_data = b"a" * (ZBNCP_LL_BODY_SIZE_MAX * 2 + 50)
hl_data = t.Bytes(large_data)

# Create an HLPacket with the large data
hl_packet = HLPacket(header=hl_header, data=hl_data)
frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet)

fragments = frame.handle_tx_fragmentation()

total_fragments = frame.count_fragments()
assert len(fragments) == total_fragments

# Calculate the expected size of the first fragment
# Exclude the CRC16 for size calculation
serialized_hl_packet = hl_packet.serialize()[2:]
first_frag_size = (
len(serialized_hl_packet) % ZBNCP_LL_BODY_SIZE_MAX
or ZBNCP_LL_BODY_SIZE_MAX
)

# Check the first fragment
first_fragment = fragments[0]
assert first_fragment.ll_header.flags == t.LLFlags.FirstFrag
assert first_fragment.ll_header.size == first_frag_size + 7
assert len(first_fragment.hl_packet.data) == first_frag_size - 4

# Check the middle fragments
for middle_fragment in fragments[1:-1]:
assert middle_fragment.ll_header.flags == 0
assert middle_fragment.ll_header.size == ZBNCP_LL_BODY_SIZE_MAX + 7
assert len(middle_fragment.hl_packet.data) == ZBNCP_LL_BODY_SIZE_MAX

# Check the last fragment
last_fragment = fragments[-1]
last_frag_size = (
len(serialized_hl_packet) -
(first_frag_size + (total_fragments - 2) * ZBNCP_LL_BODY_SIZE_MAX)
)
assert last_fragment.ll_header.flags == t.LLFlags.LastFrag
assert last_fragment.ll_header.size == last_frag_size + 7
assert len(last_fragment.hl_packet.data) == last_frag_size


def test_handle_tx_fragmentation_edge_cases():
"""Test the handle_tx_fragmentation method for various edge cases."""
# Data size exactly equal to ZBNCP_LL_BODY_SIZE_MAX
exact_size_data = b"a" * (ZBNCP_LL_BODY_SIZE_MAX - 2 - 2)
hl_header = t.HLCommonHeader(version=0x01, type=t.ControlType.RSP,
id=0x1234)
hl_packet = HLPacket(header=hl_header, data=t.Bytes(exact_size_data))
frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet)

# Perform fragmentation
fragments = frame.handle_tx_fragmentation()
assert len(fragments) == 1 # Should not fragment

# Test with data size just above ZBNCP_LL_BODY_SIZE_MAX
just_above_size_data = b"a" * (ZBNCP_LL_BODY_SIZE_MAX + 1 - 2 - 2)
hl_packet = HLPacket(header=hl_header, data=t.Bytes(just_above_size_data))
frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet)
fragments = frame.handle_tx_fragmentation()
assert len(fragments) == 2 # Should fragment into two

# Test with data size much larger than ZBNCP_LL_BODY_SIZE_MAX
large_data = b"a" * ((ZBNCP_LL_BODY_SIZE_MAX * 5) + 50 - 2 - 2)
hl_packet = HLPacket(header=hl_header, data=t.Bytes(large_data))
frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet)
fragments = frame.handle_tx_fragmentation()
assert len(fragments) == 6 # 5 full fragments and 1 partial fragment

# Test with very small data
small_data = b"a" * 10
hl_packet = HLPacket(header=hl_header, data=t.Bytes(small_data))
frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet)
fragments = frame.handle_tx_fragmentation()
assert len(fragments) == 1 # Should not fragment


def test_handle_rx_fragmentation():
"""Test the handle_rx_fragmentation method for.
proper reassembly of fragments.
"""
# Create an HLCommonHeader with specific fields
hl_header = t.HLCommonHeader(
version=0x01, type=t.ControlType.RSP, id=0x1234
)
large_data = b"a" * (ZBNCP_LL_BODY_SIZE_MAX * 2 + 50)
hl_data = t.Bytes(large_data)

# Create an HLPacket with the large data
hl_packet = HLPacket(header=hl_header, data=hl_data)
frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet)

# Perform fragmentation
fragments = frame.handle_tx_fragmentation()

# Verify that the correct number of fragments was created
total_fragments = frame.count_fragments()
assert len(fragments) == total_fragments

# Reassemble the fragments using handle_rx_fragmentation
reassembled_frame = Frame.handle_rx_fragmentation(fragments)

# Verify the reassembled frame
assert (
reassembled_frame.ll_header.frame_type == t.TYPE_ZBOSS_NCP_API_HL
)
assert (
reassembled_frame.ll_header.flags ==
(t.LLFlags.FirstFrag | t.LLFlags.LastFrag)
)

# Verify the reassembled data matches the original data
reassembled_data = reassembled_frame.hl_packet.data
assert reassembled_data == large_data

0 comments on commit 16ff8cc

Please sign in to comment.