From 72fc2789aeb9030b442a1f30134ae8544494f703 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Wed, 18 Sep 2024 14:30:20 +0300 Subject: [PATCH] Account for SOCKADDR in control packets In multipeer UDP mode, we expect userspace to prepend CC packets with SOCKADDR to know where to send the control packet. Likewise, when we receive the control packet, we prepend it with remote SOCKADDR before pushing to userspace. https://github.com/OpenVPN/ovpn-dco-win/issues/84 Co-authored-by: Leon Dang Signed-off-by: Leon Dang Signed-off-by: Lev Stipakov --- Driver.cpp | 62 +++++++++++++++++++++++++++++++++++++---------- socket.cpp | 69 +++++++++++++++++++++++++++++++++++++++-------------- socket.h | 2 +- timer.cpp | 2 +- txqueue.cpp | 4 ++-- 5 files changed, 104 insertions(+), 35 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index 80d6368..f793fb6 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -187,7 +187,7 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length) // acquire spinlock, since we access device->TransportSocket KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); - OVPN_TX_BUFFER* buffer = NULL; + OVPN_TX_BUFFER* txBuf = NULL; if (device->Socket.Socket == NULL) { status = STATUS_INVALID_DEVICE_STATE; @@ -195,31 +195,67 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length) goto error; } - // fetch tx buffer - GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnTxBufferPoolGet(device->TxBufferPool, &buffer)); - // get request buffer - PVOID requestBuffer; - size_t requestBufferLength; - GOTO_IF_NOT_NT_SUCCESS(error, status, WdfRequestRetrieveInputBuffer(request, 0, &requestBuffer, &requestBufferLength)); + PVOID buf; + size_t bufLen; + GOTO_IF_NOT_NT_SUCCESS(error, status, WdfRequestRetrieveInputBuffer(request, 0, &buf, &bufLen)); + + PSOCKADDR sa = NULL; + + if (device->Mode == OVPN_MODE_MP) { + // buffer is prepended with SOCKADDR + + sa = (PSOCKADDR)buf; + switch (sa->sa_family) { + case AF_INET: + if (bufLen <= sizeof(SOCKADDR_IN)) { + status = STATUS_INVALID_MESSAGE; + LOG_ERROR("Message too short", TraceLoggingValue(bufLen, "msgLen"), TraceLoggingValue(sizeof(SOCKADDR_IN), "minLen")); + goto error; + } + + buf = (char*)buf + sizeof(SOCKADDR_IN); + bufLen -= sizeof(SOCKADDR_IN); + break; + + case AF_INET6: + if (bufLen <= sizeof(SOCKADDR_IN6)) { + status = STATUS_INVALID_MESSAGE; + LOG_ERROR("Message too short", TraceLoggingValue(bufLen, "msgLen"), TraceLoggingValue(sizeof(SOCKADDR_IN6), "minLen")); + goto error; + } + + buf = (char*)buf + sizeof(SOCKADDR_IN6); + bufLen -= sizeof(SOCKADDR_IN6); + break; + + default: + LOG_ERROR("Invalid address family", TraceLoggingValue(sa->sa_family, "AF")); + status = STATUS_INVALID_ADDRESS; + goto error; + } + } + + // fetch tx buffer + GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnTxBufferPoolGet(device->TxBufferPool, &txBuf)); // copy data from request to tx buffer - PUCHAR buf = OvpnTxBufferPut(buffer, requestBufferLength); - RtlCopyMemory(buf, requestBuffer, requestBufferLength); + PUCHAR data = OvpnTxBufferPut(txBuf, bufLen); + RtlCopyMemory(data, buf, bufLen); - buffer->IoQueue = device->PendingWritesQueue; + txBuf->IoQueue = device->PendingWritesQueue; // move request to manual queue GOTO_IF_NOT_NT_SUCCESS(error, status, WdfRequestForwardToIoQueue(request, device->PendingWritesQueue)); // send - LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer)); + LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, txBuf, sa)); goto done_not_complete; error: - if (buffer != NULL) { - OvpnTxBufferPoolPut(buffer); + if (txBuf != NULL) { + OvpnTxBufferPoolPut(txBuf); } ULONG_PTR bytesCopied = 0; diff --git a/socket.cpp b/socket.cpp index 70987b9..bebf26e 100644 --- a/socket.cpp +++ b/socket.cpp @@ -97,8 +97,29 @@ OvpnSocketSyncOp(_In_z_ CHAR* opName, OP op, SUCCESS success) static _Requires_shared_lock_held_(device->SpinLock) VOID -OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR buf, SIZE_T len) +OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR buf, SIZE_T len, _In_opt_ PSOCKADDR remote) { + SIZE_T hdrLen = 0, totalLen = len; + + // in UDP and MP mode we prepend CC packet with remote sockaddr before pushing it to userspace + if (device->Mode == OVPN_MODE_MP && remote != NULL) { + switch (remote->sa_family) { + case AF_INET: + hdrLen = sizeof(SOCKADDR_IN); + break; + + case AF_INET6: + hdrLen = sizeof(SOCKADDR_IN6); + break; + + default: + LOG_ERROR("Invalid remote address family", TraceLoggingValue(remote->sa_family, "AF")); + InterlockedIncrementNoFence(&device->Stats.LostInControlPackets); + return; + } + totalLen += hdrLen; + } + WDFREQUEST request; NTSTATUS status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request); if (!NT_SUCCESS(status)) { @@ -113,17 +134,22 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR return; } - if (sizeof(buffer->Data) >= len) { - // copy control packet to buffer - RtlCopyMemory(buffer->Data, buf, len); - buffer->Len = len; + if (sizeof(buffer->Data) >= totalLen) { + if (hdrLen > 0) { + // prepend with sockaddr + RtlCopyMemory(buffer->Data, remote, hdrLen); + } + + // copy control packet payload + RtlCopyMemory(buffer->Data + hdrLen, buf, totalLen - hdrLen); + buffer->Len = totalLen; // enqueue buffer, it will be dequeued when read request arrives OvpnBufferQueueEnqueue(device->ControlRxBufferQueue, &buffer->QueueListEntry); } else { LOG_ERROR("Buffer too small, packet len , buf len ", - TraceLoggingValue(len, "pktlen"), TraceLoggingValue(sizeof(buffer->Data), "buflen")); + TraceLoggingValue(totalLen, "pktlen"), TraceLoggingValue(sizeof(buffer->Data), "buflen")); OvpnRxBufferPoolPut(buffer); } @@ -133,19 +159,26 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR PVOID readBuffer; size_t readBufferLength; - ULONG_PTR bytesSent = len; + ULONG_PTR bytesSent = totalLen; - LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, len, &readBuffer, &readBufferLength)); + LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, totalLen, &readBuffer, &readBufferLength)); if (NT_SUCCESS(status)) { - // copy control packet to read request buffer - RtlCopyMemory(readBuffer, buf, len); + + if (hdrLen > 0) { + // prepend with sockaddr + RtlCopyMemory(readBuffer, remote, hdrLen); + } + + // copy control packet payload + RtlCopyMemory((PCHAR)readBuffer + hdrLen, buf, totalLen - hdrLen); + InterlockedIncrementNoFence(&device->Stats.ReceivedControlPackets); } else { InterlockedIncrementNoFence(&device->Stats.LostInControlPackets); if (status == STATUS_BUFFER_TOO_SMALL) { LOG_ERROR("Buffer too small, packet len , buf len ", - TraceLoggingValue(len, "pktlen"), TraceLoggingValue(readBufferLength, "buflen")); + TraceLoggingValue(totalLen, "pktlen"), TraceLoggingValue(readBufferLength, "buflen")); } bytesSent = 0; @@ -238,7 +271,7 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ } VOID -OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLength) PUCHAR buf, SIZE_T packetLength, BOOLEAN irqlDispatch) +OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLength) PUCHAR buf, SIZE_T packetLength, BOOLEAN irqlDispatch, _In_opt_ PSOCKADDR remoteAddr) { // If we're at dispatch level, we can use a small optimization and use function // which is not calling KeRaiseIRQL to raise the IRQL to DISPATCH_LEVEL before attempting to acquire the lock @@ -255,7 +288,7 @@ OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLengt OvpnSocketDataPacketReceived(device, op, buf, packetLength); } else { - OvpnSocketControlPacketReceived(device, buf, packetLength); + OvpnSocketControlPacketReceived(device, buf, packetLength, remoteAddr); } // don't forget to release spinlock @@ -330,7 +363,7 @@ OvpnSocketUdpReceiveFromEvent(_In_ PVOID socketContext, ULONG flags, _In_opt_ PW buf = packetBuf; } - OvpnSocketProcessIncomingPacket(device, buf, dataIndication->Buffer.Length, flags & WSK_FLAG_AT_DISPATCH_LEVEL); + OvpnSocketProcessIncomingPacket(device, buf, dataIndication->Buffer.Length, flags & WSK_FLAG_AT_DISPATCH_LEVEL, dataIndication->RemoteAddress); dataIndication = dataIndication->Next; } @@ -412,7 +445,7 @@ OvpnSocketTcpReceiveEvent(_In_opt_ PVOID socketContext, _In_ ULONG flags, _In_op buf = tcpState->PacketBuf; } - OvpnSocketProcessIncomingPacket(device, buf, tcpState->PacketLength, flags & WSK_FLAG_AT_DISPATCH_LEVEL); + OvpnSocketProcessIncomingPacket(device, buf, tcpState->PacketLength, flags & WSK_FLAG_AT_DISPATCH_LEVEL, NULL); mdlDataLen -= bytesRemained; dataIndicationLen -= bytesRemained; @@ -704,7 +737,7 @@ OvpnSocketSendComplete(_In_ PDEVICE_OBJECT deviceObj, _In_ PIRP irp, _In_ PVOID NTSTATUS _Use_decl_annotations_ -OvpnSocketSend(OvpnSocket* ovpnSocket, OVPN_TX_BUFFER* buffer) { +OvpnSocketSend(OvpnSocket* ovpnSocket, OVPN_TX_BUFFER* buffer, SOCKADDR* sa) { OVPN_DEVICE* device = (OVPN_DEVICE*)OvpnTxBufferPoolGetContext(buffer->Pool); PWSK_SOCKET socket = ovpnSocket->Socket; @@ -742,11 +775,11 @@ OvpnSocketSend(OvpnSocket* ovpnSocket, OVPN_TX_BUFFER* buffer) { } else if (buffer->WskBufList.Buffer.Length != 0) { PWSK_PROVIDER_DATAGRAM_DISPATCH datagramDispatch = (PWSK_PROVIDER_DATAGRAM_DISPATCH)socket->Dispatch; - LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendMessages(socket, &buffer->WskBufList, 0, NULL, 0, NULL, irp)); + LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendMessages(socket, &buffer->WskBufList, 0, sa, 0, NULL, irp)); } else { WSK_BUF wskBuf{ buffer->Mdl, FIELD_OFFSET(OVPN_TX_BUFFER, Head) + (ULONG)(buffer->Data - buffer->Head), buffer->Len }; PWSK_PROVIDER_DATAGRAM_DISPATCH datagramDispatch = (PWSK_PROVIDER_DATAGRAM_DISPATCH)socket->Dispatch; - LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendTo(socket, &wskBuf, 0, NULL, 0, NULL, irp)); + LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendTo(socket, &wskBuf, 0, sa, 0, NULL, irp)); } return status; diff --git a/socket.h b/socket.h index f1df149..0a6fc34 100644 --- a/socket.h +++ b/socket.h @@ -70,7 +70,7 @@ OvpnSocketClose(_In_opt_ PWSK_SOCKET socket); _Must_inspect_result_ NTSTATUS -OvpnSocketSend(_In_ OvpnSocket* ovpnSocket, _In_ OVPN_TX_BUFFER* buffer); +OvpnSocketSend(_In_ OvpnSocket* ovpnSocket, _In_ OVPN_TX_BUFFER* buffer, _In_opt_ SOCKADDR* sa); _Must_inspect_result_ NTSTATUS diff --git a/timer.cpp b/timer.cpp index 95eeaa1..5c15ab9 100644 --- a/timer.cpp +++ b/timer.cpp @@ -91,7 +91,7 @@ static VOID OvpnTimerXmit(WDFTIMER timer) if (NT_SUCCESS(status)) { // start async send, completion handler will return ciphertext buffer to the pool - LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer)); + LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer, NULL)); if (NT_SUCCESS(status)) { LOG_INFO("Ping sent"); } diff --git a/txqueue.cpp b/txqueue.cpp index feec4c5..2ae629f 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -117,7 +117,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET if (NT_SUCCESS(status)) { // start async send, this will return ciphertext buffer to the pool if (device->Socket.Tcp) { - status = OvpnSocketSend(&device->Socket, buffer); + status = OvpnSocketSend(&device->Socket, buffer, NULL); } else { // for UDP we use SendMessages to send multiple datagrams at once @@ -195,7 +195,7 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) if (!device->Socket.Tcp) { // this will use WskSendMessages to send buffers list which we constructed before - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead)); + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, NULL)); } } }