Skip to content

Commit

Permalink
defer cancellation until all replays have stopped
Browse files Browse the repository at this point in the history
  • Loading branch information
johanandren committed Sep 18, 2024
1 parent 361e8f9 commit 723b416
Showing 1 changed file with 43 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

package akka.projection.grpc.internal

import akka.Done

import scala.util.Failure
import scala.util.Success
import scala.util.Try
import scala.util.matching.Regex

import akka.NotUsed
import akka.actor.typed.scaladsl.LoggerOps
import akka.annotation.InternalApi
Expand All @@ -28,6 +29,7 @@ import akka.stream.Attributes
import akka.stream.BidiShape
import akka.stream.Inlet
import akka.stream.Outlet
import akka.stream.scaladsl.Keep
import akka.stream.scaladsl.Sink
import akka.stream.scaladsl.SinkQueueWithCancel
import akka.stream.scaladsl.Source
Expand All @@ -37,6 +39,8 @@ import akka.stream.stage.InHandler
import akka.stream.stage.OutHandler
import org.slf4j.LoggerFactory

import scala.concurrent.Future

/**
* INTERNAL API
*/
Expand Down Expand Up @@ -145,7 +149,8 @@ import org.slf4j.LoggerFactory
private case class ReplaySession(
fromSeqNr: Long,
filterAfterSeqNr: Long,
queue: SinkQueueWithCancel[EventEnvelope[Any]])
queue: SinkQueueWithCancel[EventEnvelope[Any]],
replayStreamCompletion: Future[Done])

}

Expand Down Expand Up @@ -356,11 +361,14 @@ import org.slf4j.LoggerFactory
replayInProgress(pid).copy(filterAfterSeqNr = replayPersistenceId.filterAfterSeqNr))
} else if (replayInProgress.size < replayParallelism) {
log.debugN("Stream [{}]: Starting replay of persistenceId [{}], from seqNr [{}]", logPrefix, pid, fromSeqNr)
val queue =
val (replayCompletion, queue) =
currentEventsByPersistenceId(pid, fromSeqNr)
.runWith(Sink.queue())(materializer)
replayInProgress =
replayInProgress.updated(pid, ReplaySession(fromSeqNr, replayPersistenceId.filterAfterSeqNr, queue))
.watchTermination()((_, done) => done)
.toMat(Sink.queue())(Keep.both)
.run()(materializer)
replayInProgress = replayInProgress.updated(
pid,
ReplaySession(fromSeqNr, replayPersistenceId.filterAfterSeqNr, queue, replayCompletion))
tryPullReplay(pid)
} else {
log.debugN("Stream [{}]: Queueing replay of persistenceId [{}], from seqNr [{}]", logPrefix, pid, fromSeqNr)
Expand Down Expand Up @@ -457,17 +465,36 @@ import org.slf4j.LoggerFactory
}
})

setHandler(outEnv, new OutHandler {
override def onPull(): Unit = {
log.trace("Stream [{}]: onPull outEnv", logPrefix)
pullInEnvOrReplay()
}
})
setHandler(
outEnv,
new OutHandler {
override def onPull(): Unit = {
log.trace("Stream [{}]: onPull outEnv", logPrefix)
pullInEnvOrReplay()
}

override def onDownstreamFinish(cause: Throwable): Unit = {
val runningSessions = replayInProgress.values.filterNot(_.replayStreamCompletion.isCompleted)
if (runningSessions.nonEmpty) {
// to avoid abrupt stage termination error logging,
// defer acting on cancel until any replays have completely cancelled
setKeepGoing(true)
val replayCompletedCallback = getAsyncCallback[Try[Done]] { _ =>
val stillRunning = replayInProgress.values.filterNot(_.replayStreamCompletion.isCompleted)
if (stillRunning.isEmpty) {
super.onDownstreamFinish(cause)
}
}.invoke _
runningSessions.foreach { runningSession =>
runningSession.queue.cancel()
runningSession.replayStreamCompletion.onComplete(replayCompletedCallback)(ExecutionContexts.parasitic)
}
} else {
super.onDownstreamFinish(cause)
}
}
})

override def postStop(): Unit = {
replayInProgress.values.foreach(_.queue.cancel())
replayInProgress = Map.empty
}
}

}

0 comments on commit 723b416

Please sign in to comment.