Skip to content

Commit

Permalink
Introduce WsRequestState trait to hide unsafe methods. (#471)
Browse files Browse the repository at this point in the history
  • Loading branch information
Caparow authored Nov 9, 2023
1 parent 7fb5812 commit f840362
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRu
): Lifecycle[F[Throwable, _], WsRpcClientConnection[F]] = {
for {
client <- WsRpcDispatcherFactory.asyncHttpClient[F]
requestState <- Lifecycle.liftF(F.syncThrowable(new WsRequestState[F]))
requestState <- Lifecycle.liftF(F.syncThrowable(WsRequestState.create[F]))
listener <- Lifecycle.liftF(F.syncThrowable(createListener(muxer, contextProvider, requestState, dispatcherLogger(uri, logger))))
handler <- Lifecycle.liftF(F.syncThrowable(new WebSocketUpgradeHandler(List(listener).asJava)))
nettyWebSocket <- Lifecycle.make(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ object WsClientSession {
printer: Printer,
logger: LogIO2[F],
) extends WsClientSession[F, RequestCtx, ClientId] {
private val openingTime: ZonedDateTime = IzTime.utcNow
private val sessionId = WsSessionId(UUIDGen.getTimeUUID())
private val clientId = new AtomicReference[Option[ClientId]](None)
private val requestState = new WsRequestState()
private val openingTime: ZonedDateTime = IzTime.utcNow
private val sessionId: WsSessionId = WsSessionId(UUIDGen.getTimeUUID())
private val clientId: AtomicReference[Option[ClientId]] = new AtomicReference[Option[ClientId]](None)
private val requestState: WsRequestState[F] = WsRequestState.create[F]

def id: WsClientId[ClientId] = WsClientId(sessionId, clientId.get())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import izumi.functional.bio.{Clock1, Clock2, F, IO2, Primitives2, Promise2, Temp
import izumi.fundamentals.platform.language.Quirks.*
import izumi.idealingua.runtime.rpc.*
import izumi.idealingua.runtime.rpc.http4s.ws.RawResponse.BadRawResponse
import izumi.idealingua.runtime.rpc.http4s.ws.WsRequestState.RequestHandler
import izumi.idealingua.runtime.rpc.http4s.ws.WsRpcHandler.WsClientResponder

import java.time.OffsetDateTime
Expand All @@ -15,94 +14,23 @@ import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
import scala.jdk.CollectionConverters.*

class WsRequestState[F[+_, +_]: IO2: Temporal2: Primitives2] extends WsClientResponder[F] {
// using custom clock to no allow to override it
private[this] val clock: Clock2[F] = Clock1.fromImpure(Clock1.Standard)
private[this] val requests: ConcurrentHashMap[RpcPacketId, IRTMethodId] = new ConcurrentHashMap[RpcPacketId, IRTMethodId]()
private[this] val responses: ConcurrentHashMap[RpcPacketId, RequestHandler[F]] = new ConcurrentHashMap[RpcPacketId, RequestHandler[F]]()

def requestAndAwait[A](id: RpcPacketId, methodId: Option[IRTMethodId], timeout: FiniteDuration)(request: => F[Throwable, A]): F[Throwable, Option[RawResponse]] = {
(for {
handler <- registerRequest(id, methodId, timeout)
// request should be performed after handler created
_ <- request
res <- handler.promise.await.timeout(timeout)
} yield res).guarantee(forget(id))
}

def registerRequest(id: RpcPacketId, methodId: Option[IRTMethodId], timeout: FiniteDuration): F[Nothing, RequestHandler[F]] = {
for {
now <- clock.nowOffset()
_ <- forgetExpired(now)
promise <- F.mkPromise[Nothing, RawResponse]
ttl = timeout * 3
handler = RequestHandler(id, promise, ttl, now)
_ <- F.sync(responses.put(id, handler))
_ <- F.traverse(methodId)(m => F.sync(requests.put(id, m)))
} yield handler
}

def awaitResponse(id: RpcPacketId, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] = {
F.fromOption(new IRTMissingHandlerException(s"Can not await for async response: $id. Missing handler.", null)) {
Option(responses.get(id))
}.flatMap(_.promise.await.timeout(timeout))
}

def forget(id: RpcPacketId): F[Nothing, Unit] = F.sync {
requests.remove(id)
responses.remove(id).discard()
}

def clear(): F[Nothing, Unit] = {
for {
_ <- F.sync(requests.clear())
_ <- F.traverse(responses.values().asScala)(h => h.promise.succeed(BadRawResponse(None)))
_ <- F.sync(responses.clear())
} yield ()
}

def responseWith(packetId: RpcPacketId, response: RawResponse): F[Throwable, Unit] = {
F.sync(Option(responses.get(packetId))).flatMap {
case Some(handler) => handler.promise.succeed(response).void
case None => F.unit
}
}
trait WsRequestState[F[_, _]] extends WsClientResponder[F] {
def requestAndAwait[A](
id: RpcPacketId,
methodId: Option[IRTMethodId],
timeout: FiniteDuration,
)(request: => F[Throwable, A]
): F[Throwable, Option[RawResponse]]

def responseWithData(packetId: RpcPacketId, data: Json): F[Throwable, Unit] = {
for {
method <- F.fromOption(new IRTMissingHandlerException(s"Cannot handle response for async request $packetId: no service handler", data)) {
Option(requests.get(packetId))
}
_ <- responseWith(packetId, RawResponse.GoodRawResponse(data, method))
} yield ()
}
def responseWith(packetId: RpcPacketId, response: RawResponse): F[Throwable, Unit]
def responseWithData(packetId: RpcPacketId, data: Json): F[Throwable, Unit]

private[this] def forgetExpired(now: OffsetDateTime): F[Nothing, Unit] = {
for {
removed <- F.sync {
val removed = mutable.ArrayBuffer.empty[RequestHandler[F]]
// it should be synchronized on node remove
responses.values().removeIf {
handler =>
val isExpired = handler.expired(now)
if (isExpired) removed.append(handler)
isExpired
}
removed.toList
}
_ <- F.traverse(removed) {
handler =>
requests.remove(handler.id)
handler.promise.poll.flatMap {
case Some(_) => F.unit
case None => handler.promise.succeed(BadRawResponse(Some(Json.obj("error" -> Json.fromString(s"Request expired within ${handler.ttl}.")))))
}
}
} yield ()
}
def clear(): F[Nothing, Unit]
}

object WsRequestState {
def create[F[+_, +_]: IO2: Temporal2: Primitives2]: WsRequestState[F] = new Default[F]

final case class RequestHandler[F[+_, +_]](
id: RpcPacketId,
promise: Promise2[F, Nothing, RawResponse],
Expand All @@ -113,4 +41,96 @@ object WsRequestState {
ChronoUnit.MILLIS.between(firedAt, now) >= ttl.toMillis
}
}

class Default[F[+_, +_]: IO2: Temporal2: Primitives2] extends WsRequestState[F] {
// using custom clock to no allow to override it
private[this] val clock: Clock2[F] = Clock1.fromImpure(Clock1.Standard)
private[this] val requests: ConcurrentHashMap[RpcPacketId, IRTMethodId] = new ConcurrentHashMap[RpcPacketId, IRTMethodId]()
private[this] val responses: ConcurrentHashMap[RpcPacketId, RequestHandler[F]] = new ConcurrentHashMap[RpcPacketId, RequestHandler[F]]()

override def requestAndAwait[A](
id: RpcPacketId,
methodId: Option[IRTMethodId],
timeout: FiniteDuration,
)(request: => F[Throwable, A]
): F[Throwable, Option[RawResponse]] = {
(for {
handler <- registerRequest(id, methodId, timeout)
// request should be performed after handler created
_ <- request
res <- handler.promise.await.timeout(timeout)
} yield res).guarantee(forget(id))
}

override def responseWith(packetId: RpcPacketId, response: RawResponse): F[Throwable, Unit] = {
F.sync(Option(responses.get(packetId))).flatMap {
case Some(handler) => handler.promise.succeed(response).void
case None => F.unit
}
}

override def responseWithData(packetId: RpcPacketId, data: Json): F[Throwable, Unit] = {
for {
method <- F.fromOption(new IRTMissingHandlerException(s"Cannot handle response for async request $packetId: no service handler", data)) {
Option(requests.get(packetId))
}
_ <- responseWith(packetId, RawResponse.GoodRawResponse(data, method))
} yield ()
}

override def clear(): F[Nothing, Unit] = {
for {
_ <- F.sync(requests.clear())
_ <- F.traverse(responses.values().asScala)(h => h.promise.succeed(BadRawResponse(None)))
_ <- F.sync(responses.clear())
} yield ()
}

def registerRequest(id: RpcPacketId, methodId: Option[IRTMethodId], timeout: FiniteDuration): F[Nothing, RequestHandler[F]] = {
for {
now <- clock.nowOffset()
_ <- forgetExpired(now)
promise <- F.mkPromise[Nothing, RawResponse]
ttl = timeout * 3
handler = RequestHandler(id, promise, ttl, now)
_ <- F.sync(responses.put(id, handler))
_ <- F.traverse(methodId)(m => F.sync(requests.put(id, m)))
} yield handler
}

def awaitResponse(id: RpcPacketId, timeout: FiniteDuration): F[Throwable, Option[RawResponse]] = {
F.fromOption(new IRTMissingHandlerException(s"Can not await for async response: $id. Missing handler.", null)) {
Option(responses.get(id))
}.flatMap(_.promise.await.timeout(timeout))
}

private[this] def forget(id: RpcPacketId): F[Nothing, Unit] = F.sync {
requests.remove(id)
responses.remove(id).discard()
}

private[this] def forgetExpired(now: OffsetDateTime): F[Nothing, Unit] = {
for {
removed <- F.sync {
val removed = mutable.ArrayBuffer.empty[RequestHandler[F]]
// it should be synchronized on node remove
responses.values().removeIf {
handler =>
val isExpired = handler.expired(now)
if (isExpired) removed.append(handler)
isExpired
}
removed.toList
}
_ <- F.traverse(removed) {
handler =>
requests.remove(handler.id)
handler.promise.poll.flatMap {
case Some(_) => F.unit
case None => handler.promise.succeed(BadRawResponse(Some(Json.obj("error" -> Json.fromString(s"Request expired within ${handler.ttl}.")))))
}
}
} yield ()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Http4sTransportTest extends AnyWordSpec {

"support request state clean" in {
executeIO {
val rs = new WsRequestState[IO]()
val rs = new WsRequestState.Default[IO]()
for {
id1 <- ZIO.succeed(RpcPacketId.random())
id2 <- ZIO.succeed(RpcPacketId.random())
Expand Down

0 comments on commit f840362

Please sign in to comment.