From 1a5031a9ca06d5c515a2c12dcdbd94b892f4ca13 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Wed, 10 May 2023 17:40:21 +0200 Subject: [PATCH 1/2] gs: Allow compare-and-swap semantics for gateway connection stats --- pkg/gatewayserver/gatewayserver.go | 25 +++++- pkg/gatewayserver/gatewayserver_test.go | 18 ++-- pkg/gatewayserver/grpc.go | 2 +- pkg/gatewayserver/redis/registry.go | 79 ++++++++++++++---- pkg/gatewayserver/redis/registry_test.go | 101 +++++++++++++++++------ pkg/gatewayserver/registry.go | 10 ++- 6 files changed, 174 insertions(+), 61 deletions(-) diff --git a/pkg/gatewayserver/gatewayserver.go b/pkg/gatewayserver/gatewayserver.go index d1a43b4cd1..270792c3cc 100644 --- a/pkg/gatewayserver/gatewayserver.go +++ b/pkg/gatewayserver/gatewayserver.go @@ -970,7 +970,14 @@ func (gs *GatewayServer) updateConnStats(ctx context.Context, conn connectionEnt } registerGatewayConnectionStats(ctx, ids, stats) if gs.statsRegistry != nil { - if err := gs.statsRegistry.Set(decoupledCtx, ids, stats, ttnpb.GatewayConnectionStatsFieldPathsTopLevel, gs.config.ConnectionStatsTTL); err != nil { + if err := gs.statsRegistry.Set( + decoupledCtx, + ids, + func(*ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + return stats, ttnpb.GatewayConnectionStatsFieldPathsTopLevel, nil + }, + gs.config.ConnectionStatsTTL, + ); err != nil { logger.WithError(err).Warn("Failed to initialize connection stats") } } @@ -986,8 +993,11 @@ func (gs *GatewayServer) updateConnStats(ctx context.Context, conn connectionEnt return } if err := gs.statsRegistry.Set( - decoupledCtx, ids, stats, - []string{"connected_at", "disconnected_at"}, + decoupledCtx, + ids, + func(*ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + return stats, []string{"connected_at", "disconnected_at"}, nil + }, gs.config.ConnectionStatsDisconnectTTL, ); err != nil { logger.WithError(err).Warn("Failed to clear connection stats") @@ -1028,7 +1038,14 @@ func (gs *GatewayServer) updateConnStats(ctx context.Context, conn connectionEnt if gs.statsRegistry == nil { continue } - if err := gs.statsRegistry.Set(decoupledCtx, ids, stats, paths, gs.config.ConnectionStatsTTL); err != nil { + if err := gs.statsRegistry.Set( + decoupledCtx, + ids, + func(*ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + return stats, paths, nil + }, + gs.config.ConnectionStatsTTL, + ); err != nil { logger.WithError(err).Warn("Failed to update connection stats") } } diff --git a/pkg/gatewayserver/gatewayserver_test.go b/pkg/gatewayserver/gatewayserver_test.go index 141f500429..a1b9dcb032 100644 --- a/pkg/gatewayserver/gatewayserver_test.go +++ b/pkg/gatewayserver/gatewayserver_test.go @@ -103,9 +103,9 @@ func TestGatewayServer(t *testing.T) { gsConfig := &gatewayserver.Config{ RequireRegisteredGateways: false, UpdateGatewayLocationDebounceTime: 0, - UpdateConnectionStatsInterval: time.Second, - ConnectionStatsTTL: (1 << 3) * test.Delay, - ConnectionStatsDisconnectTTL: (1 << 6) * test.Delay, + UpdateConnectionStatsInterval: (1 << 5) * test.Delay, + ConnectionStatsTTL: (1 << 6) * test.Delay, + ConnectionStatsDisconnectTTL: (1 << 7) * test.Delay, Stats: statsRegistry, FetchGatewayInterval: time.Minute, FetchGatewayJitter: 1, @@ -734,7 +734,7 @@ func TestGatewayServer(t *testing.T) { }) // Wait for gateway disconnection to be processed. - time.Sleep(6 * timeout) + time.Sleep(2 * config.ConnectionStatsDisconnectTTL) t.Run(fmt.Sprintf("Traffic/%v", ptc.Protocol), func(t *testing.T) { a := assertions.New(t) @@ -1472,9 +1472,7 @@ func TestGatewayServer(t *testing.T) { stats, paths := conn.Stats() a.So(stats, should.NotBeNil) - if statsRegistry != nil { - a.So(statsRegistry.Set(conn.Context(), ids, stats, paths, 0), should.BeNil) - } + a.So(paths, should.NotBeEmpty) stats, err := statsClient.GetGatewayConnectionStats(statsCtx, ids) if !a.So(err, should.BeNil) { @@ -1803,9 +1801,7 @@ func TestGatewayServer(t *testing.T) { stats, paths := conn.Stats() a.So(stats, should.NotBeNil) - if config.Stats != nil { - a.So(config.Stats.Set(conn.Context(), ids, stats, paths, 0), should.BeNil) - } + a.So(paths, should.NotBeEmpty) stats, err = statsClient.GetGatewayConnectionStats(statsCtx, ids) if !a.So(err, should.BeNil) { @@ -1823,7 +1819,7 @@ func TestGatewayServer(t *testing.T) { } // Wait for disconnection to be processed. - time.Sleep(4 * config.ConnectionStatsDisconnectTTL) + time.Sleep(2 * config.ConnectionStatsDisconnectTTL) // After canceling the context and awaiting the link, the connection should be gone. t.Run("Disconnected", func(t *testing.T) { diff --git a/pkg/gatewayserver/grpc.go b/pkg/gatewayserver/grpc.go index 34f8ba8c10..03a6cbaef4 100644 --- a/pkg/gatewayserver/grpc.go +++ b/pkg/gatewayserver/grpc.go @@ -80,7 +80,7 @@ func (gs *GatewayServer) BatchGetGatewayConnectionStats( } if gs.statsRegistry != nil { - entries, err := gs.statsRegistry.BatchGet(ctx, req.GatewayIds, req.FieldMask.GetPaths()) + entries, err := gs.statsRegistry.BatchGet(ctx, req.GatewayIds, req.FieldMask.GetPaths()...) if err != nil { return nil, err } diff --git a/pkg/gatewayserver/redis/registry.go b/pkg/gatewayserver/redis/registry.go index ead2e2280c..cf994f9cc5 100644 --- a/pkg/gatewayserver/redis/registry.go +++ b/pkg/gatewayserver/redis/registry.go @@ -35,10 +35,7 @@ type GatewayConnectionStatsRegistry struct { // Init initializes the GatewayConnectionStatsRegistry. func (r *GatewayConnectionStatsRegistry) Init(ctx context.Context) error { - if err := ttnredis.InitMutex(ctx, r.Redis); err != nil { - return err - } - return nil + return ttnredis.InitMutex(ctx, r.Redis) } func (r *GatewayConnectionStatsRegistry) key(uid string) string { @@ -46,7 +43,13 @@ func (r *GatewayConnectionStatsRegistry) key(uid string) string { } // Set sets or clears the connection stats for a gateway. -func (r *GatewayConnectionStatsRegistry) Set(ctx context.Context, ids *ttnpb.GatewayIdentifiers, stats *ttnpb.GatewayConnectionStats, paths []string, ttl time.Duration) error { +func (r *GatewayConnectionStatsRegistry) Set( + ctx context.Context, + ids *ttnpb.GatewayIdentifiers, + f func(*ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error), + ttl time.Duration, + gets ...string, +) error { uid := unique.ID(ctx, ids) lockerID, err := ttnredis.GenerateLockerID() @@ -57,21 +60,61 @@ func (r *GatewayConnectionStatsRegistry) Set(ctx context.Context, ids *ttnpb.Gat defer trace.StartRegion(ctx, "set gateway connection stats").End() uk := r.key(uid) - if stats == nil { - err = r.Redis.Del(ctx, uk).Err() - } else { - err = ttnredis.LockedWatch(ctx, r.Redis, uk, lockerID, r.LockTTL, func(tx *redis.Tx) error { - pb := &ttnpb.GatewayConnectionStats{} - if err := ttnredis.GetProto(ctx, tx, uk).ScanProto(pb); err != nil && !errors.IsNotFound(err) { + err = ttnredis.LockedWatch(ctx, r.Redis, uk, lockerID, r.LockTTL, func(tx *redis.Tx) error { + stored := &ttnpb.GatewayConnectionStats{} + cmd := ttnredis.GetProto(ctx, tx, uk) + if err := cmd.ScanProto(stored); errors.IsNotFound(err) { + stored = nil + } else if err != nil { + return err + } + + var pb *ttnpb.GatewayConnectionStats + if stored != nil { + pb = &ttnpb.GatewayConnectionStats{} + if err := cmd.ScanProto(pb); err != nil { return err } - if err := pb.SetFields(stats, paths...); err != nil { + if pb, err = applyGatewayConnectionStatsFieldMask(nil, pb, gets...); err != nil { return err } - _, err := ttnredis.SetProto(ctx, tx, uk, pb, ttl) + } + + var sets []string + pb, sets, err = f(pb) + if err != nil { return err - }) - } + } + if stored == nil && pb == nil { + return nil + } + var pipelined func(redis.Pipeliner) error + if pb == nil { + pipelined = func(p redis.Pipeliner) error { + p.Del(ctx, uk) + return nil + } + } else { + updated := &ttnpb.GatewayConnectionStats{} + if stored != nil { + if err := cmd.ScanProto(updated); err != nil { + return err + } + } + if updated, err = applyGatewayConnectionStatsFieldMask(updated, pb, sets...); err != nil { + return err + } + if err := updated.ValidateFields(); err != nil { + return err + } + pipelined = func(p redis.Pipeliner) error { + _, err = ttnredis.SetProto(ctx, p, uk, updated, ttl) + return err + } + } + _, err = tx.TxPipelined(ctx, pipelined) + return err + }) if err != nil { return ttnredis.ConvertError(err) } @@ -79,7 +122,9 @@ func (r *GatewayConnectionStatsRegistry) Set(ctx context.Context, ids *ttnpb.Gat } // Get returns the connection stats for a gateway. -func (r *GatewayConnectionStatsRegistry) Get(ctx context.Context, ids *ttnpb.GatewayIdentifiers) (*ttnpb.GatewayConnectionStats, error) { +func (r *GatewayConnectionStatsRegistry) Get( + ctx context.Context, ids *ttnpb.GatewayIdentifiers, +) (*ttnpb.GatewayConnectionStats, error) { uid := unique.ID(ctx, ids) result := &ttnpb.GatewayConnectionStats{} if err := ttnredis.GetProto(ctx, r.Redis, r.key(uid)).ScanProto(result); err != nil { @@ -104,7 +149,7 @@ func applyGatewayConnectionStatsFieldMask( func (r *GatewayConnectionStatsRegistry) BatchGet( ctx context.Context, ids []*ttnpb.GatewayIdentifiers, - paths []string, + paths ...string, ) (map[string]*ttnpb.GatewayConnectionStats, error) { ret := make(map[string]*ttnpb.GatewayConnectionStats, len(ids)) keys := make([]string, 0, len(ids)) diff --git a/pkg/gatewayserver/redis/registry_test.go b/pkg/gatewayserver/redis/registry_test.go index 39622749be..7fc4af9e7d 100644 --- a/pkg/gatewayserver/redis/registry_test.go +++ b/pkg/gatewayserver/redis/registry_test.go @@ -71,35 +71,52 @@ func TestRegistry(t *testing.T) { a.So(errors.IsNotFound(err), should.BeTrue) batchStats, err := registry.BatchGet(ctx, []*ttnpb.GatewayIdentifiers{ ids, - }, nil) + }) a.So(err, should.BeNil) a.So(len(batchStats), should.Equal, 0) }) + emptyStatsClearUpdate := func(pb *ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + a.So(pb, should.BeNil) + return nil, nil, nil + } + nonEmptyStatsCleanUpdate := func(pb *ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + a.So(pb, should.NotBeNil) + return nil, nil, nil + } + t.Run("EmptyStats", func(t *testing.T) { a, ctx := test.New(t) - err := registry.Set(ctx, ids3, nil, []string{}, 0) + err := registry.Set(ctx, ids3, emptyStatsClearUpdate, 0) a.So(err, should.BeNil) retrieved, err := registry.Get(ctx, ids3) a.So(retrieved, should.BeNil) a.So(errors.IsNotFound(err), should.BeTrue) batchStats, err := registry.BatchGet(ctx, []*ttnpb.GatewayIdentifiers{ ids3, - }, nil) + }) a.So(err, should.BeNil) a.So(len(batchStats), should.Equal, 0) }) t.Run("SetAndClear", func(t *testing.T) { a, ctx := test.New(t) - err := registry.Set(ctx, ids, initialStats, []string{ - "connected_at", - "protocol", - "last_downlink_received_at", - "downlink_count", - "last_uplink_received_at", - "uplink_count", - }, 0) + err := registry.Set( + ctx, + ids, + func(pb *ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + a.So(pb, should.BeNil) + return initialStats, []string{ + "connected_at", + "protocol", + "last_downlink_received_at", + "downlink_count", + "last_uplink_received_at", + "uplink_count", + }, nil + }, + 0, + ) a.So(err, should.BeNil) retrieved, err := registry.Get(ctx, ids) a.So(err, should.BeNil) @@ -115,12 +132,12 @@ func TestRegistry(t *testing.T) { ids, ids2, ids3, - }, nil) + }) a.So(err, should.BeNil) a.So(len(batchStats), should.Equal, 1) // Unset - err = registry.Set(ctx, ids, nil, nil, 0) + err = registry.Set(ctx, ids, nonEmptyStatsCleanUpdate, 0) a.So(err, should.BeNil) retrieved, err = registry.Get(ctx, ids) a.So(errors.IsNotFound(err), should.BeTrue) @@ -129,8 +146,8 @@ func TestRegistry(t *testing.T) { t.Run("ClearManyTimes", func(t *testing.T) { a, ctx := test.New(t) - a.So(registry.Set(ctx, ids, nil, nil, 0), should.BeNil) - a.So(registry.Set(ctx, ids, nil, nil, 0), should.BeNil) + a.So(registry.Set(ctx, ids, emptyStatsClearUpdate, 0), should.BeNil) + a.So(registry.Set(ctx, ids, emptyStatsClearUpdate, 0), should.BeNil) }) t.Run("SetWithTTL", func(t *testing.T) { @@ -139,7 +156,15 @@ func TestRegistry(t *testing.T) { DisconnectedAt: timestamppb.New(time.Date(2021, 12, 2, 11, 24, 58, 0, time.UTC)), } - err := registry.Set(ctx, ids, stats, []string{"disconnected_at"}, Timeout) + err := registry.Set( + ctx, + ids, + func(pb *ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + a.So(pb, should.BeNil) + return stats, []string{"disconnected_at"}, nil + }, + Timeout, + ) a.So(err, should.BeNil) // all data should exist @@ -164,10 +189,18 @@ func TestRegistry(t *testing.T) { DownlinkCount: 1, } - err := registry.Set(ctx, ids, stats, []string{ - "uplink_count", - "last_uplink_received_at", - }, 0) + err := registry.Set( + ctx, + ids, + func(pb *ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + a.So(pb, should.BeNil) + return stats, []string{ + "uplink_count", + "last_uplink_received_at", + }, nil + }, + 0, + ) a.So(err, should.BeNil) retrieved, err := registry.Get(ctx, ids) a.So(err, should.BeNil) @@ -177,7 +210,15 @@ func TestRegistry(t *testing.T) { }) // Now update downlink also - err = registry.Set(ctx, ids, stats, []string{"downlink_count"}, 0) + err = registry.Set( + ctx, + ids, + func(pb *ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + a.So(pb, should.NotBeNil) + return stats, []string{"downlink_count"}, nil + }, + 0, + ) a.So(err, should.BeNil) retrieved, err = registry.Get(ctx, ids) a.So(err, should.BeNil) @@ -191,11 +232,19 @@ func TestRegistry(t *testing.T) { stats.LastUplinkReceivedAt = nil stats.UplinkCount = 0 stats.DownlinkCount = 2 - err = registry.Set(ctx, ids, stats, []string{ - "uplink_count", - "last_uplink_received_at", - "downlink_count", - }, 0) + err = registry.Set( + ctx, + ids, + func(pb *ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + a.So(pb, should.NotBeNil) + return stats, []string{ + "uplink_count", + "last_uplink_received_at", + "downlink_count", + }, nil + }, + 0, + ) a.So(err, should.BeNil) retrieved, err = registry.Get(ctx, ids) a.So(err, should.BeNil) diff --git a/pkg/gatewayserver/registry.go b/pkg/gatewayserver/registry.go index a469c93b7a..48d8343f05 100644 --- a/pkg/gatewayserver/registry.go +++ b/pkg/gatewayserver/registry.go @@ -29,10 +29,16 @@ type GatewayConnectionStatsRegistry interface { BatchGet( ctx context.Context, ids []*ttnpb.GatewayIdentifiers, - paths []string, + paths ...string, ) (map[string]*ttnpb.GatewayConnectionStats, error) // Set sets, updates or clears the connection stats for a gateway. Only fields specified in the field mask paths are set. - Set(ctx context.Context, ids *ttnpb.GatewayIdentifiers, stats *ttnpb.GatewayConnectionStats, paths []string, ttl time.Duration) error + Set( + ctx context.Context, + ids *ttnpb.GatewayIdentifiers, + f func(*ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error), + ttl time.Duration, + gets ...string, + ) error } // EntityRegistry abstracts the Identity server gateway functions. From 46bfb910c8f9c140682934677b22f351d79785e2 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Wed, 10 May 2023 17:53:43 +0200 Subject: [PATCH 2/2] gs: Always store the earliest connection time --- pkg/gatewayserver/gatewayserver.go | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/pkg/gatewayserver/gatewayserver.go b/pkg/gatewayserver/gatewayserver.go index 270792c3cc..af428a9a6b 100644 --- a/pkg/gatewayserver/gatewayserver.go +++ b/pkg/gatewayserver/gatewayserver.go @@ -955,16 +955,31 @@ func (gs *GatewayServer) handleUpstream(ctx context.Context, conn connectionEntr } } +func earliestTimestamp(a, b *timestamppb.Timestamp) *timestamppb.Timestamp { + switch { + case a == nil && b == nil: + return nil + case a == nil: + return b + case b == nil: + return a + default: + if aT, bT := a.AsTime(), b.AsTime(); aT.Before(bT) { + return a + } + return b + } +} + func (gs *GatewayServer) updateConnStats(ctx context.Context, conn connectionEntry) { decoupledCtx := gs.FromRequestContext(ctx) logger := log.FromContext(ctx) ids := conn.Connection.Gateway().GetIds() - connectTime := conn.Connection.ConnectTime() // Initial update, so that the gateway appears connected. stats := &ttnpb.GatewayConnectionStats{ - ConnectedAt: timestamppb.New(connectTime), + ConnectedAt: timestamppb.New(conn.Connection.ConnectTime()), Protocol: conn.Connection.Frontend().Protocol(), GatewayRemoteAddress: conn.Connection.GatewayRemoteAddress(), } @@ -1033,7 +1048,6 @@ func (gs *GatewayServer) updateConnStats(ctx context.Context, conn connectionEnt lastUpdate = time.Now() stats, paths := conn.Stats() - paths = ttnpb.ExcludeFields(paths, "connected_at", "disconnected_at", "protocol", "gateway_remote_address") registerGatewayConnectionStats(decoupledCtx, ids, stats) if gs.statsRegistry == nil { continue @@ -1041,10 +1055,12 @@ func (gs *GatewayServer) updateConnStats(ctx context.Context, conn connectionEnt if err := gs.statsRegistry.Set( decoupledCtx, ids, - func(*ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + func(pb *ttnpb.GatewayConnectionStats) (*ttnpb.GatewayConnectionStats, []string, error) { + stats.ConnectedAt = earliestTimestamp(stats.ConnectedAt, pb.ConnectedAt) return stats, paths, nil }, gs.config.ConnectionStatsTTL, + "connected_at", ); err != nil { logger.WithError(err).Warn("Failed to update connection stats") }