diff --git a/internal/integration/mqtt/backend.go b/internal/integration/mqtt/backend.go index 5266bff2..e1915d32 100644 --- a/internal/integration/mqtt/backend.go +++ b/internal/integration/mqtt/backend.go @@ -23,11 +23,11 @@ import ( // Backend implements a MQTT backend. type Backend struct { - sync.RWMutex + auth auth.Authentication - auth auth.Authentication conn paho.Client - closed bool + connMux sync.RWMutex + connClosed bool clientOpts *paho.ClientOptions downlinkFrameFunc func(gw.DownlinkFrame) @@ -35,7 +35,10 @@ type Backend struct { gatewayCommandExecRequestFunc func(gw.GatewayCommandExecRequest) rawPacketForwarderCommandFunc func(gw.RawPacketForwarderCommand) + gatewaysMux sync.RWMutex gateways map[lorawan.EUI64]struct{} + gatewaysSubscribedMux sync.Mutex + gatewaysSubscribed map[lorawan.EUI64]struct{} terminateOnConnectError bool qos uint8 @@ -55,6 +58,7 @@ func NewBackend(conf config.Config) (*Backend, error) { terminateOnConnectError: conf.Integration.MQTT.TerminateOnConnectError, clientOpts: paho.NewClientOptions(), gateways: make(map[lorawan.EUI64]struct{}), + gatewaysSubscribed: make(map[lorawan.EUI64]struct{}), } switch conf.Integration.MQTT.Auth.Type { @@ -140,16 +144,17 @@ func NewBackend(conf config.Config) (*Backend, error) { func (b *Backend) Start() error { b.connectLoop() go b.reconnectLoop() + go b.subscribeLoop() return nil } // Stop stops the integration. func (b *Backend) Stop() error { - b.Lock() - b.closed = true - b.Unlock() + b.connMux.Lock() + defer b.connMux.Unlock() b.conn.Disconnect(250) + b.connClosed = true return nil } @@ -173,45 +178,23 @@ func (b *Backend) SetRawPacketForwarderCommandFunc(f func(gw.RawPacketForwarderC b.rawPacketForwarderCommandFunc = f } -// SetGatewaySubscription (un)subscribes the given gateway. +// SetGatewaySubscription sets or unsets the gateway. +// Note: the actual MQTT (un)subscribe happens in a separate function to avoid +// race conditions in case of connection issues. 
This way, the gateways map +// always reflects the desired state. func (b *Backend) SetGatewaySubscription(subscribe bool, gatewayID lorawan.EUI64) error { - b.Lock() - defer b.Unlock() - log.WithFields(log.Fields{ "gateway_id": gatewayID, "subscribe": subscribe, - }).Debug("integration/mqtt: set gateway subscription called") - - _, ok := b.gateways[gatewayID] - if ok == subscribe { - return nil - } - - for { - if subscribe { - if err := b.subscribeGateway(gatewayID); err != nil { - log.WithError(err).WithFields(log.Fields{ - "gateway_id": gatewayID, - }).Error("integration/mqtt: subscribe gateway error") - time.Sleep(time.Second) - continue - } - - b.gateways[gatewayID] = struct{}{} - } else { - if err := b.unsubscribeGateway(gatewayID); err != nil { - log.WithError(err).WithFields(log.Fields{ - "gateway_id": gatewayID, - }).Error("integration/mqtt: unsubscribe gateway error") - time.Sleep(time.Second) - continue - } + }).Debug("integration/mqtt: set gateway subscription") - delete(b.gateways, gatewayID) - } + b.gatewaysMux.Lock() + defer b.gatewaysMux.Unlock() - break + if subscribe { + b.gateways[gatewayID] = struct{}{} + } else { + delete(b.gateways, gatewayID) } return nil @@ -265,8 +248,8 @@ func (b *Backend) PublishEvent(gatewayID lorawan.EUI64, event string, id uuid.UU } func (b *Backend) connect() error { - b.Lock() - defer b.Unlock() + b.connMux.Lock() + defer b.connMux.Unlock() if err := b.auth.Update(b.clientOpts); err != nil { return errors.Wrap(err, "integration/mqtt: update authentication error") @@ -300,8 +283,8 @@ func (b *Backend) connectLoop() { func (b *Backend) disconnect() error { mqttDisconnectCounter().Inc() - b.Lock() - defer b.Unlock() + b.connMux.Lock() + defer b.connMux.Unlock() b.conn.Disconnect(250) return nil @@ -310,7 +293,11 @@ func (b *Backend) disconnect() error { func (b *Backend) reconnectLoop() { if b.auth.ReconnectAfter() > 0 { for { - if b.closed { + b.connMux.RLock() + closed := b.connClosed + b.connMux.RUnlock() + + if 
closed { break } time.Sleep(b.auth.ReconnectAfter()) @@ -326,22 +313,71 @@ func (b *Backend) reconnectLoop() { func (b *Backend) onConnected(c paho.Client) { mqttConnectCounter().Inc() + log.Info("integration/mqtt: connected to mqtt broker") - b.RLock() - defer b.RUnlock() + b.gatewaysSubscribedMux.Lock() + defer b.gatewaysSubscribedMux.Unlock() - log.Info("integration/mqtt: connected to mqtt broker") + // reset the subscriptions as we have a new connection + // note: this is done in the onConnected function because the subscribeLoop + // locks the gatewaysSubscribedMux and will only release it after all + // (un)subscribe operations have been completed. If it were done in the + // onConnectionLost function, the function could block until the connection + // is restored because the (un)subscribe operations will block until then. + b.gatewaysSubscribed = make(map[lorawan.EUI64]struct{}) +} - for gatewayID := range b.gateways { - for { +func (b *Backend) subscribeLoop() { + for { + b.connMux.RLock() + closed := b.connClosed + b.connMux.RUnlock() + if closed { + break + } + + var subscribe []lorawan.EUI64 + var unsubscribe []lorawan.EUI64 + + b.gatewaysMux.RLock() + b.gatewaysSubscribedMux.Lock() + + // subscribe + for gatewayID := range b.gateways { + if _, ok := b.gatewaysSubscribed[gatewayID]; !ok { + subscribe = append(subscribe, gatewayID) + } + } + + // unsubscribe + for gatewayID := range b.gatewaysSubscribed { + if _, ok := b.gateways[gatewayID]; !ok { + unsubscribe = append(unsubscribe, gatewayID) + } + } + + // unlock gatewaysMux so that SetGatewaySubscription can write again + // to the map, in which case changes are picked up in the next run + b.gatewaysMux.RUnlock() + + for _, gatewayID := range subscribe { if err := b.subscribeGateway(gatewayID); err != nil { log.WithError(err).WithField("gateway_id", gatewayID).Error("integration/mqtt: subscribe gateway error") - time.Sleep(time.Second) - continue + } else { + b.gatewaysSubscribed[gatewayID] = 
struct{}{} } + } - break + for _, gatewayID := range unsubscribe { + if err := b.unsubscribeGateway(gatewayID); err != nil { + log.WithError(err).WithField("gateway_id", gatewayID).Error("integration/mqtt: unsubscribe gateway error") + } else { + delete(b.gatewaysSubscribed, gatewayID) + } } + + b.gatewaysSubscribedMux.Unlock() + time.Sleep(time.Millisecond * 100) } }