Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pool: Correct endpoint waitgroup logic. #359

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions pool/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ type connection struct {
type Endpoint struct {
listenAddr string
connCh chan *connection
discCh chan struct{}
listener net.Listener
cfg *EndpointConfig
clients map[string]*Client
Expand All @@ -88,7 +87,6 @@ func NewEndpoint(eCfg *EndpointConfig, listenAddr string) (*Endpoint, error) {
cfg: eCfg,
clients: make(map[string]*Client),
connCh: make(chan *connection, bufferSize),
discCh: make(chan struct{}, bufferSize),
}
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
Expand All @@ -109,7 +107,9 @@ func (e *Endpoint) removeClient(c *Client) {

// listen accepts incoming client connections on the endpoint.
// It must be run as a goroutine.
func (e *Endpoint) listen() {
func (e *Endpoint) listen(ctx context.Context) {
defer e.wg.Done()

log.Infof("listening on %s", e.listenAddr)
for {
conn, err := e.listener.Accept()
Expand All @@ -126,16 +126,19 @@ func (e *Endpoint) listen() {
log.Errorf("unable to accept client connection: %v", err)
return
}
e.connCh <- &connection{
Conn: conn,
Done: make(chan bool),
select {
case <-ctx.Done():
return
case e.connCh <- &connection{Conn: conn, Done: make(chan bool)}:
}
}
}

// connect creates new pool clients from established connections.
// It must be run as a goroutine.
func (e *Endpoint) connect(ctx context.Context) {
defer e.wg.Done()

for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -203,22 +206,13 @@ func (e *Endpoint) connect(ctx context.Context) {
// disconnect relays client disconnections to the endpoint for processing.
// It must be run as a goroutine.
func (e *Endpoint) disconnect(ctx context.Context) {
for {
select {
case <-ctx.Done():
e.clientsMtx.Lock()
for _, client := range e.clients {
client.cancel()
}
e.clientsMtx.Unlock()

e.wg.Done()
return

case <-e.discCh:
e.wg.Done()
}
<-ctx.Done()
e.clientsMtx.Lock()
for _, client := range e.clients {
client.cancel()
}
e.clientsMtx.Unlock()
e.wg.Done()
}

// generateHashIDs generates hash ids of all client connections to the pool.
Expand All @@ -238,8 +232,8 @@ func (e *Endpoint) generateHashIDs() map[string]struct{} {
// run handles the lifecycle of all endpoint related processes.
// This should be run as a goroutine.
func (e *Endpoint) run(ctx context.Context) {
e.wg.Add(1)
go e.listen()
e.wg.Add(3)
go e.listen(ctx)
go e.connect(ctx)
go e.disconnect(ctx)
e.wg.Wait()
Expand Down
5 changes: 1 addition & 4 deletions pool/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ func testEndpoint(t *testing.T) {
t.Fatalf("[NewEndpoint] unexpected error: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
endpoint.wg.Add(1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
Expand Down Expand Up @@ -235,7 +234,5 @@ func testEndpoint(t *testing.T) {
defer conn.Close()

cancel()
// TODO: This never finishes because endpoint.run never actually finishes
// due to the internal waitgroup not being handled properly.
// wg.Wait()
wg.Wait()
}
Loading