Skip to content

Commit

Permalink
Migrate ktor (tcp and websocket) and nodejs tcp transport to new API (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
whyoleg authored Oct 21, 2024
1 parent c279c51 commit e22410f
Show file tree
Hide file tree
Showing 30 changed files with 1,339 additions and 70 deletions.
2 changes: 2 additions & 0 deletions rsocket-internal-io/api/rsocket-internal-io.api
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ public final class io/rsocket/kotlin/internal/io/ChannelsKt {
public final class io/rsocket/kotlin/internal/io/ContextKt {
public static final fun childContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
public static final fun ensureActive (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;)V
public static final fun launchCoroutine (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun launchCoroutine$default (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public static final fun onCompletion (Lkotlinx/coroutines/Job;Lkotlin/jvm/functions/Function1;)Lkotlinx/coroutines/Job;
public static final fun supervisorContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,13 @@ public inline fun CoroutineContext.ensureActive(onInactive: () -> Unit) {
onInactive() // should not throw
ensureActive() // will throw
}

@Suppress("SuspendFunctionOnCoroutineScope")
public suspend inline fun <T> CoroutineScope.launchCoroutine(
context: CoroutineContext = EmptyCoroutineContext,
crossinline block: suspend (CancellableContinuation<T>) -> Unit,
): T = suspendCancellableCoroutine { cont ->
val job = launch(context) { block(cont) }
job.invokeOnCompletion { if (it != null && cont.isActive) cont.resumeWithException(it) }
cont.invokeOnCancellation { job.cancel("launchCoroutine was cancelled", it) }
}
40 changes: 40 additions & 0 deletions rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api
Original file line number Diff line number Diff line change
@@ -1,3 +1,43 @@
public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport : io/rsocket/kotlin/transport/RSocketTransport {
public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory;
public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketClientTarget;
public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketClientTarget;
}

public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory {
}

public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder {
public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V
public fun inheritDispatcher ()V
public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V
public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V
public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V
}

public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance {
public abstract fun getLocalAddress ()Lio/ktor/network/sockets/SocketAddress;
}

public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport : io/rsocket/kotlin/transport/RSocketTransport {
public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory;
public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketServerTarget;
public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketServerTarget;
public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Lio/ktor/network/sockets/SocketAddress;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget;
public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget;
}

public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory {
}

public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder {
public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V
public fun inheritDispatcher ()V
public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V
public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V
public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V
}

public final class io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransportKt {
public static final fun TcpClientTransport (Lio/ktor/network/sockets/InetSocketAddress;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport;
public static final fun TcpClientTransport (Ljava/lang/String;ILkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.ktor.tcp

import io.ktor.network.selector.*
import io.ktor.network.sockets.*
import io.rsocket.kotlin.internal.io.*
import io.rsocket.kotlin.transport.*
import kotlinx.coroutines.*
import kotlin.coroutines.*

@OptIn(RSocketTransportApi::class)
public sealed interface KtorTcpClientTransport : RSocketTransport {
public fun target(remoteAddress: SocketAddress): RSocketClientTarget
public fun target(host: String, port: Int): RSocketClientTarget

public companion object Factory :
RSocketTransportFactory<KtorTcpClientTransport, KtorTcpClientTransportBuilder>(::KtorTcpClientTransportBuilderImpl)
}

@OptIn(RSocketTransportApi::class)
public sealed interface KtorTcpClientTransportBuilder : RSocketTransportBuilder<KtorTcpClientTransport> {
public fun dispatcher(context: CoroutineContext)
public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext)

public fun selectorManagerDispatcher(context: CoroutineContext)
public fun selectorManager(manager: SelectorManager, manage: Boolean)

public fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit)

//TODO: TLS support
}

private class KtorTcpClientTransportBuilderImpl : KtorTcpClientTransportBuilder {
private var dispatcher: CoroutineContext = Dispatchers.Default
private var selector: KtorTcpSelector = KtorTcpSelector.FromContext(Dispatchers.IO)
private var socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit = {}

override fun dispatcher(context: CoroutineContext) {
check(context[Job] == null) { "Dispatcher shouldn't contain job" }
this.dispatcher = context
}

override fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit) {
this.socketOptions = block
}

override fun selectorManagerDispatcher(context: CoroutineContext) {
check(context[Job] == null) { "Dispatcher shouldn't contain job" }
this.selector = KtorTcpSelector.FromContext(context)
}

override fun selectorManager(manager: SelectorManager, manage: Boolean) {
this.selector = KtorTcpSelector.FromInstance(manager, manage)
}

@RSocketTransportApi
override fun buildTransport(context: CoroutineContext): KtorTcpClientTransport {
val transportContext = context.supervisorContext() + dispatcher
return KtorTcpClientTransportImpl(
coroutineContext = transportContext,
socketOptions = socketOptions,
selectorManager = selector.createFor(transportContext)
)
}
}

private class KtorTcpClientTransportImpl(
override val coroutineContext: CoroutineContext,
private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit,
private val selectorManager: SelectorManager,
) : KtorTcpClientTransport {
override fun target(remoteAddress: SocketAddress): RSocketClientTarget = KtorTcpClientTargetImpl(
coroutineContext = coroutineContext.supervisorContext(),
socketOptions = socketOptions,
selectorManager = selectorManager,
remoteAddress = remoteAddress
)

override fun target(host: String, port: Int): RSocketClientTarget = target(InetSocketAddress(host, port))
}

@OptIn(RSocketTransportApi::class)
private class KtorTcpClientTargetImpl(
override val coroutineContext: CoroutineContext,
private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit,
private val selectorManager: SelectorManager,
private val remoteAddress: SocketAddress,
) : RSocketClientTarget {

@RSocketTransportApi
override fun connectClient(handler: RSocketConnectionHandler): Job = launch {
val socket = aSocket(selectorManager).tcp().connect(remoteAddress, socketOptions)
handler.handleKtorTcpConnection(socket)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.ktor.tcp

import io.ktor.network.sockets.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import io.rsocket.kotlin.internal.io.*
import io.rsocket.kotlin.transport.*
import io.rsocket.kotlin.transport.internal.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*

@RSocketTransportApi
internal suspend fun RSocketConnectionHandler.handleKtorTcpConnection(socket: Socket): Unit = coroutineScope {
val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED)
val inbound = channelForCloseable<ByteReadPacket>(Channel.BUFFERED)

val readerJob = launch {
val input = socket.openReadChannel()
try {
while (true) inbound.send(input.readFrame() ?: break)
input.cancel(null)
} catch (cause: Throwable) {
input.cancel(cause)
throw cause
}
}.onCompletion { inbound.cancel() }

val writerJob = launch {
val output = socket.openWriteChannel()
try {
while (true) {
// we write all available frames here, and only after it flush
// in this case, if there are several buffered frames we can send them in one go
// avoiding unnecessary flushes
output.writeFrame(outboundQueue.dequeueFrame() ?: break)
while (true) output.writeFrame(outboundQueue.tryDequeueFrame() ?: break)
output.flush()
}
output.close(null)
} catch (cause: Throwable) {
output.close(cause)
throw cause
}
}.onCompletion { outboundQueue.cancel() }

try {
handleConnection(KtorTcpConnection(outboundQueue, inbound))
} finally {
readerJob.cancel()
outboundQueue.close() // will cause `writerJob` completion
// even if it was cancelled, we still need to close socket and await it closure
withContext(NonCancellable) {
// await completion of read/write and then close socket
readerJob.join()
writerJob.join()
// close socket
socket.close()
socket.socketContext.join()
}
}
}

@RSocketTransportApi
private class KtorTcpConnection(
private val outboundQueue: PrioritizationFrameQueue,
private val inbound: ReceiveChannel<ByteReadPacket>,
) : RSocketSequentialConnection {
override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend
override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) {
return outboundQueue.enqueueFrame(streamId, frame)
}

override suspend fun receiveFrame(): ByteReadPacket? {
return inbound.receiveCatching().getOrNull()
}
}

private suspend fun ByteWriteChannel.writeFrame(frame: ByteReadPacket) {
val packet = buildPacket {
writeInt24(frame.remaining.toInt())
writePacket(frame)
}
try {
writePacket(packet)
} catch (cause: Throwable) {
packet.close()
throw cause
}
}

private suspend fun ByteReadChannel.readFrame(): ByteReadPacket? {
val lengthPacket = readRemaining(3)
if (lengthPacket.remaining == 0L) return null
return readPacket(lengthPacket.readInt24())
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.ktor.tcp

import io.ktor.network.selector.*
import kotlinx.coroutines.*
import kotlin.coroutines.*

internal sealed class KtorTcpSelector {
class FromContext(val context: CoroutineContext) : KtorTcpSelector()
class FromInstance(val selectorManager: SelectorManager, val manage: Boolean) : KtorTcpSelector()
}

internal fun KtorTcpSelector.createFor(parentContext: CoroutineContext): SelectorManager {
val selectorManager: SelectorManager
val manage: Boolean
when (this) {
is KtorTcpSelector.FromContext -> {
selectorManager = SelectorManager(parentContext + context)
manage = true
}

is KtorTcpSelector.FromInstance -> {
selectorManager = this.selectorManager
manage = this.manage
}
}
if (manage) Job(parentContext.job).invokeOnCompletion { selectorManager.close() }
return selectorManager
}
Loading

0 comments on commit e22410f

Please sign in to comment.