From 29621a386c03faa4ce961b00d91f3d0aeb12c1d1 Mon Sep 17 00:00:00 2001 From: Jason Ernst Date: Fri, 27 Sep 2024 12:17:09 +0000 Subject: [PATCH] increase test coverage, fixes for fragmentation --- .../kotlin/com/jasonernst/knet/ip/IpHeader.kt | 7 +- .../com/jasonernst/knet/ip/v4/Ipv4Header.kt | 31 +++-- .../com/jasonernst/knet/ip/v6/Ipv6Header.kt | 3 +- .../com/jasonernst/knet/ip/IpHeaderTest.kt | 6 + .../jasonernst/knet/ip/v4/Ipv4FragmentTest.kt | 121 ++++++++++++++++++ .../jasonernst/knet/ip/v4/Ipv4HeaderTest.kt | 21 +-- 6 files changed, 155 insertions(+), 34 deletions(-) create mode 100644 knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4FragmentTest.kt diff --git a/knet/src/main/kotlin/com/jasonernst/knet/ip/IpHeader.kt b/knet/src/main/kotlin/com/jasonernst/knet/ip/IpHeader.kt index 2a5400c..149a09f 100644 --- a/knet/src/main/kotlin/com/jasonernst/knet/ip/IpHeader.kt +++ b/knet/src/main/kotlin/com/jasonernst/knet/ip/IpHeader.kt @@ -23,9 +23,10 @@ interface IpHeader { /** * Helper function so that we can ensure the payload length is a multiple of 8 */ - fun closestDivisibleBy(initialValue: UInt, divisor: UInt): UInt { - return (initialValue + divisor - 1u) / divisor * divisor - } + fun closestDivisibleBy( + initialValue: UInt, + divisor: UInt, + ): UInt = (initialValue + divisor - 1u) / divisor * divisor fun fromStream(stream: ByteBuffer): IpHeader { val start = stream.position() diff --git a/knet/src/main/kotlin/com/jasonernst/knet/ip/v4/Ipv4Header.kt b/knet/src/main/kotlin/com/jasonernst/knet/ip/v4/Ipv4Header.kt index 806f5de..d237667 100644 --- a/knet/src/main/kotlin/com/jasonernst/knet/ip/v4/Ipv4Header.kt +++ b/knet/src/main/kotlin/com/jasonernst/knet/ip/v4/Ipv4Header.kt @@ -15,6 +15,7 @@ import java.nio.ByteBuffer import java.nio.ByteOrder import kotlin.experimental.or import kotlin.math.ceil +import kotlin.math.min /** * Internet Protocol Version 4 Header Implementation. @@ -29,8 +30,8 @@ data class Ipv4Header( val dscp: UByte = 0u, // 2-bits, explicit congestion notification. val ecn: UByte = 0u, - // 16-bits, IP packet, including the header - private val totalLength: UShort = 0u, + // 16-bits, IP packet, including the header: default to a just the header with no payload + private val totalLength: UShort = IP4_MIN_HEADER_LENGTH.toUShort(), // 16-bits, groups fragments of a single IPv4 datagram. val id: UShort = 0u, // if the packet is marked as don't fragment and we can't fit it in a single packet, drop it. @@ -313,9 +314,12 @@ data class Ipv4Header( * */ fun fragment( - maxSize: UInt, + maxSize: UInt, // max size includes the header size payload: ByteArray, ): List> { + if (maxSize.toInt() % 8 != 0) { + throw IllegalArgumentException("Fragment max size must be divisible by 8") + } if (dontFragment) { throw IllegalArgumentException("Cannot fragment packets marked as don't fragment") } @@ -328,11 +332,17 @@ data class Ipv4Header( "The smallest fragment size is ${IP4_MIN_FRAGMENT_PAYLOAD.toInt()} bytes because it must align on a 64-bit boundary", ) } - var payloadPerPacket = closestDivisibleBy(maxSize - getHeaderLength(), 8u) + var lastFragment = false var payloadPosition = 0u + var payloadPerPacket = min(payload.size - payloadPosition.toInt(), closestDivisibleBy(maxSize - getHeaderLength(), 8u).toInt()) + logger.debug("PAYLOAD PER PACKET: $payloadPerPacket HEADERSIZE: ${getHeaderLength()}") + if (payloadPosition.toInt() + payloadPerPacket == payload.size) { + lastFragment = true + } + var isFirstFragment = true while (payloadPosition < payload.size.toUInt()) { - logger.debug("$payloadPosition:${payloadPosition + payloadPerPacket}") + logger.debug("$payloadPosition:${payloadPosition + payloadPerPacket.toUInt()}") val offsetIn64BitOctets = payloadPosition / 8u var newOptions = options var newIhl = ihl @@ -356,10 +366,10 @@ data class Ipv4Header( ihl = newIhl, dscp = dscp, ecn = ecn, - totalLength = (getHeaderLength() + payloadPerPacket).toUShort(), + totalLength = (getHeaderLength() + payloadPerPacket.toUInt()).toUShort(), id = id, dontFragment = false, - lastFragment = payloadPosition >= getHeaderLength(), + lastFragment = lastFragment, fragmentOffset = offsetIn64BitOctets.toUShort(), ttl = ttl, protocol = protocol, @@ -370,9 +380,10 @@ data class Ipv4Header( logger.debug("payload len:${newHeader.getPayloadLength()}") val newPayload = payload.copyOfRange(payloadPosition.toInt(), payloadPosition.toInt() + payloadPerPacket.toInt()) packetList.add(Pair(newHeader, newPayload)) - payloadPosition += payloadPerPacket - if (payloadPosition + payloadPerPacket > payload.size.toUInt()) { - payloadPerPacket = payload.size.toUInt() - payloadPosition + payloadPosition += payloadPerPacket.toUInt() + if (payloadPosition + payloadPerPacket.toUInt() > payload.size.toUInt()) { + payloadPerPacket = (payload.size.toUInt() - payloadPosition).toInt() + lastFragment = true } } return packetList diff --git a/knet/src/main/kotlin/com/jasonernst/knet/ip/v6/Ipv6Header.kt b/knet/src/main/kotlin/com/jasonernst/knet/ip/v6/Ipv6Header.kt index 79f95a5..2295d2b 100644 --- a/knet/src/main/kotlin/com/jasonernst/knet/ip/v6/Ipv6Header.kt +++ b/knet/src/main/kotlin/com/jasonernst/knet/ip/v6/Ipv6Header.kt @@ -255,7 +255,8 @@ data class Ipv6Header( destinationAddress, firstHeaderExtensions, ) - val firstPayloadBytes = closestDivisibleBy(maxSize - IP6_HEADER_SIZE - perFragmentHeaderBytes.toUInt() - extAndUpperBytes.toUInt(), 8u) + val firstPayloadBytes = + closestDivisibleBy(maxSize - IP6_HEADER_SIZE - perFragmentHeaderBytes.toUInt() - extAndUpperBytes.toUInt(), 8u) val firstPair = Triple(firstFragment, nextHeader, payload.sliceArray(0 until firstPayloadBytes.toInt())) fragments.add(firstPair) var payloadPosition = firstPayloadBytes.toInt() diff --git a/knet/src/test/kotlin/com/jasonernst/knet/ip/IpHeaderTest.kt b/knet/src/test/kotlin/com/jasonernst/knet/ip/IpHeaderTest.kt index 1db4700..eafc57b 100644 --- a/knet/src/test/kotlin/com/jasonernst/knet/ip/IpHeaderTest.kt +++ b/knet/src/test/kotlin/com/jasonernst/knet/ip/IpHeaderTest.kt @@ -8,6 +8,12 @@ import java.net.Inet6Address import java.nio.ByteBuffer class IpHeaderTest { + companion object { + fun byteArrayOfInts(vararg ints: Int) = ByteArray(ints.size) { pos -> ints[pos].toByte() } + + fun byteBufferOfInts(vararg ints: Int) = ByteBuffer.wrap(byteArrayOfInts(*ints)) + } + @Test fun tooShortBuffer() { val stream = ByteBuffer.allocate(0) assertThrows { diff --git a/knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4FragmentTest.kt b/knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4FragmentTest.kt new file mode 100644 index 0000000..07b882b --- /dev/null +++ b/knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4FragmentTest.kt @@ -0,0 +1,121 @@ +package com.jasonernst.knet.ip.v4 + +import com.jasonernst.knet.ip.IpHeader.Companion.closestDivisibleBy +import com.jasonernst.knet.ip.IpHeaderTest.Companion.byteArrayOfInts +import com.jasonernst.knet.ip.IpType +import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_MIN_HEADER_LENGTH +import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_WORD_LENGTH +import com.jasonernst.knet.ip.v4.options.Ipv4OptionEndOfOptionList +import com.jasonernst.knet.ip.v4.options.Ipv4OptionNoOperation +import org.junit.jupiter.api.Assertions.assertArrayEquals +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import java.net.InetAddress +import kotlin.math.ceil + +class Ipv4FragmentTest { + @Test + fun fragmentationAndReassembly() { + val payload = byteArrayOfInts(0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09) + val ipv4Header = Ipv4Header(totalLength = (IP4_MIN_HEADER_LENGTH + payload.size.toUShort()).toUShort(), dontFragment = false) + val fragmentSize = closestDivisibleBy(IP4_MIN_HEADER_LENGTH + (payload.size / 2).toUInt(), 8u) + val fragments = ipv4Header.fragment(fragmentSize, payload) + assertEquals(2, fragments.size) + + val reassembly = Ipv4Header.reassemble(fragments) + assertEquals(ipv4Header, reassembly.first) + assertArrayEquals(payload, reassembly.second) + } + + @Test + fun emptyFragments() { + assertThrows { + Ipv4Header.reassemble(emptyList()) + } + } + + @Test + fun singleFragmentWithoutLastFragmentSetToTrue() { + val ipv4Header = Ipv4Header(lastFragment = false) + assertThrows { + Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0)))) + } + } + + @Test + fun singleFragmentLastFragment() { + val payload = byteArrayOfInts(0x01, 0x02, 0x03) + val ipv4Header = Ipv4Header(totalLength = (IP4_MIN_HEADER_LENGTH + payload.size.toUInt()).toUShort(), dontFragment = false) + val fragmentSize = closestDivisibleBy(IP4_MIN_HEADER_LENGTH + 8u, 8u) + val fragmented = ipv4Header.fragment(fragmentSize, payload) + val reassembled = Ipv4Header.reassemble(fragmented) + assertEquals(ipv4Header, reassembled.first) + } + + @Test + fun notDivisbleBy8() { + val ipv4Header = Ipv4Header() + assertThrows { + ipv4Header.fragment(1u, ByteArray(0)) + } + } + + @Test + fun nonMatchingFragmentFields() { + val ipv4Header = Ipv4Header(id = 1u) + val ipv4Header2 = Ipv4Header(id = 2u) + assertThrows { + Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0)), Pair(ipv4Header2, ByteArray(0)))) + } + + val ipv4Header3 = Ipv4Header(id = 1u, protocol = IpType.TCP.value) + assertThrows { + Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0)), Pair(ipv4Header3, ByteArray(0)))) + } + + val ipv4Header4 = Ipv4Header(id = 1u, sourceAddress = InetAddress.getLoopbackAddress()) + assertThrows { + Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0)), Pair(ipv4Header4, ByteArray(0)))) + } + + val ipv4Header5 = Ipv4Header(id = 1u, destinationAddress = InetAddress.getLoopbackAddress()) + assertThrows { + Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0)), Pair(ipv4Header5, ByteArray(0)))) + } + } + + @Test + fun fragmentADontFragment() { + val ipv4Header = Ipv4Header(dontFragment = true) + assertThrows { + ipv4Header.fragment(8u, byteArrayOfInts(0x01, 0x02, 0x03)) + } + } + + @Test + fun fragmentTooSmall() { + val ipv4Header = Ipv4Header(dontFragment = false) + assertThrows { + ipv4Header.fragment(0u, byteArrayOfInts(0x01, 0x02, 0x03)) + } + } + + @Test + fun fragmentWithOptions() { + val options = listOf(Ipv4OptionNoOperation(isCopied = true), Ipv4OptionEndOfOptionList(isCopied = false)) + val optionsLength = options.sumOf { it.size.toInt() } + val payload = ByteArray(16) + val totalHeaderLength = (IP4_MIN_HEADER_LENGTH + optionsLength.toUInt()) + val ihl = ceil(totalHeaderLength.toDouble() / IP4_WORD_LENGTH.toDouble()).toUInt().toUByte() + val totalLength = (((ihl * IP4_WORD_LENGTH) + payload.size.toUInt()).toUShort()) // account for zero padding the header + val ipv4Header = Ipv4Header(ihl = ihl, totalLength = totalLength, dontFragment = false, options = options) + val maxSize = closestDivisibleBy(totalHeaderLength + 8u, 8u) + val fragments = ipv4Header.fragment(maxSize, payload) + assertEquals(2, fragments.size) + + val reassembly = Ipv4Header.reassemble(fragments) + assertEquals(ipv4Header, reassembly.first) + assertArrayEquals(payload, reassembly.second) + } +} diff --git a/knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4HeaderTest.kt b/knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4HeaderTest.kt index 37265cf..b18920a 100644 --- a/knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4HeaderTest.kt +++ b/knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4HeaderTest.kt @@ -4,15 +4,14 @@ import com.jasonernst.icmp_common.Checksum import com.jasonernst.knet.PacketTooShortException import com.jasonernst.knet.ip.IpHeader import com.jasonernst.knet.ip.IpHeader.Companion.IP4_VERSION +import com.jasonernst.knet.ip.IpHeaderTest.Companion.byteBufferOfInts import com.jasonernst.knet.ip.IpType -import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_MIN_FRAGMENT_PAYLOAD import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_MIN_HEADER_LENGTH import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_WORD_LENGTH import com.jasonernst.knet.ip.v4.options.Ipv4OptionNoOperation import com.jasonernst.knet.transport.tcp.TcpHeader import com.jasonernst.knet.transport.tcp.options.TcpOptionEndOfOptionList import com.jasonernst.packetdumper.stringdumper.StringPacketDumper -import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows @@ -27,13 +26,6 @@ class Ipv4HeaderTest { private val logger = LoggerFactory.getLogger(javaClass) private val stringPacketDumper = StringPacketDumper() - private fun byteArrayOfInts(vararg ints: Int) = - ByteArray(ints.size) { pos -> - ints[pos].toByte() - } - - private fun byteBufferOfInts(vararg ints: Int) = ByteBuffer.wrap(byteArrayOfInts(*ints)) - @Test fun ipv4checksumTest2() { val buffer = @@ -319,15 +311,4 @@ class Ipv4HeaderTest { Ipv4Header.fromStream(stream) } } - - @Test fun fragmentationAndReassembly() { - val payload = byteArrayOfInts(0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09) - val ipv4Header = Ipv4Header(totalLength = (IP4_MIN_HEADER_LENGTH + payload.size.toUShort()).toUShort(), dontFragment = false) - val fragments = ipv4Header.fragment(IP4_MIN_HEADER_LENGTH + IP4_MIN_FRAGMENT_PAYLOAD.toUInt(), payload) - assertEquals(2, fragments.size) - - val reassembly = Ipv4Header.reassemble(fragments) - assertEquals(ipv4Header, reassembly.first) - assertArrayEquals(payload, reassembly.second) - } }