From 236fe909715a7707fe6d6bc2d77cd95f9b3b08d6 Mon Sep 17 00:00:00 2001 From: Jason Ernst Date: Thu, 12 Sep 2024 11:27:51 +0200 Subject: [PATCH] added a packet encapsulation test that uses ip,nextheader and payload together --- .../main/kotlin/com/jasonernst/knet/Packet.kt | 7 ++--- .../com/jasonernst/knet/EncapsulationTests.kt | 27 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/knet/src/main/kotlin/com/jasonernst/knet/Packet.kt b/knet/src/main/kotlin/com/jasonernst/knet/Packet.kt index ecf668f..b1ada49 100644 --- a/knet/src/main/kotlin/com/jasonernst/knet/Packet.kt +++ b/knet/src/main/kotlin/com/jasonernst/knet/Packet.kt @@ -23,10 +23,11 @@ data class Packet( val ipHeader = IPHeader.fromStream(stream) val nextHeader = NextHeader.fromStream(stream, ipHeader.protocol) - if (stream.remaining() < ipHeader.getPayloadLength().toInt()) { - throw PacketTooShortException("Packet too short to obtain entire payload") + val expectedRemaining = (ipHeader.getPayloadLength() - nextHeader.getHeaderLength()).toInt() + if (stream.remaining() < expectedRemaining) { + throw PacketTooShortException("Packet too short to obtain entire payload, have ${stream.remaining()}, expecting $expectedRemaining") } - val payload = ByteArray(ipHeader.getPayloadLength().toInt()) + val payload = ByteArray(expectedRemaining) stream.get(payload) return Packet(ipHeader, nextHeader, payload) } diff --git a/knet/src/test/kotlin/com/jasonernst/knet/EncapsulationTests.kt b/knet/src/test/kotlin/com/jasonernst/knet/EncapsulationTests.kt index dc2fdb7..8c4af8d 100644 --- a/knet/src/test/kotlin/com/jasonernst/knet/EncapsulationTests.kt +++ b/knet/src/test/kotlin/com/jasonernst/knet/EncapsulationTests.kt @@ -148,4 +148,31 @@ class EncapsulationTests { ) encapsulationTest(sourceAddress, destinationAddress, tcpHeader, payload) } + + @Test + fun packetEncapsulationTest() { + val payload = "test".toByteArray() + val sourcePort = Random.Default.nextInt(2 * Short.MAX_VALUE - 1) + val sourceAddress = InetSocketAddress(InetAddress.getByName("::1"), sourcePort) + val destPort = Random.Default.nextInt(2 * Short.MAX_VALUE - 1) + val destinationAddress = InetSocketAddress(InetAddress.getByName("::2"), destPort) + val tcpHeader = + TCPHeader( + sourcePort = sourcePort.toUShort(), + destinationPort = destPort.toUShort(), + sequenceNumber = 100u, + acknowledgementNumber = 500u, + windowSize = 35000.toUShort(), + ) + val ipHeader = IPHeader.createIPHeader( + sourceAddress.address, + destinationAddress.address, + IPType.TCP, + tcpHeader.getHeaderLength().toInt() + payload.size + ) + val packet = Packet(ipHeader, tcpHeader, payload) + val stream = ByteBuffer.wrap(packet.toByteArray()) + val parsedPacket = Packet.fromStream(stream) + assertEquals(packet, parsedPacket) + } }