diff --git a/src/main/scala/zio/akka/cluster/sharding/Entity.scala b/src/main/scala/zio/akka/cluster/sharding/Entity.scala index 540c3e5..6186947 100644 --- a/src/main/scala/zio/akka/cluster/sharding/Entity.scala +++ b/src/main/scala/zio/akka/cluster/sharding/Entity.scala @@ -1,5 +1,6 @@ package zio.akka.cluster.sharding +import scala.concurrent.duration.Duration import zio.{ Ref, Task, UIO } trait Entity[State] { @@ -7,4 +8,6 @@ trait Entity[State] { def id: String def state: Ref[Option[State]] def stop: UIO[Unit] + def passivate: UIO[Unit] + def passivateAfter(duration: Duration): UIO[Unit] } diff --git a/src/main/scala/zio/akka/cluster/sharding/SetTimeout.scala b/src/main/scala/zio/akka/cluster/sharding/SetTimeout.scala new file mode 100644 index 0000000..07698ca --- /dev/null +++ b/src/main/scala/zio/akka/cluster/sharding/SetTimeout.scala @@ -0,0 +1,5 @@ +package zio.akka.cluster.sharding + +import scala.concurrent.duration.Duration + +case class SetTimeout(duration: Duration) diff --git a/src/main/scala/zio/akka/cluster/sharding/Sharding.scala b/src/main/scala/zio/akka/cluster/sharding/Sharding.scala index e23bf88..b8d5198 100644 --- a/src/main/scala/zio/akka/cluster/sharding/Sharding.scala +++ b/src/main/scala/zio/akka/cluster/sharding/Sharding.scala @@ -2,7 +2,7 @@ package zio.akka.cluster.sharding import scala.concurrent.duration._ import scala.reflect.ClassTag -import akka.actor.{ Actor, ActorContext, ActorRef, ActorSystem, PoisonPill, Props } +import akka.actor.{ Actor, ActorContext, ActorRef, ActorSystem, PoisonPill, Props, ReceiveTimeout } import akka.cluster.sharding.ShardRegion.Passivate import akka.cluster.sharding.{ ClusterSharding, ClusterShardingSettings } import akka.pattern.{ ask => askPattern } @@ -136,13 +136,19 @@ object Sharding { val ref: Ref[Option[State]] = rts.unsafeRun(Ref.make[Option[State]](None)) val actorContext: ActorContext = context val entity: Entity[State] = new Entity[State] { - override def id: String = context.self.path.name - override def state: Ref[Option[State]] = ref - override def stop: UIO[Unit] = UIO(actorContext.stop(self)) - override def replyToSender[R](msg: R): Task[Unit] = Task(context.sender() ! msg) + override def id: String = actorContext.self.path.name + override def state: Ref[Option[State]] = ref + override def stop: UIO[Unit] = UIO(actorContext.stop(self)) + override def passivate: UIO[Unit] = UIO(actorContext.parent ! Passivate(PoisonPill)) + override def passivateAfter(duration: Duration): UIO[Unit] = UIO(actorContext.self ! SetTimeout(duration)) + override def replyToSender[R](msg: R): Task[Unit] = Task(actorContext.sender() ! msg) } def receive: Receive = { + case SetTimeout(duration) => + actorContext.setReceiveTimeout(duration) + case ReceiveTimeout => + actorContext.parent ! Passivate(PoisonPill) case p: Passivate => actorContext.parent ! p case MessagePayload(msg) => diff --git a/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala b/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala index a4e9239..c36d73e 100644 --- a/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala +++ b/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala @@ -152,6 +152,29 @@ object ShardingSpec extends DefaultRunnableSpec { } yield res )(isNone).provideLayer(actorSystem) }, + testM("passivateAfter") { + assertM( + for { + p <- Promise.make[Nothing, Option[Unit]] + onMessage = (msg: String) => + msg match { + case "set" => ZIO.accessM[Entity[Unit]](_.state.set(Some(()))) + case "get" => ZIO.accessM[Entity[Unit]](_.state.get.flatMap(s => p.succeed(s).unit)) + case "timeout" => ZIO.accessM[Entity[Unit]](_.passivateAfter((1 millisecond).asScala)) + } + sharding <- Sharding.start(shardName, onMessage) + _ <- sharding.send(shardId, "set") + _ <- sharding.send(shardId, "timeout") + _ <- ZIO + .sleep(3 seconds) + .provideLayer( + Clock.live + ) // give time to the ShardCoordinator to notice the death of the actor and recreate one + _ <- sharding.send(shardId, "get") + res <- p.await + } yield res + )(isNone).provideLayer(actorSystem) + }, testM("work with 2 actor systems") { assertM( actorSystem.build.use(a1 =>