Skip to content

Commit

Permalink
Have the semaphore return a guard with a .release()
Browse files Browse the repository at this point in the history
  • Loading branch information
natsukagami authored and m8nmueller committed Aug 13, 2024
1 parent 08fb3a0 commit fe2dc98
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 34 deletions.
66 changes: 39 additions & 27 deletions shared/src/main/scala/async/Semaphore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,53 +9,65 @@ import java.util.concurrent.atomic.AtomicInteger
* @param initialValue
* the initial counter of this semaphore
*/
class Semaphore(initialValue: Int) extends Async.Source[Unit]:
class Semaphore(initialValue: Int) extends Async.Source[Semaphore.Guard]:
self =>
private val value = AtomicInteger(initialValue)
private val waiting = ConcurrentLinkedQueue[Listener[Unit]]()
private val waiting = ConcurrentLinkedQueue[Listener[Semaphore.Guard]]()

override def onComplete(k: Listener[Unit]): Unit =
override def onComplete(k: Listener[Semaphore.Guard]): Unit =
if k.acquireLock() then // if k is gone, we are done
if value.getAndDecrement() > 0 then
// we got a ticket
k.complete((), this)
k.complete(guard, this)
else
// no ticket -> add to queue and reset value (was now negative - unless concurrently increased)
k.releaseLock()
waiting.add(k)
release()
guard.release()

override def dropListener(k: Listener[Unit]): Unit = waiting.remove(k)
override def dropListener(k: Listener[Semaphore.Guard]): Unit = waiting.remove(k)

override def poll(k: Listener[Unit]): Boolean =
override def poll(k: Listener[Semaphore.Guard]): Boolean =
if !k.acquireLock() then return true
val success = value.getAndUpdate(i => if i > 0 then i - 1 else i) > 0
if success then k.complete((), this) else k.releaseLock()
if success then k.complete(guard, self) else k.releaseLock()
success

override def poll(): Option[Unit] =
if value.getAndUpdate(i => if i > 0 then i - 1 else i) > 0 then Some(()) else None
override def poll(): Option[Semaphore.Guard] =
if value.getAndUpdate(i => if i > 0 then i - 1 else i) > 0 then Some(guard) else None

/** Decrease the number of grants available from this semaphore, possibly waiting if none is available.
*
* @param a
* the async capability used for waiting
*/
inline def acquire()(using a: Async): Unit =
a.await(this) // do not short-circuit because cancellation should be considered first
inline def acquire()(using Async): Semaphore.Guard =
this.awaitResult // do not short-circuit because cancellation should be considered first

/** Increase the number of grants available to this semaphore, possibly waking up a waiting [[acquire]].
private object guard extends Semaphore.Guard:
/** Increase the number of grants available to this semaphore, possibly waking up a waiting [[acquire]].
*/
def release(): Unit =
// if value is < 0, a ticket is missing anyway -> do nothing now
if value.getAndUpdate(i => if i < 0 then i + 1 else i) >= 0 then
// we kept the ticket for now

var listener = waiting.poll()
while listener != null && !listener.completeNow(guard, self) do listener = waiting.poll()
// if listener not null, then we quit because listener was completed -> ticket is reused -> we are done

// if listener is null, return the ticket by incrementing, then recheck waiting queue (if incremented to >0)
if listener == null && value.getAndIncrement() >= 0 then
listener = waiting.poll()
if listener != null then // if null now, we are done
onComplete(listener)

object Semaphore:
/** A guard that marks a single usage of the [[Semaphore]]. Implements [[java.lang.AutoCloseable]] so it can be used
* as a try-with-resource (e.g. with [[scala.util.Using]]).
*/
def release(): Unit =
// if value is < 0, a ticket is missing anyway -> do nothing now
if value.getAndUpdate(i => if i < 0 then i + 1 else i) >= 0 then
// we kept the ticket for now

var listener = waiting.poll()
while listener != null && !listener.completeNow((), this) do listener = waiting.poll()
// if listener not null, then we quit because listener was completed -> ticket is reused -> we are done

// if listener is null, return the ticket by incrementing, then recheck waiting queue (if incremented to >0)
if listener == null && value.getAndIncrement() >= 0 then
listener = waiting.poll()
if listener != null then // if null now, we are done
onComplete(listener)
trait Guard extends java.lang.AutoCloseable:
/** Release the semaphore, must be called exactly once. */
def release(): Unit

final def close() = release()
12 changes: 5 additions & 7 deletions shared/src/test/scala/SemaphoreBehavior.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,19 @@ class SemaphoreBehavior extends munit.FunSuite {
test("single threaded semaphore") {
Async.blocking:
val sem = Semaphore(2)
sem.acquire()
sem.release()
sem.acquire().release()
sem.acquire()
sem.acquire()
}

test("single threaded semaphore blocked") {
Async.blocking:
val sem = Semaphore(2)
sem.acquire()
val guard = sem.acquire()
sem.acquire()
val res = withTimeoutOption(100.millis)(sem.acquire())
assertEquals(res, None)
sem.release()
guard.release()
sem.acquire()
}

Expand All @@ -39,9 +38,8 @@ class SemaphoreBehavior extends munit.FunSuite {
Seq
.fill(100)(Future {
for i <- 0 until 1_000 do
sem.acquire()
count += 1
sem.release()
scala.util.Using(sem.acquire()): _ =>
count += 1
})
.awaitAll
assertEquals(count, 100_000)
Expand Down

0 comments on commit fe2dc98

Please sign in to comment.