From 53d93ee6abc7b6ec9c689465bf431fff37cfd2e2 Mon Sep 17 00:00:00 2001 From: "Ilya.Usov" Date: Thu, 5 Oct 2023 14:36:53 +0200 Subject: [PATCH] ensure counterpart is not called with terminated lifetime guarantee cancellation on the counterpart if lifetime was terminated during send --- .../com/jetbrains/rd/framework/impl/RdTask.kt | 15 +++-- .../rd/framework/test/cases/RdTaskTest.kt | 57 +++++++++++++++++-- rd-net/RdFramework/Tasks/RdCall.cs | 20 ++++--- rd-net/Test.RdFramework/RdTaskTest.cs | 53 +++++++++++++++++ 4 files changed, 128 insertions(+), 17 deletions(-) diff --git a/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/impl/RdTask.kt b/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/impl/RdTask.kt index f7157b690..30ece6630 100644 --- a/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/impl/RdTask.kt +++ b/rd-kt/rd-framework/src/main/kotlin/com/jetbrains/rd/framework/impl/RdTask.kt @@ -365,12 +365,15 @@ class RdCall(internal val requestSzr: ISerializer = Polymorphi val taskId = proto.identity.next(RdId.Null) val bindLifetime = bindLifetime - val task = CallSiteWiredRdTask(lifetime.intersect(bindLifetime), this, taskId, scheduler ?: proto.scheduler) - - proto.wire.send(rdid) { buffer -> - logSend.trace { "call `$location`::($rdid) send${sync.condstr {" SYNC"}} request '$taskId' : ${request.printToString()} " } - taskId.write(buffer) - requestSzr.write(ctx, buffer, request) + val taskLifetime = lifetime.intersect(bindLifetime) + + val task = CallSiteWiredRdTask(taskLifetime, this, taskId, scheduler ?: proto.scheduler) + taskLifetime.executeIfAlive { + proto.wire.send(rdid) { buffer -> + logSend.trace { "call `$location`::($rdid) send${sync.condstr {" SYNC"}} request '$taskId' : ${request.printToString()} " } + taskId.write(buffer) + requestSzr.write(ctx, buffer, request) + } } return task diff --git a/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/RdTaskTest.kt b/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/RdTaskTest.kt index 9f3c498f0..19c6cb927 100644 --- a/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/RdTaskTest.kt +++ b/rd-kt/rd-framework/src/test/kotlin/com/jetbrains/rd/framework/test/cases/RdTaskTest.kt @@ -3,13 +3,14 @@ package com.jetbrains.rd.framework.test.cases import com.jetbrains.rd.framework.RdTaskResult import com.jetbrains.rd.framework.base.static import com.jetbrains.rd.framework.impl.RdCall +import com.jetbrains.rd.framework.impl.RdTask import com.jetbrains.rd.framework.isFaulted import com.jetbrains.rd.framework.test.util.RdFrameworkTestBase +import com.jetbrains.rd.util.lifetime.Lifetime +import com.jetbrains.rd.util.lifetime.isAlive +import com.jetbrains.rd.util.lifetime.waitTermination import com.jetbrains.rd.util.reactive.valueOrThrow -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertFails -import kotlin.test.assertTrue +import kotlin.test.* class RdTaskTest : RdFrameworkTestBase() { @Test @@ -59,6 +60,54 @@ class RdTaskTest : RdFrameworkTestBase() { assertEquals("Cancelled", RdTaskResult.Cancelled().toString()) assertEquals("Fault :: com.jetbrains.rd.util.reactive.RdFault: error", RdTaskResult.Fault(Error("error")).toString()) } + + @Test + fun startWithTerminatedLifetime() { + val entity_id = 1 + + val client_entity = RdCall().static(entity_id) + val server_entity = RdCall().static(entity_id) + + clientProtocol.bindStatic(client_entity, "top") + serverProtocol.bindStatic(server_entity, "top") + + var called = false + server_entity.set { lf, value -> + called = true + error("Must not be reached") + } + + val task = client_entity.start(Lifetime.Terminated, 1) + val result = task.result.valueOrThrow + assertIs>(result) + assertFalse(called) + } + + + @Test + fun startWithTerminatingDuringSet() { + val entity_id = 1 + + val client_entity = RdCall().static(entity_id) + val server_entity = RdCall().static(entity_id) + + clientProtocol.bindStatic(client_entity, "top") + serverProtocol.bindStatic(server_entity, "top") + + val def = clientLifetime.createNested() + var callLifetime: Lifetime? = null + server_entity.set { lf, value -> + def.terminate(true) + + callLifetime = lf + RdTask.fromResult(value.toString()) + } + + val task = client_entity.start(def, 1) + val result = task.result.valueOrThrow + assertIs>(result) + assertFalse(callLifetime!!.isAlive) + } } open class A(open val a:Any) {} diff --git a/rd-net/RdFramework/Tasks/RdCall.cs b/rd-net/RdFramework/Tasks/RdCall.cs index cac238f13..a1e2ac3ba 100644 --- a/rd-net/RdFramework/Tasks/RdCall.cs +++ b/rd-net/RdFramework/Tasks/RdCall.cs @@ -199,15 +199,21 @@ private IRdTask StartInternal(Lifetime requestLifetime, TReq request, ISch return RunHandler(request, requestLifetime, moniker: this); var taskId = proto.Identities.Next(RdId.Nil); - var task = new WiredRdTask.CallSite(Lifetime.Intersect(requestLifetime, myBindLifetime), this, taskId, scheduler ?? proto.Scheduler); - - proto.Wire.Send(RdId, (writer) => + + var taskLifetime = Lifetime.Intersect(requestLifetime, myBindLifetime); + var task = new WiredRdTask.CallSite(taskLifetime, this, taskId, scheduler ?? proto.Scheduler); + + using var cookie = taskLifetime.UsingExecuteIfAlive(); + if (cookie.Succeed) { - SendTrace?.Log($"{task} :: send request: {request.PrintToString()}"); + proto.Wire.Send(RdId, (writer) => + { + SendTrace?.Log($"{task} :: send request: {request.PrintToString()}"); - taskId.Write(writer); - WriteRequestDelegate(serializationContext, writer, request); - }); + taskId.Write(writer); + WriteRequestDelegate(serializationContext, writer, request); + }); + } return task; } diff --git a/rd-net/Test.RdFramework/RdTaskTest.cs b/rd-net/Test.RdFramework/RdTaskTest.cs index 7a06de5b9..66f09802f 100644 --- a/rd-net/Test.RdFramework/RdTaskTest.cs +++ b/rd-net/Test.RdFramework/RdTaskTest.cs @@ -241,5 +241,58 @@ public void TestOverriddenHandlerScheduler() Assert.AreEqual("0", result.Value.Result); } + + [Test] + public void StartWithTerminatedLifetime() + { + ClientWire.AutoTransmitMode = true; + ServerWire.AutoTransmitMode = true; + var entity_id = 1; + + var serverEntity = BindToServer(LifetimeDefinition.Lifetime, NewRdCall(), ourKey); + var clientEntity = BindToClient(LifetimeDefinition.Lifetime, NewRdCall(), ourKey); + + var called = false; + serverEntity.Set((lf, value) => + { + called = true; + throw new InvalidOperationException("Must not be reached"); + }); + + var task = clientEntity.Start(Lifetime.Terminated, 1); + var result = task.Result.Value; + Assert.IsTrue(result.Status == RdTaskStatus.Canceled); + Assert.IsFalse(called); + } + + + [Test] + public void StartWithTerminatingDuringSet() + { + ClientWire.AutoTransmitMode = true; + ServerWire.AutoTransmitMode = true; + var entity_id = 1; + + var serverEntity = BindToServer(LifetimeDefinition.Lifetime, NewRdCall(), ourKey); + var clientEntity = BindToClient(LifetimeDefinition.Lifetime, NewRdCall(), ourKey); + + var def = TestLifetime.CreateNested(); + Lifetime callLifetime = default; + serverEntity.Set((lf, value) => + { + using (new LifetimeDefinition.AllowTerminationUnderExecutionCookie(Thread.CurrentThread)) + { + def.Terminate(); + } + + callLifetime = lf; + return RdTask.Successful(value.ToString()); + }); + + var task = clientEntity.Start(def.Lifetime, 1); + var result = task.Result.Value; + Assert.IsTrue(result.Status == RdTaskStatus.Canceled); + Assert.IsFalse(callLifetime.IsAlive); + } } } \ No newline at end of file