From 0ed54c1472e7706603368f64e9c4362ce02c4a55 Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Thu, 25 Apr 2024 20:07:09 +0800 Subject: [PATCH] fix: channel consume left data after close --- lang/channel/channel.go | 11 ++++++----- lang/channel/channel_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/lang/channel/channel.go b/lang/channel/channel.go index b9c75f43..4a811b50 100644 --- a/lang/channel/channel.go +++ b/lang/channel/channel.go @@ -216,10 +216,7 @@ func (c *channel) Close() { if !atomic.CompareAndSwapInt32(&c.state, 0, -1) { return } - // stop consumer - c.bufferLock.Lock() - c.buffer.Init() // clear buffer - c.bufferLock.Unlock() + // Close function only notify Input/consume goroutine to close gracefully c.bufferCond.Broadcast() } @@ -253,6 +250,10 @@ func (c *channel) Input(v interface{}) { for c.buffer.Len() >= c.size { // wait for consuming c.bufferCond.Wait() + if c.isClosed() { + // blocking send a closed channel should return directly + return + } } } c.enqueueBuffer(it) @@ -289,7 +290,7 @@ func (c *channel) consume() { c.bufferLock.Lock() for c.buffer.Len() == 0 { if c.isClosed() { - close(c.consumer) + close(c.consumer) // close consumer atomic.StoreInt32(&c.state, -2) // -2 means closed totally c.bufferLock.Unlock() return diff --git a/lang/channel/channel_test.go b/lang/channel/channel_test.go index b0258bb5..196b39e3 100644 --- a/lang/channel/channel_test.go +++ b/lang/channel/channel_test.go @@ -462,3 +462,32 @@ func TestFastRecoverConsumer(t *testing.T) { } // all consumed } + +func TestChannelCloseThenConsume(t *testing.T) { + size := 10 + ch := New(WithNonBlock(), WithSize(size)) + for i := 0; i < size; i++ { + ch.Input(i) + } + ch.Close() + for i := 0; i < size; i++ { + x := <-ch.Output() + assert.NotNil(t, x) + n := x.(int) + assert.Equal(t, n, x) + } +} + +func TestChannelInputAndClose(t *testing.T) { + ch := New(WithSize(1)) + go func() { + time.Sleep(time.Millisecond * 100) + ch.Close() + }() + begin := time.Now() + for i := 0; i < 10; i++ { + ch.Input(1) + } + cost := time.Now().Sub(begin) + assert.True(t, cost.Milliseconds() >= 100) +}