diff --git a/Driver.h b/Driver.h index 232ad70..62476f5 100644 --- a/Driver.h +++ b/Driver.h @@ -84,13 +84,9 @@ struct OVPN_DEVICE { _Guarded_by_(SpinLock) LONG KeepaliveTimeout; - // timer used to send periodic ping messages to the server if no data has been sent within the past KeepaliveInterval seconds + // 1-sec timer which handles ping intervals and keepalive timeouts _Guarded_by_(SpinLock) - WDFTIMER KeepaliveXmitTimer; - - // timer used to report keepalive timeout error to userspace when no data has been received for KeepaliveTimeout seconds - _Guarded_by_(SpinLock) - WDFTIMER KeepaliveRecvTimer; + WDFTIMER Timer; // set from the userspace, defines TCP Maximum Segment Size _Guarded_by_(SpinLock) diff --git a/PropertySheet.props b/PropertySheet.props index 5382bfd..0eabefd 100644 --- a/PropertySheet.props +++ b/PropertySheet.props @@ -3,8 +3,8 @@ 1 - 0 - 1 + 1 + 0 diff --git a/peer.cpp b/peer.cpp index ad14df1..4a548bb 100644 --- a/peer.cpp +++ b/peer.cpp @@ -95,6 +95,8 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) status = OvpnSocketTcpConnect(socket, device, (PSOCKADDR)&peer->Remote); } + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, &device->Timer)); + done: LOG_EXIT(); @@ -116,8 +118,7 @@ OvpnPeerDel(POVPN_DEVICE device) KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - OvpnTimerDestroy(&device->KeepaliveXmitTimer); - OvpnTimerDestroy(&device->KeepaliveRecvTimer); + OvpnTimerDestroy(&device->Timer); aesAlgHandle = device->CryptoContext.AesAlgHandle; chachaAlgHandle = device->CryptoContext.ChachaAlgHandle; @@ -183,29 +184,15 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) if (peer->KeepaliveInterval != -1) { device->KeepaliveInterval = peer->KeepaliveInterval; - if (device->KeepaliveInterval > 0) { - // keepalive xmit timer, sends ping packets - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerXmitCreate(device->WdfDevice, peer->KeepaliveInterval, &device->KeepaliveXmitTimer)); - OvpnTimerReset(device->KeepaliveXmitTimer, peer->KeepaliveInterval); - } - else { - LOG_INFO("Destroy xmit timer"); - OvpnTimerDestroy(&device->KeepaliveXmitTimer); - } + // keepalive xmit timer, sends ping packets + OvpnTimerSetXmitInterval(device->Timer, peer->KeepaliveInterval); } if (peer->KeepaliveTimeout != -1) { device->KeepaliveTimeout = peer->KeepaliveTimeout; - if (device->KeepaliveTimeout > 0) { - // keepalive recv timer, detects keepalive timeout - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerRecvCreate(device->WdfDevice, &device->KeepaliveRecvTimer)); - OvpnTimerReset(device->KeepaliveRecvTimer, peer->KeepaliveTimeout); - } - else { - LOG_INFO("Destroy recv timer"); - OvpnTimerDestroy(&device->KeepaliveRecvTimer); - } + // keepalive recv timer, detects keepalive timeout + OvpnTimerSetRecvTimeout(device->Timer, peer->KeepaliveTimeout); } done: diff --git a/socket.cpp b/socket.cpp index 3b342f0..765152f 100644 --- a/socket.cpp +++ b/socket.cpp @@ -194,7 +194,7 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ return; } - OvpnTimerReset(device->KeepaliveRecvTimer, device->KeepaliveTimeout); + OvpnTimerResetRecv(device->Timer); // points to the beginning of plaintext UCHAR* buf = buffer->Data + device->CryptoContext.CryptoOverhead; diff --git a/timer.cpp b/timer.cpp index d9563db..a58dd95 100644 --- a/timer.cpp +++ b/timer.cpp @@ -30,13 +30,23 @@ static const UCHAR OvpnKeepaliveMessage[] = { 0x07, 0xed, 0x2d, 0x0a, 0x98, 0x1f, 0xc7, 0x48 }; +typedef struct _OVPN_TIMER_CONTEXT { + LARGE_INTEGER lastXmit; + LARGE_INTEGER lastRecv; + + // 0 means "not set" + LONG recvTimeout; + LONG xmitInterval; +} OVPN_TIMER_CONTEXT, * POVPN_TIMER_CONTEXT; + +WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_TIMER_CONTEXT, OvpnGetTimerContext); + _Use_decl_annotations_ BOOLEAN OvpnTimerIsKeepaliveMessage(const PUCHAR buf, SIZE_T len) { return RtlCompareMemory(buf, OvpnKeepaliveMessage, len) == sizeof(OvpnKeepaliveMessage); } -_Function_class_(EVT_WDF_TIMER) static VOID OvpnTimerXmit(WDFTIMER timer) { POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer)); @@ -78,21 +88,21 @@ static VOID OvpnTimerXmit(WDFTIMER timer) ExReleaseSpinLockShared(&device->SpinLock, kiqrl); } -_Function_class_(EVT_WDF_TIMER) -static VOID OvpnTimerRecv(WDFTIMER timer) +static BOOLEAN OvpnTimerRecv(WDFTIMER timer) { - LOG_WARN("Keepalive timeout"); - POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer)); WDFREQUEST request; NTSTATUS status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request); if (!NT_SUCCESS(status)) { LOG_WARN("No pending request for keepalive timeout notification"); + return FALSE; } else { + LOG_INFO("Notify userspace about keepalive timeout"); ULONG_PTR bytesSent = 0; WdfRequestCompleteWithInformation(request, STATUS_CONNECTION_DISCONNECTED, bytesSent); + return TRUE; } } @@ -107,7 +117,31 @@ VOID OvpnTimerDestroy(WDFTIMER* timer) } } -static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, ULONG period, PFN_WDF_TIMER func, _Inout_ WDFTIMER* timer) +_Function_class_(EVT_WDF_TIMER) +static VOID OvpnTimerTick(WDFTIMER timer) +{ + LARGE_INTEGER now; + KeQuerySystemTime(&now); + + POVPN_TIMER_CONTEXT timerCtx = OvpnGetTimerContext(timer); + if ((timerCtx->xmitInterval > 0) && (((now.QuadPart - timerCtx->lastXmit.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->xmitInterval)) + { + OvpnTimerXmit(timer); + timerCtx->lastXmit = now; + } + + if ((timerCtx->recvTimeout > 0) && (((now.QuadPart - timerCtx->lastRecv.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->recvTimeout)) + { + // have we have completed pending read request? + if (OvpnTimerRecv(timer)) + { + timerCtx->recvTimeout = 0; // one-off timer + } + } +} + +_Use_decl_annotations_ +NTSTATUS OvpnTimerCreate(WDFOBJECT parent, WDFTIMER* timer) { if (*timer != WDF_NO_HANDLE) { WdfTimerStop(*timer, FALSE); @@ -117,40 +151,47 @@ static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, ULONG period, PFN_WDF_TIMER fu } WDF_TIMER_CONFIG timerConfig; - WDF_TIMER_CONFIG_INIT(&timerConfig, func); - timerConfig.Period = period * 1000; + WDF_TIMER_CONFIG_INIT(&timerConfig, OvpnTimerTick); + timerConfig.TolerableDelay = TolerableDelayUnlimited; + timerConfig.Period = 1000; WDF_OBJECT_ATTRIBUTES timerAttributes; WDF_OBJECT_ATTRIBUTES_INIT(&timerAttributes); + WDF_OBJECT_ATTRIBUTES_SET_CONTEXT_TYPE(&timerAttributes, OVPN_TIMER_CONTEXT); timerAttributes.ParentObject = parent; - return WdfTimerCreate(&timerConfig, &timerAttributes, timer); -} - -_Use_decl_annotations_ -NTSTATUS OvpnTimerXmitCreate(WDFOBJECT parent, ULONG period, WDFTIMER* timer) -{ + *timer = WDF_NO_HANDLE; NTSTATUS status; - LOG_INFO("Create xmit timer", TraceLoggingValue(period, "period")); - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, period, OvpnTimerXmit, timer)); + LOG_IF_NOT_NT_SUCCESS(status = WdfTimerCreate(&timerConfig, &timerAttributes, timer)); + if (NT_SUCCESS(status)) { + WdfTimerStart(*timer, WDF_REL_TIMEOUT_IN_SEC(1)); + } return status; } -_Use_decl_annotations_ -NTSTATUS OvpnTimerRecvCreate(WDFOBJECT parent, WDFTIMER* timer) +VOID OvpnTimerSetXmitInterval(WDFTIMER timer, LONG xmitInterval) { - NTSTATUS status; - LOG_INFO("Create recv timer"); - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, 0, OvpnTimerRecv, timer)); + POVPN_TIMER_CONTEXT timerCtx = OvpnGetTimerContext(timer); + timerCtx->xmitInterval = xmitInterval; + KeQuerySystemTime(&timerCtx->lastXmit); +} - return status; +VOID OvpnTimerSetRecvTimeout(WDFTIMER timer, LONG recvTimeout) +{ + POVPN_TIMER_CONTEXT timerCtx = OvpnGetTimerContext(timer); + timerCtx->recvTimeout = recvTimeout; + KeQuerySystemTime(&timerCtx->lastRecv); } -VOID OvpnTimerReset(WDFTIMER timer, ULONG dueTime) +VOID OvpnTimerResetXmit(WDFTIMER timer) { - if (timer != WDF_NO_HANDLE) { - // if timer has already been created this will reset "due time" value to the new one - WdfTimerStart(timer, WDF_REL_TIMEOUT_IN_SEC(dueTime)); - } + POVPN_TIMER_CONTEXT timerCtx = OvpnGetTimerContext(timer); + KeQuerySystemTime(&timerCtx->lastXmit); } + +VOID OvpnTimerResetRecv(WDFTIMER timer) +{ + POVPN_TIMER_CONTEXT timerCtx = OvpnGetTimerContext(timer); + KeQuerySystemTime(&timerCtx->lastRecv); +} \ No newline at end of file diff --git a/timer.h b/timer.h index 41f52a5..8325919 100644 --- a/timer.h +++ b/timer.h @@ -25,15 +25,20 @@ #include VOID -OvpnTimerReset(WDFTIMER timer, ULONG dueTime); +OvpnTimerResetXmit(WDFTIMER timer); -_Must_inspect_result_ -NTSTATUS -OvpnTimerXmitCreate(WDFOBJECT parent, ULONG period, _Inout_ WDFTIMER* timer); +VOID +OvpnTimerResetRecv(WDFTIMER timer); _Must_inspect_result_ NTSTATUS -OvpnTimerRecvCreate(WDFOBJECT parent, _Inout_ WDFTIMER* timer); +OvpnTimerCreate(WDFOBJECT parent, _Inout_ WDFTIMER* timer); + +VOID +OvpnTimerSetXmitInterval(WDFTIMER timer, LONG xmitInterval); + +VOID +OvpnTimerSetRecvTimeout(WDFTIMER timer, LONG recvTimeout); VOID OvpnTimerDestroy(_Inout_ WDFTIMER* timer); diff --git a/txqueue.cpp b/txqueue.cpp index 711b555..d3887dd 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -169,7 +169,7 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) // reset keepalive timer if (packetSent) { - OvpnTimerReset(device->KeepaliveXmitTimer, device->KeepaliveInterval); + OvpnTimerResetXmit(device->Timer); if (!device->Socket.Tcp) { // this will use WskSendMessages to send buffers list which we constructed before