Skip to content

Commit

Permalink
Merge pull request #14 from ksysoev/refactor_closing_connection
Browse files Browse the repository at this point in the history
Add context handling to DerivAPI client
  • Loading branch information
ksysoev authored Aug 18, 2024
2 parents 854eb1f + c915b11 commit e800a6c
Showing 1 changed file with 60 additions and 37 deletions.
97 changes: 60 additions & 37 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,25 @@ const (
defaultTimeout = 30 * time.Second
)

// DerivAPI is the main struct for the DerivAPI client.
type DerivAPI struct { //nolint:revive // don't want to change the name for now
reqChan chan APIReqest
Endpoint *url.URL
keepAliveOnDisconnect chan bool
Origin *url.URL
ws *websocket.Conn
closingChan chan int
Lang string
TimeOut time.Duration
lastRequestID int64
AppID int
keepAliveInterval time.Duration
connectionLock sync.Mutex
keepAlive bool
debugEnabled bool
// DerivAPI is the main struct for the DerivAPI client
//
//nolint:revive // don't want to break backward compatibility for now
type DerivAPI struct {
ctx context.Context
Endpoint *url.URL
Origin *url.URL
ws *websocket.Conn
closingChan chan int
reqChan chan APIReqest
cancel context.CancelFunc
Lang string
TimeOut time.Duration
keepAliveInterval time.Duration
AppID int
lastRequestID int64
connectionLock sync.Mutex
keepAlive bool
debugEnabled bool
}

// APIReqest is an interface for all API requests.
Expand Down Expand Up @@ -116,12 +119,15 @@ func NewDerivAPI(endpoint string, appID int, lang, origin string, opts ...APIOpt
connectionLock: sync.Mutex{},
closingChan: make(chan int),
keepAliveInterval: keepAliveInterval,
ctx: context.Background(),
}

for _, opt := range opts {
opt(&api)
}

api.ctx, api.cancel = context.WithCancel(api.ctx)

return &api, nil
}

Expand Down Expand Up @@ -182,21 +188,19 @@ func (api *DerivAPI) Connect() error {
go api.requestMapper(respChan, outputChan, api.reqChan, api.closingChan)

if api.keepAlive {
api.keepAliveOnDisconnect = make(chan bool, 1)

go func(interval time.Duration, onDisconnect chan bool) {
go func(interval time.Duration) {
for {
select {
case <-time.After(interval):
_, err := api.Ping(schema.Ping{Ping: 1})
if err != nil {
return
}
case <-onDisconnect:
case <-api.ctx.Done():
return
}
}
}(api.keepAliveInterval, api.keepAliveOnDisconnect)
}(api.keepAliveInterval)
}

return nil
Expand All @@ -216,12 +220,7 @@ func (api *DerivAPI) Disconnect() {

api.logDebugf("Disconnecting from %s", api.Endpoint.String())

close(api.reqChan)

if api.keepAlive {
api.keepAliveOnDisconnect <- true
close(api.keepAliveOnDisconnect)
}
api.cancel()

api.ws.Close(websocket.StatusNormalClosure, "disconnecting")
api.ws = nil
Expand All @@ -231,14 +230,26 @@ func (api *DerivAPI) Disconnect() {
// It reads requests from the reqChan channel and sends them using the websocket.Message.Send method.
// If an error occurs while sending a request, it calls the Disconnect method to gracefully disconnect from the WebSocket server.
func (api *DerivAPI) requestSender(wsConn *websocket.Conn, reqChan chan []byte) {
for req := range reqChan {
api.logDebugf("Sending request: %s", req)

if err := wsConn.Write(context.TODO(), websocket.MessageText, req); err != nil {
api.logDebugf("Failed to send request: %s", err.Error())
api.Disconnect()
defer func() {
api.Disconnect()
}()

for {
select {
case <-api.ctx.Done():
return
case req, ok := <-reqChan:
if !ok {
return
}

api.logDebugf("Sending request: %s", req)

err := wsConn.Write(api.ctx, websocket.MessageText, req)
if err != nil {
api.logDebugf("Failed to send request: %s", err.Error())
return
}
}
}
}
Expand All @@ -254,8 +265,8 @@ func (api *DerivAPI) handleResponses(wsConn *websocket.Conn, respChan chan []byt
api.Disconnect()
}()

for {
msgType, reader, err := wsConn.Reader(context.TODO())
for api.ctx.Err() == nil {
msgType, reader, err := wsConn.Reader(api.ctx)
if err != nil {
api.logDebugf("Failed to receive response: %s", err.Error())
return
Expand All @@ -275,7 +286,12 @@ func (api *DerivAPI) handleResponses(wsConn *websocket.Conn, respChan chan []byt
}

api.logDebugf("Received response: %s", buffer.String())
respChan <- buffer.Bytes()

select {
case <-api.ctx.Done():
return
case respChan <- buffer.Bytes():
}
}
}

Expand Down Expand Up @@ -315,6 +331,8 @@ func (api *DerivAPI) requestMapper(respChan, outputChan chan []byte, reqChan cha
close(channel)
delete(responseMap, reqID)
}
case <-api.ctx.Done():
return
}
}
}
Expand Down Expand Up @@ -356,6 +374,8 @@ func (api *DerivAPI) SendRequest(reqID int, request, response any) error {
defer api.closeRequestChannel(reqID)

select {
case <-api.ctx.Done():
return fmt.Errorf("connection closed")
case <-time.After(api.TimeOut):
api.logDebugf("Timeout waiting for response for request %d", reqID)
return fmt.Errorf("timeout")
Expand Down Expand Up @@ -385,5 +405,8 @@ func (api *DerivAPI) getNextRequestID() int {

// closeRequestChannel closes the channel that receives the response for a request
func (api *DerivAPI) closeRequestChannel(reqID int) {
api.closingChan <- reqID
select {
case api.closingChan <- reqID:
case <-api.ctx.Done():
}
}

0 comments on commit e800a6c

Please sign in to comment.