Skip to content

Commit

Permalink
feat: cond with timeout control
Browse files Browse the repository at this point in the history
  • Loading branch information
joway committed Mar 19, 2024
1 parent 21fc7a1 commit 1ec1731
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 0 deletions.
5 changes: 5 additions & 0 deletions lang/channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package channel

import (
"container/list"
"context"
"runtime"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -265,6 +266,10 @@ func (c *channel) Output() <-chan interface{} {
return c.consumer
}

func (c *channel) OutputCtx(ctx context.Context) <-chan interface{} {
return c.consumer
}

func (c *channel) Len() int {
produced, consumed := c.Stats()
l := produced - consumed
Expand Down
86 changes: 86 additions & 0 deletions lang/channel/cond.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package channel

import (
"context"
"sync/atomic"
"time"
)

var (
_ Cond = (*cond)(nil)
)

type Cond interface {
Signal() bool
Broadcast() bool
Wait(ctx context.Context) bool
}

func NewCond() Cond {
return new(cond)
}

type condSignal = chan struct{}

type cond struct {
signal atomic.Value
timeout time.Duration
}

func (c *cond) Signal() bool {
sv := c.signal.Load()
if sv == nil {
return false
}
signal := sv.(condSignal)
select {
case signal <- struct{}{}:
return true
default:
return false
}
}

func (c *cond) Broadcast() bool {
BROADCAST:
sv := c.signal.Load()
if sv == nil {
return false
}
var signal condSignal = nil
if !c.signal.CompareAndSwap(sv, signal) {
goto BROADCAST
}
signal = sv.(condSignal)
select {
case <-signal:
return false
default:
close(signal)
return true
}
}

func (c *cond) Wait(ctx context.Context) bool {
WAIT:
sv := c.signal.Load()
var signal condSignal
if sv == nil {
signal = make(condSignal)
if !c.signal.CompareAndSwap(nil, signal) {
goto WAIT
}
} else {
signal = sv.(condSignal)
}
if ctx == nil || ctx.Done() == nil {
<-signal
return true
}
select {
case <-signal:
return true
case <-ctx.Done():
return false
}
}
75 changes: 75 additions & 0 deletions lang/channel/cond_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package channel

import (
"context"
"runtime"
"sync/atomic"
"testing"
"time"
)

func TestCond(t *testing.T) {
cd := NewCond()
var finished int32
emptyCtx := context.Background()
cancelCtx, cancelFunc := context.WithCancel(emptyCtx)
for i := 0; i < 10; i++ {
go func(i int) {
if i%2 == 0 {
cd.Wait(emptyCtx)
} else {
cd.Wait(cancelCtx)
}
atomic.AddInt32(&finished, 1)
}(i)
}
time.Sleep(time.Millisecond * 100)
cancelFunc()
for atomic.LoadInt32(&finished) != int32(5) {
runtime.Gosched()
}
cd.Signal()
for atomic.LoadInt32(&finished) != int32(6) {
runtime.Gosched()
}
cd.Signal()
for atomic.LoadInt32(&finished) != int32(7) {
runtime.Gosched()
}
cd.Broadcast()
cd.Signal()
for atomic.LoadInt32(&finished) != int32(10) {
runtime.Gosched()
}
}

func BenchmarkChanCond(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ch := make(chan struct{})
go func() {
time.Sleep(time.Millisecond)
close(ch)
}()
select {
case <-ch:
case <-time.After(10 * time.Millisecond):
}
}
}

func BenchmarkCond(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cd := NewCond()
go func() {
time.Sleep(time.Millisecond)
cd.Signal()
}()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
cd.Wait(ctx)
cancel()
}
}

0 comments on commit 1ec1731

Please sign in to comment.