diff --git a/README.md b/README.md index d7dd348..59c4e09 100644 --- a/README.md +++ b/README.md @@ -135,11 +135,11 @@ See [Akka Documentation](https://doc.akka.io/docs/akka/current/cluster-sharding. To start sharding a given entity type on a node, use `Sharding.start`. It returns a `Sharding` object which can be used to send messages, stop or passivate sharded entities. ```scala -def start[Msg, State]( +def start[R, Msg, State]( name: String, - onMessage: Msg => ZIO[Entity[State], Nothing, Unit], + onMessage: Msg => ZIO[Entity[State] with R, Nothing, Unit], numberOfShards: Int = 100 - ): ZIO[Has[ActorSystem], Throwable, Sharding[Msg]] + ): ZIO[Has[ActorSystem] with R, Throwable, Sharding[Msg]] ``` It requires: @@ -178,8 +178,8 @@ val actorSystem: ZLayer[Any, Throwable, Has[ActorSystem]] = ZLayer.fromManaged(Managed.make(Task(ActorSystem("Test")))(sys => Task.fromFuture(_ => sys.terminate()).either)) val behavior: String => ZIO[Entity[Int], Nothing, Unit] = { - case "+" => ZIO.accessM[Entity[Int]](_.state.update(x => Some(x.getOrElse(0) + 1))) - case "-" => ZIO.accessM[Entity[Int]](_.state.update(x => Some(x.getOrElse(0) - 1))) + case "+" => ZIO.accessM[Entity[Int]](_.get.state.update(x => Some(x.getOrElse(0) + 1))) + case "-" => ZIO.accessM[Entity[Int]](_.get.state.update(x => Some(x.getOrElse(0) - 1))) case _ => ZIO.unit } diff --git a/src/main/scala/zio/akka/cluster/sharding/Entity.scala b/src/main/scala/zio/akka/cluster/sharding/Entity.scala deleted file mode 100644 index 6186947..0000000 --- a/src/main/scala/zio/akka/cluster/sharding/Entity.scala +++ /dev/null @@ -1,13 +0,0 @@ -package zio.akka.cluster.sharding - -import scala.concurrent.duration.Duration -import zio.{ Ref, Task, UIO } - -trait Entity[State] { - def replyToSender[R](msg: R): Task[Unit] - def id: String - def state: Ref[Option[State]] - def stop: UIO[Unit] - def passivate: UIO[Unit] - def passivateAfter(duration: Duration): UIO[Unit] -} diff --git a/src/main/scala/zio/akka/cluster/sharding/Sharding.scala b/src/main/scala/zio/akka/cluster/sharding/Sharding.scala index b8d5198..bdcf0f4 100644 --- a/src/main/scala/zio/akka/cluster/sharding/Sharding.scala +++ b/src/main/scala/zio/akka/cluster/sharding/Sharding.scala @@ -9,7 +9,7 @@ import akka.pattern.{ ask => askPattern } import akka.util.Timeout import zio.akka.cluster.sharding import zio.akka.cluster.sharding.MessageEnvelope.{ MessagePayload, PassivatePayload, PoisonPillPayload } -import zio.{ =!=, Has, Ref, Runtime, Task, UIO, ZIO } +import zio.{ =!=, Has, Ref, Runtime, Tagged, Task, UIO, ZIO, ZLayer } /** * A `Sharding[M]` is able to send messages of type `M` to a sharded entity or to stop one. @@ -37,19 +37,19 @@ object Sharding { * @param askTimeout a finite duration specifying how long an ask is allowed to wait for an entity to respond * @return a [[Sharding]] object that can be used to send messages to sharded entities */ - def start[Msg, State]( + def start[R <: Has[_], Msg, State: Tagged]( name: String, - onMessage: Msg => ZIO[Entity[State], Nothing, Unit], + onMessage: Msg => ZIO[Entity[State] with R, Nothing, Unit], numberOfShards: Int = 100, askTimeout: FiniteDuration = 10.seconds - ): ZIO[Has[ActorSystem], Throwable, Sharding[Msg]] = + ): ZIO[Has[ActorSystem] with R, Throwable, Sharding[Msg]] = for { - rts <- ZIO.runtime[Has[ActorSystem]] - actorSystem = rts.environment.get + rts <- ZIO.runtime[Has[ActorSystem] with R] + actorSystem = rts.environment.get[ActorSystem] shardingRegion <- Task( ClusterSharding(actorSystem).start( typeName = name, - entityProps = Props(new ShardEntity(rts)(onMessage)), + entityProps = Props(new ShardEntity[R, Msg, State](rts)(onMessage)), settings = ClusterShardingSettings(actorSystem), extractEntityId = { case MessageEnvelope(entityId, payload) => @@ -129,20 +129,20 @@ object Sharding { ) } - private[sharding] class ShardEntity[Msg, State](rts: Runtime[Any])( - onMessage: Msg => ZIO[Entity[State], Nothing, Unit] + private[sharding] class ShardEntity[R <: Has[_], Msg, State: Tagged](rts: Runtime[R])( + onMessage: Msg => ZIO[Entity[State] with R, Nothing, Unit] ) extends Actor { val ref: Ref[Option[State]] = rts.unsafeRun(Ref.make[Option[State]](None)) val actorContext: ActorContext = context - val entity: Entity[State] = new Entity[State] { + val entity: ZLayer[Any, Nothing, Entity[State]] = ZLayer.succeed(new Entity.Service[State] { override def id: String = actorContext.self.path.name override def state: Ref[Option[State]] = ref override def stop: UIO[Unit] = UIO(actorContext.stop(self)) override def passivate: UIO[Unit] = UIO(actorContext.parent ! Passivate(PoisonPill)) override def passivateAfter(duration: Duration): UIO[Unit] = UIO(actorContext.self ! SetTimeout(duration)) override def replyToSender[R](msg: R): Task[Unit] = Task(actorContext.sender() ! msg) - } + }) def receive: Receive = { case SetTimeout(duration) => @@ -152,7 +152,7 @@ object Sharding { case p: Passivate => actorContext.parent ! p case MessagePayload(msg) => - rts.unsafeRunSync(onMessage(msg.asInstanceOf[Msg]).provide(entity)) + rts.unsafeRunSync(onMessage(msg.asInstanceOf[Msg]).provideSomeLayer[R](entity)) () case _ => } diff --git a/src/main/scala/zio/akka/cluster/sharding/package.scala b/src/main/scala/zio/akka/cluster/sharding/package.scala new file mode 100644 index 0000000..cabbf48 --- /dev/null +++ b/src/main/scala/zio/akka/cluster/sharding/package.scala @@ -0,0 +1,35 @@ +package zio.akka.cluster + +import zio.{ Has, Ref, Tagged, Task, UIO, ZIO } + +import scala.concurrent.duration.Duration + +package object sharding { + type Entity[State] = Has[Entity.Service[State]] + + object Entity { + + trait Service[State] { + def replyToSender[R](msg: R): Task[Unit] + def id: String + def state: Ref[Option[State]] + def stop: UIO[Unit] + def passivate: UIO[Unit] + def passivateAfter(duration: Duration): UIO[Unit] + } + + def replyToSender[State: Tagged, R](msg: R) = + ZIO.accessM[Entity[State]](_.get.replyToSender(msg)) + def id[State: Tagged] = + ZIO.access[Entity[State]](_.get.id) + def state[State: Tagged] = + ZIO.access[Entity[State]](_.get.state) + def stop[State: Tagged] = + ZIO.accessM[Entity[State]](_.get.stop) + def passivate[State: Tagged] = + ZIO.accessM[Entity[State]](_.get.passivate) + def passivateAfter[State: Tagged](duration: Duration) = + ZIO.accessM[Entity[State]](_.get.passivateAfter(duration)) + + } +} diff --git a/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala b/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala index 1b118a7..46c93b8 100644 --- a/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala +++ b/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala @@ -8,7 +8,7 @@ import zio.duration._ import zio.test.Assertion._ import zio.test._ import zio.test.environment.TestEnvironment -import zio.{ ExecutionStrategy, Has, Managed, Promise, Task, ZIO, ZLayer } +import zio.{ ExecutionStrategy, Has, Managed, Promise, Task, UIO, ZIO, ZLayer } object ShardingSpec extends DefaultRunnableSpec { @@ -77,7 +77,7 @@ object ShardingSpec extends DefaultRunnableSpec { }, testM("send and receive a message using ask") { val onMessage: String => ZIO[Entity[Any], Nothing, Unit] = - incomingMsg => ZIO.accessM[Entity[Any]](r => r.replyToSender(incomingMsg).orDie) + incomingMsg => ZIO.accessM[Entity[Any]](r => r.get.replyToSender(incomingMsg).orDie) assertM( for { sharding <- Sharding.start(shardName, onMessage) @@ -91,7 +91,7 @@ object ShardingSpec extends DefaultRunnableSpec { p <- Promise.make[Nothing, Boolean] onMessage = (_: String) => for { - state <- ZIO.access[Entity[Int]](_.state) + state <- ZIO.access[Entity[Int]](_.get.state) newState <- state.updateAndGet { case None => Some(1) case Some(x) => Some(x + 1) @@ -113,9 +113,9 @@ object ShardingSpec extends DefaultRunnableSpec { p <- Promise.make[Nothing, Option[Unit]] onMessage = (msg: String) => msg match { - case "set" => ZIO.accessM[Entity[Unit]](_.state.set(Some(()))) - case "get" => ZIO.accessM[Entity[Unit]](_.state.get.flatMap(s => p.succeed(s).unit)) - case "die" => ZIO.accessM[Entity[Unit]](_.stop) + case "set" => ZIO.accessM[Entity[Unit]](_.get.state.set(Some(()))) + case "get" => ZIO.accessM[Entity[Unit]](_.get.state.get.flatMap(s => p.succeed(s).unit)) + case "die" => ZIO.accessM[Entity[Unit]](_.get.stop) } sharding <- Sharding.start(shardName, onMessage) _ <- sharding.send(shardId, "set") @@ -136,8 +136,8 @@ object ShardingSpec extends DefaultRunnableSpec { p <- Promise.make[Nothing, Option[Unit]] onMessage = (msg: String) => msg match { - case "set" => ZIO.accessM[Entity[Unit]](_.state.set(Some(()))) - case "get" => ZIO.accessM[Entity[Unit]](_.state.get.flatMap(s => p.succeed(s).unit)) + case "set" => ZIO.accessM[Entity[Unit]](_.get.state.set(Some(()))) + case "get" => ZIO.accessM[Entity[Unit]](_.get.state.get.flatMap(s => p.succeed(s).unit)) } sharding <- Sharding.start(shardName, onMessage) _ <- sharding.send(shardId, "set") @@ -158,9 +158,9 @@ object ShardingSpec extends DefaultRunnableSpec { p <- Promise.make[Nothing, Option[Unit]] onMessage = (msg: String) => msg match { - case "set" => ZIO.accessM[Entity[Unit]](_.state.set(Some(()))) - case "get" => ZIO.accessM[Entity[Unit]](_.state.get.flatMap(s => p.succeed(s).unit)) - case "timeout" => ZIO.accessM[Entity[Unit]](_.passivateAfter((1 millisecond).asScala)) + case "set" => ZIO.accessM[Entity[Unit]](_.get.state.set(Some(()))) + case "get" => ZIO.accessM[Entity[Unit]](_.get.state.get.flatMap(s => p.succeed(s).unit)) + case "timeout" => ZIO.accessM[Entity[Unit]](_.get.passivateAfter((1 millisecond).asScala)) } sharding <- Sharding.start(shardName, onMessage) _ <- sharding.send(shardId, "set") @@ -194,9 +194,30 @@ object ShardingSpec extends DefaultRunnableSpec { ) ) )(isUnit) + }, + testM("provide proper environment to onMessage") { + trait TestService { + def doSomething(): UIO[String] + } + def doSomething = + ZIO.accessM[Has[TestService]](_.get.doSomething()) + + val l = ZLayer.succeed(new TestService { + override def doSomething(): UIO[String] = UIO("test") + }) + + assertM( + for { + p <- Promise.make[Nothing, String] + onMessage = (_: String) => (doSomething >>= p.succeed).unit + sharding <- Sharding.start(shardName, onMessage) + _ <- sharding.send(shardId, msg) + res <- p.await + } yield res + )(equalTo("test")).provideLayer(actorSystem ++ l) } ) override def aspects: List[TestAspect[Nothing, TestEnvironment, Nothing, Any]] = - List(TestAspect.executionStrategy(ExecutionStrategy.Sequential)) + List(TestAspect.executionStrategy(ExecutionStrategy.Sequential), TestAspect.timeout(30.seconds)) }