diff --git a/internal/backend/basicstation/backend.go b/internal/backend/basicstation/backend.go index 1faee8cc..f970a3ca 100644 --- a/internal/backend/basicstation/backend.go +++ b/internal/backend/basicstation/backend.go @@ -84,7 +84,7 @@ func NewBackend(conf config.Config) (*Backend, error) { scheme: "ws", gateways: gateways{ - gateways: make(map[lorawan.EUI64]gateway), + gateways: make(map[lorawan.EUI64]*connection), }, caCert: conf.Backend.BasicStation.CACert, @@ -312,11 +312,11 @@ func (b *Backend) Stop() error { return b.ln.Close() } -func (b *Backend) handleRouterInfo(r *http.Request, c *websocket.Conn) { +func (b *Backend) handleRouterInfo(r *http.Request, conn *connection) { websocketReceiveCounter("router_info").Inc() var req structs.RouterInfoRequest - if err := c.ReadJSON(&req); err != nil { + if err := conn.conn.ReadJSON(&req); err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.WithError(err).Error("backend/basicstation: read message error") } @@ -345,8 +345,11 @@ func (b *Backend) handleRouterInfo(r *http.Request, c *websocket.Conn) { return } - c.SetWriteDeadline(time.Now().Add(b.writeTimeout)) - if err := c.WriteMessage(websocket.TextMessage, bb); err != nil { + conn.Lock() + defer conn.Unlock() + + conn.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) + if err := conn.conn.WriteMessage(websocket.TextMessage, bb); err != nil { log.WithError(err).Error("backend/basicstation: websocket send message error") return } @@ -358,7 +361,7 @@ func (b *Backend) handleRouterInfo(r *http.Request, c *websocket.Conn) { }).Info("backend/basicstation: router-info request received") } -func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) { +func (b *Backend) handleGateway(r *http.Request, conn *connection) { // get the gateway id from the url urlParts := strings.Split(r.URL.Path, "/") if len(urlParts) < 2 { @@ -391,7 +394,7 @@ func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) { } // set the gateway connection - if err := b.gateways.set(gatewayID, gateway{conn: c}); err != nil { + if err := b.gateways.set(gatewayID, conn); err != nil { log.WithError(err).WithField("gateway_id", gatewayID).Error("backend/basicstation: set gateway error") } log.WithFields(log.Fields{ @@ -466,7 +469,7 @@ func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) { // receive data for { - mt, msg, err := c.ReadMessage() + mt, msg, err := conn.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.WithField("gateway_id", gatewayID).WithError(err).Error("backend/basicstation: read message error") @@ -475,7 +478,7 @@ func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) { } // reset the read deadline as the Basic Station doesn't respond to PONG messages (yet) - c.SetReadDeadline(time.Now().Add(b.readTimeout)) + conn.conn.SetReadDeadline(time.Now().Add(b.readTimeout)) if mt == websocket.BinaryMessage { log.WithFields(log.Fields{ @@ -768,11 +771,14 @@ func (b *Backend) handleTimeSync(gatewayID lorawan.EUI64, v structs.TimeSyncRequ } func (b *Backend) sendToGateway(gatewayID lorawan.EUI64, v interface{}) error { - gw, err := b.gateways.get(gatewayID) + conn, err := b.gateways.get(gatewayID) if err != nil { return errors.Wrap(err, "get gateway error") } + conn.Lock() + defer conn.Unlock() + bb, err := json.Marshal(v) if err != nil { return errors.Wrap(err, "marshal json error") @@ -783,8 +789,8 @@ func (b *Backend) sendToGateway(gatewayID lorawan.EUI64, v interface{}) error { "message": string(bb), }).Debug("sending message to gateway") - gw.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) - if err := gw.conn.WriteMessage(websocket.TextMessage, bb); err != nil { + conn.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) + if err := conn.conn.WriteMessage(websocket.TextMessage, bb); err != nil { return errors.Wrap(err, "send message to gateway error") } @@ -792,20 +798,23 @@ func (b *Backend) sendToGateway(gatewayID lorawan.EUI64, v interface{}) error { } func (b *Backend) sendRawToGateway(gatewayID lorawan.EUI64, messageType int, data []byte) error { - gw, err := b.gateways.get(gatewayID) + conn, err := b.gateways.get(gatewayID) if err != nil { return errors.Wrap(err, "get gateway error") } - gw.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) - if err := gw.conn.WriteMessage(messageType, data); err != nil { + conn.Lock() + defer conn.Unlock() + + conn.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) + if err := conn.conn.WriteMessage(messageType, data); err != nil { return errors.Wrap(err, "send message to gateway error") } return nil } -func (b *Backend) websocketWrap(handler func(*http.Request, *websocket.Conn), w http.ResponseWriter, r *http.Request) { +func (b *Backend) websocketWrap(handler func(*http.Request, *connection), w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.WithError(err).Error("backend/basicstation: websocket upgrade error") @@ -824,23 +833,29 @@ func (b *Backend) websocketWrap(handler func(*http.Request, *websocket.Conn), w defer ticker.Stop() done := make(chan struct{}) + // Wrap the conn inside a gateway struct, so that we can lock it when writing + // data. + c := connection{conn: conn} + go func() { for { select { case <-ticker.C: + c.Lock() websocketPingPongCounter("ping").Inc() - conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) + c.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { log.WithError(err).Error("backend/basicstation: send ping message error") - conn.Close() + c.conn.Close() } + c.Unlock() case <-done: return } } }() - handler(r, conn) + handler(r, &c) done <- struct{}{} } diff --git a/internal/backend/basicstation/gateway.go b/internal/backend/basicstation/gateway.go index 9a7b65d8..40dfe533 100644 --- a/internal/backend/basicstation/gateway.go +++ b/internal/backend/basicstation/gateway.go @@ -14,19 +14,19 @@ var ( errGatewayDoesNotExist = errors.New("gateway does not exist") ) -type gateway struct { - conn *websocket.Conn - configVersion string +type connection struct { + sync.Mutex + conn *websocket.Conn } type gateways struct { sync.RWMutex - gateways map[lorawan.EUI64]gateway + gateways map[lorawan.EUI64]*connection subscribeEventFunc func(events.Subscribe) } -func (g *gateways) get(id lorawan.EUI64) (gateway, error) { +func (g *gateways) get(id lorawan.EUI64) (*connection, error) { g.RLock() defer g.RUnlock() @@ -37,11 +37,11 @@ func (g *gateways) get(id lorawan.EUI64) (gateway, error) { return gw, nil } -func (g *gateways) set(id lorawan.EUI64, gw gateway) error { +func (g *gateways) set(id lorawan.EUI64, c *connection) error { g.Lock() defer g.Unlock() - g.gateways[id] = gw + g.gateways[id] = c if g.subscribeEventFunc != nil { g.subscribeEventFunc(events.Subscribe{Subscribe: true, GatewayID: id})