Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace ws package and fix reading from ws connection #10

Merged
merged 3 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 45 additions & 18 deletions api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package deriv

import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
Expand All @@ -10,8 +12,8 @@ import (
"sync/atomic"
"time"

"github.com/coder/websocket"
"github.com/ksysoev/deriv-api/schema"
"golang.org/x/net/websocket"
)

// DerivAPI is the main struct for the DerivAPI client.
Expand All @@ -36,7 +38,7 @@ type DerivAPI struct {
type ApiReqest struct {
id int
msg []byte
respChan chan string
respChan chan []byte
}

// ApiObjectRequest is an interface for all API requests that return an object.
Expand Down Expand Up @@ -152,17 +154,26 @@ func (api *DerivAPI) Connect() error {

api.logDebugf("Connecting to %s", api.Endpoint.String())

ws, err := websocket.Dial(api.Endpoint.String(), "", api.Origin.String())
ws, resp, err := websocket.Dial(context.TODO(), api.Endpoint.String(), &websocket.DialOptions{
HTTPHeader: map[string][]string{
"Origin": {api.Origin.String()},
},
})
if err != nil {
api.logDebugf("Failed to establish WS connection: %s", err.Error())
return err
}

if resp.Body != nil {
defer resp.Body.Close()
}

api.logDebugf("Connected to %s", api.Endpoint.String())

api.ws = ws

api.reqChan = make(chan ApiReqest)
respChan := make(chan string)
respChan := make(chan []byte)
outputChan := make(chan []byte)

go api.handleResponses(ws, respChan)
Expand Down Expand Up @@ -212,7 +223,7 @@ func (api *DerivAPI) Disconnect() {
close(api.keepAliveOnDisconnect)
}

api.ws.Close()
api.ws.Close(websocket.StatusNormalClosure, "disconnecting")
api.ws = nil
}

Expand All @@ -222,7 +233,8 @@ func (api *DerivAPI) Disconnect() {
func (api *DerivAPI) requestSender(wsConn *websocket.Conn, reqChan chan []byte) {
for req := range reqChan {
api.logDebugf("Sending request: %s", req)
err := websocket.Message.Send(wsConn, req)

err := wsConn.Write(context.TODO(), websocket.MessageText, req)

if err != nil {
api.logDebugf("Failed to send request: %s", err.Error())
Expand All @@ -236,27 +248,42 @@ func (api *DerivAPI) requestSender(wsConn *websocket.Conn, reqChan chan []byte)
// It reads responses using the websocket.Message.Receive method and sends them to the respChan channel.
// If an error occurs while receiving a response, it calls the Disconnect method to gracefully disconnect from the WebSocket server.
// The function returns when the WebSocket connection is closed or when an error occurs.
func (api *DerivAPI) handleResponses(wsConn *websocket.Conn, respChan chan string) {
defer close(respChan)
func (api *DerivAPI) handleResponses(wsConn *websocket.Conn, respChan chan []byte) {
defer func() {
close(respChan)

api.Disconnect()
}()

for {
var msg string
err := websocket.Message.Receive(wsConn, &msg)
msgType, reader, err := wsConn.Reader(context.TODO())
if err != nil {
api.logDebugf("Failed to receive response: %s", err.Error())
api.Disconnect()
return
}

api.logDebugf("Received response: %s", msg)
respChan <- msg
if msgType != websocket.MessageText {
api.logDebugf("Unexpected message type: %d", msgType)
continue
}

buffer := bytes.NewBuffer(make([]byte, 0))
_, err = buffer.ReadFrom(reader)

if err != nil {
api.logDebugf("Failed to read response: %s", err.Error())
return
}

api.logDebugf("Received response: %s", buffer.String())
respChan <- buffer.Bytes()
}
}

// requestMapper forward requests to the Deriv API server and
// responses from the WebSocket server to the appropriate channels.
func (api *DerivAPI) requestMapper(respChan chan string, outputChan chan []byte, reqChan chan ApiReqest, closingChan chan int) {
responseMap := make(map[int]chan string)
func (api *DerivAPI) requestMapper(respChan chan []byte, outputChan chan []byte, reqChan chan ApiReqest, closingChan chan int) {
responseMap := make(map[int]chan []byte)

defer func() {
for reqID, channel := range responseMap {
Expand All @@ -269,7 +296,7 @@ func (api *DerivAPI) requestMapper(respChan chan string, outputChan chan []byte,
select {
case rawResp := <-respChan:
var response APIResponseReqID
err := json.Unmarshal([]byte(rawResp), &response)
err := json.Unmarshal(rawResp, &response)
if err != nil {
continue
}
Expand Down Expand Up @@ -297,7 +324,7 @@ func (api *DerivAPI) requestMapper(respChan chan string, outputChan chan []byte,
}

// Send sends a request to the Deriv API and returns a channel that will receive the response
func (api *DerivAPI) Send(reqID int, request any) (chan string, error) {
func (api *DerivAPI) Send(reqID int, request any) (chan []byte, error) {
err := api.Connect()

if err != nil {
Expand All @@ -309,7 +336,7 @@ func (api *DerivAPI) Send(reqID int, request any) (chan string, error) {
return nil, err
}

respChan := make(chan string)
respChan := make(chan []byte)

ApiReqest := ApiReqest{
id: reqID,
Expand Down
2 changes: 1 addition & 1 deletion api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func TestSend(t *testing.T) {

msg := <-respChan
testMsg := "{\"req_id\":1}"
if msg != testMsg {
if string(msg) != testMsg {
t.Errorf("Expected message to be %s, but got %s", testMsg, msg)
}
}
Expand Down
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type APIErrorResponse struct {
// If the response is not a valid JSON-encoded APIErrorResponse, an error is returned.
// If the APIErrorResponse contains a non-empty APIError, it is returned as an error.
// Otherwise, nil is returned.
func parseError(rawResponse string) error {
func parseError(rawResponse []byte) error {
var errorResponse APIErrorResponse

err := json.Unmarshal([]byte(rawResponse), &errorResponse)
Expand Down
8 changes: 4 additions & 4 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ func TestParseError_ValidResponse(t *testing.T) {
}

expected := &errorResponse.Error
actual := parseError(string(rawResponse))
actual := parseError(rawResponse)

if actual.Error() != expected.Error() {
t.Errorf("parseError() returned %v, expected %v", actual, expected)
}
}

func TestParseError_InvalidResponse(t *testing.T) {
rawResponse := "invalid-json"
rawResponse := []byte("invalid-json")

expected := errors.New("invalid character 'i' looking for beginning of value")
actual := parseError(rawResponse)
Expand All @@ -55,7 +55,7 @@ func TestParseError_InvalidResponse(t *testing.T) {
}

func TestParseError_EmptyErrorResponse(t *testing.T) {
rawResponse := "{}"
rawResponse := []byte("{}")

expected := (error)(nil)
actual := parseError(rawResponse)
Expand All @@ -76,7 +76,7 @@ func TestParseError_EmptyAPIError(t *testing.T) {
}

expected := (error)(nil)
actual := parseError(string(rawResponse))
actual := parseError(rawResponse)

if actual != expected {
t.Errorf("parseError() returned %v, expected %v", actual, expected)
Expand Down
2 changes: 1 addition & 1 deletion examples/candles/candles.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func main() {
api, err := deriv.NewDerivAPI("wss://ws.binaryws.com/websockets/v3", 1089, "en", "https://localhost/")
api, err := deriv.NewDerivAPI("wss://ws.derivws.com/websockets/v3", 1089, "en", "https://localhost/")

if err != nil {
log.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion examples/creating_app/create_app.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
const ApiToken = "YOUR_API_TOKEN_HERE" // Replace with your API token

func main() {
api, err := deriv.NewDerivAPI("wss://ws.binaryws.com/websockets/v3", 1089, "en", "https://localhost/")
api, err := deriv.NewDerivAPI("wss://ws.derivws.com/websockets/v3", 1089, "en", "https://localhost/")

if err != nil {
log.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion examples/tick_history/tick_history.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func main() {
api, err := deriv.NewDerivAPI("wss://ws.binaryws.com/websockets/v3", 1089, "en", "https://localhost/")
api, err := deriv.NewDerivAPI("wss://ws.derivws.com/websockets/v3", 1089, "en", "https://localhost/")

if err != nil {
log.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion examples/tick_stream/tick_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func main() {
api, err := deriv.NewDerivAPI("wss://ws.binaryws.com/websockets/v3", 1089, "en", "https://localhost/", deriv.Debug)
api, err := deriv.NewDerivAPI("wss://ws.derivws.com/websockets/v3", 36544, "en", "https://localhost/", deriv.Debug)

if err != nil {
log.Fatal(err)
Expand Down
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@ module github.com/ksysoev/deriv-api
go 1.20

require golang.org/x/net v0.24.0

require (
github.com/coder/websocket v1.8.12 // indirect
nhooyr.io/websocket v1.8.17 // indirect
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y=
nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
8 changes: 4 additions & 4 deletions subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ type SubscriptionIDResponse struct {
// returns it along with any error that occurs during the deserialization process.
// If the response contains an error code, it returns a SubscriptionResponse
// struct and an error that wraps the Error field of the struct.
func parseSubsciption(rawResponse string) (SubscriptionResponse, error) {
func parseSubsciption(rawResponse []byte) (SubscriptionResponse, error) {
var sub SubscriptionResponse

err := json.Unmarshal([]byte(rawResponse), &sub)
err := json.Unmarshal(rawResponse, &sub)
if err != nil {
return sub, err
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func (s *Subsciption[initResp, Resp]) Start(reqID int, request any) (initResp, e
panic("Response object must implement ApiResponse interface")
}

err = apiResp.UnmarshalJSON([]byte(initResponse))
err = apiResp.UnmarshalJSON(initResponse)
if err != nil {
s.API.logDebugf("Failed to parse response for request %d: %s", reqID, err.Error())
s.API.closeRequestChannel(reqID)
Expand All @@ -151,7 +151,7 @@ func (s *Subsciption[initResp, Resp]) Start(reqID int, request any) (initResp, e
}

// messageHandler is a goroutine that handles subscription updates received on the channel passed to it.
func (s *Subsciption[initResp, Resp]) messageHandler(inChan chan string) {
func (s *Subsciption[initResp, Resp]) messageHandler(inChan chan []byte) {
defer func() {
s.statusLock.Lock()
if s.isActive {
Expand Down
10 changes: 5 additions & 5 deletions subscriptions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func TestParseSubscription_ValidInput(t *testing.T) {
input := `{"subscription": {"id": "123"}}`
input := []byte(`{"subscription": {"id": "123"}}`)
expected := SubscriptionResponse{Subscription: SubscriptionIDResponse{ID: "123"}}
result, err := parseSubsciption(input)
if err != nil {
Expand All @@ -25,15 +25,15 @@ func TestParseSubscription_ValidInput(t *testing.T) {
}

func TestParseSubscription_InvalidJSONInput(t *testing.T) {
input := `{"subscription": {"id": "123", "status": "active"`
input := []byte(`{"subscription": {"id": "123", "status": "active"`)
_, err := parseSubsciption(input)
if err == nil {
t.Errorf("Expected an error, but got nil")
}
}

func TestParseSubscription_InvalidSubscriptionData(t *testing.T) {
input := `{"subscription": {"id": "123", "status": "active"}, "error": {"code": "invalid_subscription"}}`
input := []byte(`{"subscription": {"id": "123", "status": "active"}, "error": {"code": "invalid_subscription"}}`)
expectedErr := &APIError{Code: "invalid_subscription"}
_, err := parseSubsciption(input)
if err == nil {
Expand All @@ -45,14 +45,14 @@ func TestParseSubscription_InvalidSubscriptionData(t *testing.T) {
}

func TestParseSubscription_EmptyInput(t *testing.T) {
_, err := parseSubsciption("")
_, err := parseSubsciption([]byte(""))
if err == nil {
t.Errorf("Expected an error, but got nil")
}
}

func TestParseSubscription_EmptySubscriptionData(t *testing.T) {
input := `{}`
input := []byte(`{}`)
expectedErr := fmt.Errorf("subscription ID is empty")
_, err := parseSubsciption(input)
if err == nil {
Expand Down
Loading