diff --git a/shared/src/main/scala/async/Semaphore.scala b/shared/src/main/scala/async/Semaphore.scala index 26e801d..e509aaf 100644 --- a/shared/src/main/scala/async/Semaphore.scala +++ b/shared/src/main/scala/async/Semaphore.scala @@ -9,53 +9,65 @@ import java.util.concurrent.atomic.AtomicInteger * @param initialValue * the initial counter of this semaphore */ -class Semaphore(initialValue: Int) extends Async.Source[Unit]: +class Semaphore(initialValue: Int) extends Async.Source[Semaphore.Guard]: + self => private val value = AtomicInteger(initialValue) - private val waiting = ConcurrentLinkedQueue[Listener[Unit]]() + private val waiting = ConcurrentLinkedQueue[Listener[Semaphore.Guard]]() - override def onComplete(k: Listener[Unit]): Unit = + override def onComplete(k: Listener[Semaphore.Guard]): Unit = if k.acquireLock() then // if k is gone, we are done if value.getAndDecrement() > 0 then // we got a ticket - k.complete((), this) + k.complete(guard, this) else // no ticket -> add to queue and reset value (was now negative - unless concurrently increased) k.releaseLock() waiting.add(k) - release() + guard.release() - override def dropListener(k: Listener[Unit]): Unit = waiting.remove(k) + override def dropListener(k: Listener[Semaphore.Guard]): Unit = waiting.remove(k) - override def poll(k: Listener[Unit]): Boolean = + override def poll(k: Listener[Semaphore.Guard]): Boolean = if !k.acquireLock() then return true val success = value.getAndUpdate(i => if i > 0 then i - 1 else i) > 0 - if success then k.complete((), this) else k.releaseLock() + if success then k.complete(guard, self) else k.releaseLock() success - override def poll(): Option[Unit] = - if value.getAndUpdate(i => if i > 0 then i - 1 else i) > 0 then Some(()) else None + override def poll(): Option[Semaphore.Guard] = + if value.getAndUpdate(i => if i > 0 then i - 1 else i) > 0 then Some(guard) else None /** Decrease the number of grants available from this semaphore, possibly waiting if none is available. * * @param a * the async capability used for waiting */ - inline def acquire()(using a: Async): Unit = - a.await(this) // do not short-circuit because cancellation should be considered first + inline def acquire()(using Async): Semaphore.Guard = + this.awaitResult // do not short-circuit because cancellation should be considered first - /** Increase the number of grants available to this semaphore, possibly waking up a waiting [[acquire]]. + private object guard extends Semaphore.Guard: + /** Increase the number of grants available to this semaphore, possibly waking up a waiting [[acquire]]. + */ + def release(): Unit = + // if value is < 0, a ticket is missing anyway -> do nothing now + if value.getAndUpdate(i => if i < 0 then i + 1 else i) >= 0 then + // we kept the ticket for now + + var listener = waiting.poll() + while listener != null && !listener.completeNow(guard, self) do listener = waiting.poll() + // if listener not null, then we quit because listener was completed -> ticket is reused -> we are done + + // if listener is null, return the ticket by incrementing, then recheck waiting queue (if incremented to >0) + if listener == null && value.getAndIncrement() >= 0 then + listener = waiting.poll() + if listener != null then // if null now, we are done + onComplete(listener) + +object Semaphore: + /** A guard that marks a single usage of the [[Semaphore]]. Implements [[java.lang.AutoCloseable]] so it can be used + * as a try-with-resource (e.g. with [[scala.util.Using]]). */ - def release(): Unit = - // if value is < 0, a ticket is missing anyway -> do nothing now - if value.getAndUpdate(i => if i < 0 then i + 1 else i) >= 0 then - // we kept the ticket for now - - var listener = waiting.poll() - while listener != null && !listener.completeNow((), this) do listener = waiting.poll() - // if listener not null, then we quit because listener was completed -> ticket is reused -> we are done - - // if listener is null, return the ticket by incrementing, then recheck waiting queue (if incremented to >0) - if listener == null && value.getAndIncrement() >= 0 then - listener = waiting.poll() - if listener != null then // if null now, we are done - onComplete(listener) + trait Guard extends java.lang.AutoCloseable: + /** Release the semaphore, must be called exactly once. */ + def release(): Unit + + final def close() = release() diff --git a/shared/src/test/scala/SemaphoreBehavior.scala b/shared/src/test/scala/SemaphoreBehavior.scala index ff32fbc..9496837 100644 --- a/shared/src/test/scala/SemaphoreBehavior.scala +++ b/shared/src/test/scala/SemaphoreBehavior.scala @@ -14,8 +14,7 @@ class SemaphoreBehavior extends munit.FunSuite { test("single threaded semaphore") { Async.blocking: val sem = Semaphore(2) - sem.acquire() - sem.release() + sem.acquire().release() sem.acquire() sem.acquire() } @@ -23,11 +22,11 @@ class SemaphoreBehavior extends munit.FunSuite { test("single threaded semaphore blocked") { Async.blocking: val sem = Semaphore(2) - sem.acquire() + val guard = sem.acquire() sem.acquire() val res = withTimeoutOption(100.millis)(sem.acquire()) assertEquals(res, None) - sem.release() + guard.release() sem.acquire() } @@ -39,9 +38,8 @@ class SemaphoreBehavior extends munit.FunSuite { Seq .fill(100)(Future { for i <- 0 until 1_000 do - sem.acquire() - count += 1 - sem.release() + scala.util.Using(sem.acquire()): _ => + count += 1 }) .awaitAll assertEquals(count, 100_000)