Skip to content

Commit

Permalink
increase test coverage, fixes for fragmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
compscidr committed Sep 27, 2024
1 parent ef037e4 commit 29621a3
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 34 deletions.
7 changes: 4 additions & 3 deletions knet/src/main/kotlin/com/jasonernst/knet/ip/IpHeader.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ interface IpHeader {
/**
* Helper function so that we can ensure the payload length is a multiple of 8
*/
fun closestDivisibleBy(initialValue: UInt, divisor: UInt): UInt {
return (initialValue + divisor - 1u) / divisor * divisor
}
fun closestDivisibleBy(
initialValue: UInt,
divisor: UInt,
): UInt = (initialValue + divisor - 1u) / divisor * divisor

fun fromStream(stream: ByteBuffer): IpHeader {
val start = stream.position()
Expand Down
31 changes: 21 additions & 10 deletions knet/src/main/kotlin/com/jasonernst/knet/ip/v4/Ipv4Header.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlin.experimental.or
import kotlin.math.ceil
import kotlin.math.min

/**
* Internet Protocol Version 4 Header Implementation.
Expand All @@ -29,8 +30,8 @@ data class Ipv4Header(
val dscp: UByte = 0u,
// 2-bits, explicit congestion notification.
val ecn: UByte = 0u,
// 16-bits, IP packet, including the header
private val totalLength: UShort = 0u,
// 16-bits, IP packet, including the header: default to a just the header with no payload
private val totalLength: UShort = IP4_MIN_HEADER_LENGTH.toUShort(),
// 16-bits, groups fragments of a single IPv4 datagram.
val id: UShort = 0u,
// if the packet is marked as don't fragment and we can't fit it in a single packet, drop it.
Expand Down Expand Up @@ -313,9 +314,12 @@ data class Ipv4Header(
*
*/
fun fragment(
maxSize: UInt,
maxSize: UInt, // max size includes the header size
payload: ByteArray,
): List<Pair<Ipv4Header, ByteArray>> {
if (maxSize.toInt() % 8 != 0) {
throw IllegalArgumentException("Fragment max size must be divisible by 8")
}
if (dontFragment) {
throw IllegalArgumentException("Cannot fragment packets marked as don't fragment")
}
Expand All @@ -328,11 +332,17 @@ data class Ipv4Header(
"The smallest fragment size is ${IP4_MIN_FRAGMENT_PAYLOAD.toInt()} bytes because it must align on a 64-bit boundary",
)
}
var payloadPerPacket = closestDivisibleBy(maxSize - getHeaderLength(), 8u)
var lastFragment = false
var payloadPosition = 0u
var payloadPerPacket = min(payload.size - payloadPosition.toInt(), closestDivisibleBy(maxSize - getHeaderLength(), 8u).toInt())
logger.debug("PAYLOAD PER PACKET: $payloadPerPacket HEADERSIZE: ${getHeaderLength()}")
if (payloadPosition.toInt() + payloadPerPacket == payload.size) {
lastFragment = true
}

var isFirstFragment = true
while (payloadPosition < payload.size.toUInt()) {
logger.debug("$payloadPosition:${payloadPosition + payloadPerPacket}")
logger.debug("$payloadPosition:${payloadPosition + payloadPerPacket.toUInt()}")
val offsetIn64BitOctets = payloadPosition / 8u
var newOptions = options
var newIhl = ihl
Expand All @@ -356,10 +366,10 @@ data class Ipv4Header(
ihl = newIhl,
dscp = dscp,
ecn = ecn,
totalLength = (getHeaderLength() + payloadPerPacket).toUShort(),
totalLength = (getHeaderLength() + payloadPerPacket.toUInt()).toUShort(),
id = id,
dontFragment = false,
lastFragment = payloadPosition >= getHeaderLength(),
lastFragment = lastFragment,
fragmentOffset = offsetIn64BitOctets.toUShort(),
ttl = ttl,
protocol = protocol,
Expand All @@ -370,9 +380,10 @@ data class Ipv4Header(
logger.debug("payload len:${newHeader.getPayloadLength()}")
val newPayload = payload.copyOfRange(payloadPosition.toInt(), payloadPosition.toInt() + payloadPerPacket.toInt())
packetList.add(Pair(newHeader, newPayload))
payloadPosition += payloadPerPacket
if (payloadPosition + payloadPerPacket > payload.size.toUInt()) {
payloadPerPacket = payload.size.toUInt() - payloadPosition
payloadPosition += payloadPerPacket.toUInt()
if (payloadPosition + payloadPerPacket.toUInt() > payload.size.toUInt()) {
payloadPerPacket = (payload.size.toUInt() - payloadPosition).toInt()
lastFragment = true
}
}
return packetList
Expand Down
3 changes: 2 additions & 1 deletion knet/src/main/kotlin/com/jasonernst/knet/ip/v6/Ipv6Header.kt
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ data class Ipv6Header(
destinationAddress,
firstHeaderExtensions,
)
val firstPayloadBytes = closestDivisibleBy(maxSize - IP6_HEADER_SIZE - perFragmentHeaderBytes.toUInt() - extAndUpperBytes.toUInt(), 8u)
val firstPayloadBytes =
closestDivisibleBy(maxSize - IP6_HEADER_SIZE - perFragmentHeaderBytes.toUInt() - extAndUpperBytes.toUInt(), 8u)
val firstPair = Triple(firstFragment, nextHeader, payload.sliceArray(0 until firstPayloadBytes.toInt()))
fragments.add(firstPair)
var payloadPosition = firstPayloadBytes.toInt()
Expand Down
6 changes: 6 additions & 0 deletions knet/src/test/kotlin/com/jasonernst/knet/ip/IpHeaderTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import java.net.Inet6Address
import java.nio.ByteBuffer

class IpHeaderTest {
companion object {
fun byteArrayOfInts(vararg ints: Int) = ByteArray(ints.size) { pos -> ints[pos].toByte() }

fun byteBufferOfInts(vararg ints: Int) = ByteBuffer.wrap(byteArrayOfInts(*ints))
}

@Test fun tooShortBuffer() {
val stream = ByteBuffer.allocate(0)
assertThrows<PacketTooShortException> {
Expand Down
121 changes: 121 additions & 0 deletions knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4FragmentTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package com.jasonernst.knet.ip.v4

import com.jasonernst.knet.ip.IpHeader.Companion.closestDivisibleBy
import com.jasonernst.knet.ip.IpHeaderTest.Companion.byteArrayOfInts
import com.jasonernst.knet.ip.IpType
import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_MIN_HEADER_LENGTH
import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_WORD_LENGTH
import com.jasonernst.knet.ip.v4.options.Ipv4OptionEndOfOptionList
import com.jasonernst.knet.ip.v4.options.Ipv4OptionNoOperation
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.net.InetAddress
import kotlin.math.ceil

class Ipv4FragmentTest {
@Test
fun fragmentationAndReassembly() {
val payload = byteArrayOfInts(0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09)
val ipv4Header = Ipv4Header(totalLength = (IP4_MIN_HEADER_LENGTH + payload.size.toUShort()).toUShort(), dontFragment = false)
val fragmentSize = closestDivisibleBy(IP4_MIN_HEADER_LENGTH + (payload.size / 2).toUInt(), 8u)
val fragments = ipv4Header.fragment(fragmentSize, payload)
assertEquals(2, fragments.size)

val reassembly = Ipv4Header.reassemble(fragments)
assertEquals(ipv4Header, reassembly.first)
assertArrayEquals(payload, reassembly.second)
}

@Test
fun emptyFragments() {
assertThrows<IllegalArgumentException> {
Ipv4Header.reassemble(emptyList())
}
}

@Test
fun singleFragmentWithoutLastFragmentSetToTrue() {
val ipv4Header = Ipv4Header(lastFragment = false)
assertThrows<IllegalArgumentException> {
Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0))))
}
}

@Test
fun singleFragmentLastFragment() {
val payload = byteArrayOfInts(0x01, 0x02, 0x03)
val ipv4Header = Ipv4Header(totalLength = (IP4_MIN_HEADER_LENGTH + payload.size.toUInt()).toUShort(), dontFragment = false)
val fragmentSize = closestDivisibleBy(IP4_MIN_HEADER_LENGTH + 8u, 8u)
val fragmented = ipv4Header.fragment(fragmentSize, payload)
val reassembled = Ipv4Header.reassemble(fragmented)
assertEquals(ipv4Header, reassembled.first)
}

@Test
fun notDivisbleBy8() {
val ipv4Header = Ipv4Header()
assertThrows<IllegalArgumentException> {
ipv4Header.fragment(1u, ByteArray(0))
}
}

@Test
fun nonMatchingFragmentFields() {
val ipv4Header = Ipv4Header(id = 1u)
val ipv4Header2 = Ipv4Header(id = 2u)
assertThrows<IllegalArgumentException> {
Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0)), Pair(ipv4Header2, ByteArray(0))))
}

val ipv4Header3 = Ipv4Header(id = 1u, protocol = IpType.TCP.value)
assertThrows<IllegalArgumentException> {
Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0)), Pair(ipv4Header3, ByteArray(0))))
}

val ipv4Header4 = Ipv4Header(id = 1u, sourceAddress = InetAddress.getLoopbackAddress())
assertThrows<IllegalArgumentException> {
Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0)), Pair(ipv4Header4, ByteArray(0))))
}

val ipv4Header5 = Ipv4Header(id = 1u, destinationAddress = InetAddress.getLoopbackAddress())
assertThrows<IllegalArgumentException> {
Ipv4Header.reassemble(listOf(Pair(ipv4Header, ByteArray(0)), Pair(ipv4Header5, ByteArray(0))))
}
}

@Test
fun fragmentADontFragment() {
val ipv4Header = Ipv4Header(dontFragment = true)
assertThrows<IllegalArgumentException> {
ipv4Header.fragment(8u, byteArrayOfInts(0x01, 0x02, 0x03))
}
}

@Test
fun fragmentTooSmall() {
val ipv4Header = Ipv4Header(dontFragment = false)
assertThrows<IllegalArgumentException> {
ipv4Header.fragment(0u, byteArrayOfInts(0x01, 0x02, 0x03))
}
}

@Test
fun fragmentWithOptions() {
val options = listOf(Ipv4OptionNoOperation(isCopied = true), Ipv4OptionEndOfOptionList(isCopied = false))
val optionsLength = options.sumOf { it.size.toInt() }
val payload = ByteArray(16)
val totalHeaderLength = (IP4_MIN_HEADER_LENGTH + optionsLength.toUInt())
val ihl = ceil(totalHeaderLength.toDouble() / IP4_WORD_LENGTH.toDouble()).toUInt().toUByte()
val totalLength = (((ihl * IP4_WORD_LENGTH) + payload.size.toUInt()).toUShort()) // account for zero padding the header
val ipv4Header = Ipv4Header(ihl = ihl, totalLength = totalLength, dontFragment = false, options = options)
val maxSize = closestDivisibleBy(totalHeaderLength + 8u, 8u)
val fragments = ipv4Header.fragment(maxSize, payload)
assertEquals(2, fragments.size)

val reassembly = Ipv4Header.reassemble(fragments)
assertEquals(ipv4Header, reassembly.first)
assertArrayEquals(payload, reassembly.second)
}
}
21 changes: 1 addition & 20 deletions knet/src/test/kotlin/com/jasonernst/knet/ip/v4/Ipv4HeaderTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ import com.jasonernst.icmp_common.Checksum
import com.jasonernst.knet.PacketTooShortException
import com.jasonernst.knet.ip.IpHeader
import com.jasonernst.knet.ip.IpHeader.Companion.IP4_VERSION
import com.jasonernst.knet.ip.IpHeaderTest.Companion.byteBufferOfInts
import com.jasonernst.knet.ip.IpType
import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_MIN_FRAGMENT_PAYLOAD
import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_MIN_HEADER_LENGTH
import com.jasonernst.knet.ip.v4.Ipv4Header.Companion.IP4_WORD_LENGTH
import com.jasonernst.knet.ip.v4.options.Ipv4OptionNoOperation
import com.jasonernst.knet.transport.tcp.TcpHeader
import com.jasonernst.knet.transport.tcp.options.TcpOptionEndOfOptionList
import com.jasonernst.packetdumper.stringdumper.StringPacketDumper
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
Expand All @@ -27,13 +26,6 @@ class Ipv4HeaderTest {
private val logger = LoggerFactory.getLogger(javaClass)
private val stringPacketDumper = StringPacketDumper()

private fun byteArrayOfInts(vararg ints: Int) =
ByteArray(ints.size) { pos ->
ints[pos].toByte()
}

private fun byteBufferOfInts(vararg ints: Int) = ByteBuffer.wrap(byteArrayOfInts(*ints))

@Test
fun ipv4checksumTest2() {
val buffer =
Expand Down Expand Up @@ -319,15 +311,4 @@ class Ipv4HeaderTest {
Ipv4Header.fromStream(stream)
}
}

@Test fun fragmentationAndReassembly() {
val payload = byteArrayOfInts(0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09)
val ipv4Header = Ipv4Header(totalLength = (IP4_MIN_HEADER_LENGTH + payload.size.toUShort()).toUShort(), dontFragment = false)
val fragments = ipv4Header.fragment(IP4_MIN_HEADER_LENGTH + IP4_MIN_FRAGMENT_PAYLOAD.toUInt(), payload)
assertEquals(2, fragments.size)

val reassembly = Ipv4Header.reassemble(fragments)
assertEquals(ipv4Header, reassembly.first)
assertArrayEquals(payload, reassembly.second)
}
}

0 comments on commit 29621a3

Please sign in to comment.