diff --git a/api.go b/api.go index fb66544..109dc8c 100644 --- a/api.go +++ b/api.go @@ -22,22 +22,25 @@ const ( defaultTimeout = 30 * time.Second ) -// DerivAPI is the main struct for the DerivAPI client. -type DerivAPI struct { //nolint:revive // don't want to change the name for now - reqChan chan APIReqest - Endpoint *url.URL - keepAliveOnDisconnect chan bool - Origin *url.URL - ws *websocket.Conn - closingChan chan int - Lang string - TimeOut time.Duration - lastRequestID int64 - AppID int - keepAliveInterval time.Duration - connectionLock sync.Mutex - keepAlive bool - debugEnabled bool +// DerivAPI is the main struct for the DerivAPI client +// +//nolint:revive // don't want to break backward compatibility for now +type DerivAPI struct { + ctx context.Context + Endpoint *url.URL + Origin *url.URL + ws *websocket.Conn + closingChan chan int + reqChan chan APIReqest + cancel context.CancelFunc + Lang string + TimeOut time.Duration + keepAliveInterval time.Duration + AppID int + lastRequestID int64 + connectionLock sync.Mutex + keepAlive bool + debugEnabled bool } // APIReqest is an interface for all API requests. @@ -116,12 +119,15 @@ func NewDerivAPI(endpoint string, appID int, lang, origin string, opts ...APIOpt connectionLock: sync.Mutex{}, closingChan: make(chan int), keepAliveInterval: keepAliveInterval, + ctx: context.Background(), } for _, opt := range opts { opt(&api) } + api.ctx, api.cancel = context.WithCancel(api.ctx) + return &api, nil } @@ -182,9 +188,7 @@ func (api *DerivAPI) Connect() error { go api.requestMapper(respChan, outputChan, api.reqChan, api.closingChan) if api.keepAlive { - api.keepAliveOnDisconnect = make(chan bool, 1) - - go func(interval time.Duration, onDisconnect chan bool) { + go func(interval time.Duration) { for { select { case <-time.After(interval): @@ -192,11 +196,11 @@ func (api *DerivAPI) Connect() error { if err != nil { return } - case <-onDisconnect: + case <-api.ctx.Done(): return } } - }(api.keepAliveInterval, api.keepAliveOnDisconnect) + }(api.keepAliveInterval) } return nil @@ -216,12 +220,7 @@ func (api *DerivAPI) Disconnect() { api.logDebugf("Disconnecting from %s", api.Endpoint.String()) - close(api.reqChan) - - if api.keepAlive { - api.keepAliveOnDisconnect <- true - close(api.keepAliveOnDisconnect) - } + api.cancel() api.ws.Close(websocket.StatusNormalClosure, "disconnecting") api.ws = nil @@ -231,14 +230,26 @@ func (api *DerivAPI) Disconnect() { // It reads requests from the reqChan channel and sends them using the websocket.Message.Send method. // If an error occurs while sending a request, it calls the Disconnect method to gracefully disconnect from the WebSocket server. func (api *DerivAPI) requestSender(wsConn *websocket.Conn, reqChan chan []byte) { - for req := range reqChan { - api.logDebugf("Sending request: %s", req) - - if err := wsConn.Write(context.TODO(), websocket.MessageText, req); err != nil { - api.logDebugf("Failed to send request: %s", err.Error()) - api.Disconnect() + defer func() { + api.Disconnect() + }() + for { + select { + case <-api.ctx.Done(): return + case req, ok := <-reqChan: + if !ok { + return + } + + api.logDebugf("Sending request: %s", req) + + err := wsConn.Write(api.ctx, websocket.MessageText, req) + if err != nil { + api.logDebugf("Failed to send request: %s", err.Error()) + return + } } } } @@ -254,8 +265,8 @@ func (api *DerivAPI) handleResponses(wsConn *websocket.Conn, respChan chan []byt api.Disconnect() }() - for { - msgType, reader, err := wsConn.Reader(context.TODO()) + for api.ctx.Err() == nil { + msgType, reader, err := wsConn.Reader(api.ctx) if err != nil { api.logDebugf("Failed to receive response: %s", err.Error()) return @@ -275,7 +286,12 @@ func (api *DerivAPI) handleResponses(wsConn *websocket.Conn, respChan chan []byt } api.logDebugf("Received response: %s", buffer.String()) - respChan <- buffer.Bytes() + + select { + case <-api.ctx.Done(): + return + case respChan <- buffer.Bytes(): + } } } @@ -315,6 +331,8 @@ func (api *DerivAPI) requestMapper(respChan, outputChan chan []byte, reqChan cha close(channel) delete(responseMap, reqID) } + case <-api.ctx.Done(): + return } } } @@ -356,6 +374,8 @@ func (api *DerivAPI) SendRequest(reqID int, request, response any) error { defer api.closeRequestChannel(reqID) select { + case <-api.ctx.Done(): + return fmt.Errorf("connection closed") case <-time.After(api.TimeOut): api.logDebugf("Timeout waiting for response for request %d", reqID) return fmt.Errorf("timeout") @@ -385,5 +405,8 @@ func (api *DerivAPI) getNextRequestID() int { // closeRequestChannel closes the channel that receives the response for a request func (api *DerivAPI) closeRequestChannel(reqID int) { - api.closingChan <- reqID + select { + case api.closingChan <- reqID: + case <-api.ctx.Done(): + } }