diff --git a/Driver.cpp b/Driver.cpp index 51f2df6..a9e9b5f 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -316,6 +316,7 @@ OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) // those IOCTLs are for MP mode case OVPN_IOCTL_MP_START_VPN: case OVPN_IOCTL_MP_NEW_PEER: + case OVPN_IOCTL_MP_SET_PEER: return FALSE; } } @@ -503,6 +504,12 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe status = OvpnMPPeerNew(device, request); break; + case OVPN_IOCTL_MP_SET_PEER: + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + status = OvpnMPPeerSet(device, request); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + break; + default: LOG_WARN("Unknown ", TraceLoggingValue(ioControlCode, "ioControlCode")); status = STATUS_INVALID_DEVICE_REQUEST; diff --git a/Driver.h b/Driver.h index 745854c..0f60608 100644 --- a/Driver.h +++ b/Driver.h @@ -81,10 +81,6 @@ struct OVPN_DEVICE { BCRYPT_ALG_HANDLE AesAlgHandle; BCRYPT_ALG_HANDLE ChachaAlgHandle; - // set from the userspace, defines TCP Maximum Segment Size - _Guarded_by_(SpinLock) - UINT16 MSS; - _Guarded_by_(SpinLock) OvpnSocket Socket; diff --git a/peer.cpp b/peer.cpp index 6a140b4..6ba9a8c 100644 --- a/peer.cpp +++ b/peer.cpp @@ -282,6 +282,13 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) peerCtx->PeerId = peer->PeerId; + // create peer-specific timer + LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); + if (status != STATUS_SUCCESS) { + OvpnPeerCtxFree(peerCtx); + goto done; + } + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeer(device, peerCtx)); if (status == STATUS_SUCCESS) { @@ -304,6 +311,27 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) return status; } +VOID OvpnPeerSetDoWork(OvpnPeerContext *peer, LONG keepaliveInterval, LONG keepaliveTimeout, LONG mss) +{ + if (mss != -1) { + peer->MSS = (UINT16)mss; + } + + if (keepaliveInterval != -1) { + peer->KeepaliveInterval = keepaliveInterval; + + // keepalive xmit timer, sends ping packets + OvpnTimerSetXmitInterval(peer->Timer, peer->KeepaliveInterval); + } + + if (keepaliveTimeout != -1) { + peer->KeepaliveTimeout = keepaliveTimeout; + + // keepalive recv timer, detects keepalive timeout + OvpnTimerSetRecvTimeout(peer->Timer, peer->KeepaliveTimeout); + } +} + _Use_decl_annotations_ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) { @@ -326,24 +354,37 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) TraceLoggingValue(set_peer->KeepaliveTimeout, "timeout"), TraceLoggingValue(set_peer->MSS, "MSS")); - if (set_peer->MSS != -1) { - device->MSS = (UINT16)set_peer->MSS; - } + OvpnPeerSetDoWork(peer, set_peer->KeepaliveInterval, set_peer->KeepaliveTimeout, set_peer->MSS); - if (set_peer->KeepaliveInterval != -1) { - peer->KeepaliveInterval = set_peer->KeepaliveInterval; +done: + LOG_EXIT(); + return status; +} - // keepalive xmit timer, sends ping packets - OvpnTimerSetXmitInterval(peer->Timer, peer->KeepaliveInterval); - } +_Use_decl_annotations_ +NTSTATUS OvpnMPPeerSet(POVPN_DEVICE device, WDFREQUEST request) +{ + LOG_ENTER(); + + NTSTATUS status = STATUS_SUCCESS; - if (peer->KeepaliveTimeout != -1) { - peer->KeepaliveTimeout = set_peer->KeepaliveTimeout; + POVPN_MP_SET_PEER set_peer = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_SET_PEER), (PVOID*)&set_peer, nullptr)); - // keepalive recv timer, detects keepalive timeout - OvpnTimerSetRecvTimeout(peer->Timer, peer->KeepaliveTimeout); + LOG_INFO("MP Set peer", TraceLoggingValue(set_peer->PeerId, "peer-id"), + TraceLoggingValue(set_peer->KeepaliveInterval, "interval"), + TraceLoggingValue(set_peer->KeepaliveTimeout, "timeout"), + TraceLoggingValue(set_peer->MSS, "MSS")); + + OvpnPeerContext* peer = OvpnFindPeer(device, set_peer->PeerId); + if (peer == NULL) { + LOG_ERROR("Peer not found", TraceLoggingValue(set_peer->PeerId, "peer-id")); + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; } + OvpnPeerSetDoWork(peer, set_peer->KeepaliveInterval, set_peer->KeepaliveTimeout, set_peer->MSS); + done: LOG_EXIT(); return status; diff --git a/peer.h b/peer.h index 907d44b..00a1585 100644 --- a/peer.h +++ b/peer.h @@ -42,6 +42,8 @@ struct OvpnPeerContext // 1-sec timer which handles ping intervals and keepalive timeouts WDFTIMER Timer; + UINT16 MSS; + struct { IN_ADDR IPv4; IN6_ADDR IPv6; @@ -91,6 +93,11 @@ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS OvpnPeerSet(_In_ POVPN_DEVICE device, WDFREQUEST request); +_Must_inspect_result_ +_Requires_exclusive_lock_held_(device->SpinLock) +NTSTATUS +OvpnMPPeerSet(_In_ POVPN_DEVICE device, WDFREQUEST request); + _Must_inspect_result_ NTSTATUS _Requires_shared_lock_held_(device->SpinLock) diff --git a/socket.cpp b/socket.cpp index 2b8653c..5281207 100644 --- a/socket.cpp +++ b/socket.cpp @@ -254,16 +254,16 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee // ping packet? if (OvpnTimerIsKeepaliveMessage(buffer->Data, buffer->Len)) { - LOG_INFO("Ping received"); + LOG_INFO("Ping received", TraceLoggingValue(peer->PeerId, "peer-id")); // no need to inject ping packet into OS, return buffer to the pool OvpnRxBufferPoolPut(buffer); } else { if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { - OvpnMssDoIPv4(buffer->Data, buffer->Len, device->MSS); + OvpnMssDoIPv4(buffer->Data, buffer->Len, peer->MSS); } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { - OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); + OvpnMssDoIPv6(buffer->Data, buffer->Len, peer->MSS); } // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath diff --git a/timer.cpp b/timer.cpp index dc409c9..d67e158 100644 --- a/timer.cpp +++ b/timer.cpp @@ -91,9 +91,10 @@ static VOID OvpnTimerXmit(WDFTIMER timer) if (NT_SUCCESS(status)) { // start async send, completion handler will return ciphertext buffer to the pool - LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer, NULL)); + SOCKADDR* sa = (SOCKADDR*)&(peer->TransportAddrs.Remote); + LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer, sa)); if (NT_SUCCESS(status)) { - LOG_INFO("Ping sent"); + LOG_INFO("Ping sent", TraceLoggingValue(peer->PeerId, "peer-id")); } } else { diff --git a/txqueue.cpp b/txqueue.cpp index f0cf628..dd1b063 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -78,18 +78,24 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET } OvpnPeerContext* peer = NULL; + if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { - OvpnMssDoIPv4(buffer->Data, buffer->Len, device->MSS); - peer = OvpnFindPeerVPN4(device, ((IPV4_HEADER*)buffer->Data)->DestinationAddress); + auto addr = ((IPV4_HEADER*)buffer->Data)->DestinationAddress; + peer = OvpnFindPeerVPN4(device, addr); + if (peer != NULL) { + OvpnMssDoIPv4(buffer->Data, buffer->Len, peer->MSS); + } } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { - OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); - peer = OvpnFindPeerVPN6(device, ((IPV6_HEADER*)buffer->Data)->DestinationAddress); + auto addr = ((IPV6_HEADER*)buffer->Data)->DestinationAddress; + peer = OvpnFindPeerVPN6(device, addr); + if (peer != NULL) { + OvpnMssDoIPv6(buffer->Data, buffer->Len, peer->MSS); + } } if (peer == NULL) { status = STATUS_ADDRESS_NOT_ASSOCIATED; OvpnTxBufferPoolPut(buffer); - LOG_WARN("No peer"); goto out; } diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index 60e4a53..be4770b 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -119,10 +119,17 @@ typedef struct _OVPN_CRYPTO_DATA_V2 { UINT32 CryptoOptions; } OVPN_CRYPTO_DATA_V2, * POVPN_CRYPTO_DATA_V2; +typedef struct _OVPN_MP_SET_PEER { + int PeerId; + LONG KeepaliveInterval; + LONG KeepaliveTimeout; + LONG MSS; +} OVPN_MP_SET_PEER, * POVPN_MP_SET_PEER; + typedef struct _OVPN_SET_PEER { - LONG KeepaliveInterval; - LONG KeepaliveTimeout; - LONG MSS; + LONG KeepaliveInterval; + LONG KeepaliveTimeout; + LONG MSS; } OVPN_SET_PEER, * POVPN_SET_PEER; typedef struct _OVPN_VERSION { @@ -160,3 +167,4 @@ typedef struct _OVPN_MP_START_VPN { #define OVPN_IOCTL_MP_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 11, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_MP_NEW_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 12, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define OVPN_IOCTL_MP_SET_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 13, METHOD_BUFFERED, FILE_ANY_ACCESS)