diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcherFactory.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcherFactory.scala index 3c819570..170a3350 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcherFactory.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/clients/WsRpcDispatcherFactory.scala @@ -34,7 +34,7 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu ): Lifecycle[F[Throwable, _], WsRpcClientConnection[F]] = { for { client <- WsRpcDispatcherFactory.asyncHttpClient[F] - requestState <- Lifecycle.liftF(F.syncThrowable(new WsRequestState[F])) + requestState <- Lifecycle.liftF(F.syncThrowable(WsRequestState.create[F])) listener <- Lifecycle.liftF(F.syncThrowable(createListener(muxer, contextProvider, requestState, dispatcherLogger(uri, logger)))) handler <- Lifecycle.liftF(F.syncThrowable(new WebSocketUpgradeHandler(List(listener).asJava))) nettyWebSocket <- Lifecycle.make( diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala index d5c74583..636c41ef 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsClientSession.scala @@ -38,10 +38,10 @@ object WsClientSession { printer: Printer, logger: LogIO2[F], ) extends WsClientSession[F, RequestCtx, ClientId] { - private val openingTime: ZonedDateTime = IzTime.utcNow - private val sessionId = WsSessionId(UUIDGen.getTimeUUID()) - private val clientId = new AtomicReference[Option[ClientId]](None) - private val requestState = new WsRequestState() + private val openingTime: ZonedDateTime = IzTime.utcNow + private val sessionId: WsSessionId = WsSessionId(UUIDGen.getTimeUUID()) + private val clientId: AtomicReference[Option[ClientId]] = new AtomicReference[Option[ClientId]](None) + private val requestState: WsRequestState[F] = WsRequestState.create[F] def id: WsClientId[ClientId] = WsClientId(sessionId, clientId.get()) diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRequestState.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRequestState.scala index 73eca291..33d2352d 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRequestState.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/main/scala/izumi/idealingua/runtime/rpc/http4s/ws/WsRequestState.scala @@ -5,7 +5,6 @@ import izumi.functional.bio.{Clock1, Clock2, F, IO2, Primitives2, Promise2, Temp import izumi.fundamentals.platform.language.Quirks.* import izumi.idealingua.runtime.rpc.* import izumi.idealingua.runtime.rpc.http4s.ws.RawResponse.BadRawResponse -import izumi.idealingua.runtime.rpc.http4s.ws.WsRequestState.RequestHandler import izumi.idealingua.runtime.rpc.http4s.ws.WsRpcHandler.WsClientResponder import java.time.OffsetDateTime @@ -15,94 +14,23 @@ import scala.collection.mutable import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters.* -class WsRequestState[F[+_, +_]: IO2: Temporal2: Primitives2] extends WsClientResponder[F] { - // using custom clock to no allow to override it - private[this] val clock: Clock2[F] = Clock1.fromImpure(Clock1.Standard) - private[this] val requests: ConcurrentHashMap[RpcPacketId, IRTMethodId] = new ConcurrentHashMap[RpcPacketId, IRTMethodId]() - private[this] val responses: ConcurrentHashMap[RpcPacketId, RequestHandler[F]] = new ConcurrentHashMap[RpcPacketId, RequestHandler[F]]() - - def requestAndAwait[A](id: RpcPacketId, methodId: Option[IRTMethodId], timeout: FiniteDuration)(request: => F[Throwable, A]): F[Throwable, Option[RawResponse]] = { - (for { - handler <- registerRequest(id, methodId, timeout) - // request should be performed after handler created - _ <- request - res <- handler.promise.await.timeout(timeout) - } yield res).guarantee(forget(id)) - } - - def registerRequest(id: RpcPacketId, methodId: Option[IRTMethodId], timeout: FiniteDuration): F[Nothing, RequestHandler[F]] = { - for { - now <- clock.nowOffset() - _ <- forgetExpired(now) - promise <- F.mkPromise[Nothing, RawResponse] - ttl = timeout * 3 - handler = RequestHandler(id, promise, ttl, now) - _ <- F.sync(responses.put(id, handler)) - _ <- F.traverse(methodId)(m => F.sync(requests.put(id, m))) - } yield handler - } - - def awaitResponse(id: RpcPacketId, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] = { - F.fromOption(new IRTMissingHandlerException(s"Can not await for async response: $id. Missing handler.", null)) { - Option(responses.get(id)) - }.flatMap(_.promise.await.timeout(timeout)) - } - - def forget(id: RpcPacketId): F[Nothing, Unit] = F.sync { - requests.remove(id) - responses.remove(id).discard() - } - - def clear(): F[Nothing, Unit] = { - for { - _ <- F.sync(requests.clear()) - _ <- F.traverse(responses.values().asScala)(h => h.promise.succeed(BadRawResponse(None))) - _ <- F.sync(responses.clear()) - } yield () - } - - def responseWith(packetId: RpcPacketId, response: RawResponse): F[Throwable, Unit] = { - F.sync(Option(responses.get(packetId))).flatMap { - case Some(handler) => handler.promise.succeed(response).void - case None => F.unit - } - } +trait WsRequestState[F[_, _]] extends WsClientResponder[F] { + def requestAndAwait[A]( + id: RpcPacketId, + methodId: Option[IRTMethodId], + timeout: FiniteDuration, + )(request: => F[Throwable, A] + ): F[Throwable, Option[RawResponse]] - def responseWithData(packetId: RpcPacketId, data: Json): F[Throwable, Unit] = { - for { - method <- F.fromOption(new IRTMissingHandlerException(s"Cannot handle response for async request $packetId: no service handler", data)) { - Option(requests.get(packetId)) - } - _ <- responseWith(packetId, RawResponse.GoodRawResponse(data, method)) - } yield () - } + def responseWith(packetId: RpcPacketId, response: RawResponse): F[Throwable, Unit] + def responseWithData(packetId: RpcPacketId, data: Json): F[Throwable, Unit] - private[this] def forgetExpired(now: OffsetDateTime): F[Nothing, Unit] = { - for { - removed <- F.sync { - val removed = mutable.ArrayBuffer.empty[RequestHandler[F]] - // it should be synchronized on node remove - responses.values().removeIf { - handler => - val isExpired = handler.expired(now) - if (isExpired) removed.append(handler) - isExpired - } - removed.toList - } - _ <- F.traverse(removed) { - handler => - requests.remove(handler.id) - handler.promise.poll.flatMap { - case Some(_) => F.unit - case None => handler.promise.succeed(BadRawResponse(Some(Json.obj("error" -> Json.fromString(s"Request expired within ${handler.ttl}."))))) - } - } - } yield () - } + def clear(): F[Nothing, Unit] } object WsRequestState { + def create[F[+_, +_]: IO2: Temporal2: Primitives2]: WsRequestState[F] = new Default[F] + final case class RequestHandler[F[+_, +_]]( id: RpcPacketId, promise: Promise2[F, Nothing, RawResponse], @@ -113,4 +41,96 @@ object WsRequestState { ChronoUnit.MILLIS.between(firedAt, now) >= ttl.toMillis } } + + class Default[F[+_, +_]: IO2: Temporal2: Primitives2] extends WsRequestState[F] { + // using custom clock to no allow to override it + private[this] val clock: Clock2[F] = Clock1.fromImpure(Clock1.Standard) + private[this] val requests: ConcurrentHashMap[RpcPacketId, IRTMethodId] = new ConcurrentHashMap[RpcPacketId, IRTMethodId]() + private[this] val responses: ConcurrentHashMap[RpcPacketId, RequestHandler[F]] = new ConcurrentHashMap[RpcPacketId, RequestHandler[F]]() + + override def requestAndAwait[A]( + id: RpcPacketId, + methodId: Option[IRTMethodId], + timeout: FiniteDuration, + )(request: => F[Throwable, A] + ): F[Throwable, Option[RawResponse]] = { + (for { + handler <- registerRequest(id, methodId, timeout) + // request should be performed after handler created + _ <- request + res <- handler.promise.await.timeout(timeout) + } yield res).guarantee(forget(id)) + } + + override def responseWith(packetId: RpcPacketId, response: RawResponse): F[Throwable, Unit] = { + F.sync(Option(responses.get(packetId))).flatMap { + case Some(handler) => handler.promise.succeed(response).void + case None => F.unit + } + } + + override def responseWithData(packetId: RpcPacketId, data: Json): F[Throwable, Unit] = { + for { + method <- F.fromOption(new IRTMissingHandlerException(s"Cannot handle response for async request $packetId: no service handler", data)) { + Option(requests.get(packetId)) + } + _ <- responseWith(packetId, RawResponse.GoodRawResponse(data, method)) + } yield () + } + + override def clear(): F[Nothing, Unit] = { + for { + _ <- F.sync(requests.clear()) + _ <- F.traverse(responses.values().asScala)(h => h.promise.succeed(BadRawResponse(None))) + _ <- F.sync(responses.clear()) + } yield () + } + + def registerRequest(id: RpcPacketId, methodId: Option[IRTMethodId], timeout: FiniteDuration): F[Nothing, RequestHandler[F]] = { + for { + now <- clock.nowOffset() + _ <- forgetExpired(now) + promise <- F.mkPromise[Nothing, RawResponse] + ttl = timeout * 3 + handler = RequestHandler(id, promise, ttl, now) + _ <- F.sync(responses.put(id, handler)) + _ <- F.traverse(methodId)(m => F.sync(requests.put(id, m))) + } yield handler + } + + def awaitResponse(id: RpcPacketId, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] = { + F.fromOption(new IRTMissingHandlerException(s"Can not await for async response: $id. Missing handler.", null)) { + Option(responses.get(id)) + }.flatMap(_.promise.await.timeout(timeout)) + } + + private[this] def forget(id: RpcPacketId): F[Nothing, Unit] = F.sync { + requests.remove(id) + responses.remove(id).discard() + } + + private[this] def forgetExpired(now: OffsetDateTime): F[Nothing, Unit] = { + for { + removed <- F.sync { + val removed = mutable.ArrayBuffer.empty[RequestHandler[F]] + // it should be synchronized on node remove + responses.values().removeIf { + handler => + val isExpired = handler.expired(now) + if (isExpired) removed.append(handler) + isExpired + } + removed.toList + } + _ <- F.traverse(removed) { + handler => + requests.remove(handler.id) + handler.promise.poll.flatMap { + case Some(_) => F.unit + case None => handler.promise.succeed(BadRawResponse(Some(Json.obj("error" -> Json.fromString(s"Request expired within ${handler.ttl}."))))) + } + } + } yield () + } + } } diff --git a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/Http4sTransportTest.scala b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/Http4sTransportTest.scala index 91862311..c0d1f0ca 100644 --- a/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/Http4sTransportTest.scala +++ b/idealingua-v1/idealingua-v1-runtime-rpc-http4s/src/test/scala/izumi/idealingua/runtime/rpc/http4s/Http4sTransportTest.scala @@ -83,7 +83,7 @@ class Http4sTransportTest extends AnyWordSpec { "support request state clean" in { executeIO { - val rs = new WsRequestState[IO]() + val rs = new WsRequestState.Default[IO]() for { id1 <- ZIO.succeed(RpcPacketId.random()) id2 <- ZIO.succeed(RpcPacketId.random())