Skip to content

Commit

Permalink
Add websocket connection mutex to avoid concurrent writes.
Browse files Browse the repository at this point in the history
Fixes #191.
  • Loading branch information
brocaar committed Apr 6, 2021
1 parent 2097398 commit 5f22f6b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 26 deletions.
53 changes: 34 additions & 19 deletions internal/backend/basicstation/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func NewBackend(conf config.Config) (*Backend, error) {
scheme: "ws",

gateways: gateways{
gateways: make(map[lorawan.EUI64]gateway),
gateways: make(map[lorawan.EUI64]*connection),
},

caCert: conf.Backend.BasicStation.CACert,
Expand Down Expand Up @@ -312,11 +312,11 @@ func (b *Backend) Stop() error {
return b.ln.Close()
}

func (b *Backend) handleRouterInfo(r *http.Request, c *websocket.Conn) {
func (b *Backend) handleRouterInfo(r *http.Request, conn *connection) {
websocketReceiveCounter("router_info").Inc()
var req structs.RouterInfoRequest

if err := c.ReadJSON(&req); err != nil {
if err := conn.conn.ReadJSON(&req); err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.WithError(err).Error("backend/basicstation: read message error")
}
Expand Down Expand Up @@ -345,8 +345,11 @@ func (b *Backend) handleRouterInfo(r *http.Request, c *websocket.Conn) {
return
}

c.SetWriteDeadline(time.Now().Add(b.writeTimeout))
if err := c.WriteMessage(websocket.TextMessage, bb); err != nil {
conn.Lock()
defer conn.Unlock()

conn.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
if err := conn.conn.WriteMessage(websocket.TextMessage, bb); err != nil {
log.WithError(err).Error("backend/basicstation: websocket send message error")
return
}
Expand All @@ -358,7 +361,7 @@ func (b *Backend) handleRouterInfo(r *http.Request, c *websocket.Conn) {
}).Info("backend/basicstation: router-info request received")
}

func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) {
func (b *Backend) handleGateway(r *http.Request, conn *connection) {
// get the gateway id from the url
urlParts := strings.Split(r.URL.Path, "/")
if len(urlParts) < 2 {
Expand Down Expand Up @@ -391,7 +394,7 @@ func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) {
}

// set the gateway connection
if err := b.gateways.set(gatewayID, gateway{conn: c}); err != nil {
if err := b.gateways.set(gatewayID, conn); err != nil {
log.WithError(err).WithField("gateway_id", gatewayID).Error("backend/basicstation: set gateway error")
}
log.WithFields(log.Fields{
Expand Down Expand Up @@ -466,7 +469,7 @@ func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) {

// receive data
for {
mt, msg, err := c.ReadMessage()
mt, msg, err := conn.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.WithField("gateway_id", gatewayID).WithError(err).Error("backend/basicstation: read message error")
Expand All @@ -475,7 +478,7 @@ func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) {
}

// reset the read deadline as the Basic Station doesn't respond to PONG messages (yet)
c.SetReadDeadline(time.Now().Add(b.readTimeout))
conn.conn.SetReadDeadline(time.Now().Add(b.readTimeout))

if mt == websocket.BinaryMessage {
log.WithFields(log.Fields{
Expand Down Expand Up @@ -768,11 +771,14 @@ func (b *Backend) handleTimeSync(gatewayID lorawan.EUI64, v structs.TimeSyncRequ
}

func (b *Backend) sendToGateway(gatewayID lorawan.EUI64, v interface{}) error {
gw, err := b.gateways.get(gatewayID)
conn, err := b.gateways.get(gatewayID)
if err != nil {
return errors.Wrap(err, "get gateway error")
}

conn.Lock()
defer conn.Unlock()

bb, err := json.Marshal(v)
if err != nil {
return errors.Wrap(err, "marshal json error")
Expand All @@ -783,29 +789,32 @@ func (b *Backend) sendToGateway(gatewayID lorawan.EUI64, v interface{}) error {
"message": string(bb),
}).Debug("sending message to gateway")

gw.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
if err := gw.conn.WriteMessage(websocket.TextMessage, bb); err != nil {
conn.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
if err := conn.conn.WriteMessage(websocket.TextMessage, bb); err != nil {
return errors.Wrap(err, "send message to gateway error")
}

return nil
}

func (b *Backend) sendRawToGateway(gatewayID lorawan.EUI64, messageType int, data []byte) error {
gw, err := b.gateways.get(gatewayID)
conn, err := b.gateways.get(gatewayID)
if err != nil {
return errors.Wrap(err, "get gateway error")
}

gw.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
if err := gw.conn.WriteMessage(messageType, data); err != nil {
conn.Lock()
defer conn.Unlock()

conn.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
if err := conn.conn.WriteMessage(messageType, data); err != nil {
return errors.Wrap(err, "send message to gateway error")
}

return nil
}

func (b *Backend) websocketWrap(handler func(*http.Request, *websocket.Conn), w http.ResponseWriter, r *http.Request) {
func (b *Backend) websocketWrap(handler func(*http.Request, *connection), w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.WithError(err).Error("backend/basicstation: websocket upgrade error")
Expand All @@ -824,23 +833,29 @@ func (b *Backend) websocketWrap(handler func(*http.Request, *websocket.Conn), w
defer ticker.Stop()
done := make(chan struct{})

// Wrap the conn inside a gateway struct, so that we can lock it when writing
// data.
c := connection{conn: conn}

go func() {
for {
select {
case <-ticker.C:
c.Lock()
websocketPingPongCounter("ping").Inc()
conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
c.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
log.WithError(err).Error("backend/basicstation: send ping message error")
conn.Close()
c.conn.Close()
}
c.Unlock()
case <-done:
return
}
}
}()

handler(r, conn)
handler(r, &c)
done <- struct{}{}
}

Expand Down
14 changes: 7 additions & 7 deletions internal/backend/basicstation/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ var (
errGatewayDoesNotExist = errors.New("gateway does not exist")
)

type gateway struct {
conn *websocket.Conn
configVersion string
type connection struct {
sync.Mutex
conn *websocket.Conn
}

type gateways struct {
sync.RWMutex
gateways map[lorawan.EUI64]gateway
gateways map[lorawan.EUI64]*connection

subscribeEventFunc func(events.Subscribe)
}

func (g *gateways) get(id lorawan.EUI64) (gateway, error) {
func (g *gateways) get(id lorawan.EUI64) (*connection, error) {
g.RLock()
defer g.RUnlock()

Expand All @@ -37,11 +37,11 @@ func (g *gateways) get(id lorawan.EUI64) (gateway, error) {
return gw, nil
}

func (g *gateways) set(id lorawan.EUI64, gw gateway) error {
func (g *gateways) set(id lorawan.EUI64, c *connection) error {
g.Lock()
defer g.Unlock()

g.gateways[id] = gw
g.gateways[id] = c

if g.subscribeEventFunc != nil {
g.subscribeEventFunc(events.Subscribe{Subscribe: true, GatewayID: id})
Expand Down

0 comments on commit 5f22f6b

Please sign in to comment.