diff --git a/rsocket-internal-io/api/rsocket-internal-io.api b/rsocket-internal-io/api/rsocket-internal-io.api index 31ae74c9..e98f43e4 100644 --- a/rsocket-internal-io/api/rsocket-internal-io.api +++ b/rsocket-internal-io/api/rsocket-internal-io.api @@ -6,6 +6,8 @@ public final class io/rsocket/kotlin/internal/io/ChannelsKt { public final class io/rsocket/kotlin/internal/io/ContextKt { public static final fun childContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; public static final fun ensureActive (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;)V + public static final fun launchCoroutine (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun launchCoroutine$default (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public static final fun onCompletion (Lkotlinx/coroutines/Job;Lkotlin/jvm/functions/Function1;)Lkotlinx/coroutines/Job; public static final fun supervisorContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; } diff --git a/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt index dc9b9780..1ed305c7 100644 --- a/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt +++ b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt @@ -32,3 +32,13 @@ public inline fun CoroutineContext.ensureActive(onInactive: () -> Unit) { onInactive() // should not throw ensureActive() // will throw } + +@Suppress("SuspendFunctionOnCoroutineScope") +public suspend inline fun CoroutineScope.launchCoroutine( + context: CoroutineContext = EmptyCoroutineContext, + crossinline block: suspend (CancellableContinuation) -> Unit, +): T = suspendCancellableCoroutine { cont -> + val job = launch(context) { block(cont) } + job.invokeOnCompletion { if (it != null && cont.isActive) cont.resumeWithException(it) } + cont.invokeOnCancellation { job.cancel("launchCoroutine was cancelled", it) } +} diff --git a/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api b/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api index 5833bcf1..d6153045 100644 --- a/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api +++ b/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api @@ -1,3 +1,43 @@ +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory; + public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketClientTarget; +} + +public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V + public fun inheritDispatcher ()V + public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V + public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V + public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getLocalAddress ()Lio/ktor/network/sockets/SocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory; + public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Lio/ktor/network/sockets/SocketAddress;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; +} + +public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V + public fun inheritDispatcher ()V + public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V + public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V + public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V +} + public final class io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransportKt { public static final fun TcpClientTransport (Lio/ktor/network/sockets/InetSocketAddress;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport; public static final fun TcpClientTransport (Ljava/lang/String;ILkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport; diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt new file mode 100644 index 00000000..ae5c815f --- /dev/null +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt @@ -0,0 +1,110 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.tcp + +import io.ktor.network.selector.* +import io.ktor.network.sockets.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorTcpClientTransport : RSocketTransport { + public fun target(remoteAddress: SocketAddress): RSocketClientTarget + public fun target(host: String, port: Int): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::KtorTcpClientTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorTcpClientTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) + + public fun selectorManagerDispatcher(context: CoroutineContext) + public fun selectorManager(manager: SelectorManager, manage: Boolean) + + public fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit) + + //TODO: TLS support +} + +private class KtorTcpClientTransportBuilderImpl : KtorTcpClientTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.Default + private var selector: KtorTcpSelector = KtorTcpSelector.FromContext(Dispatchers.IO) + private var socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit = {} + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + override fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit) { + this.socketOptions = block + } + + override fun selectorManagerDispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.selector = KtorTcpSelector.FromContext(context) + } + + override fun selectorManager(manager: SelectorManager, manage: Boolean) { + this.selector = KtorTcpSelector.FromInstance(manager, manage) + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): KtorTcpClientTransport { + val transportContext = context.supervisorContext() + dispatcher + return KtorTcpClientTransportImpl( + coroutineContext = transportContext, + socketOptions = socketOptions, + selectorManager = selector.createFor(transportContext) + ) + } +} + +private class KtorTcpClientTransportImpl( + override val coroutineContext: CoroutineContext, + private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit, + private val selectorManager: SelectorManager, +) : KtorTcpClientTransport { + override fun target(remoteAddress: SocketAddress): RSocketClientTarget = KtorTcpClientTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + socketOptions = socketOptions, + selectorManager = selectorManager, + remoteAddress = remoteAddress + ) + + override fun target(host: String, port: Int): RSocketClientTarget = target(InetSocketAddress(host, port)) +} + +@OptIn(RSocketTransportApi::class) +private class KtorTcpClientTargetImpl( + override val coroutineContext: CoroutineContext, + private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit, + private val selectorManager: SelectorManager, + private val remoteAddress: SocketAddress, +) : RSocketClientTarget { + + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + val socket = aSocket(selectorManager).tcp().connect(remoteAddress, socketOptions) + handler.handleKtorTcpConnection(socket) + } +} diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt new file mode 100644 index 00000000..b65c1077 --- /dev/null +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt @@ -0,0 +1,111 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.tcp + +import io.ktor.network.sockets.* +import io.ktor.utils.io.* +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* + +@RSocketTransportApi +internal suspend fun RSocketConnectionHandler.handleKtorTcpConnection(socket: Socket): Unit = coroutineScope { + val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) + val inbound = channelForCloseable(Channel.BUFFERED) + + val readerJob = launch { + val input = socket.openReadChannel() + try { + while (true) inbound.send(input.readFrame() ?: break) + input.cancel(null) + } catch (cause: Throwable) { + input.cancel(cause) + throw cause + } + }.onCompletion { inbound.cancel() } + + val writerJob = launch { + val output = socket.openWriteChannel() + try { + while (true) { + // we write all available frames here, and only after it flush + // in this case, if there are several buffered frames we can send them in one go + // avoiding unnecessary flushes + output.writeFrame(outboundQueue.dequeueFrame() ?: break) + while (true) output.writeFrame(outboundQueue.tryDequeueFrame() ?: break) + output.flush() + } + output.close(null) + } catch (cause: Throwable) { + output.close(cause) + throw cause + } + }.onCompletion { outboundQueue.cancel() } + + try { + handleConnection(KtorTcpConnection(outboundQueue, inbound)) + } finally { + readerJob.cancel() + outboundQueue.close() // will cause `writerJob` completion + // even if it was cancelled, we still need to close socket and await it closure + withContext(NonCancellable) { + // await completion of read/write and then close socket + readerJob.join() + writerJob.join() + // close socket + socket.close() + socket.socketContext.join() + } + } +} + +@RSocketTransportApi +private class KtorTcpConnection( + private val outboundQueue: PrioritizationFrameQueue, + private val inbound: ReceiveChannel, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return inbound.receiveCatching().getOrNull() + } +} + +private suspend fun ByteWriteChannel.writeFrame(frame: ByteReadPacket) { + val packet = buildPacket { + writeInt24(frame.remaining.toInt()) + writePacket(frame) + } + try { + writePacket(packet) + } catch (cause: Throwable) { + packet.close() + throw cause + } +} + +private suspend fun ByteReadChannel.readFrame(): ByteReadPacket? { + val lengthPacket = readRemaining(3) + if (lengthPacket.remaining == 0L) return null + return readPacket(lengthPacket.readInt24()) +} diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt new file mode 100644 index 00000000..161752e0 --- /dev/null +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt @@ -0,0 +1,44 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.tcp + +import io.ktor.network.selector.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +internal sealed class KtorTcpSelector { + class FromContext(val context: CoroutineContext) : KtorTcpSelector() + class FromInstance(val selectorManager: SelectorManager, val manage: Boolean) : KtorTcpSelector() +} + +internal fun KtorTcpSelector.createFor(parentContext: CoroutineContext): SelectorManager { + val selectorManager: SelectorManager + val manage: Boolean + when (this) { + is KtorTcpSelector.FromContext -> { + selectorManager = SelectorManager(parentContext + context) + manage = true + } + + is KtorTcpSelector.FromInstance -> { + selectorManager = this.selectorManager + manage = this.manage + } + } + if (manage) Job(parentContext.job).invokeOnCompletion { selectorManager.close() } + return selectorManager +} diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt new file mode 100644 index 00000000..be170e4d --- /dev/null +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt @@ -0,0 +1,155 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.tcp + +import io.ktor.network.selector.* +import io.ktor.network.sockets.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorTcpServerInstance : RSocketServerInstance { + public val localAddress: SocketAddress +} + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorTcpServerTransport : RSocketTransport { + public fun target(localAddress: SocketAddress? = null): RSocketServerTarget + public fun target(host: String = "0.0.0.0", port: Int = 0): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory(::KtorTcpServerTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorTcpServerTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) + + public fun selectorManagerDispatcher(context: CoroutineContext) + public fun selectorManager(manager: SelectorManager, manage: Boolean) + + public fun socketOptions(block: SocketOptions.AcceptorOptions.() -> Unit) +} + +private class KtorTcpServerTransportBuilderImpl : KtorTcpServerTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.Default + private var selector: KtorTcpSelector = KtorTcpSelector.FromContext(Dispatchers.IO) + private var socketOptions: SocketOptions.AcceptorOptions.() -> Unit = {} + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + override fun socketOptions(block: SocketOptions.AcceptorOptions.() -> Unit) { + this.socketOptions = block + } + + override fun selectorManagerDispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.selector = KtorTcpSelector.FromContext(context) + } + + override fun selectorManager(manager: SelectorManager, manage: Boolean) { + this.selector = KtorTcpSelector.FromInstance(manager, manage) + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): KtorTcpServerTransport { + val transportContext = context.supervisorContext() + dispatcher + return KtorTcpServerTransportImpl( + coroutineContext = transportContext, + socketOptions = socketOptions, + selectorManager = selector.createFor(transportContext) + ) + } +} + +private class KtorTcpServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val socketOptions: SocketOptions.AcceptorOptions.() -> Unit, + private val selectorManager: SelectorManager, +) : KtorTcpServerTransport { + override fun target(localAddress: SocketAddress?): RSocketServerTarget = KtorTcpServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + socketOptions = socketOptions, + selectorManager = selectorManager, + localAddress = localAddress + ) + + override fun target(host: String, port: Int): RSocketServerTarget = target(InetSocketAddress(host, port)) +} + +@OptIn(RSocketTransportApi::class) +private class KtorTcpServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val socketOptions: SocketOptions.AcceptorOptions.() -> Unit, + private val selectorManager: SelectorManager, + private val localAddress: SocketAddress?, +) : RSocketServerTarget { + + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): KtorTcpServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + return startKtorTcpServer(this, bindSocket(), handler) + } + + @OptIn(ExperimentalCoroutinesApi::class) + private suspend fun bindSocket(): ServerSocket = launchCoroutine { cont -> + val socket = aSocket(selectorManager).tcp().bind(localAddress, socketOptions) + cont.resume(socket) { socket.close() } + } +} + +@RSocketTransportApi +private fun startKtorTcpServer( + scope: CoroutineScope, + serverSocket: ServerSocket, + handler: RSocketConnectionHandler, +): KtorTcpServerInstance { + val serverJob = scope.launch { + try { + // the failure of one connection should not stop all other connections + supervisorScope { + while (true) { + val socket = serverSocket.accept() + launch { handler.handleKtorTcpConnection(socket) } + } + } + } finally { + // even if it was cancelled, we still need to close socket and await it closure + withContext(NonCancellable) { + serverSocket.close() + serverSocket.socketContext.join() + } + } + } + return KtorTcpServerInstanceImpl( + coroutineContext = scope.coroutineContext + serverJob, + localAddress = serverSocket.localAddress + ) +} + +@RSocketTransportApi +private class KtorTcpServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: SocketAddress, +) : KtorTcpServerInstance diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransport.kt index cd3a208a..8194bd8d 100644 --- a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransport.kt +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransport.kt @@ -26,11 +26,6 @@ import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlin.coroutines.* -//TODO user should close ClientTransport manually if there is no job provided in context - -//this dispatcher will be used, if no dispatcher were provided by user in client and server -internal expect val defaultDispatcher: CoroutineDispatcher - public fun TcpClientTransport( hostname: String, port: Int, context: CoroutineContext = EmptyCoroutineContext, @@ -42,10 +37,10 @@ public fun TcpClientTransport( remoteAddress: InetSocketAddress, context: CoroutineContext = EmptyCoroutineContext, intercept: (Socket) -> Socket = { it }, //f.e. for tls, which is currently supported by ktor only on JVM - configure: SocketOptions.TCPClientSocketOptions.() -> Unit = {} + configure: SocketOptions.TCPClientSocketOptions.() -> Unit = {}, ): ClientTransport { val transportJob = SupervisorJob(context[Job]) - val transportContext = defaultDispatcher + context + transportJob + CoroutineName("rSocket-tcp-client") + val transportContext = Dispatchers.IO + context + transportJob + CoroutineName("rSocket-tcp-client") val selector = SelectorManager(transportContext) Job(transportJob).invokeOnCompletion { selector.close() } return ClientTransport(transportContext) { diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTransport.kt index 7282068f..f9be46d8 100644 --- a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTransport.kt +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTransport.kt @@ -41,7 +41,7 @@ public fun TcpServerTransport( configure: SocketOptions.AcceptorOptions.() -> Unit = {}, ): ServerTransport = ServerTransport { accept -> val serverSocketDeferred = CompletableDeferred() - val handlerJob = launch(defaultDispatcher + coroutineContext) { + val handlerJob = launch(Dispatchers.IO + coroutineContext) { SelectorManager(coroutineContext).use { selector -> aSocket(selector).tcp().bind(localAddress, configure).use { serverSocket -> serverSocketDeferred.complete(serverSocket) diff --git a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt index 95667dfb..df99dc02 100644 --- a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt +++ b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt @@ -16,7 +16,6 @@ package io.rsocket.kotlin.transport.ktor.tcp -import io.ktor.network.sockets.* import io.rsocket.kotlin.* import io.rsocket.kotlin.test.* import kotlinx.coroutines.* @@ -25,11 +24,9 @@ import kotlin.test.* class TcpServerTest : SuspendTest, TestWithLeakCheck { private val testJob = Job() private val testContext = testJob + TestExceptionHandler - private val serverTransport = TcpServerTransport() - private suspend fun clientTransport(server: TcpServer) = TcpClientTransport( - server.serverSocket.await().localAddress as InetSocketAddress, - testContext - ) + private val serverTransport = KtorTcpServerTransport(testContext).target() + private fun KtorTcpServerInstance.clientTransport() = + KtorTcpClientTransport(testContext).target(localAddress) override suspend fun after() { testJob.cancelAndJoin() @@ -37,13 +34,13 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { @Test fun testFailedConnection() = test { - val server = TestServer().bindIn(CoroutineScope(testContext), serverTransport) { + val server = TestServer().startServer(serverTransport) { if (config.setupPayload.data.readText() == "ok") { RSocketRequestHandler { requestResponse { it } } } else error("FAILED") - }.also { it.serverSocket.await() } + } suspend fun newClient(text: String) = TestConnector { connectionConfig { @@ -51,7 +48,7 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { payload(text) } } - }.connect(clientTransport(server)) + }.connect(server.clientTransport()) val client1 = newClient("ok") client1.requestResponse(payload("ok")).close() @@ -70,8 +67,7 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { assertFalse(client2.isActive) assertTrue(client3.isActive) - assertTrue(server.serverSocket.await().socketContext.isActive) - assertTrue(server.handlerJob.isActive) + assertTrue(server.isActive) client1.coroutineContext.job.cancelAndJoin() client2.coroutineContext.job.cancelAndJoin() @@ -81,13 +77,13 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { @Test fun testFailedHandler() = test { val handlers = mutableListOf() - val server = TestServer().bindIn(CoroutineScope(testContext), serverTransport) { + val server = TestServer().startServer(serverTransport) { RSocketRequestHandler { requestResponse { it } }.also { handlers += it } - }.also { it.serverSocket.await() } + } - suspend fun newClient() = TestConnector().connect(clientTransport(server)) + suspend fun newClient() = TestConnector().connect(server.clientTransport()) val client1 = newClient() @@ -118,8 +114,7 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { assertFalse(client2.isActive) assertTrue(client3.isActive) - assertTrue(server.serverSocket.await().socketContext.isActive) - assertTrue(server.handlerJob.isActive) + assertTrue(server.isActive) client1.coroutineContext.job.cancelAndJoin() client2.coroutineContext.job.cancelAndJoin() diff --git a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt index a55440e5..b859c9e1 100644 --- a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt +++ b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt @@ -16,8 +16,10 @@ package io.rsocket.kotlin.transport.ktor.tcp +import io.ktor.network.selector.* import io.ktor.network.sockets.* import io.rsocket.kotlin.transport.tests.* +import kotlinx.coroutines.* class TcpTransportTest : TransportTest() { override suspend fun before() { @@ -25,3 +27,25 @@ class TcpTransportTest : TransportTest() { client = connectClient(TcpClientTransport(serverSocket.localAddress as InetSocketAddress, testContext)) } } + +class KtorTcpTransportTest : TransportTest() { + // a single SelectorManager for both client and server works much better in K/N + // in user code in most of the cases, only one SelectorManager will be created + private val selector = SelectorManager(Dispatchers.IO) + override suspend fun before() { + val server = startServer(KtorTcpServerTransport(testContext) { + selectorManager(selector, false) + }.target()) + client = connectClient(KtorTcpClientTransport(testContext) { + selectorManager(selector, false) + }.target(server.localAddress)) + } + + override suspend fun after() { + try { + super.after() + } finally { + selector.close() + } + } +} diff --git a/rsocket-transports/ktor-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt b/rsocket-transports/ktor-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt deleted file mode 100644 index aec80603..00000000 --- a/rsocket-transports/ktor-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.ktor.tcp - -import kotlinx.coroutines.* - -internal actual val defaultDispatcher: CoroutineDispatcher get() = Dispatchers.IO diff --git a/rsocket-transports/ktor-tcp/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt b/rsocket-transports/ktor-tcp/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt deleted file mode 100644 index 5fa3a8e2..00000000 --- a/rsocket-transports/ktor-tcp/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.ktor.tcp - -import kotlinx.coroutines.* - -internal actual val defaultDispatcher: CoroutineDispatcher get() = Dispatchers.Unconfined diff --git a/rsocket-transports/ktor-websocket-client/api/rsocket-transport-ktor-websocket-client.api b/rsocket-transports/ktor-websocket-client/api/rsocket-transport-ktor-websocket-client.api index 9aaa4d11..1761b00d 100644 --- a/rsocket-transports/ktor-websocket-client/api/rsocket-transport-ktor-websocket-client.api +++ b/rsocket-transports/ktor-websocket-client/api/rsocket-transport-ktor-websocket-client.api @@ -1,3 +1,24 @@ +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport$Factory; + public abstract fun target (Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/Integer;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/Integer;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketClientTarget; +} + +public final class io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun httpEngine (Lio/ktor/client/engine/HttpClientEngine;Lkotlin/jvm/functions/Function1;)V + public abstract fun httpEngine (Lio/ktor/client/engine/HttpClientEngineFactory;Lkotlin/jvm/functions/Function1;)V + public abstract fun httpEngine (Lkotlin/jvm/functions/Function1;)V + public static synthetic fun httpEngine$default (Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransportBuilder;Lio/ktor/client/engine/HttpClientEngine;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun httpEngine$default (Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransportBuilder;Lio/ktor/client/engine/HttpClientEngineFactory;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public abstract fun webSocketsConfig (Lkotlin/jvm/functions/Function1;)V +} + public final class io/rsocket/kotlin/transport/ktor/websocket/client/WebSocketClientTransportKt { public static final fun WebSocketClientTransport (Lio/ktor/client/engine/HttpClientEngineFactory;Ljava/lang/String;Ljava/lang/Integer;Ljava/lang/String;ZLkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport; public static final fun WebSocketClientTransport (Lio/ktor/client/engine/HttpClientEngineFactory;Ljava/lang/String;ZLkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport; diff --git a/rsocket-transports/ktor-websocket-client/build.gradle.kts b/rsocket-transports/ktor-websocket-client/build.gradle.kts index 86ee3265..f9bcbf6a 100644 --- a/rsocket-transports/ktor-websocket-client/build.gradle.kts +++ b/rsocket-transports/ktor-websocket-client/build.gradle.kts @@ -30,6 +30,7 @@ kotlin { sourceSets { commonMain.dependencies { implementation(projects.rsocketTransportKtorWebsocketInternal) + implementation(projects.rsocketInternalIo) api(projects.rsocketCore) api(libs.ktor.client.core) api(libs.ktor.client.websockets) diff --git a/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt b/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt new file mode 100644 index 00000000..cc91d10d --- /dev/null +++ b/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt @@ -0,0 +1,183 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.websocket.client + +import io.ktor.client.* +import io.ktor.client.engine.* +import io.ktor.client.plugins.websocket.* +import io.ktor.client.request.* +import io.ktor.http.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.ktor.websocket.internal.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorWebSocketClientTransport : RSocketTransport { + public fun target(request: HttpRequestBuilder.() -> Unit): RSocketClientTarget + public fun target(urlString: String, request: HttpRequestBuilder.() -> Unit = {}): RSocketClientTarget + + public fun target( + method: HttpMethod = HttpMethod.Get, + host: String? = null, + port: Int? = null, + path: String? = null, + request: HttpRequestBuilder.() -> Unit = {}, + ): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::KtorWebSocketClientTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorWebSocketClientTransportBuilder : RSocketTransportBuilder { + public fun httpEngine(configure: HttpClientEngineConfig.() -> Unit) + public fun httpEngine(engine: HttpClientEngine, configure: HttpClientEngineConfig.() -> Unit = {}) + public fun httpEngine(factory: HttpClientEngineFactory, configure: T.() -> Unit = {}) + + public fun webSocketsConfig(block: WebSockets.Config.() -> Unit) +} + +private class KtorWebSocketClientTransportBuilderImpl : KtorWebSocketClientTransportBuilder { + private var httpClientFactory: HttpClientFactory = HttpClientFactory.Default + private var webSocketsConfig: WebSockets.Config.() -> Unit = {} + + override fun httpEngine(configure: HttpClientEngineConfig.() -> Unit) { + this.httpClientFactory = HttpClientFactory.FromConfiguration(configure) + } + + override fun httpEngine(engine: HttpClientEngine, configure: HttpClientEngineConfig.() -> Unit) { + this.httpClientFactory = HttpClientFactory.FromEngine(engine, configure) + } + + override fun httpEngine(factory: HttpClientEngineFactory, configure: T.() -> Unit) { + this.httpClientFactory = HttpClientFactory.FromFactory(factory, configure) + } + + override fun webSocketsConfig(block: WebSockets.Config.() -> Unit) { + this.webSocketsConfig = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): KtorWebSocketClientTransport { + val httpClient = httpClientFactory.createHttpClient { + install(WebSockets, webSocketsConfig) + } + // only dispatcher of a client is used - it looks like it's Dispatchers.IO now + val newContext = context.supervisorContext() + (httpClient.coroutineContext[ContinuationInterceptor] ?: EmptyCoroutineContext) + val newJob = newContext.job + val httpClientJob = httpClient.coroutineContext.job + + httpClientJob.invokeOnCompletion { newJob.cancel("HttpClient closed", it) } + newJob.invokeOnCompletion { httpClientJob.cancel("KtorWebSocketClientTransport closed", it) } + + return KtorWebSocketClientTransportImpl( + coroutineContext = newContext, + httpClient = httpClient, + ) + } +} + +private class KtorWebSocketClientTransportImpl( + override val coroutineContext: CoroutineContext, + private val httpClient: HttpClient, +) : KtorWebSocketClientTransport { + override fun target(request: HttpRequestBuilder.() -> Unit): RSocketClientTarget = KtorWebSocketClientTargetImpl( + coroutineContext = coroutineContext, + httpClient = httpClient, + request = request + ) + + override fun target( + urlString: String, + request: HttpRequestBuilder.() -> Unit, + ): RSocketClientTarget = target( + method = HttpMethod.Get, host = null, port = null, path = null, + request = { + url.protocol = URLProtocol.WS + url.port = port + + url.takeFrom(urlString) + request() + }, + ) + + override fun target( + method: HttpMethod, + host: String?, + port: Int?, + path: String?, + request: HttpRequestBuilder.() -> Unit, + ): RSocketClientTarget = target { + this.method = method + url("ws", host, port, path) + request() + } +} + +@OptIn(RSocketTransportApi::class) +private class KtorWebSocketClientTargetImpl( + override val coroutineContext: CoroutineContext, + private val httpClient: HttpClient, + private val request: HttpRequestBuilder.() -> Unit, +) : RSocketClientTarget { + + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + httpClient.webSocket(request) { + handler.handleKtorWebSocketConnection(this) + } + } +} + +private sealed class HttpClientFactory { + abstract fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient + + object Default : HttpClientFactory() { + override fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient = HttpClient(block) + } + + class FromConfiguration( + private val configure: HttpClientEngineConfig.() -> Unit, + ) : HttpClientFactory() { + override fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient = HttpClient { + engine(configure) + block() + } + } + + class FromEngine( + private val engine: HttpClientEngine, + private val configure: HttpClientEngineConfig.() -> Unit, + ) : HttpClientFactory() { + override fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient = HttpClient(engine) { + engine(configure) + block() + } + } + + class FromFactory( + private val factory: HttpClientEngineFactory, + private val configure: T.() -> Unit, + ) : HttpClientFactory() { + override fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient = HttpClient(factory) { + engine(configure) + block() + } + } +} diff --git a/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api b/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api index 9ca3776f..8607860c 100644 --- a/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api +++ b/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api @@ -1,3 +1,7 @@ +public final class io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnectionKt { + public static final fun handleKtorWebSocketConnection (Lio/rsocket/kotlin/transport/RSocketConnectionHandler;Lio/ktor/websocket/WebSocketSession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/rsocket/kotlin/transport/ktor/websocket/internal/WebSocketConnection : io/rsocket/kotlin/Connection, kotlinx/coroutines/CoroutineScope { public fun (Lio/ktor/websocket/WebSocketSession;)V public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext; diff --git a/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt b/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt new file mode 100644 index 00000000..05e351ae --- /dev/null +++ b/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt @@ -0,0 +1,63 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.websocket.internal + +import io.ktor.utils.io.core.* +import io.ktor.websocket.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* + +@RSocketTransportApi +public suspend fun RSocketConnectionHandler.handleKtorWebSocketConnection(webSocketSession: WebSocketSession): Unit = coroutineScope { + val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) + + val senderJob = launch { + while (true) webSocketSession.send(outboundQueue.dequeueFrame()?.readBytes() ?: break) + }.onCompletion { outboundQueue.cancel() } + + try { + handleConnection(KtorWebSocketConnection(outboundQueue, webSocketSession.incoming)) + } finally { + webSocketSession.incoming.cancel() + outboundQueue.close() + withContext(NonCancellable) { + senderJob.join() // await all frames sent + webSocketSession.close() + webSocketSession.coroutineContext.job.join() + } + } +} + +@RSocketTransportApi +private class KtorWebSocketConnection( + private val outboundQueue: PrioritizationFrameQueue, + private val inbound: ReceiveChannel, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend + + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + val frame = inbound.receiveCatching().getOrNull() ?: return null + return ByteReadPacket(frame.data) + } +} diff --git a/rsocket-transports/ktor-websocket-server/api/rsocket-transport-ktor-websocket-server.api b/rsocket-transports/ktor-websocket-server/api/rsocket-transport-ktor-websocket-server.api index 4c9b6cde..7dece198 100644 --- a/rsocket-transports/ktor-websocket-server/api/rsocket-transport-ktor-websocket-server.api +++ b/rsocket-transports/ktor-websocket-server/api/rsocket-transport-ktor-websocket-server.api @@ -1,3 +1,30 @@ +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getConnectors ()Ljava/util/List; + public abstract fun getPath ()Ljava/lang/String; + public abstract fun getProtocol ()Ljava/lang/String; +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport$Factory; + public abstract fun target (Lio/ktor/server/engine/EngineConnectorConfig;Ljava/lang/String;Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/lang/String;ILjava/lang/String;Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/util/List;Ljava/lang/String;Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport;Lio/ktor/server/engine/EngineConnectorConfig;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport;Ljava/lang/String;ILjava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; +} + +public final class io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun httpEngine (Lio/ktor/server/engine/ApplicationEngineFactory;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun httpEngine$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransportBuilder;Lio/ktor/server/engine/ApplicationEngineFactory;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public abstract fun webSocketsConfig (Lkotlin/jvm/functions/Function1;)V +} + public final class io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransportKt { public static final fun WebSocketServerTransport (Lio/ktor/server/engine/ApplicationEngineFactory;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ServerTransport; public static final fun WebSocketServerTransport (Lio/ktor/server/engine/ApplicationEngineFactory;[Lio/ktor/server/engine/EngineConnectorConfig;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ServerTransport; diff --git a/rsocket-transports/ktor-websocket-server/build.gradle.kts b/rsocket-transports/ktor-websocket-server/build.gradle.kts index a5868053..157ab04f 100644 --- a/rsocket-transports/ktor-websocket-server/build.gradle.kts +++ b/rsocket-transports/ktor-websocket-server/build.gradle.kts @@ -29,6 +29,7 @@ kotlin { sourceSets { commonMain.dependencies { implementation(projects.rsocketTransportKtorWebsocketInternal) + implementation(projects.rsocketInternalIo) api(projects.rsocketCore) api(libs.ktor.server.host.common) api(libs.ktor.server.websockets) diff --git a/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt new file mode 100644 index 00000000..c78bc569 --- /dev/null +++ b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt @@ -0,0 +1,214 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.websocket.server + +import io.ktor.server.application.* +import io.ktor.server.engine.* +import io.ktor.server.routing.* +import io.ktor.server.websocket.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.ktor.websocket.internal.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorWebSocketServerInstance : RSocketServerInstance { + public val connectors: List + public val path: String + public val protocol: String? +} + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorWebSocketServerTransport : RSocketTransport { + + public fun target( + host: String = "0.0.0.0", + port: Int = 80, + path: String = "", + protocol: String? = null, + ): RSocketServerTarget + + public fun target( + path: String = "", + protocol: String? = null, + connectorBuilder: EngineConnectorBuilder.() -> Unit, + ): RSocketServerTarget + + public fun target( + connector: EngineConnectorConfig, + path: String = "", + protocol: String? = null, + ): RSocketServerTarget + + public fun target( + connectors: List, + path: String = "", + protocol: String? = null, + ): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory(::KtorWebSocketServerTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface KtorWebSocketServerTransportBuilder : RSocketTransportBuilder { + public fun httpEngine( + factory: ApplicationEngineFactory, + configure: T.() -> Unit = {}, + ) + + public fun webSocketsConfig(block: WebSockets.WebSocketOptions.() -> Unit) +} + +private class KtorWebSocketServerTransportBuilderImpl : KtorWebSocketServerTransportBuilder { + private var httpServerFactory: HttpServerFactory<*, *>? = null + private var webSocketsConfig: WebSockets.WebSocketOptions.() -> Unit = {} + + override fun httpEngine( + factory: ApplicationEngineFactory, + configure: T.() -> Unit, + ) { + this.httpServerFactory = HttpServerFactory(factory, configure) + } + + override fun webSocketsConfig(block: WebSockets.WebSocketOptions.() -> Unit) { + this.webSocketsConfig = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): KtorWebSocketServerTransport = KtorWebSocketServerTransportImpl( + // we always add IO - as it's the best choice here, server will use it's own dispatcher anyway + coroutineContext = context.supervisorContext() + Dispatchers.IO, + factory = requireNotNull(httpServerFactory) { "httpEngine is required" }, + webSocketsConfig = webSocketsConfig, + ) +} + +private class KtorWebSocketServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val factory: HttpServerFactory<*, *>, + private val webSocketsConfig: WebSockets.WebSocketOptions.() -> Unit, +) : KtorWebSocketServerTransport { + override fun target( + connectors: List, + path: String, + protocol: String?, + ): RSocketServerTarget = KtorWebSocketServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + factory = factory, + webSocketsConfig = webSocketsConfig, + connectors = connectors, + path = path, + protocol = protocol + ) + + override fun target( + host: String, + port: Int, + path: String, + protocol: String?, + ): RSocketServerTarget = target(path, protocol) { + this.host = host + this.port = port + } + + override fun target( + path: String, + protocol: String?, + connectorBuilder: EngineConnectorBuilder.() -> Unit, + ): RSocketServerTarget = target(EngineConnectorBuilder().apply(connectorBuilder), path, protocol) + + override fun target( + connector: EngineConnectorConfig, + path: String, + protocol: String?, + ): RSocketServerTarget = target(listOf(connector), path, protocol) +} + +@OptIn(RSocketTransportApi::class) +private class KtorWebSocketServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val factory: HttpServerFactory<*, *>, + private val webSocketsConfig: WebSockets.WebSocketOptions.() -> Unit, + private val connectors: List, + private val path: String, + private val protocol: String?, +) : RSocketServerTarget { + + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): KtorWebSocketServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val engine = createServerEngine(handler) + val resolvedConnectors = startServerEngine(engine) + + return KtorWebSocketServerInstanceImpl( + coroutineContext = engine.environment.parentCoroutineContext, + connectors = resolvedConnectors, + path = path, + protocol = protocol + ) + } + + // parentCoroutineContext is the context of server instance + @RSocketTransportApi + private fun createServerEngine(handler: RSocketConnectionHandler): ApplicationEngine = factory.createServer( + applicationEngineEnvironment { + val target = this@KtorWebSocketServerTargetImpl + parentCoroutineContext = target.coroutineContext.childContext() + connectors.addAll(target.connectors) + module { + install(WebSockets, webSocketsConfig) + routing { + webSocket(target.path, target.protocol) { + handler.handleKtorWebSocketConnection(this) + } + } + } + } + ) + + @OptIn(ExperimentalCoroutinesApi::class) + private suspend fun startServerEngine( + applicationEngine: ApplicationEngine, + ): List = launchCoroutine { cont -> + applicationEngine.start().stopServerOnCancellation() + cont.resume(applicationEngine.resolvedConnectors()) { + // will cause stopping of the server + applicationEngine.environment.parentCoroutineContext.job.cancel("Cancelled", it) + } + } +} + +private class KtorWebSocketServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val connectors: List, + override val path: String, + override val protocol: String?, +) : KtorWebSocketServerInstance + +private class HttpServerFactory( + private val factory: ApplicationEngineFactory, + private val configure: T.() -> Unit = {}, +) { + @RSocketTransportApi + fun createServer(environment: ApplicationEngineEnvironment): ApplicationEngine { + return factory.create(environment, configure) + } +} diff --git a/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransport.kt b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransport.kt index 24c2278c..28ccb4c7 100644 --- a/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransport.kt +++ b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransport.kt @@ -24,8 +24,6 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.ktor.websocket.internal.* -//TODO: will be reworked later with transport API rework - @Suppress("FunctionName") public fun WebSocketServerTransport( engineFactory: ApplicationEngineFactory, diff --git a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt index 850b3cd0..82b880da 100644 --- a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt +++ b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt @@ -20,3 +20,5 @@ import io.ktor.client.engine.cio.CIO as ClientCIO import io.ktor.server.cio.CIO as ServerCIO class CIOWebSocketTransportTest : WebSocketTransportTest(ClientCIO, ServerCIO) + +class CIOKtorWebSocketTransportTest : KtorWebSocketTransportTest(ClientCIO, ServerCIO) diff --git a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt index ffd26b12..088f6cda 100644 --- a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt +++ b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt @@ -35,3 +35,22 @@ abstract class WebSocketTransportTest( ) } } + +abstract class KtorWebSocketTransportTest( + private val clientEngine: HttpClientEngineFactory<*>, + private val serverEngine: ApplicationEngineFactory<*, *>, +) : TransportTest() { + override suspend fun before() { + val server = startServer( + KtorWebSocketServerTransport(testContext) { + httpEngine(serverEngine) + }.target(port = 0) + ) + val port = server.connectors.single().port + client = connectClient( + KtorWebSocketClientTransport(testContext) { + httpEngine(clientEngine) + }.target(port = port) + ) + } +} diff --git a/rsocket-transports/ktor-websocket-server/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTests.kt b/rsocket-transports/ktor-websocket-server/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTests.kt index bb17b9a2..3afca1bf 100644 --- a/rsocket-transports/ktor-websocket-server/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTests.kt +++ b/rsocket-transports/ktor-websocket-server/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTests.kt @@ -27,3 +27,11 @@ class OkHttpClientWebSocketTransportTest : WebSocketTransportTest(OkHttp, Server class NettyServerWebSocketTransportTest : WebSocketTransportTest(ClientCIO, Netty) class JettyServerWebSocketTransportTest : WebSocketTransportTest(ClientCIO, Jetty) + + + +class OkHttpClientKtorWebSocketTransportTest : KtorWebSocketTransportTest(OkHttp, ServerCIO) + +class NettyServerKtorWebSocketTransportTest : KtorWebSocketTransportTest(ClientCIO, Netty) + +class JettyServerKtorWebSocketTransportTest : KtorWebSocketTransportTest(ClientCIO, Jetty) diff --git a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/FrameWithLengthAssembler.kt b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/FrameWithLengthAssembler.kt index b141081e..d91c45cc 100644 --- a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/FrameWithLengthAssembler.kt +++ b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/FrameWithLengthAssembler.kt @@ -24,10 +24,18 @@ internal fun ByteReadPacket.withLength(): ByteReadPacket = buildPacket { writePacket(this@withLength) } -internal class FrameWithLengthAssembler(private val onFrame: (frame: ByteReadPacket) -> Unit) { - private var expectedFrameLength = 0 //TODO atomic for native +internal class FrameWithLengthAssembler(private val onFrame: (frame: ByteReadPacket) -> Unit) : Closeable { + private var closed = false + private var expectedFrameLength = 0 private val packetBuilder: BytePacketBuilder = BytePacketBuilder() + + override fun close() { + packetBuilder.close() + closed = true + } + inline fun write(write: BytePacketBuilder.() -> Unit) { + if (closed) return packetBuilder.write() loop() } @@ -39,6 +47,7 @@ internal class FrameWithLengthAssembler(private val onFrame: (frame: ByteReadPac expectedFrameLength = it.readInt24() if (it.remaining >= expectedFrameLength) build(it) // if has length and frame } + packetBuilder.size < expectedFrameLength -> return // not enough bytes to read frame else -> withTemp { build(it) } // enough bytes to read frame } diff --git a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpClientTransport.kt b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpClientTransport.kt new file mode 100644 index 00000000..498dbc8a --- /dev/null +++ b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpClientTransport.kt @@ -0,0 +1,74 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.nodejs.tcp + +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.nodejs.tcp.internal.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@OptIn(RSocketTransportApi::class) +public sealed interface NodejsTcpClientTransport : RSocketTransport { + public fun target(host: String, port: Int): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::NodejsTcpClientTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NodejsTcpClientTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) +} + +private class NodejsTcpClientTransportBuilderImpl : NodejsTcpClientTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.Default + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NodejsTcpClientTransport = NodejsTcpClientTransportImpl( + coroutineContext = context.supervisorContext() + dispatcher, + ) +} + +private class NodejsTcpClientTransportImpl( + override val coroutineContext: CoroutineContext, +) : NodejsTcpClientTransport { + override fun target(host: String, port: Int): RSocketClientTarget = NodejsTcpClientTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + host = host, + port = port + ) +} + +@OptIn(RSocketTransportApi::class) +private class NodejsTcpClientTargetImpl( + override val coroutineContext: CoroutineContext, + private val host: String, + private val port: Int, +) : RSocketClientTarget { + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + val socket = connect(port, host) + handler.handleNodejsTcpConnection(socket) + } +} diff --git a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpConnection.kt b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpConnection.kt new file mode 100644 index 00000000..7846f0f4 --- /dev/null +++ b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpConnection.kt @@ -0,0 +1,85 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.nodejs.tcp + +import io.ktor.utils.io.core.* +import io.ktor.utils.io.js.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import io.rsocket.kotlin.transport.nodejs.tcp.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import org.khronos.webgl.* + +@RSocketTransportApi +internal suspend fun RSocketConnectionHandler.handleNodejsTcpConnection(socket: Socket): Unit = coroutineScope { + val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) + val inbound = channelForCloseable(Channel.UNLIMITED) + + val closed = CompletableDeferred() + val frameAssembler = FrameWithLengthAssembler { inbound.trySend(it) } + socket.on( + onData = { frameAssembler.write { writeFully(it.buffer) } }, + onError = { closed.completeExceptionally(it) }, + onClose = { + frameAssembler.close() + if (!it) closed.complete(Unit) + } + ) + + val writerJob = launch { + while (true) socket.writeFrame(outboundQueue.dequeueFrame() ?: break) + }.onCompletion { outboundQueue.cancel() } + + try { + handleConnection(NodejsTcpConnection(outboundQueue, inbound)) + } finally { + inbound.cancel() + outboundQueue.close() // will cause `writerJob` completion + // even if it was cancelled, we still need to close socket and await it closure + withContext(NonCancellable) { + writerJob.join() + // close socket + socket.destroy() + closed.join() + } + } +} + +@RSocketTransportApi +private class NodejsTcpConnection( + private val outboundQueue: PrioritizationFrameQueue, + private val inbound: ReceiveChannel, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return inbound.receiveCatching().getOrNull() + } +} + +private fun Socket.writeFrame(frame: ByteReadPacket) { + val packet = buildPacket { + writeInt24(frame.remaining.toInt()) + writePacket(frame) + } + write(Uint8Array(packet.readArrayBuffer())) +} diff --git a/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpServerTransport.kt b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpServerTransport.kt new file mode 100644 index 00000000..f30d16e0 --- /dev/null +++ b/rsocket-transports/nodejs-tcp/src/jsMain/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/NodejsTcpServerTransport.kt @@ -0,0 +1,108 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.nodejs.tcp + +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.nodejs.tcp.internal.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@OptIn(RSocketTransportApi::class) +public sealed interface NodejsTcpServerInstance : RSocketServerInstance { + public val host: String + public val port: Int +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NodejsTcpServerTransport : RSocketTransport { + public fun target(host: String, port: Int): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory({ NodejsTcpServerTransportBuilderImpl }) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NodejsTcpServerTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) +} + +private object NodejsTcpServerTransportBuilderImpl : NodejsTcpServerTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.Default + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NodejsTcpServerTransport = NodejsTcpServerTransportImpl( + coroutineContext = context.supervisorContext() + dispatcher, + ) +} + +private class NodejsTcpServerTransportImpl( + override val coroutineContext: CoroutineContext, +) : NodejsTcpServerTransport { + override fun target(host: String, port: Int): RSocketServerTarget = NodejsTcpServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + host = host, + port = port + ) +} + +@OptIn(RSocketTransportApi::class) +private class NodejsTcpServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val host: String, + private val port: Int, +) : RSocketServerTarget { + + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): NodejsTcpServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val serverJob = launch { + val handlerScope = CoroutineScope(coroutineContext.supervisorContext()) + val server = createServer(port, host, { + coroutineContext.job.cancel("Server closed") + }) { + handlerScope.launch { handler.handleNodejsTcpConnection(it) } + } + try { + awaitCancellation() + } finally { + suspendCoroutine { cont -> server.close { cont.resume(Unit) } } + } + } + + return NodejsTcpServerInstanceImpl( + coroutineContext = coroutineContext + serverJob, + host = host, + port = port + ) + } +} + +@RSocketTransportApi +private class NodejsTcpServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val host: String, + override val port: Int, +) : NodejsTcpServerInstance diff --git a/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt b/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt index 0fd9a497..c2fc9163 100644 --- a/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt +++ b/rsocket-transports/nodejs-tcp/src/jsTest/kotlin/io/rsocket/kotlin/transport/nodejs/tcp/TcpTransportTest.kt @@ -34,3 +34,11 @@ class TcpTransportTest : TransportTest() { server.close() } } + +class NodejsTcpTransportTest : TransportTest() { + override suspend fun before() { + val port = PortProvider.next() + startServer(NodejsTcpServerTransport(testContext).target("127.0.0.1", port)) + client = connectClient(NodejsTcpClientTransport(testContext).target("127.0.0.1", port)) + } +}