From d9947ef033c44a7e3ca43be03340d7c1391a1661 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Tue, 3 Sep 2024 14:29:33 +0800 Subject: [PATCH] fix: close consume chan when close in throttle waiting --- lang/channel/channel.go | 2 ++ lang/channel/channel_test.go | 60 ++++++++++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/lang/channel/channel.go b/lang/channel/channel.go index 4a811b5..2b3eec4 100644 --- a/lang/channel/channel.go +++ b/lang/channel/channel.go @@ -283,6 +283,8 @@ func (c *channel) consume() { // check throttle if c.throttling(c.consumerThrottle) { // closed + close(c.consumer) // close consumer + atomic.StoreInt32(&c.state, -2) // -2 means closed totally return } diff --git a/lang/channel/channel_test.go b/lang/channel/channel_test.go index 438a0d5..52955d2 100644 --- a/lang/channel/channel_test.go +++ b/lang/channel/channel_test.go @@ -373,9 +373,12 @@ func TestChannelProduceRateControl(t *testing.T) { ch := New( WithRateThrottle(produceMaxRate, 0), ) - defer ch.Close() + var wg sync.WaitGroup + const total = 300 + wg.Add(1) go func() { + defer wg.Done() for c := range ch.Output() { id := c.(int) //tlogf(t, "consumed: %d", id) @@ -383,30 +386,75 @@ func TestChannelProduceRateControl(t *testing.T) { } }() begin := time.Now() - for i := 1; i <= 500; i++ { + for i := 1; i <= total; i++ { ch.Input(i) } + ch.Close() // when channel closed, ch.Output() should return + wg.Wait() cost := time.Now().Sub(begin) tlogf(t, "Cost %dms", cost.Milliseconds()) } func TestChannelConsumeRateControl(t *testing.T) { + consumeRate := 100 ch := New( - WithRateThrottle(0, 100), + WithRateThrottle(0, consumeRate), ) - defer ch.Close() + var wg sync.WaitGroup + const total = 300 + var counter int32 + wg.Add(1) go func() { + defer wg.Done() for c := range ch.Output() { id := c.(int) //tlogf(t, "consumed: %d", id) - _ = id + if id == 0 { + t.Errorf("get zero output") + } + atomic.AddInt32(&counter, 1) + } + }() + begin := time.Now() + for i := 1; i <= total; i++ { + ch.Input(i) + } + ch.Close() // when channel closed, ch.Output() should return + wg.Wait() + assert.Equal(t, int32(total), atomic.LoadInt32(&counter)) + cost := time.Now().Sub(begin) + tlogf(t, "Cost %dms", cost.Milliseconds()) +} + +func TestChannelProduceAndConsumeRateControl(t *testing.T) { + produceRate, consumeRate := 100, 50 + ch := New( + WithRateThrottle(produceRate, consumeRate), + ) + + var wg sync.WaitGroup + const total = 300 + var counter int32 + wg.Add(1) + go func() { + defer wg.Done() + for c := range ch.Output() { + id := c.(int) + //tlogf(t, "consumed: %d", id) + if id == 0 { + t.Errorf("get zero output") + } + atomic.AddInt32(&counter, 1) } }() begin := time.Now() - for i := 1; i <= 500; i++ { + for i := 1; i <= total; i++ { ch.Input(i) } + ch.Close() // when channel closed, ch.Output() should return + wg.Wait() + assert.Equal(t, int32(total), atomic.LoadInt32(&counter)) cost := time.Now().Sub(begin) tlogf(t, "Cost %dms", cost.Milliseconds()) }