Skip to content

Commit

Permalink
Add extra environment to onMessage (#71)
Browse files Browse the repository at this point in the history
* Implement support for adding extra environment to onMessage

* Update README

* Add timeout for Sharding tests

Co-authored-by: Pierre Ricadat <[email protected]>
  • Loading branch information
octavz and ghostdogpr authored May 16, 2020
1 parent 5847b5d commit 7dda318
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 42 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ See [Akka Documentation](https://doc.akka.io/docs/akka/current/cluster-sharding.
To start sharding a given entity type on a node, use `Sharding.start`. It returns a `Sharding` object which can be used to send messages, stop or passivate sharded entities.

```scala
def start[Msg, State](
def start[R, Msg, State](
name: String,
onMessage: Msg => ZIO[Entity[State], Nothing, Unit],
onMessage: Msg => ZIO[Entity[State] with R, Nothing, Unit],
numberOfShards: Int = 100
): ZIO[Has[ActorSystem], Throwable, Sharding[Msg]]
): ZIO[Has[ActorSystem] with R, Throwable, Sharding[Msg]]
```

It requires:
Expand Down Expand Up @@ -178,8 +178,8 @@ val actorSystem: ZLayer[Any, Throwable, Has[ActorSystem]] =
ZLayer.fromManaged(Managed.make(Task(ActorSystem("Test")))(sys => Task.fromFuture(_ => sys.terminate()).either))

val behavior: String => ZIO[Entity[Int], Nothing, Unit] = {
case "+" => ZIO.accessM[Entity[Int]](_.state.update(x => Some(x.getOrElse(0) + 1)))
case "-" => ZIO.accessM[Entity[Int]](_.state.update(x => Some(x.getOrElse(0) - 1)))
case "+" => ZIO.accessM[Entity[Int]](_.get.state.update(x => Some(x.getOrElse(0) + 1)))
case "-" => ZIO.accessM[Entity[Int]](_.get.state.update(x => Some(x.getOrElse(0) - 1)))
case _ => ZIO.unit
}

Expand Down
13 changes: 0 additions & 13 deletions src/main/scala/zio/akka/cluster/sharding/Entity.scala

This file was deleted.

24 changes: 12 additions & 12 deletions src/main/scala/zio/akka/cluster/sharding/Sharding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import akka.pattern.{ ask => askPattern }
import akka.util.Timeout
import zio.akka.cluster.sharding
import zio.akka.cluster.sharding.MessageEnvelope.{ MessagePayload, PassivatePayload, PoisonPillPayload }
import zio.{ =!=, Has, Ref, Runtime, Task, UIO, ZIO }
import zio.{ =!=, Has, Ref, Runtime, Tagged, Task, UIO, ZIO, ZLayer }

/**
* A `Sharding[M]` is able to send messages of type `M` to a sharded entity or to stop one.
Expand Down Expand Up @@ -37,19 +37,19 @@ object Sharding {
* @param askTimeout a finite duration specifying how long an ask is allowed to wait for an entity to respond
* @return a [[Sharding]] object that can be used to send messages to sharded entities
*/
def start[Msg, State](
def start[R <: Has[_], Msg, State: Tagged](
name: String,
onMessage: Msg => ZIO[Entity[State], Nothing, Unit],
onMessage: Msg => ZIO[Entity[State] with R, Nothing, Unit],
numberOfShards: Int = 100,
askTimeout: FiniteDuration = 10.seconds
): ZIO[Has[ActorSystem], Throwable, Sharding[Msg]] =
): ZIO[Has[ActorSystem] with R, Throwable, Sharding[Msg]] =
for {
rts <- ZIO.runtime[Has[ActorSystem]]
actorSystem = rts.environment.get
rts <- ZIO.runtime[Has[ActorSystem] with R]
actorSystem = rts.environment.get[ActorSystem]
shardingRegion <- Task(
ClusterSharding(actorSystem).start(
typeName = name,
entityProps = Props(new ShardEntity(rts)(onMessage)),
entityProps = Props(new ShardEntity[R, Msg, State](rts)(onMessage)),
settings = ClusterShardingSettings(actorSystem),
extractEntityId = {
case MessageEnvelope(entityId, payload) =>
Expand Down Expand Up @@ -129,20 +129,20 @@ object Sharding {
)
}

private[sharding] class ShardEntity[Msg, State](rts: Runtime[Any])(
onMessage: Msg => ZIO[Entity[State], Nothing, Unit]
private[sharding] class ShardEntity[R <: Has[_], Msg, State: Tagged](rts: Runtime[R])(
onMessage: Msg => ZIO[Entity[State] with R, Nothing, Unit]
) extends Actor {

val ref: Ref[Option[State]] = rts.unsafeRun(Ref.make[Option[State]](None))
val actorContext: ActorContext = context
val entity: Entity[State] = new Entity[State] {
val entity: ZLayer[Any, Nothing, Entity[State]] = ZLayer.succeed(new Entity.Service[State] {
override def id: String = actorContext.self.path.name
override def state: Ref[Option[State]] = ref
override def stop: UIO[Unit] = UIO(actorContext.stop(self))
override def passivate: UIO[Unit] = UIO(actorContext.parent ! Passivate(PoisonPill))
override def passivateAfter(duration: Duration): UIO[Unit] = UIO(actorContext.self ! SetTimeout(duration))
override def replyToSender[R](msg: R): Task[Unit] = Task(actorContext.sender() ! msg)
}
})

def receive: Receive = {
case SetTimeout(duration) =>
Expand All @@ -152,7 +152,7 @@ object Sharding {
case p: Passivate =>
actorContext.parent ! p
case MessagePayload(msg) =>
rts.unsafeRunSync(onMessage(msg.asInstanceOf[Msg]).provide(entity))
rts.unsafeRunSync(onMessage(msg.asInstanceOf[Msg]).provideSomeLayer[R](entity))
()
case _ =>
}
Expand Down
35 changes: 35 additions & 0 deletions src/main/scala/zio/akka/cluster/sharding/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package zio.akka.cluster

import zio.{ Has, Ref, Tagged, Task, UIO, ZIO }

import scala.concurrent.duration.Duration

package object sharding {
type Entity[State] = Has[Entity.Service[State]]

object Entity {

trait Service[State] {
def replyToSender[R](msg: R): Task[Unit]
def id: String
def state: Ref[Option[State]]
def stop: UIO[Unit]
def passivate: UIO[Unit]
def passivateAfter(duration: Duration): UIO[Unit]
}

def replyToSender[State: Tagged, R](msg: R) =
ZIO.accessM[Entity[State]](_.get.replyToSender(msg))
def id[State: Tagged] =
ZIO.access[Entity[State]](_.get.id)
def state[State: Tagged] =
ZIO.access[Entity[State]](_.get.state)
def stop[State: Tagged] =
ZIO.accessM[Entity[State]](_.get.stop)
def passivate[State: Tagged] =
ZIO.accessM[Entity[State]](_.get.passivate)
def passivateAfter[State: Tagged](duration: Duration) =
ZIO.accessM[Entity[State]](_.get.passivateAfter(duration))

}
}
45 changes: 33 additions & 12 deletions src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import zio.duration._
import zio.test.Assertion._
import zio.test._
import zio.test.environment.TestEnvironment
import zio.{ ExecutionStrategy, Has, Managed, Promise, Task, ZIO, ZLayer }
import zio.{ ExecutionStrategy, Has, Managed, Promise, Task, UIO, ZIO, ZLayer }

object ShardingSpec extends DefaultRunnableSpec {

Expand Down Expand Up @@ -77,7 +77,7 @@ object ShardingSpec extends DefaultRunnableSpec {
},
testM("send and receive a message using ask") {
val onMessage: String => ZIO[Entity[Any], Nothing, Unit] =
incomingMsg => ZIO.accessM[Entity[Any]](r => r.replyToSender(incomingMsg).orDie)
incomingMsg => ZIO.accessM[Entity[Any]](r => r.get.replyToSender(incomingMsg).orDie)
assertM(
for {
sharding <- Sharding.start(shardName, onMessage)
Expand All @@ -91,7 +91,7 @@ object ShardingSpec extends DefaultRunnableSpec {
p <- Promise.make[Nothing, Boolean]
onMessage = (_: String) =>
for {
state <- ZIO.access[Entity[Int]](_.state)
state <- ZIO.access[Entity[Int]](_.get.state)
newState <- state.updateAndGet {
case None => Some(1)
case Some(x) => Some(x + 1)
Expand All @@ -113,9 +113,9 @@ object ShardingSpec extends DefaultRunnableSpec {
p <- Promise.make[Nothing, Option[Unit]]
onMessage = (msg: String) =>
msg match {
case "set" => ZIO.accessM[Entity[Unit]](_.state.set(Some(())))
case "get" => ZIO.accessM[Entity[Unit]](_.state.get.flatMap(s => p.succeed(s).unit))
case "die" => ZIO.accessM[Entity[Unit]](_.stop)
case "set" => ZIO.accessM[Entity[Unit]](_.get.state.set(Some(())))
case "get" => ZIO.accessM[Entity[Unit]](_.get.state.get.flatMap(s => p.succeed(s).unit))
case "die" => ZIO.accessM[Entity[Unit]](_.get.stop)
}
sharding <- Sharding.start(shardName, onMessage)
_ <- sharding.send(shardId, "set")
Expand All @@ -136,8 +136,8 @@ object ShardingSpec extends DefaultRunnableSpec {
p <- Promise.make[Nothing, Option[Unit]]
onMessage = (msg: String) =>
msg match {
case "set" => ZIO.accessM[Entity[Unit]](_.state.set(Some(())))
case "get" => ZIO.accessM[Entity[Unit]](_.state.get.flatMap(s => p.succeed(s).unit))
case "set" => ZIO.accessM[Entity[Unit]](_.get.state.set(Some(())))
case "get" => ZIO.accessM[Entity[Unit]](_.get.state.get.flatMap(s => p.succeed(s).unit))
}
sharding <- Sharding.start(shardName, onMessage)
_ <- sharding.send(shardId, "set")
Expand All @@ -158,9 +158,9 @@ object ShardingSpec extends DefaultRunnableSpec {
p <- Promise.make[Nothing, Option[Unit]]
onMessage = (msg: String) =>
msg match {
case "set" => ZIO.accessM[Entity[Unit]](_.state.set(Some(())))
case "get" => ZIO.accessM[Entity[Unit]](_.state.get.flatMap(s => p.succeed(s).unit))
case "timeout" => ZIO.accessM[Entity[Unit]](_.passivateAfter((1 millisecond).asScala))
case "set" => ZIO.accessM[Entity[Unit]](_.get.state.set(Some(())))
case "get" => ZIO.accessM[Entity[Unit]](_.get.state.get.flatMap(s => p.succeed(s).unit))
case "timeout" => ZIO.accessM[Entity[Unit]](_.get.passivateAfter((1 millisecond).asScala))
}
sharding <- Sharding.start(shardName, onMessage)
_ <- sharding.send(shardId, "set")
Expand Down Expand Up @@ -194,9 +194,30 @@ object ShardingSpec extends DefaultRunnableSpec {
)
)
)(isUnit)
},
testM("provide proper environment to onMessage") {
trait TestService {
def doSomething(): UIO[String]
}
def doSomething =
ZIO.accessM[Has[TestService]](_.get.doSomething())

val l = ZLayer.succeed(new TestService {
override def doSomething(): UIO[String] = UIO("test")
})

assertM(
for {
p <- Promise.make[Nothing, String]
onMessage = (_: String) => (doSomething >>= p.succeed).unit
sharding <- Sharding.start(shardName, onMessage)
_ <- sharding.send(shardId, msg)
res <- p.await
} yield res
)(equalTo("test")).provideLayer(actorSystem ++ l)
}
)

override def aspects: List[TestAspect[Nothing, TestEnvironment, Nothing, Any]] =
List(TestAspect.executionStrategy(ExecutionStrategy.Sequential))
List(TestAspect.executionStrategy(ExecutionStrategy.Sequential), TestAspect.timeout(30.seconds))
}

0 comments on commit 7dda318

Please sign in to comment.