Skip to content

Commit

Permalink
Added ipv6 fragmentation and reassembly
Browse files Browse the repository at this point in the history
  • Loading branch information
compscidr committed Sep 27, 2024
1 parent b688823 commit 618a781
Show file tree
Hide file tree
Showing 9 changed files with 502 additions and 17 deletions.
196 changes: 196 additions & 0 deletions knet/src/main/kotlin/com/jasonernst/knet/ip/v6/Ipv6Header.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import com.jasonernst.knet.ip.IpHeader
import com.jasonernst.knet.ip.IpHeader.Companion.IP6_VERSION
import com.jasonernst.knet.ip.IpType
import com.jasonernst.knet.ip.v6.extenions.Ipv6ExtensionHeader
import com.jasonernst.knet.ip.v6.extenions.Ipv6Fragment
import com.jasonernst.knet.nextheader.NextHeader
import org.slf4j.LoggerFactory
import java.net.Inet6Address
import java.net.InetAddress
Expand Down Expand Up @@ -37,6 +39,17 @@ data class Ipv6Header(
private val logger = LoggerFactory.getLogger(Ipv6Header::class.java)
private const val IP6_HEADER_SIZE: UShort = 40u // ipv6 header is not variable like ipv4

// The Per-Fragment headers must consist of the IPv6 header plus any
// extension headers that must be processed by nodes en route to the
// destination, that is, all headers up to and including the Routing
// header if present
private val onRouteHeaders =
listOf(
IpType.HOPOPT,
IpType.IPV6_OPTS,
IpType.IPV6_ROUTE,
)

fun fromStream(stream: ByteBuffer): Ipv6Header {
val start = stream.position()

Expand Down Expand Up @@ -93,6 +106,189 @@ data class Ipv6Header(
extensionHeaders,
)
}

fun reassemble(fragments: List<Triple<Ipv6Header, NextHeader?, ByteArray>>): Triple<Ipv6Header, NextHeader, ByteArray> {
if (fragments.isEmpty()) {
throw IllegalArgumentException("No fragments to reassemble")
}

val extensionHeaders = mutableListOf<Ipv6ExtensionHeader>()
for (extensionHeader in fragments[0].first.extensionHeaders) {
if (extensionHeader.type == IpType.IPV6_FRAG) {
continue
}
extensionHeaders.add(extensionHeader)
}
val payloadLength =
fragments.sumOf {
it.third.size
} +
extensionHeaders.sumOf {
it.getExtensionLengthInBytes()
} + fragments[0].second!!.getHeaderLength().toInt()

val ipv6Header =
Ipv6Header(
version = fragments[0].first.version,
trafficClass = fragments[0].first.trafficClass,
flowLabel = fragments[0].first.flowLabel,
payloadLength = payloadLength.toUShort(),
protocol = fragments[0].first.protocol,
hopLimit = fragments[0].first.hopLimit,
sourceAddress = fragments[0].first.sourceAddress,
destinationAddress = fragments[0].first.destinationAddress,
extensionHeaders = extensionHeaders,
)

val payload =
fragments
.flatMap {
it.third.toList()
}.toByteArray()

val nextHeader = fragments.first().second

return Triple(ipv6Header, nextHeader!!, payload)
}
}

/**
* Fragments this ipv6 header into smaller fragments. The fragments have:
* 1) an Ipv6 header with different sets of extension headers, depending on if its the first
* fragment or not
* 2) a next header that is either the next header in the original packet or null if it is a
* fragment
* 3) a payload that is a subset of the original payload
*/
fun fragment(
maxSize: UInt,
nextHeader: NextHeader,
payload: ByteArray,
): List<Triple<Ipv6Header, NextHeader?, ByteArray>> {
if (maxSize.toInt() % 8 != 0) {
throw IllegalArgumentException("Max size must be a multiple of 8")
}

val fragments = mutableListOf<Triple<Ipv6Header, NextHeader?, ByteArray>>()

val perFragmentHeaders = mutableListOf<Ipv6ExtensionHeader>()

// need to figure out type because this could be a mix of extension and upper layer headers (tcp)
val nonPerFragmentExtensionHeaders = mutableListOf<Ipv6ExtensionHeader>()

// up to an including the routing header if they exist
for (extensionHeader in extensionHeaders) {
if (onRouteHeaders.contains(extensionHeader.type)) {
perFragmentHeaders.add(extensionHeader)
} else {
nonPerFragmentExtensionHeaders.add(extensionHeader)
}
}

if (perFragmentHeaders.isNotEmpty()) {
perFragmentHeaders.last().nextHeader = IpType.IPV6_FRAG.value
}

val perFragmentHeaderBytes =
perFragmentHeaders.sumOf {
it.getExtensionLengthInBytes()
}
val extAndUpperBytes =
nonPerFragmentExtensionHeaders.sumOf {
it.getExtensionLengthInBytes()
} + nextHeader.getHeaderLength().toInt()

val fragmentHeaderNextHeader =
if (nonPerFragmentExtensionHeaders.isEmpty()) {
nextHeader.protocol
} else {
nonPerFragmentExtensionHeaders.first().type.value
}

val firstFragmentHeader =
Ipv6Fragment(
nextHeader = fragmentHeaderNextHeader,
fragmentOffset = 0u,
moreFlag = true,
identification = Ipv6Fragment.globalIdentificationCounter++,
)

val firstHeaderExtensions = mutableListOf<Ipv6ExtensionHeader>()
firstHeaderExtensions.addAll(perFragmentHeaders)
firstHeaderExtensions.add(firstFragmentHeader)
firstHeaderExtensions.addAll(nonPerFragmentExtensionHeaders)

val firstFragment =
Ipv6Header(
version,
trafficClass,
flowLabel,
(maxSize - IP6_HEADER_SIZE).toUShort(),
if (nonPerFragmentExtensionHeaders.isEmpty()) {
protocol
} else {
IpType.IPV6_FRAG.value
},
hopLimit,
sourceAddress,
destinationAddress,
firstHeaderExtensions,
)
val firstPayloadBytes = maxSize - IP6_HEADER_SIZE - perFragmentHeaderBytes.toUInt() - extAndUpperBytes.toUInt()
val firstPair = Triple(firstFragment, nextHeader, payload.sliceArray(0 until firstPayloadBytes.toInt()))
fragments.add(firstPair)
var payloadPosition = firstPayloadBytes.toInt()

while (payloadPosition < payload.size) {
val nextPayloadBytes =
minOf(
(payload.size - payloadPosition).toUInt(),
maxSize - IP6_HEADER_SIZE - perFragmentHeaderBytes.toUInt(),
)
val moreFlag = nextPayloadBytes >= maxSize - IP6_HEADER_SIZE - perFragmentHeaderBytes.toUInt()

val nextFragment =
Ipv6Fragment(
nextHeader = fragmentHeaderNextHeader,
fragmentOffset = (payloadPosition / 8).toUShort(),
moreFlag = moreFlag,
identification = firstFragmentHeader.identification,
)

val nextHeaderExtensions = mutableListOf<Ipv6ExtensionHeader>()
nextHeaderExtensions.addAll(perFragmentHeaders)
nextHeaderExtensions.add(nextFragment)

val nextFragmentHeader =
Ipv6Header(
version,
trafficClass,
flowLabel,
(maxSize - IP6_HEADER_SIZE).toUShort(),
if (nonPerFragmentExtensionHeaders.isEmpty()) {
protocol
} else {
IpType.IPV6_FRAG.value
},
hopLimit,
sourceAddress,
destinationAddress,
nextHeaderExtensions,
)

val nextPair =
Triple(
nextFragmentHeader,
null,
payload.sliceArray(
payloadPosition until (payloadPosition + nextPayloadBytes.toInt()),
),
)
fragments.add(nextPair)
payloadPosition += nextPayloadBytes.toInt()
}

return fragments
}

override fun toByteArray(order: ByteOrder): ByteArray {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package com.jasonernst.knet.ip.v6.extenions

import com.jasonernst.knet.ip.IpType
import java.nio.ByteBuffer

/**
* https://datatracker.ietf.org/doc/html/rfc4302
*/
class Ipv6Authentication(
override val nextHeader: UByte,
override var nextHeader: UByte,
override val length: UByte,
) : Ipv6ExtensionHeader(nextHeader = nextHeader, length = length) {
) : Ipv6ExtensionHeader(IpType.AH, nextHeader = nextHeader, length = length) {
companion object {
// nextheader, length, reserved, SPI, sequence number, ICV
const val MIN_LENGTH = 20
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import java.nio.ByteBuffer
import java.nio.ByteOrder

data class Ipv6DestinationOptions(
override val nextHeader: UByte = IpType.TCP.value,
override var nextHeader: UByte = IpType.TCP.value,
override val length: UByte = 0u,
val optionData: List<Ipv6Tlv> = emptyList(),
) : Ipv6ExtensionHeader(nextHeader, length) {
) : Ipv6ExtensionHeader(IpType.IPV6_OPTS, nextHeader, length) {
companion object {
const val MIN_LENGTH = 2 // next header and length with no actual option data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import java.nio.ByteOrder
* https://datatracker.ietf.org/doc/html/rfc4303
*/
class Ipv6EncapsulatingSecurityPayload(
override val nextHeader: UByte = IpType.TCP.value,
override var nextHeader: UByte = IpType.TCP.value,
override val length: UByte = MIN_LENGTH,
) : Ipv6ExtensionHeader(nextHeader = nextHeader, length = length) {
) : Ipv6ExtensionHeader(IpType.ESP, nextHeader = nextHeader, length = length) {
companion object {
const val MIN_LENGTH: UByte = 2u // next header and length with no actual option data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ import java.nio.ByteOrder
* that recommendation.
*/
open class Ipv6ExtensionHeader(
open val nextHeader: UByte,
val type: IpType,
open var nextHeader: UByte,
open val length: UByte, // measured in 64-bit / 8-octet units
) {
/**
Expand Down
Loading

0 comments on commit 618a781

Please sign in to comment.