diff --git a/.dockerignore b/.dockerignore
index 0370155b5..ad89aae4a 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -24,4 +24,9 @@
 go.work.sum
 Dockerfile
 docker/docker-compose.*
-cmd/**/Dockerfile
\ No newline at end of file
+cmd/**/Dockerfile
+
+.dockerignore
+.gcloudignore
+
+Makefile
\ No newline at end of file
diff --git a/.github/workflows/lava.yml b/.github/workflows/lava.yml
index 5b0dad8e3..6ee665857 100644
--- a/.github/workflows/lava.yml
+++ b/.github/workflows/lava.yml
@@ -393,7 +393,6 @@ jobs:
       contents: write
       packages: write
       id-token: write
-    needs: [test-consensus, test-protocol]
     runs-on: ubuntu-latest
     strategy:
       matrix:
@@ -423,6 +422,9 @@
         uses: docker/build-push-action@v5
         continue-on-error: true
        with:
+          provenance: false
+          sbom: false
+          context: .
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          file: cmd/${{ matrix.binary }}/Dockerfile
diff --git a/cmd/lavad/Dockerfile b/cmd/lavad/Dockerfile
index 506990e30..00828f148 100644
--- a/cmd/lavad/Dockerfile
+++ b/cmd/lavad/Dockerfile
@@ -14,11 +14,12 @@ RUN apk add --no-cache \
 
 WORKDIR /lava
 
-COPY go.mod go.sum ./
-
+ENV GOCACHE=/root/.cache/go-build
 RUN --mount=type=cache,target=/root/.cache/go-build \
     --mount=type=cache,target=/root/go/pkg/mod \
-    go mod download
+    --mount=type=bind,source=go.sum,target=go.sum \
+    --mount=type=bind,source=go.mod,target=go.mod \
+    go mod download -x
 
 COPY . .
 
diff --git a/cmd/lavad/Dockerfile.Cosmovisor b/cmd/lavad/Dockerfile.Cosmovisor
index f0612bfe0..332b2b6f5 100644
--- a/cmd/lavad/Dockerfile.Cosmovisor
+++ b/cmd/lavad/Dockerfile.Cosmovisor
@@ -15,9 +15,12 @@ WORKDIR /lava
 
 COPY go.mod go.sum ./
 
+ENV GOCACHE=/root/.cache/go-build
 RUN --mount=type=cache,target=/root/.cache/go-build \
     --mount=type=cache,target=/root/go/pkg/mod \
-    go mod download
+    --mount=type=bind,source=go.sum,target=go.sum \
+    --mount=type=bind,source=go.mod,target=go.mod \
+    go mod download -x
 
 COPY . .
 
diff --git a/cmd/lavap/Dockerfile b/cmd/lavap/Dockerfile
index 9da381095..bf83e92d5 100644
--- a/cmd/lavap/Dockerfile
+++ b/cmd/lavap/Dockerfile
@@ -14,12 +14,13 @@ RUN apk add --no-cache \
 
 WORKDIR /lava
 
-COPY go.mod go.sum ./
-
+ENV GOCACHE=/root/.cache/go-build
 RUN --mount=type=cache,target=/root/.cache/go-build \
     --mount=type=cache,target=/root/go/pkg/mod \
-    go mod download
-
+    --mount=type=bind,source=go.sum,target=go.sum \
+    --mount=type=bind,source=go.mod,target=go.mod \
+    go mod download -x
+
 COPY . .
 
ARG GIT_VERSION diff --git a/protocol/chainlib/consumer_ws_subscription_manager.go b/protocol/chainlib/consumer_ws_subscription_manager.go index e19420812..10d8972e3 100644 --- a/protocol/chainlib/consumer_ws_subscription_manager.go +++ b/protocol/chainlib/consumer_ws_subscription_manager.go @@ -201,7 +201,6 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), - utils.LogAttr("connectedDapps", cwsm.connectedDapps), ) websocketRepliesChan := make(chan *pairingtypes.RelayReply) diff --git a/protocol/chainlib/grpc.go b/protocol/chainlib/grpc.go index 68367b1ec..f2425c12c 100644 --- a/protocol/chainlib/grpc.go +++ b/protocol/chainlib/grpc.go @@ -9,7 +9,6 @@ import ( "net/http" "strconv" "strings" - "sync" "time" "github.com/goccy/go-json" @@ -50,25 +49,6 @@ type GrpcNodeErrorResponse struct { ErrorCode uint32 `json:"error_code"` } -type grpcDescriptorCache struct { - cachedDescriptors sync.Map // method name is the key, method descriptor is the value -} - -func (gdc *grpcDescriptorCache) getDescriptor(methodName string) *desc.MethodDescriptor { - if descriptor, ok := gdc.cachedDescriptors.Load(methodName); ok { - converted, success := descriptor.(*desc.MethodDescriptor) // convert to a descriptor - if success { - return converted - } - utils.LavaFormatError("Failed Converting method descriptor", nil, utils.Attribute{Key: "Method", Value: methodName}) - } - return nil -} - -func (gdc *grpcDescriptorCache) setDescriptor(methodName string, descriptor *desc.MethodDescriptor) { - gdc.cachedDescriptors.Store(methodName, descriptor) -} - type GrpcChainParser struct { BaseChainParser @@ -388,7 +368,7 @@ func (apil *GrpcChainListener) GetListeningAddress() string { type GrpcChainProxy struct { BaseChainProxy conn grpcConnectorInterface - descriptorsCache *grpcDescriptorCache + descriptorsCache *common.SafeSyncMap[string, *desc.MethodDescriptor] } type grpcConnectorInterface interface { Close() @@ -413,7 +393,7 @@ func NewGrpcChainProxy(ctx context.Context, nConns uint, rpcProviderEndpoint lav func newGrpcChainProxy(ctx context.Context, averageBlockTime time.Duration, parser ChainParser, conn grpcConnectorInterface, rpcProviderEndpoint lavasession.RPCProviderEndpoint) (ChainProxy, error) { cp := &GrpcChainProxy{ BaseChainProxy: BaseChainProxy{averageBlockTime: averageBlockTime, ErrorHandler: &GRPCErrorHandler{}, ChainID: rpcProviderEndpoint.ChainID, HashedNodeUrl: chainproxy.HashURL(rpcProviderEndpoint.NodeUrls[0].Url)}, - descriptorsCache: &grpcDescriptorCache{}, + descriptorsCache: &common.SafeSyncMap[string, *desc.MethodDescriptor]{}, } cp.conn = conn if cp.conn == nil { @@ -471,9 +451,12 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, descriptorSource := rpcInterfaceMessages.DescriptorSourceFromServer(cl) svc, methodName := rpcInterfaceMessages.ParseSymbol(nodeMessage.Path) - // check if we have method descriptor already cached. - methodDescriptor := cp.descriptorsCache.getDescriptor(methodName) - if methodDescriptor == nil { // method descriptor not cached yet, need to fetch it and add to cache + // Check if we have method descriptor already cached. + // The reason we do Load and then Store here, instead of LoadOrStore: + // On the worst case scenario, where 2 threads are accessing the map at the same time, the same descriptor will be stored twice. 
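Note on the load-then-store comment above (it continues in the next diff line): the trade-off can be illustrated with a small, self-contained sketch of the same pattern on a plain `sync.Map`. This is not the repository's `SafeSyncMap`; the `buildDescriptor`/`getDescriptor` names below are hypothetical stand-ins for the expensive reflection lookup done in `SendNodeMsg`.

```go
package main

import (
	"fmt"
	"sync"
)

var cache sync.Map // method name -> descriptor (a string here, for simplicity)

// buildDescriptor stands in for the costly FindSymbol/FindMethodByName lookup.
func buildDescriptor(method string) string {
	return "descriptor-for-" + method
}

func getDescriptor(method string) string {
	if v, ok := cache.Load(method); ok {
		return v.(string) // cache hit: nothing to rebuild
	}
	// Cache miss: build and store. If two goroutines race here, both build the
	// same descriptor and the second Store overwrites the first with an equal
	// value, which is harmless -- the trade-off the comment above describes.
	d := buildDescriptor(method)
	cache.Store(method, d)
	return d
}

func main() {
	fmt.Println(getDescriptor("lavanet.lava.pairing.Query/Providers"))
}
```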
+ // It is better than the alternative, which is always creating the descriptor, since the outcome is the same. + methodDescriptor, found, _ := cp.descriptorsCache.Load(methodName) + if !found { // method descriptor not cached yet, need to fetch it and add to cache var descriptor desc.Descriptor if descriptor, err = descriptorSource.FindSymbol(svc); err != nil { return nil, "", nil, utils.LavaFormatError("descriptorSource.FindSymbol", err, utils.Attribute{Key: "GUID", Value: ctx}) @@ -488,7 +471,7 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, } // add the descriptor to the chainProxy cache - cp.descriptorsCache.setDescriptor(methodName, methodDescriptor) + cp.descriptorsCache.Store(methodName, methodDescriptor) } msgFactory := dynamic.NewMessageFactoryWithDefaults() diff --git a/protocol/chainlib/grpcproxy/grpcproxy.go b/protocol/chainlib/grpcproxy/grpcproxy.go index e5a06dcf3..7ec2ec0d2 100644 --- a/protocol/chainlib/grpcproxy/grpcproxy.go +++ b/protocol/chainlib/grpcproxy/grpcproxy.go @@ -13,6 +13,8 @@ import ( "golang.org/x/net/http2/h2c" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/health" + "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -28,6 +30,7 @@ type HealthReporter interface { func NewGRPCProxy(cb ProxyCallBack, healthCheckPath string, cmdFlags common.ConsumerCmdFlags, healthReporter HealthReporter) (*grpc.Server, *http.Server, error) { serverReceiveMaxMessageSize := grpc.MaxRecvMsgSize(MaxCallRecvMsgSize) // setting receive size to 32mb instead of 4mb default s := grpc.NewServer(grpc.UnknownServiceHandler(makeProxyFunc(cb)), grpc.ForceServerCodec(RawBytesCodec{}), serverReceiveMaxMessageSize) + grpc_health_v1.RegisterHealthServer(s, health.NewServer()) wrappedServer := grpcweb.WrapServer(s) handler := func(resp http.ResponseWriter, req *http.Request) { // Set CORS headers diff --git a/protocol/chainlib/provider_node_subscription_manager_test.go b/protocol/chainlib/provider_node_subscription_manager_test.go index 38bb6283c..9a8e24fb6 100644 --- a/protocol/chainlib/provider_node_subscription_manager_test.go +++ b/protocol/chainlib/provider_node_subscription_manager_test.go @@ -344,7 +344,7 @@ func TestSubscriptionManager_MultipleParallelSubscriptionsWithTheSameParamsAndNo t.Run(play.name, func(t *testing.T) { ts := SetupForTests(t, 1, play.specId, "../../") - wg := sync.WaitGroup{} + sentMessageToNodeChannel := make(chan bool, 1) // msgCount := 0 upgrader := websocket.Upgrader{} first := true @@ -373,7 +373,12 @@ func TestSubscriptionManager_MultipleParallelSubscriptionsWithTheSameParamsAndNo return } utils.LavaFormatDebug("write message") - wg.Done() + select { + case sentMessageToNodeChannel <- true: + utils.LavaFormatDebug("sent message to node") + default: + utils.LavaFormatDebug("unable to communicate with the test") + } // Write the first reply err = conn.WriteMessage(messageType, play.subscriptionFirstReply) @@ -405,7 +410,6 @@ func TestSubscriptionManager_MultipleParallelSubscriptionsWithTheSameParamsAndNo mockRpcProvider := &RelayFinalizationBlocksHandlerMock{} pnsm := NewProviderNodeSubscriptionManager(chainRouter, chainParser, mockRpcProvider, ts.Providers[0].SK) - wg.Add(1) wgAllIds := sync.WaitGroup{} wgAllIds.Add(9) errors := []error{} @@ -429,7 +433,11 @@ func TestSubscriptionManager_MultipleParallelSubscriptionsWithTheSameParamsAndNo utils.LavaFormatDebug("Waiting wait group") wgAllIds.Wait() - wg.Wait() // Make sure the 
subscription manager sent a message to the node + select { + case <-sentMessageToNodeChannel: // Make sure the subscription manager sent a message to the node + case <-time.After(time.Second * 10): + require.Fail(t, "timeout waiting for message to node") + } // make sure we had only one error, on the first subscription attempt require.Len(t, errors, 1) diff --git a/protocol/chaintracker/chain_tracker.go b/protocol/chaintracker/chain_tracker.go index 9b70ba07c..29d6d390d 100644 --- a/protocol/chaintracker/chain_tracker.go +++ b/protocol/chaintracker/chain_tracker.go @@ -68,6 +68,10 @@ type ChainTracker struct { blockEventsGap []time.Duration blockTimeUpdatables map[blockTimeUpdatable]struct{} pmetrics *metrics.ProviderMetricsManager + + // initial config + averageBlockTime time.Duration + serverAddress string } // this function returns block hashes of the blocks: [from block - to block] inclusive. an additional specific block hash can be provided. order is sorted ascending @@ -570,6 +574,16 @@ func (ct *ChainTracker) serve(ctx context.Context, listenAddr string) error { return nil } +func (ct *ChainTracker) StartAndServe(ctx context.Context) error { + err := ct.start(ctx, ct.averageBlockTime) + if err != nil { + return err + } + + err = ct.serve(ctx, ct.serverAddress) + return err +} + func NewChainTracker(ctx context.Context, chainFetcher ChainFetcher, config ChainTrackerConfig) (chainTracker *ChainTracker, err error) { if !rand.Initialized() { utils.LavaFormatFatal("can't start chainTracker with nil rand source", nil) @@ -598,16 +612,13 @@ func NewChainTracker(ctx context.Context, chainFetcher ChainFetcher, config Chai startupTime: time.Now(), pmetrics: config.Pmetrics, pollingTimeMultiplier: time.Duration(pollingTime), + averageBlockTime: config.AverageBlockTime, + serverAddress: config.ServerAddress, } if chainFetcher == nil { return nil, utils.LavaFormatError("can't start chainTracker with nil chainFetcher argument", nil) } chainTracker.endpoint = chainFetcher.FetchEndpoint() - err = chainTracker.start(ctx, config.AverageBlockTime) - if err != nil { - return nil, err - } - err = chainTracker.serve(ctx, config.ServerAddress) return chainTracker, err } diff --git a/protocol/chaintracker/chain_tracker_test.go b/protocol/chaintracker/chain_tracker_test.go index c0140af61..1ebcf62a2 100644 --- a/protocol/chaintracker/chain_tracker_test.go +++ b/protocol/chaintracker/chain_tracker_test.go @@ -161,6 +161,7 @@ func TestChainTracker(t *testing.T) { chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)} chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig) + chainTracker.StartAndServe(context.Background()) require.NoError(t, err) for _, advancement := range tt.advancements { for i := 0; i < int(advancement); i++ { @@ -218,6 +219,7 @@ func TestChainTrackerRangeOnly(t *testing.T) { chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)} chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig) + chainTracker.StartAndServe(context.Background()) require.NoError(t, err) for _, advancement := range tt.advancements { for i := 0; i < int(advancement); i++ { @@ -302,6 +304,7 @@ func TestChainTrackerCallbacks(t *testing.T) { chainTrackerConfig := 
chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback, NewLatestCallback: newBlockCallback} chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig) require.NoError(t, err) + chainTracker.StartAndServe(context.Background()) totalAdvancement := 0 t.Run("one long test", func(t *testing.T) { for _, tt := range tests { @@ -368,6 +371,7 @@ func TestChainTrackerFetchSpreadAcrossPollingTime(t *testing.T) { chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: localTimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)} tracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig) require.NoError(t, err) + tracker.StartAndServe(context.Background()) // fool the tracker so it thinks blocks will come every localTimeForPollingMock (ms), and not adjust it's polling timers for i := 0; i < 50; i++ { tracker.AddBlockGap(localTimeForPollingMock, 1) @@ -491,6 +495,7 @@ func TestChainTrackerPollingTimeUpdate(t *testing.T) { mockChainFetcher.AdvanceBlock() chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: play.localTimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)} tracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig) + tracker.StartAndServe(context.Background()) tracker.RegisterForBlockTimeUpdates(&mockTimeUpdater) require.NoError(t, err) // initial delay @@ -555,6 +560,7 @@ func TestChainTrackerMaintainMemory(t *testing.T) { chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback} chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig) require.NoError(t, err) + chainTracker.StartAndServe(context.Background()) t.Run("one long test", func(t *testing.T) { for _, tt := range tests { utils.LavaFormatInfo(startedTestStr + tt.name) @@ -607,6 +613,7 @@ func TestFindRequestedBlockHash(t *testing.T) { chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)} chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig) require.NoError(t, err) + chainTracker.StartAndServe(context.Background()) latestBlock, onlyLatestBlockData, _, err := chainTracker.GetLatestBlockData(spectypes.LATEST_BLOCK, spectypes.LATEST_BLOCK, spectypes.NOT_APPLICABLE) require.NoError(t, err) require.Equal(t, currentLatestBlockInMock, latestBlock) diff --git a/protocol/common/cobra_common.go b/protocol/common/cobra_common.go index c8d4615d2..40cbffdce 100644 --- a/protocol/common/cobra_common.go +++ b/protocol/common/cobra_common.go @@ -33,6 +33,11 @@ const ( // This feature is suppose to help with successful relays in some chains that return node errors on rare race conditions on the serviced chains. 
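Since `NewChainTracker` no longer starts polling or serving internally (see the `chain_tracker.go` hunk above), callers now construct the tracker and then call `StartAndServe` explicitly, as the test updates above do. A minimal sketch of the new call sequence, assuming the `github.com/lavanet/lava/v3/protocol/chaintracker` import path and placeholder config values; the fetcher is supplied by the caller:

```go
package chaintrackerusage

import (
	"context"
	"time"

	"github.com/lavanet/lava/v3/protocol/chaintracker"
)

func runTracker(ctx context.Context, fetcher chaintracker.ChainFetcher) (*chaintracker.ChainTracker, error) {
	cfg := chaintracker.ChainTrackerConfig{
		BlocksToSave:      100,
		AverageBlockTime:  10 * time.Second,
		ServerBlockMemory: 100,
	}
	tracker, err := chaintracker.NewChainTracker(ctx, fetcher, cfg) // construction only
	if err != nil {
		return nil, err
	}
	// Polling (and the optional gRPC endpoint) now starts here, not inside the constructor.
	if err := tracker.StartAndServe(ctx); err != nil {
		return nil, err
	}
	return tracker, nil
}
```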
 	SetRelayCountOnNodeErrorFlag = "set-retry-count-on-node-error"
 	UseStaticSpecFlag            = "use-static-spec" // allows the user to manually load a spec providing a path, this is useful to test spec changes before they hit the blockchain
+
+	// optimizer flags
+	SetProviderOptimizerBestTierPickChance    = "set-provider-optimizer-best-tier-pick-chance"
+	SetProviderOptimizerWorstTierPickChance   = "set-provider-optimizer-worst-tier-pick-chance"
+	SetProviderOptimizerNumberOfTiersToCreate = "set-provider-optimizer-number-of-tiers-to-create"
 )
 
 const (
diff --git a/protocol/common/safe_sync_map.go b/protocol/common/safe_sync_map.go
new file mode 100644
index 000000000..b0e94a421
--- /dev/null
+++ b/protocol/common/safe_sync_map.go
@@ -0,0 +1,51 @@
+package common
+
+import (
+	"sync"
+
+	"github.com/lavanet/lava/v3/utils"
+)
+
+type SafeSyncMap[K, V any] struct {
+	localMap sync.Map
+}
+
+func (ssm *SafeSyncMap[K, V]) Store(key K, toSet V) {
+	ssm.localMap.Store(key, toSet)
+}
+
+func (ssm *SafeSyncMap[K, V]) Load(key K) (ret V, ok bool, err error) {
+	value, ok := ssm.localMap.Load(key)
+	if !ok {
+		return ret, ok, nil
+	}
+	ret, ok = value.(V)
+	if !ok {
+		return ret, false, utils.LavaFormatError("invalid usage of SafeSyncMap, could not cast the stored value to the expected type", nil)
+	}
+	return ret, true, nil
+}
+
+// LoadOrStore returns the existing value for the key if present.
+// Otherwise, it stores and returns the given value.
+// The loaded result is true if the value was loaded, false if stored.
+// The function returns the value that was loaded or stored.
+func (ssm *SafeSyncMap[K, V]) LoadOrStore(key K, value V) (ret V, loaded bool, err error) {
+	actual, loaded := ssm.localMap.LoadOrStore(key, value)
+	if loaded {
+		// loaded from map
+		var ok bool
+		ret, ok = actual.(V)
+		if !ok {
+			return ret, false, utils.LavaFormatError("invalid usage of SafeSyncMap, could not cast the stored value to the expected type", nil)
+		}
+		return ret, true, nil
+	}
+
+	// stored in map
+	return value, false, nil
+}
+
+func (ssm *SafeSyncMap[K, V]) Range(f func(key, value any) bool) {
+	ssm.localMap.Range(f)
+}
diff --git a/protocol/integration/protocol_test.go b/protocol/integration/protocol_test.go
index 05b0273d1..048bafedd 100644
--- a/protocol/integration/protocol_test.go
+++ b/protocol/integration/protocol_test.go
@@ -293,6 +293,7 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string
 	mockChainFetcher := NewMockChainFetcher(1000, int64(blocksToSaveChainTracker), nil)
 	chainTracker, err := chaintracker.NewChainTracker(ctx, mockChainFetcher, chainTrackerConfig)
 	require.NoError(t, err)
+	chainTracker.StartAndServe(ctx)
 	reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker, &mockProviderStateTracker, account.Addr.String(), chainRouter, chainParser)
 	mockReliabilityManager := NewMockReliabilityManager(reliabilityManager)
 	rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, mockReliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil, false)
diff --git a/protocol/lavasession/consumer_session_manager.go b/protocol/lavasession/consumer_session_manager.go
index fca5f44af..7fef3af42 100644
--- a/protocol/lavasession/consumer_session_manager.go
+++ b/protocol/lavasession/consumer_session_manager.go
@@ -114,6 +114,7 @@ func (csm *ConsumerSessionManager) UpdateAllProviders(epoch uint64, pairingList
 	csm.setValidAddressesToDefaultValue("", nil) // the
starting point is that valid addresses are equal to pairing addresses. // reset session related metrics csm.consumerMetricsManager.ResetSessionRelatedMetrics() + csm.providerOptimizer.UpdateWeights(CalcWeightsByStake(pairingList)) utils.LavaFormatDebug("updated providers", utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "spec", Value: csm.rpcEndpoint.Key()}) return nil } @@ -638,7 +639,7 @@ func (csm *ConsumerSessionManager) getValidProviderAddresses(ignoredProvidersLis if stateful == common.CONSISTENCY_SELECT_ALL_PROVIDERS && csm.providerOptimizer.Strategy() != provideroptimizer.STRATEGY_COST { providers = csm.getTopTenProvidersForStatefulCalls(validAddresses, ignoredProvidersList) } else { - providers = csm.providerOptimizer.ChooseProvider(validAddresses, ignoredProvidersList, cu, requestedBlock, OptimizerPerturbation) + providers, _ = csm.providerOptimizer.ChooseProvider(validAddresses, ignoredProvidersList, cu, requestedBlock) } utils.LavaFormatTrace("Choosing providers", diff --git a/protocol/lavasession/consumer_types.go b/protocol/lavasession/consumer_types.go index ef9b61e82..67c8b7b25 100644 --- a/protocol/lavasession/consumer_types.go +++ b/protocol/lavasession/consumer_types.go @@ -42,6 +42,7 @@ const ( AllowInsecureConnectionToProvidersFlag = "allow-insecure-provider-dialing" AllowGRPCCompressionFlag = "allow-grpc-compression-for-consumer-provider-communication" maximumStreamsOverASingleConnection = 100 + WeightMultiplierForStaticProviders = 10 ) var ( @@ -72,9 +73,10 @@ type ProviderOptimizer interface { AppendProbeRelayData(providerAddress string, latency time.Duration, success bool) AppendRelayFailure(providerAddress string) AppendRelayData(providerAddress string, latency time.Duration, isHangingApi bool, cu, syncBlock uint64) - ChooseProvider(allAddresses []string, ignoredProviders map[string]struct{}, cu uint64, requestedBlock int64, perturbationPercentage float64) (addresses []string) + ChooseProvider(allAddresses []string, ignoredProviders map[string]struct{}, cu uint64, requestedBlock int64) (addresses []string, tier int) GetExcellenceQoSReportForProvider(string) (*pairingtypes.QualityOfServiceReport, *pairingtypes.QualityOfServiceReport) Strategy() provideroptimizer.Strategy + UpdateWeights(map[string]int64) } type ignoredProviders struct { @@ -595,3 +597,28 @@ func CalculateAvailabilityScore(qosReport *QoSReport) (downtimePercentageRet, sc scaledAvailabilityScore := sdk.MaxDec(sdk.ZeroDec(), AvailabilityPercentage.Sub(downtimePercentage).Quo(AvailabilityPercentage)) return downtimePercentage, scaledAvailabilityScore } + +func CalcWeightsByStake(providers map[uint64]*ConsumerSessionsWithProvider) (weights map[string]int64) { + weights = make(map[string]int64) + staticProviders := make([]*ConsumerSessionsWithProvider, 0) + maxWeight := int64(1) + for _, cswp := range providers { + if cswp.StaticProvider { + staticProviders = append(staticProviders, cswp) + continue + } + stakeAmount := cswp.getProviderStakeSize().Amount + stake := int64(10) // defaults to 10 if stake isn't set + if !stakeAmount.IsNil() && stakeAmount.IsInt64() { + stake = stakeAmount.Int64() + } + if stake > maxWeight { + maxWeight = stake + } + weights[cswp.PublicLavaAddress] = stake + } + for _, cswp := range staticProviders { + weights[cswp.PublicLavaAddress] = maxWeight * WeightMultiplierForStaticProviders + } + return weights +} diff --git a/protocol/provideroptimizer/provider_optimizer.go b/protocol/provideroptimizer/provider_optimizer.go index 6936c87fb..806818d97 100644 --- 
a/protocol/provideroptimizer/provider_optimizer.go +++ b/protocol/provideroptimizer/provider_optimizer.go @@ -30,6 +30,13 @@ const ( WANTED_PRECISION = int64(8) ) +var ( + OptimizerNumTiers = 4 + MinimumEntries = 5 + ATierChance = 0.75 + LastTierChance = 0.0 +) + type ConcurrentBlockStore struct { Lock sync.Mutex Time time.Time @@ -49,6 +56,13 @@ type ProviderOptimizer struct { baseWorldLatency time.Duration wantedNumProvidersInConcurrency uint latestSyncData ConcurrentBlockStore + selectionWeighter SelectionWeighter + OptimizerNumTiers int +} + +type Exploration struct { + address string + time time.Time } type ProviderData struct { @@ -72,6 +86,10 @@ const ( STRATEGY_DISTRIBUTED ) +func (po *ProviderOptimizer) UpdateWeights(weights map[string]int64) { + po.selectionWeighter.SetWeights(weights) +} + func (po *ProviderOptimizer) AppendRelayFailure(providerAddress string) { po.appendRelayData(providerAddress, 0, false, false, 0, 0, time.Now()) } @@ -131,16 +149,12 @@ func (po *ProviderOptimizer) AppendProbeRelayData(providerAddress string, latenc ) } -// returns a sub set of selected providers according to their scores, perturbation factor will be added to each score in order to randomly select providers that are not always on top -func (po *ProviderOptimizer) ChooseProvider(allAddresses []string, ignoredProviders map[string]struct{}, cu uint64, requestedBlock int64, perturbationPercentage float64) (addresses []string) { - returnedProviders := make([]string, 1) // location 0 is always the best score - latencyScore := math.MaxFloat64 // smaller = better i.e less latency - syncScore := math.MaxFloat64 // smaller = better i.e less sync lag - numProviders := len(allAddresses) - if po.strategy == STRATEGY_DISTRIBUTED { - // distribute relays across more providers - perturbationPercentage *= 2 - } +func (po *ProviderOptimizer) CalculateSelectionTiers(allAddresses []string, ignoredProviders map[string]struct{}, cu uint64, requestedBlock int64) (SelectionTier, Exploration) { + latencyScore := math.MaxFloat64 // smaller = better i.e less latency + syncScore := math.MaxFloat64 // smaller = better i.e less sync lag + + explorationCandidate := Exploration{address: "", time: time.Now().Add(time.Hour)} + selectionTier := NewSelectionTier() for _, providerAddress := range allAddresses { if _, ok := ignoredProviders[providerAddress]; ok { // ignored provider, skip it @@ -152,16 +166,12 @@ func (po *ProviderOptimizer) ChooseProvider(allAddresses []string, ignoredProvid } // latency score latencyScoreCurrent := po.calculateLatencyScore(providerData, cu, requestedBlock) // smaller == better i.e less latency - // latency perturbation - latencyScoreCurrent = pertrubWithNormalGaussian(latencyScoreCurrent, perturbationPercentage) // sync score syncScoreCurrent := float64(0) if requestedBlock < 0 { // means user didn't ask for a specific block and we want to give him the best syncScoreCurrent = po.calculateSyncScore(providerData.Sync) // smaller == better i.e less sync lag - // sync perturbation - syncScoreCurrent = pertrubWithNormalGaussian(syncScoreCurrent, perturbationPercentage) } utils.LavaFormatTrace("scores information", @@ -171,29 +181,51 @@ func (po *ProviderOptimizer) ChooseProvider(allAddresses []string, ignoredProvid utils.LogAttr("latencyScore", latencyScore), utils.LogAttr("syncScore", syncScore), ) - - // we want the minimum latency and sync diff - if po.isBetterProviderScore(latencyScore, latencyScoreCurrent, syncScore, syncScoreCurrent) || len(returnedProviders) == 0 { - if returnedProviders[0] != 
"" && po.shouldExplore(len(returnedProviders), numProviders) { - // we are about to overwrite position 0, and this provider needs a chance to be in exploration - returnedProviders = append(returnedProviders, returnedProviders[0]) - } - returnedProviders[0] = providerAddress // best provider is always on position 0 - latencyScore = latencyScoreCurrent - syncScore = syncScoreCurrent - continue - } - if po.shouldExplore(len(returnedProviders), numProviders) { - returnedProviders = append(returnedProviders, providerAddress) + providerScore := po.calcProviderScore(latencyScoreCurrent, syncScoreCurrent) + selectionTier.AddScore(providerAddress, providerScore) + + // check if candidate for exploration + updateTime := providerData.Latency.Time + if updateTime.Add(10*time.Second).Before(time.Now()) && updateTime.Before(explorationCandidate.time) { + // if the provider didn't update its data for 10 seconds, it is a candidate for exploration + explorationCandidate = Exploration{address: providerAddress, time: updateTime} } } + return selectionTier, explorationCandidate +} - utils.LavaFormatTrace("returned providers", +// returns a sub set of selected providers according to their scores, perturbation factor will be added to each score in order to randomly select providers that are not always on top +func (po *ProviderOptimizer) ChooseProvider(allAddresses []string, ignoredProviders map[string]struct{}, cu uint64, requestedBlock int64) (addresses []string, tier int) { + selectionTier, explorationCandidate := po.CalculateSelectionTiers(allAddresses, ignoredProviders, cu, requestedBlock) + if selectionTier.ScoresCount() == 0 { + // no providers to choose from + return []string{}, -1 + } + initialChances := map[int]float64{0: ATierChance} + if selectionTier.ScoresCount() < po.OptimizerNumTiers { + po.OptimizerNumTiers = selectionTier.ScoresCount() + } + if selectionTier.ScoresCount() >= MinimumEntries*2 { + // if we have more than 2*MinimumEntries we set the LastTierChance configured + initialChances[(po.OptimizerNumTiers - 1)] = LastTierChance + } + shiftedChances := selectionTier.ShiftTierChance(po.OptimizerNumTiers, initialChances) + tier = selectionTier.SelectTierRandomly(po.OptimizerNumTiers, shiftedChances) + tierProviders := selectionTier.GetTier(tier, po.OptimizerNumTiers, MinimumEntries) + // TODO: add penalty if a provider is chosen too much + selectedProvider := po.selectionWeighter.WeightedChoice(tierProviders) + returnedProviders := []string{selectedProvider} + if explorationCandidate.address != "" && po.shouldExplore(1, selectionTier.ScoresCount()) { + returnedProviders = append(returnedProviders, explorationCandidate.address) + } + utils.LavaFormatTrace("[Optimizer] returned providers", utils.LogAttr("providers", strings.Join(returnedProviders, ",")), utils.LogAttr("cu", cu), + utils.LogAttr("shiftedChances", shiftedChances), + utils.LogAttr("tier", tier), ) - return returnedProviders + return returnedProviders, tier } // calculate the expected average time until this provider catches up with the given latestSync block @@ -242,30 +274,35 @@ func (po *ProviderOptimizer) shouldExplore(currentNumProvders, numProviders int) case STRATEGY_PRIVACY: return false // only one at a time } - // Dividing the random threshold by the loop count ensures that the overall probability of success is the requirement for the entire loop not per iteration - return rand.Float64() < explorationChance/float64(numProviders) + return rand.Float64() < explorationChance } func (po *ProviderOptimizer) 
isBetterProviderScore(latencyScore, latencyScoreCurrent, syncScore, syncScoreCurrent float64) bool { - var latencyWeight float64 switch po.strategy { - case STRATEGY_LATENCY: - latencyWeight = 0.7 - case STRATEGY_SYNC_FRESHNESS: - latencyWeight = 0.2 case STRATEGY_PRIVACY: // pick at random regardless of score if rand.Intn(2) == 0 { return true } return false - default: - latencyWeight = 0.6 } if syncScoreCurrent == 0 { return latencyScore > latencyScoreCurrent } - return latencyScore*latencyWeight+syncScore*(1-latencyWeight) > latencyScoreCurrent*latencyWeight+syncScoreCurrent*(1-latencyWeight) + return po.calcProviderScore(latencyScore, syncScore) > po.calcProviderScore(latencyScoreCurrent, syncScoreCurrent) +} + +func (po *ProviderOptimizer) calcProviderScore(latencyScore, syncScore float64) float64 { + var latencyWeight float64 + switch po.strategy { + case STRATEGY_LATENCY: + latencyWeight = 0.7 + case STRATEGY_SYNC_FRESHNESS: + latencyWeight = 0.2 + default: + latencyWeight = 0.6 + } + return latencyScore*latencyWeight + syncScore*(1-latencyWeight) } func (po *ProviderOptimizer) calculateSyncScore(syncScore score.ScoreStore) float64 { @@ -469,7 +506,16 @@ func NewProviderOptimizer(strategy Strategy, averageBlockTIme, baseWorldLatency // overwrite wantedNumProvidersInConcurrency = 1 } - return &ProviderOptimizer{strategy: strategy, providersStorage: cache, averageBlockTime: averageBlockTIme, baseWorldLatency: baseWorldLatency, providerRelayStats: relayCache, wantedNumProvidersInConcurrency: wantedNumProvidersInConcurrency} + return &ProviderOptimizer{ + strategy: strategy, + providersStorage: cache, + averageBlockTime: averageBlockTIme, + baseWorldLatency: baseWorldLatency, + providerRelayStats: relayCache, + wantedNumProvidersInConcurrency: wantedNumProvidersInConcurrency, + selectionWeighter: NewSelectionWeighter(), + OptimizerNumTiers: OptimizerNumTiers, + } } // calculate the probability a random variable with a poisson distribution diff --git a/protocol/provideroptimizer/provider_optimizer_test.go b/protocol/provideroptimizer/provider_optimizer_test.go index 31a4df2f5..37b770e40 100644 --- a/protocol/provideroptimizer/provider_optimizer_test.go +++ b/protocol/provideroptimizer/provider_optimizer_test.go @@ -40,7 +40,7 @@ func (posc *providerOptimizerSyncCache) Set(key, value interface{}, cost int64) func setupProviderOptimizer(maxProvidersCount int) *ProviderOptimizer { averageBlockTIme := TEST_AVERAGE_BLOCK_TIME baseWorldLatency := TEST_BASE_WORLD_LATENCY - return NewProviderOptimizer(STRATEGY_BALANCED, averageBlockTIme, baseWorldLatency, 1) + return NewProviderOptimizer(STRATEGY_BALANCED, averageBlockTIme, baseWorldLatency, uint(maxProvidersCount)) } type providersGenerator struct { @@ -139,107 +139,148 @@ func TestProviderOptimizerBasic(t *testing.T) { requestCU := uint64(10) requestBlock := int64(1000) - pertrubationPercentage := 0.0 - returnedProviders := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) + returnedProviders, tier := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock) require.Equal(t, 1, len(returnedProviders)) - providerOptimizer.AppendProbeRelayData(providersGen.providersAddresses[1], TEST_BASE_WORLD_LATENCY*3, true) + require.NotEqual(t, 4, tier) + // damage their chance to be selected by placing them in the worst tier + providerOptimizer.AppendProbeRelayData(providersGen.providersAddresses[5], TEST_BASE_WORLD_LATENCY*3, true) + 
providerOptimizer.AppendProbeRelayData(providersGen.providersAddresses[6], TEST_BASE_WORLD_LATENCY*3, true) + providerOptimizer.AppendProbeRelayData(providersGen.providersAddresses[7], TEST_BASE_WORLD_LATENCY*3, true) time.Sleep(4 * time.Millisecond) - returnedProviders = providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) + returnedProviders, _ = providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock) require.Equal(t, 1, len(returnedProviders)) - require.NotEqual(t, returnedProviders[0], providersGen.providersAddresses[1]) // we shouldn't pick the wrong provider + require.NotEqual(t, 4, tier) + require.NotEqual(t, returnedProviders[0], providersGen.providersAddresses[5]) // we shouldn't pick the worst provider + require.NotEqual(t, returnedProviders[0], providersGen.providersAddresses[6]) // we shouldn't pick the worst provider + require.NotEqual(t, returnedProviders[0], providersGen.providersAddresses[7]) // we shouldn't pick the worst provider + // improve selection chance by placing them in the top tier providerOptimizer.AppendProbeRelayData(providersGen.providersAddresses[0], TEST_BASE_WORLD_LATENCY/2, true) + providerOptimizer.AppendProbeRelayData(providersGen.providersAddresses[1], TEST_BASE_WORLD_LATENCY/2, true) + providerOptimizer.AppendProbeRelayData(providersGen.providersAddresses[2], TEST_BASE_WORLD_LATENCY/2, true) time.Sleep(4 * time.Millisecond) - returnedProviders = providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[0], returnedProviders[0]) // we should pick the best provider + results, tierResults := runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, nil, requestCU, requestBlock, 1000) + require.Greater(t, tierResults[0], 650, tierResults) // we should pick the best tier most often + // out of 10 providers, and with 3 in the top tier we should pick 0 around a third of that + require.Greater(t, results[providersGen.providersAddresses[0]], 250, results) // we should pick the best tier most often +} + +func runChooseManyTimesAndReturnResults(t *testing.T, providerOptimizer *ProviderOptimizer, providers []string, ignoredProviders map[string]struct{}, requestCU uint64, requestBlock int64, times int) (map[string]int, map[int]int) { + tierResults := make(map[int]int) + results := make(map[string]int) + for i := 0; i < times; i++ { + returnedProviders, tier := providerOptimizer.ChooseProvider(providers, ignoredProviders, requestCU, requestBlock) + require.Equal(t, 1, len(returnedProviders)) + results[returnedProviders[0]]++ + tierResults[tier]++ + } + return results, tierResults } func TestProviderOptimizerBasicRelayData(t *testing.T) { providerOptimizer := setupProviderOptimizer(1) providersGen := (&providersGenerator{}).setupProvidersForTest(10) - + rand.InitRandomSeed() requestCU := uint64(1) requestBlock := int64(1000) - pertrubationPercentage := 0.0 + syncBlock := uint64(requestBlock) - providerOptimizer.AppendRelayData(providersGen.providersAddresses[1], TEST_BASE_WORLD_LATENCY*4, false, requestCU, syncBlock) - returnedProviders := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) + providerOptimizer.AppendRelayData(providersGen.providersAddresses[5], TEST_BASE_WORLD_LATENCY*4, false, requestCU, 
syncBlock) + providerOptimizer.AppendRelayData(providersGen.providersAddresses[6], TEST_BASE_WORLD_LATENCY*4, false, requestCU, syncBlock) + providerOptimizer.AppendRelayData(providersGen.providersAddresses[7], TEST_BASE_WORLD_LATENCY*4, false, requestCU, syncBlock) + time.Sleep(4 * time.Millisecond) + returnedProviders, tier := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock) require.Equal(t, 1, len(returnedProviders)) - require.NotEqual(t, returnedProviders[0], providersGen.providersAddresses[1]) // we shouldn't pick the wrong provider + // we shouldn't pick the low tier providers + require.NotEqual(t, tier, 3) + require.NotEqual(t, returnedProviders[0], providersGen.providersAddresses[5], tier) + require.NotEqual(t, returnedProviders[0], providersGen.providersAddresses[6], tier) + require.NotEqual(t, returnedProviders[0], providersGen.providersAddresses[7], tier) + providerOptimizer.AppendRelayData(providersGen.providersAddresses[0], TEST_BASE_WORLD_LATENCY/4, false, requestCU, syncBlock) - returnedProviders = providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[0], returnedProviders[0]) // we should pick the best provider + providerOptimizer.AppendRelayData(providersGen.providersAddresses[1], TEST_BASE_WORLD_LATENCY/4, false, requestCU, syncBlock) + providerOptimizer.AppendRelayData(providersGen.providersAddresses[2], TEST_BASE_WORLD_LATENCY/4, false, requestCU, syncBlock) + time.Sleep(4 * time.Millisecond) + results, tierResults := runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, nil, requestCU, requestBlock, 1000) + + require.Zero(t, results[providersGen.providersAddresses[5]]) + require.Zero(t, results[providersGen.providersAddresses[6]]) + require.Zero(t, results[providersGen.providersAddresses[7]]) + + require.Greater(t, tierResults[0], 650, tierResults) // we should pick the best tier most often + // out of 10 providers, and with 3 in the top tier we should pick 0 around a third of that + require.Greater(t, results[providersGen.providersAddresses[0]], 250, results) // we should pick the best tier most often } func TestProviderOptimizerAvailability(t *testing.T) { providerOptimizer := setupProviderOptimizer(1) providersCount := 100 providersGen := (&providersGenerator{}).setupProvidersForTest(providersCount) - + rand.InitRandomSeed() requestCU := uint64(10) requestBlock := int64(1000) - pertrubationPercentage := 0.0 - skipIndex := rand.Intn(providersCount) + + skipIndex := rand.Intn(providersCount - 3) + providerOptimizer.OptimizerNumTiers = 33 // set many tiers so good providers can stand out in the test for i := range providersGen.providersAddresses { - // give all providers a worse availability score - if i == skipIndex { + // give all providers a worse availability score except these 3 + if i == skipIndex || i == skipIndex+1 || i == skipIndex+2 { // skip 0 continue } providerOptimizer.AppendProbeRelayData(providersGen.providersAddresses[i], TEST_BASE_WORLD_LATENCY, false) } time.Sleep(4 * time.Millisecond) - returnedProviders := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[skipIndex], returnedProviders[0]) - returnedProviders = 
providerOptimizer.ChooseProvider(providersGen.providersAddresses, map[string]struct{}{providersGen.providersAddresses[skipIndex]: {}}, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.NotEqual(t, providersGen.providersAddresses[skipIndex], returnedProviders[0]) + results, tierResults := runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, nil, requestCU, requestBlock, 1000) + require.Greater(t, tierResults[0], 300, tierResults) // 0.42 chance for top tier due to the algorithm to rebalance chances + require.Greater(t, results[providersGen.providersAddresses[skipIndex]]+results[providersGen.providersAddresses[skipIndex+1]]+results[providersGen.providersAddresses[skipIndex+2]], 300) + require.InDelta(t, results[providersGen.providersAddresses[skipIndex]], results[providersGen.providersAddresses[skipIndex+1]], 50) + results, _ = runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, map[string]struct{}{providersGen.providersAddresses[skipIndex]: {}}, requestCU, requestBlock, 1000) + require.Zero(t, results[providersGen.providersAddresses[skipIndex]]) } func TestProviderOptimizerAvailabilityRelayData(t *testing.T) { providerOptimizer := setupProviderOptimizer(1) providersCount := 100 providersGen := (&providersGenerator{}).setupProvidersForTest(providersCount) + rand.InitRandomSeed() requestCU := uint64(10) requestBlock := int64(1000) - pertrubationPercentage := 0.0 - skipIndex := rand.Intn(providersCount) + + skipIndex := rand.Intn(providersCount - 3) + providerOptimizer.OptimizerNumTiers = 33 // set many tiers so good providers can stand out in the test for i := range providersGen.providersAddresses { - // give all providers a worse availability score - if i == skipIndex { - // skip one provider + // give all providers a worse availability score except these 3 + if i == skipIndex || i == skipIndex+1 || i == skipIndex+2 { + // skip 0 continue } providerOptimizer.AppendRelayFailure(providersGen.providersAddresses[i]) } time.Sleep(4 * time.Millisecond) - returnedProviders := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[skipIndex], returnedProviders[0]) - returnedProviders = providerOptimizer.ChooseProvider(providersGen.providersAddresses, map[string]struct{}{providersGen.providersAddresses[skipIndex]: {}}, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.NotEqual(t, providersGen.providersAddresses[skipIndex], returnedProviders[0]) + results, tierResults := runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, nil, requestCU, requestBlock, 1000) + require.Greater(t, tierResults[0], 300, tierResults) // 0.42 chance for top tier due to the algorithm to rebalance chances + require.Greater(t, results[providersGen.providersAddresses[skipIndex]]+results[providersGen.providersAddresses[skipIndex+1]]+results[providersGen.providersAddresses[skipIndex+2]], 270) + require.InDelta(t, results[providersGen.providersAddresses[skipIndex]], results[providersGen.providersAddresses[skipIndex+1]], 50) + results, _ = runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, map[string]struct{}{providersGen.providersAddresses[skipIndex]: {}}, requestCU, requestBlock, 1000) + require.Zero(t, 
results[providersGen.providersAddresses[skipIndex]]) } func TestProviderOptimizerAvailabilityBlockError(t *testing.T) { providerOptimizer := setupProviderOptimizer(1) providersCount := 10 providersGen := (&providersGenerator{}).setupProvidersForTest(providersCount) - + rand.InitRandomSeed() requestCU := uint64(10) requestBlock := int64(1000) - pertrubationPercentage := 0.0 + syncBlock := uint64(requestBlock) - chosenIndex := rand.Intn(providersCount) + chosenIndex := rand.Intn(providersCount - 2) for i := range providersGen.providersAddresses { time.Sleep(4 * time.Millisecond) - // give all providers a worse availability score - if i == chosenIndex { + if i == chosenIndex || i == chosenIndex+1 || i == chosenIndex+2 { // give better syncBlock, worse latency by a little providerOptimizer.AppendRelayData(providersGen.providersAddresses[i], TEST_BASE_WORLD_LATENCY+10*time.Millisecond, false, requestCU, syncBlock) continue @@ -247,13 +288,23 @@ func TestProviderOptimizerAvailabilityBlockError(t *testing.T) { providerOptimizer.AppendRelayData(providersGen.providersAddresses[i], TEST_BASE_WORLD_LATENCY, false, requestCU, syncBlock-1) // update that he doesn't have the latest requested block } time.Sleep(4 * time.Millisecond) - returnedProviders := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[chosenIndex], returnedProviders[0]) + selectionTier, _ := providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, nil, requestCU, requestBlock) + tierChances := selectionTier.ShiftTierChance(OptimizerNumTiers, map[int]float64{0: ATierChance, OptimizerNumTiers - 1: LastTierChance}) + require.Greater(t, tierChances[0], 0.7, tierChances) + results, tierResults := runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, nil, requestCU, requestBlock, 1000) + require.Greater(t, tierResults[0], 500, tierResults) // we should pick the best tier most often + // out of 10 providers, and with 3 in the top tier we should pick 0 around a third of that + require.Greater(t, results[providersGen.providersAddresses[chosenIndex]], 200, results) // we should pick the best tier most often + sumResults := results[providersGen.providersAddresses[chosenIndex]] + results[providersGen.providersAddresses[chosenIndex+1]] + results[providersGen.providersAddresses[chosenIndex+2]] + require.Greater(t, sumResults, 500, results) // we should pick the best tier most often // now try to get a previous block, our chosenIndex should be inferior in latency and blockError chance should be the same - returnedProviders = providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock-1, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.NotEqual(t, providersGen.providersAddresses[chosenIndex], returnedProviders[0]) + + results, tierResults = runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, nil, requestCU, requestBlock-1, 1000) + require.Greater(t, tierResults[0], 500, tierResults) // we should pick the best tier most often + // out of 10 providers, and with 3 in the top tier we should pick 0 around a third of that + require.Less(t, results[providersGen.providersAddresses[chosenIndex]], 50, results) // chosen indexes shoulnt be in the tier + sumResults = results[providersGen.providersAddresses[chosenIndex]] + 
results[providersGen.providersAddresses[chosenIndex+1]] + results[providersGen.providersAddresses[chosenIndex+2]] + require.Less(t, sumResults, 150, results) // we should pick the best tier most often } // TODO::PRT-1114 This needs to be fixed asap. currently commented out as it prevents pushing unrelated code @@ -289,57 +340,80 @@ func TestProviderOptimizerAvailabilityBlockError(t *testing.T) { // } // } -func TestProviderOptimizerStrategiesProviderCount(t *testing.T) { - providerOptimizer := setupProviderOptimizer(3) - providersCount := 5 +func TestProviderOptimizerExploration(t *testing.T) { + providerOptimizer := setupProviderOptimizer(2) + providersCount := 10 providersGen := (&providersGenerator{}).setupProvidersForTest(providersCount) requestCU := uint64(10) requestBlock := int64(1000) syncBlock := uint64(requestBlock) - pertrubationPercentage := 0.0 - // set a basic state for all of them - for i := 0; i < 10; i++ { - for _, address := range providersGen.providersAddresses { - providerOptimizer.AppendRelayData(address, TEST_BASE_WORLD_LATENCY*2, false, requestCU, syncBlock) - } - time.Sleep(4 * time.Millisecond) - } - testProvidersCount := func(iterations int) float64 { + + rand.InitRandomSeed() + // start with a disabled chosen index + chosenIndex := -1 + testProvidersExploration := func(iterations int) float64 { exploration := 0.0 for i := 0; i < iterations; i++ { - returnedProviders := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) + returnedProviders, _ := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock) if len(returnedProviders) > 1 { exploration++ + // check if we have a specific chosen index + if chosenIndex >= 0 { + // there's only one provider eligible for exploration it must be him + require.Equal(t, providersGen.providersAddresses[chosenIndex], returnedProviders[1]) + } } } return exploration } - // with a cost strategy we expect only one provider, two with a chance of 1/100 - providerOptimizer.strategy = STRATEGY_COST - providerOptimizer.wantedNumProvidersInConcurrency = 2 + // make sure exploration works when providers are defaulted (no data at all) + exploration := testProvidersExploration(1000) + require.Greater(t, exploration, float64(10)) + + chosenIndex = rand.Intn(providersCount - 2) + // set chosen index with a value in the past so it can be selected for exploration + providerOptimizer.appendRelayData(providersGen.providersAddresses[chosenIndex], TEST_BASE_WORLD_LATENCY*2, false, true, requestCU, syncBlock, time.Now().Add(-35*time.Second)) + // set a basic state for all other provider, with a recent time (so they can't be selected for exploration) + for i := 0; i < 10; i++ { + for index, address := range providersGen.providersAddresses { + if index == chosenIndex { + // we set chosenIndex with a past time so it can be selected for exploration + continue + } + // set samples in the future so they are never a candidate for exploration + providerOptimizer.appendRelayData(address, TEST_BASE_WORLD_LATENCY*2, false, true, requestCU, syncBlock, time.Now().Add(1*time.Second)) + } + time.Sleep(4 * time.Millisecond) + } + + // with a cost strategy we expect exploration at a 10% rate + providerOptimizer.strategy = STRATEGY_BALANCED // that's the default but to be explicit + providerOptimizer.wantedNumProvidersInConcurrency = 2 // that's in the constructor but to be explicit iterations := 10000 - exploration := testProvidersCount(iterations) - 
require.Less(t, exploration, float64(1.3)*float64(iterations*providersCount)*COST_EXPLORATION_CHANCE) // allow mistake buffer of 30% because of randomness + exploration = testProvidersExploration(iterations) + require.Less(t, exploration, float64(1.4)*float64(iterations)*DEFAULT_EXPLORATION_CHANCE) // allow mistake buffer of 40% because of randomness + require.Greater(t, exploration, float64(0.6)*float64(iterations)*DEFAULT_EXPLORATION_CHANCE) // allow mistake buffer of 40% because of randomness - // with a cost strategy we expect only one provider, two with a chance of 10/100 - providerOptimizer.strategy = STRATEGY_BALANCED - exploration = testProvidersCount(iterations) - require.Greater(t, exploration, float64(1.3)*float64(iterations*providersCount)/100.0) - require.Less(t, exploration, float64(1.3)*float64(iterations*providersCount)*DEFAULT_EXPLORATION_CHANCE) // allow mistake buffer of 30% because of randomness + // with a cost strategy we expect exploration to happen once in 100 samples + providerOptimizer.strategy = STRATEGY_COST + exploration = testProvidersExploration(iterations) + require.Less(t, exploration, float64(1.4)*float64(iterations)*COST_EXPLORATION_CHANCE) // allow mistake buffer of 40% because of randomness + require.Greater(t, exploration, float64(0.6)*float64(iterations)*COST_EXPLORATION_CHANCE) // allow mistake buffer of 40% because of randomness + // privacy disables exploration providerOptimizer.strategy = STRATEGY_PRIVACY - exploration = testProvidersCount(iterations) + exploration = testProvidersExploration(iterations) require.Equal(t, exploration, float64(0)) } func TestProviderOptimizerSyncScore(t *testing.T) { providerOptimizer := setupProviderOptimizer(1) providersGen := (&providersGenerator{}).setupProvidersForTest(10) - + rand.InitRandomSeed() requestCU := uint64(10) requestBlock := spectypes.LATEST_BLOCK - pertrubationPercentage := 0.0 + syncBlock := uint64(1000) chosenIndex := rand.Intn(len(providersGen.providersAddresses)) @@ -357,32 +431,38 @@ func TestProviderOptimizerSyncScore(t *testing.T) { sampleTime = sampleTime.Add(time.Millisecond * 5) } time.Sleep(4 * time.Millisecond) - returnedProviders := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[chosenIndex], returnedProviders[0]) // we should pick the best sync score - - returnedProviders = providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, int64(syncBlock), pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.NotEqual(t, providersGen.providersAddresses[chosenIndex], returnedProviders[0]) // sync score doesn't matter now + selectionTier, _ := providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, nil, requestCU, requestBlock) + tier0 := selectionTier.GetTier(0, 4, 3) + require.Greater(t, len(tier0), 0) // shouldn't be empty + // we have the best score on the top tier and it's sorted + require.Equal(t, providersGen.providersAddresses[chosenIndex], tier0[0].Address) + + // now choose with a specific block that all providers have + selectionTier, _ = providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, nil, requestCU, int64(syncBlock)) + tier0 = selectionTier.GetTier(0, 4, 3) + for idx := range tier0 { + // sync score doesn't matter now so the tier0 is recalculated and chosenIndex has worst latency + require.NotEqual(t, 
providersGen.providersAddresses[chosenIndex], tier0[idx].Address) + } } func TestProviderOptimizerStrategiesScoring(t *testing.T) { rand.InitRandomSeed() providerOptimizer := setupProviderOptimizer(1) - providersCount := 5 + providersCount := 10 providersGen := (&providersGenerator{}).setupProvidersForTest(providersCount) requestCU := uint64(10) requestBlock := spectypes.LATEST_BLOCK syncBlock := uint64(1000) - pertrubationPercentage := 0.0 - // set a basic state for all of them + + // set a basic state for all providers + sampleTime := time.Now() for i := 0; i < 10; i++ { for _, address := range providersGen.providersAddresses { - providerOptimizer.AppendRelayData(address, TEST_BASE_WORLD_LATENCY*2, false, requestCU, syncBlock) + providerOptimizer.appendRelayData(address, TEST_BASE_WORLD_LATENCY*2, false, true, requestCU, syncBlock, sampleTime) } time.Sleep(4 * time.Millisecond) } - time.Sleep(4 * time.Millisecond) // provider 2 doesn't get a probe availability hit, this is the most meaningful factor for idx, address := range providersGen.providersAddresses { if idx != 2 { @@ -399,7 +479,7 @@ func TestProviderOptimizerStrategiesScoring(t *testing.T) { time.Sleep(4 * time.Millisecond) } - sampleTime := time.Now() + sampleTime = time.Now() improvedLatency := 280 * time.Millisecond normalLatency := TEST_BASE_WORLD_LATENCY * 2 improvedBlock := syncBlock + 1 @@ -423,97 +503,39 @@ func TestProviderOptimizerStrategiesScoring(t *testing.T) { time.Sleep(4 * time.Millisecond) providerOptimizer.strategy = STRATEGY_BALANCED // a balanced strategy should pick provider 2 because of it's high availability - returnedProviders := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[2], returnedProviders[0]) + selectionTier, _ := providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, nil, requestCU, requestBlock) + tier0 := selectionTier.GetTier(0, 4, 3) + require.Greater(t, len(tier0), 0) // shouldn't be empty + // we have the best score on the top tier and it's sorted + require.Equal(t, providersGen.providersAddresses[2], tier0[0].Address) providerOptimizer.strategy = STRATEGY_COST // with a cost strategy we expect the same as balanced - returnedProviders = providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[2], returnedProviders[0]) + selectionTier, _ = providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, nil, requestCU, requestBlock) + tier0 = selectionTier.GetTier(0, 4, 3) + require.Greater(t, len(tier0), 0) // shouldn't be empty + // we have the best score on the top tier and it's sorted + require.Equal(t, providersGen.providersAddresses[2], tier0[0].Address) providerOptimizer.strategy = STRATEGY_LATENCY // latency strategy should pick the best latency - returnedProviders = providerOptimizer.ChooseProvider(providersGen.providersAddresses, map[string]struct{}{providersGen.providersAddresses[2]: {}}, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[0], returnedProviders[0]) + selectionTier, _ = providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, map[string]struct{}{providersGen.providersAddresses[2]: {}}, requestCU, 
requestBlock) + tier0 = selectionTier.GetTier(0, 4, 3) + require.Greater(t, len(tier0), 0) // shouldn't be empty + require.Equal(t, providersGen.providersAddresses[0], tier0[0].Address) providerOptimizer.strategy = STRATEGY_SYNC_FRESHNESS // freshness strategy should pick the most advanced provider - returnedProviders = providerOptimizer.ChooseProvider(providersGen.providersAddresses, map[string]struct{}{providersGen.providersAddresses[2]: {}}, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - require.Equal(t, providersGen.providersAddresses[1], returnedProviders[0]) -} - -func TestPerturbation(t *testing.T) { - origValue1 := 1.0 - origValue2 := 0.5 - pertrubationPercentage := 0.03 // this is statistical and we don;t want this failing - runs := 100000 - success := 0 - for i := 0; i < runs; i++ { - res1 := pertrubWithNormalGaussian(origValue1, pertrubationPercentage) - res2 := pertrubWithNormalGaussian(origValue2, pertrubationPercentage) - if res1 > res2 { - success++ - } - } - require.GreaterOrEqual(t, float64(success), float64(runs)*0.9) -} - -// TODO: fix this test "22" is not less than "10" -func TestProviderOptimizerPerturbation(t *testing.T) { - providerOptimizer := setupProviderOptimizer(1) - providersCount := 100 - providersGen := (&providersGenerator{}).setupProvidersForTest(providersCount) - requestCU := uint64(10) - requestBlock := spectypes.LATEST_BLOCK - syncBlock := uint64(1000) - pertrubationPercentage := 0.03 // this is statistical and we don't want this failing - // set a basic state for all of them - sampleTime := time.Now() - for i := 0; i < 10; i++ { - for idx, address := range providersGen.providersAddresses { - if idx < len(providersGen.providersAddresses)/2 { - // first half are good - providerOptimizer.appendRelayData(address, TEST_BASE_WORLD_LATENCY, false, true, requestCU, syncBlock, sampleTime) - } else { - // second half are bad - providerOptimizer.appendRelayData(address, TEST_BASE_WORLD_LATENCY*10, false, true, requestCU, syncBlock, sampleTime) - } - } - sampleTime = sampleTime.Add(time.Millisecond * 5) - time.Sleep(4 * time.Millisecond) // let the cache add the entries - } - seed := time.Now().UnixNano() // constant seed. - // seed := int64(XXX) // constant seed. 
- for _, providerAddress := range providersGen.providersAddresses { - _, found := providerOptimizer.getProviderData(providerAddress) - require.True(t, found, providerAddress) - } - t.Logf("rand seed %d", seed) - same := 0 - pickFaults := 0 - chosenProvider := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, 0)[0] - runs := 1000 - // runs := 0 - for i := 0; i < runs; i++ { - returnedProviders := providerOptimizer.ChooseProvider(providersGen.providersAddresses, nil, requestCU, requestBlock, pertrubationPercentage) - require.Equal(t, 1, len(returnedProviders)) - if chosenProvider == returnedProviders[0] { - same++ - } - for idx, address := range providersGen.providersAddresses { - if address == returnedProviders[0] && idx >= len(providersGen.providersAddresses)/2 { - t.Logf("picked provider %s at index %d i: %d", returnedProviders[0], idx, i) - pickFaults++ - break - } - } - } - require.Less(t, float64(pickFaults), float64(runs)*0.01) - require.Less(t, same, runs/10) + selectionTier, _ = providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, map[string]struct{}{providersGen.providersAddresses[2]: {}}, requestCU, requestBlock) + tier0 = selectionTier.GetTier(0, 4, 3) + require.Greater(t, len(tier0), 0) // shouldn't be empty + require.Equal(t, providersGen.providersAddresses[1], tier0[0].Address) + + // but if we request a past block, then it doesnt matter and we choose by latency: + selectionTier, _ = providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, map[string]struct{}{providersGen.providersAddresses[2]: {}}, requestCU, int64(syncBlock)) + tier0 = selectionTier.GetTier(0, 4, 3) + require.Greater(t, len(tier0), 0) // shouldn't be empty + require.Equal(t, providersGen.providersAddresses[0], tier0[0].Address) } func TestExcellence(t *testing.T) { @@ -576,3 +598,165 @@ func TestPerturbationWithNormalGaussianOnConcurrentComputation(t *testing.T) { wg.Wait() fmt.Println("Test completed successfully") } + +// test low providers count 0-9 +func TestProviderOptimizerProvidersCount(t *testing.T) { + rand.InitRandomSeed() + providerOptimizer := setupProviderOptimizer(1) + providersCount := 10 + providersGen := (&providersGenerator{}).setupProvidersForTest(providersCount) + requestCU := uint64(10) + requestBlock := spectypes.LATEST_BLOCK + syncBlock := uint64(1000) + sampleTime := time.Now() + for i := 0; i < 10; i++ { + for _, address := range providersGen.providersAddresses { + providerOptimizer.appendRelayData(address, TEST_BASE_WORLD_LATENCY*2, false, true, requestCU, syncBlock, sampleTime) + } + time.Sleep(4 * time.Millisecond) + } + playbook := []struct { + name string + providers int + }{ + { + name: "one", + providers: 1, + }, + { + name: "two", + providers: 2, + }, + { + name: "three", + providers: 3, + }, + { + name: "four", + providers: 4, + }, + { + name: "five", + providers: 5, + }, + { + name: "six", + providers: 6, + }, + { + name: "seven", + providers: 7, + }, + { + name: "eight", + providers: 8, + }, + { + name: "nine", + providers: 9, + }, + } + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + for i := 0; i < 10; i++ { + returnedProviders, _ := providerOptimizer.ChooseProvider(providersGen.providersAddresses[:play.providers], nil, requestCU, requestBlock) + require.Greater(t, len(returnedProviders), 0) + } + }) + } +} + +func TestProviderOptimizerWeights(t *testing.T) { + rand.InitRandomSeed() + providerOptimizer := setupProviderOptimizer(1) + providersCount := 10 + 
providersGen := (&providersGenerator{}).setupProvidersForTest(providersCount)
+	requestCU := uint64(10)
+	requestBlock := spectypes.LATEST_BLOCK
+	syncBlock := uint64(1000)
+	sampleTime := time.Now()
+	weights := map[string]int64{
+		providersGen.providersAddresses[0]: 10000000000000, // simulating 10m tokens
+	}
+	for i := 1; i < 10; i++ {
+		weights[providersGen.providersAddresses[i]] = 50000000000
+	}
+
+	normalLatency := TEST_BASE_WORLD_LATENCY * 2
+	improvedLatency := normalLatency - 5*time.Millisecond
+	improvedBlock := syncBlock + 2
+
+	providerOptimizer.UpdateWeights(weights)
+	for i := 0; i < 10; i++ {
+		for idx, address := range providersGen.providersAddresses {
+			if idx == 0 {
+				providerOptimizer.appendRelayData(address, normalLatency, false, true, requestCU, improvedBlock, sampleTime)
+			} else {
+				providerOptimizer.appendRelayData(address, improvedLatency, false, true, requestCU, syncBlock, sampleTime)
+			}
+			sampleTime = sampleTime.Add(5 * time.Millisecond)
+			time.Sleep(4 * time.Millisecond)
+		}
+	}
+
+	// verify 0 has the best score
+	selectionTier, _ := providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, nil, requestCU, requestBlock)
+	tier0 := selectionTier.GetTier(0, 4, 3)
+	require.Greater(t, len(tier0), 0) // shouldn't be empty
+	require.Equal(t, providersGen.providersAddresses[0], tier0[0].Address)
+
+	// when picking with the latest block requested (sync matters), provider 0 is in the top tier and should be selected very often
+	results, tierResults := runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, nil, requestCU, requestBlock, 1000)
+	require.Greater(t, tierResults[0], 600, tierResults) // we should pick the best tier most often
+	// with an even split among the providers in the top tier, provider 0 would get roughly a third of those picks; its stake weight should push it well above that
+	require.Greater(t, results[providersGen.providersAddresses[0]], 550, results) // we should pick the top provider in tier 0 most times due to weight
+
+	// if we pick by latency only, provider 0 is in the worst tier and can't be selected at all
+	results, tierResults = runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, nil, requestCU, int64(syncBlock), 1000)
+	require.Greater(t, tierResults[0], 500, tierResults) // we should pick the best tier most often
+	// provider 0 now has the worst latency, so it should never be picked
+	require.Zero(t, results[providersGen.providersAddresses[0]])
+}
+
+func TestProviderOptimizerTiers(t *testing.T) {
+	rand.InitRandomSeed()
+
+	providersCountList := []int{9, 10}
+	for why, providersCount := range providersCountList {
+		providerOptimizer := setupProviderOptimizer(1)
+		providersGen := (&providersGenerator{}).setupProvidersForTest(providersCount)
+		requestCU := uint64(10)
+		requestBlock := spectypes.LATEST_BLOCK
+		syncBlock := uint64(1000)
+		sampleTime := time.Now()
+		normalLatency := TEST_BASE_WORLD_LATENCY * 2
+		for i := 0; i < 10; i++ {
+			for _, address := range providersGen.providersAddresses {
+				modifierLatency := rand.Int63n(3) - 1
+				modifierSync := rand.Int63n(3) - 1
+				providerOptimizer.appendRelayData(address, normalLatency+time.Duration(modifierLatency)*time.Millisecond, false, true, requestCU, syncBlock+uint64(modifierSync), sampleTime)
+				sampleTime = sampleTime.Add(5 * time.Millisecond)
+				time.Sleep(4 * time.Millisecond)
+			}
+		}
+		selectionTier, _ := providerOptimizer.CalculateSelectionTiers(providersGen.providersAddresses, nil, requestCU, requestBlock)
+		shiftedChances := selectionTier.ShiftTierChance(4, map[int]float64{0: 0.75})
+		require.NotZero(t, shiftedChances[3])
+		// choose many times and make sure every tier is selected at least once
+		_, tierResults := runChooseManyTimesAndReturnResults(t, providerOptimizer, providersGen.providersAddresses, nil, requestCU, requestBlock, 1000)
+		for index := 0; index < OptimizerNumTiers; index++ {
+			if providersCount >= 2*MinimumEntries && index == OptimizerNumTiers-1 {
+				// skip the last tier if there aren't enough providers
+				continue
+			}
+			require.NotZero(t, tierResults[index], "tierResults %v providersCount %d index %d why: %d", tierResults, providersCount, index, why)
+		}
+	}
+}
+
+// TODO: new tests we need:
+// check 3 providers, one with great stake one with great score
+// retries: groups getting smaller
+// no possible selections full
+// do a simulation with better and worse providers, make sure it's good
+// TODO: Oren - check optimizer selection with defaults (no scores for some of the providers)
diff --git a/protocol/provideroptimizer/selection_tier.go b/protocol/provideroptimizer/selection_tier.go
new file mode 100644
index 000000000..1a239fdf8
--- /dev/null
+++ b/protocol/provideroptimizer/selection_tier.go
@@ -0,0 +1,204 @@
+package provideroptimizer
+
+import (
+	"math"
+
+	"github.com/lavanet/lava/v3/utils"
+	"github.com/lavanet/lava/v3/utils/lavaslices"
+	"github.com/lavanet/lava/v3/utils/rand"
+)
+
+type Entry struct {
+	Address string
+	Score   float64
+	Part    float64
+}
+
+// SelectionTier is a utility to get a tier of addresses based on their scores
+type SelectionTier interface {
+	AddScore(entry string, score float64)
+	GetTier(tier int, numTiers int, minimumEntries int) []Entry
+	SelectTierRandomly(numTiers int, tierChances map[int]float64) int
+	ShiftTierChance(numTiers int, initialTierChances map[int]float64) map[int]float64
+	ScoresCount() int
+}
+
+type SelectionTierInst struct {
+	scores []Entry
+}
+
+func NewSelectionTier() SelectionTier {
+	return &SelectionTierInst{scores: []Entry{}}
+}
+
+func (st *SelectionTierInst) ScoresCount() int {
+	return len(st.scores)
+}
+
+func (st *SelectionTierInst) AddScore(entry string, score float64) {
+	// add the score to the scores list for the entry while keeping it sorted in ascending order
+	// this means that the lowest (best) score is at the front of the list, so tier 0 holds the best scores
+	newEntry := Entry{Address: entry, Score: score, Part: 1}
+	// find the correct position to insert the new entry
+
+	for i, existingEntry := range st.scores {
+		if existingEntry.Address == entry {
+			// overwrite the existing entry
+			st.scores[i].Score = score
+			return
+		}
+		if score <= existingEntry.Score {
+			st.scores = append(st.scores[:i], append([]Entry{newEntry}, st.scores[i:]...)...)
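+			// the nested append above splices newEntry in at index i, keeping st.scores sorted in ascending order (lowest, i.e. best, score first)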
+ return + } + } + // it's not smaller than any existing entry, so add it to the end + st.scores = append(st.scores, newEntry) +} + +func (st *SelectionTierInst) SelectTierRandomly(numTiers int, tierChances map[int]float64) int { + // select a tier randomly based on the chances given + // if the chances are not given, select a tier randomly based on the number of tiers + if len(tierChances) == 0 || len(tierChances) > numTiers { + utils.LavaFormatError("Invalid tier chances usage", nil, utils.LogAttr("tierChances", tierChances), utils.LogAttr("numTiers", numTiers)) + return rand.Intn(numTiers) + } + // calculate the total chance + chanceForDefaultTiers := st.calcChanceForDefaultTiers(tierChances, numTiers) + // select a random number between 0 and 1 + randChance := rand.Float64() + // find the tier that the random chance falls into + currentChance := 0.0 + for i := 0; i < numTiers; i++ { + if chance, ok := tierChances[i]; ok { + currentChance += chance + } else { + currentChance += chanceForDefaultTiers + } + if randChance < currentChance { + return i + } + } + // default, should never happen + return 0 +} + +func (*SelectionTierInst) calcChanceForDefaultTiers(tierChances map[int]float64, numTiers int) float64 { + if numTiers <= len(tierChances) { + return 0 + } + totalChance := 0.0 + for _, chance := range tierChances { + totalChance += chance + } + // rounding errors can happen + if totalChance > 1 { + totalChance = 1 + } + chanceForDefaultTiers := (1 - totalChance) / float64(numTiers-len(tierChances)) + return chanceForDefaultTiers +} + +func (st *SelectionTierInst) averageScoreForTier(tier int, numTiers int) float64 { + // calculate the average score for the given tier and number of tiers + start, end, _, _ := getPositionsForTier(tier, numTiers, len(st.scores)) + sum := 0.0 + parts := 0.0 + for i := start; i < end; i++ { + sum += st.scores[i].Score * st.scores[i].Part + parts += st.scores[i].Part + } + return sum / parts +} + +func (st *SelectionTierInst) ShiftTierChance(numTiers int, initialTierChances map[int]float64) map[int]float64 { + if len(st.scores) == 0 { + return initialTierChances + } + chanceForDefaultTiers := st.calcChanceForDefaultTiers(initialTierChances, numTiers) + + // shift the chances + shiftedTierChances := make(map[int]float64) + // shift tier chances based on the difference in the average score of each tier + scores := make([]float64, numTiers) + for i := 0; i < numTiers; i++ { + // scores[i] = 1 / (st.averageScoreForTier(i, numTiers) + 0.0001) // add epsilon to avoid 0 + scores[i] = st.averageScoreForTier(i, numTiers) + } + medianScore := lavaslices.Median(scores) + medianScoreReversed := 1 / (medianScore + 0.0001) + percentile25Score := lavaslices.Percentile(scores, 0.25) + percentile25ScoreReversed := 1 / (percentile25Score + 0.0001) + + averageChance := 1 / float64(numTiers) + for i := 0; i < numTiers; i++ { + // reverse the score so that higher scores get higher chances + reversedScore := 1 / (scores[i] + 0.0001) + // offset the score based on the median and 75th percentile scores, the better they are compared to them the higher the chance + offsetFactor := 0.5*math.Pow(reversedScore/medianScoreReversed, 2) + 0.5*math.Pow(reversedScore/percentile25ScoreReversed, 2) + if _, ok := initialTierChances[i]; !ok { + if chanceForDefaultTiers > 0 { + shiftedTierChances[i] = chanceForDefaultTiers + averageChance*offsetFactor + } + } else { + if initialTierChances[i] > 0 { + shiftedTierChances[i] = initialTierChances[i] + averageChance*offsetFactor + } + } + } + // 
normalize the chances + totalChance := 0.0 + for _, chance := range shiftedTierChances { + totalChance += chance + } + for i := 0; i < numTiers; i++ { + shiftedTierChances[i] /= totalChance + } + return shiftedTierChances +} + +func (st *SelectionTierInst) GetTier(tier int, numTiers int, minimumEntries int) []Entry { + // get the tier of scores for the given tier and number of tiers + entriesLen := len(st.scores) + if entriesLen < minimumEntries || numTiers == 0 || tier >= numTiers { + return st.scores + } + + start, end, fracStart, fracEnd := getPositionsForTier(tier, numTiers, entriesLen) + if end < minimumEntries { + // only allow better tiers if there are not enough entries + return st.scores[:end] + } + ret := st.scores[start:end] + if len(ret) >= minimumEntries { + // apply the relative parts to the first and last entries + ret[0].Part = 1 - fracStart + ret[len(ret)-1].Part = fracEnd + return ret + } + // bring in entries from better tiers if insufficient, give them a handicap to weight + // end is > minimumEntries, and end - start < minimumEntries + entriesToTake := minimumEntries - len(ret) + entriesToTakeStart := start - entriesToTake + copiedEntries := st.scores[entriesToTakeStart:start] + entriesToAdd := make([]Entry, len(copiedEntries)) + copy(entriesToAdd, copiedEntries) + for i := range entriesToAdd { + entriesToAdd[i].Part = 0.5 + } + ret = append(entriesToAdd, ret...) + return ret +} + +func getPositionsForTier(tier int, numTiers int, entriesLen int) (start int, end int, fracStart float64, fracEnd float64) { + rankStart := float64(tier) / float64(numTiers) + rankEnd := float64(tier+1) / float64(numTiers) + // Calculate the position based on the rank + startPositionF := (float64(entriesLen-1) * rankStart) + endPositionF := (float64(entriesLen-1) * rankEnd) + + positionStart := int(startPositionF) + positionEnd := int(endPositionF) + 1 + + return positionStart, positionEnd, startPositionF - float64(positionStart), float64(positionEnd) - endPositionF +} diff --git a/protocol/provideroptimizer/selection_tier_test.go b/protocol/provideroptimizer/selection_tier_test.go new file mode 100644 index 000000000..5c4720c44 --- /dev/null +++ b/protocol/provideroptimizer/selection_tier_test.go @@ -0,0 +1,313 @@ +package provideroptimizer + +import ( + "strconv" + "testing" + + "github.com/lavanet/lava/v3/utils/rand" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSelectionTierInst_AddScore(t *testing.T) { + st := &SelectionTierInst{scores: []Entry{}} + + st.AddScore("entry1", 5.0) + st.AddScore("entry2", 3.0) + st.AddScore("entry3", 7.0) + st.AddScore("entry4", 1.0) + st.AddScore("entry5", 8.0) + st.AddScore("entry6", 4.0) + st.AddScore("entry7", 0.5) + + expectedScores := []Entry{ + {Address: "entry7", Score: 0.5, Part: 1}, + {Address: "entry4", Score: 1.0, Part: 1}, + {Address: "entry2", Score: 3.0, Part: 1}, + {Address: "entry6", Score: 4.0, Part: 1}, + {Address: "entry1", Score: 5.0, Part: 1}, + {Address: "entry3", Score: 7.0, Part: 1}, + {Address: "entry5", Score: 8.0, Part: 1}, + } + + assert.Equal(t, expectedScores, st.scores) +} + +func TestSelectionTierInst_GetTier(t *testing.T) { + st := NewSelectionTier() + + st.AddScore("entry1", 0.1) + st.AddScore("entry2", 0.5) + st.AddScore("entry3", 0.7) + st.AddScore("entry4", 0.3) + st.AddScore("entry5", 0.2) + st.AddScore("entry6", 0.9) + + numTiers := 3 + playbook := []struct { + tier int + minimumEntries int + expectedTier []string + name string + }{ + { + tier: 0, + minimumEntries: 2, + 
expectedTier: []string{"entry1", "entry5"}, + name: "tier 0, 2 entries", + }, + { + tier: 1, + minimumEntries: 2, + expectedTier: []string{"entry5", "entry4", "entry2"}, + name: "tier 1, 2 entries", + }, + { + tier: 2, + minimumEntries: 2, + expectedTier: []string{"entry2", "entry3", "entry6"}, + name: "tier 2, 2 entries", + }, + { + tier: 0, + minimumEntries: 3, + expectedTier: []string{"entry1", "entry5"}, // we can only bring better entries + name: "tier 0, 3 entries", + }, + { + tier: 1, + minimumEntries: 4, + expectedTier: []string{"entry1", "entry5", "entry4", "entry2"}, + name: "tier 1, 4 entries", + }, + { + tier: 1, + minimumEntries: 5, + expectedTier: []string{"entry1", "entry5", "entry4", "entry2"}, // we can only bring better entries + name: "tier 1, 5 entries", + }, + { + tier: 2, + minimumEntries: 4, + expectedTier: []string{"entry4", "entry2", "entry3", "entry6"}, + name: "tier 2, 4 entries", + }, + } + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + result := st.GetTier(play.tier, numTiers, play.minimumEntries) + require.Equal(t, len(play.expectedTier), len(result), result) + for i, entry := range play.expectedTier { + assert.Equal(t, entry, result[i].Address, "result %v, expected: %v", result, play.expectedTier) + } + for i := 1; i < len(result); i++ { + assert.LessOrEqual(t, result[i-1].Score, result[i].Score) + } + }) + } +} + +func TestSelectionTierInstGetTierBig(t *testing.T) { + st := NewSelectionTier() + + for i := 0; i < 25; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1+0.0001*float64(i)) + } + for i := 25; i < 50; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.2+0.0001*float64(i)) + } + for i := 50; i < 75; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.3+0.0001*float64(i)) + } + for i := 75; i < 100; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.4+0.0001*float64(i)) + } + + numTiers := 4 + playbook := []struct { + tier int + minimumEntries int + expectedTierLen int + name string + }{ + { + tier: 0, + minimumEntries: 5, + expectedTierLen: 25, + name: "tier 0, 25 entries", + }, + { + tier: 1, + minimumEntries: 5, + expectedTierLen: 26, + name: "tier 1, 26 entries", + }, + { + tier: 2, + minimumEntries: 5, + expectedTierLen: 26, + name: "tier 2, 26 entries", + }, + { + tier: 3, + minimumEntries: 5, + expectedTierLen: 26, + name: "tier 3, 26 entries", + }, + { + tier: 0, + minimumEntries: 26, + expectedTierLen: 25, // we can't bring entries from lower tiers + name: "tier 0, 26 entries", + }, + } + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + result := st.GetTier(play.tier, numTiers, play.minimumEntries) + require.Equal(t, play.expectedTierLen, len(result), result) + for i := 1; i < len(result); i++ { + assert.LessOrEqual(t, result[i-1].Score, result[i].Score) + } + }) + } +} + +func TestSelectionTierInstShiftTierChance(t *testing.T) { + st := NewSelectionTier() + numTiers := 4 + for i := 0; i < 25; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + for i := 25; i < 50; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + for i := 50; i < 75; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + for i := 75; i < 100; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + selectionTierChances := st.ShiftTierChance(numTiers, nil) + require.Equal(t, numTiers, len(selectionTierChances)) + require.Equal(t, selectionTierChances[0], selectionTierChances[1]) + + selectionTierChances = st.ShiftTierChance(numTiers, map[int]float64{0: 0.5, 1: 0.5}) + require.Equal(t, 0.0, 
selectionTierChances[len(selectionTierChances)-1]) + require.Equal(t, 0.5, selectionTierChances[0]) + + selectionTierChances = st.ShiftTierChance(numTiers, map[int]float64{0: 0.5, len(selectionTierChances) - 1: 0.1}) + require.Less(t, selectionTierChances[0], 0.5) + require.Greater(t, selectionTierChances[0], 0.25) + require.Greater(t, selectionTierChances[len(selectionTierChances)-1], 0.1) + + st = NewSelectionTier() + for i := 0; i < 25; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + for i := 25; i < 50; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.2) + } + for i := 50; i < 75; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.3) + } + for i := 75; i < 100; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.4) + } + selectionTierChances = st.ShiftTierChance(numTiers, nil) + require.Equal(t, numTiers, len(selectionTierChances)) + require.Greater(t, selectionTierChances[0], selectionTierChances[1]) + require.Greater(t, selectionTierChances[1]*3, selectionTierChances[0]) // make sure the adjustment is not that strong + require.Greater(t, selectionTierChances[1], selectionTierChances[2]) + + st = NewSelectionTier() + for i := 0; i < 25; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.01) + } + for i := 25; i < 50; i++ { + st.AddScore("entry"+strconv.Itoa(i), 1.2) + } + for i := 50; i < 75; i++ { + st.AddScore("entry"+strconv.Itoa(i), 1.3) + } + for i := 75; i < 100; i++ { + st.AddScore("entry"+strconv.Itoa(i), 1.4) + } + selectionTierChances = st.ShiftTierChance(numTiers, nil) + require.Equal(t, numTiers, len(selectionTierChances)) + require.Greater(t, selectionTierChances[0], 0.9) + + st = NewSelectionTier() + + for i := 25; i < 50; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.5) + } + for i := 0; i < 25; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.1) + } + for i := 50; i < 75; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.5) + } + for i := 75; i < 100; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.5) + } + selectionTierChances = st.ShiftTierChance(numTiers, nil) + require.Equal(t, numTiers, len(selectionTierChances)) + require.Greater(t, selectionTierChances[0], selectionTierChances[1]*2.5) // make sure the adjustment is strong enough + require.Greater(t, selectionTierChances[1]*10, selectionTierChances[0]) // but not too much + + selectionTierChances = st.ShiftTierChance(numTiers, map[int]float64{0: 0.5}) + require.Equal(t, numTiers, len(selectionTierChances)) + require.Greater(t, selectionTierChances[0], 0.5) // make sure the adjustment increases the base chance + require.Less(t, selectionTierChances[1], (1-0.5)/float64(numTiers-1), selectionTierChances) // and reduces it for lesser tiers +} + +func TestSelectionTierInstShiftTierChance_MaintainTopTierAdvantage(t *testing.T) { + st := NewSelectionTier() + numTiers := 4 + for i := 0; i < 3; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.195) + } + for i := 3; i < 10; i++ { + st.AddScore("entry"+strconv.Itoa(i), 0.399) + } + + selectionTierChances := st.ShiftTierChance(numTiers, map[int]float64{0: 0.75, numTiers - 1: 0}) + require.InDelta(t, 0.75, selectionTierChances[0], 0.1) +} + +func TestSelectionTierInst_SelectTierRandomly(t *testing.T) { + st := NewSelectionTier() + rand.InitRandomSeed() + numTiers := 5 + counter := map[int]int{} + for i := 0; i < 10000; i++ { + tier := st.SelectTierRandomly(numTiers, map[int]float64{0: 0.8, 4: 0}) + counter[tier]++ + assert.GreaterOrEqual(t, tier, 0) + assert.Less(t, tier, numTiers) + } + + require.Zero(t, counter[4]) + for i := 1; i < 4; i++ { + require.Greater(t, counter[i], 100) + } + 
require.Greater(t, counter[0], 7000) +} + +func TestSelectionTierInst_SelectTierRandomly_Default(t *testing.T) { + st := NewSelectionTier() + rand.InitRandomSeed() + numTiers := 5 + counter := map[int]int{} + for i := 0; i < 10000; i++ { + tier := st.SelectTierRandomly(numTiers, st.ShiftTierChance(numTiers, nil)) + counter[tier]++ + assert.GreaterOrEqual(t, tier, 0) + assert.Less(t, tier, numTiers) + } + + expectedDistribution := 10000 / numTiers + for _, count := range counter { + assert.InDelta(t, expectedDistribution, count, 300) + } +} diff --git a/protocol/provideroptimizer/selection_weight.go b/protocol/provideroptimizer/selection_weight.go new file mode 100644 index 000000000..25e9e0c19 --- /dev/null +++ b/protocol/provideroptimizer/selection_weight.go @@ -0,0 +1,69 @@ +package provideroptimizer + +import ( + "sync" + + "github.com/lavanet/lava/v3/utils" + "github.com/lavanet/lava/v3/utils/rand" +) + +// SelectionWeighter is a utility to select an address based on a weight. +type SelectionWeighter interface { + Weight(address string) int64 + SetWeights(weights map[string]int64) + WeightedChoice(possibilities []Entry) string +} + +type selectionWeighterInst struct { + lock sync.RWMutex + weights map[string]int64 +} + +func NewSelectionWeighter() SelectionWeighter { + return &selectionWeighterInst{ + weights: make(map[string]int64), + } +} + +func (sw *selectionWeighterInst) Weight(address string) int64 { + sw.lock.RLock() + defer sw.lock.RUnlock() + weight, ok := sw.weights[address] + if !ok { + // default weight is 1 + return 1 + } + return weight +} + +func (sw *selectionWeighterInst) SetWeights(weights map[string]int64) { + sw.lock.Lock() + defer sw.lock.Unlock() + for address, weight := range weights { + sw.weights[address] = weight + } +} + +func (sw *selectionWeighterInst) WeightedChoice(entries []Entry) string { + if len(entries) == 0 { + return "" + } + sw.lock.RLock() + defer sw.lock.RUnlock() + totalWeight := int64(0) + for _, entry := range entries { + totalWeight += int64(float64(sw.Weight(entry.Address)) * entry.Part) + } + randWeight := rand.Int63n(totalWeight) + currentWeight := int64(0) + for _, entry := range entries { + currentWeight += int64(float64(sw.Weight(entry.Address)) * entry.Part) + if currentWeight > randWeight { + return entry.Address + } + } + utils.LavaFormatError("invalid weighted choice, no address chosen, fallback to last one", nil, utils.LogAttr("addresses", entries), + utils.LogAttr("totalWeight", totalWeight)) + // Fallback to the last address if no address is selected + return entries[len(entries)-1].Address +} diff --git a/protocol/provideroptimizer/selection_weight_test.go b/protocol/provideroptimizer/selection_weight_test.go new file mode 100644 index 000000000..919692050 --- /dev/null +++ b/protocol/provideroptimizer/selection_weight_test.go @@ -0,0 +1,98 @@ +package provideroptimizer + +import ( + "testing" + + "github.com/lavanet/lava/v3/utils/rand" + "github.com/stretchr/testify/assert" +) + +func TestNewSelectionWeighter(t *testing.T) { + sw := NewSelectionWeighter() + assert.NotNil(t, sw) +} + +func TestWeight(t *testing.T) { + sw := NewSelectionWeighter() + weights := map[string]int64{ + "address1": 10, + "address2": 20, + } + sw.SetWeights(weights) + + assert.Equal(t, int64(10), sw.Weight("address1")) + assert.Equal(t, int64(20), sw.Weight("address2")) + assert.Equal(t, int64(1), sw.Weight("address3")) // address not set + + weights = map[string]int64{ + "address1": 25, + "address3": 30, + } + sw.SetWeights(weights) + + assert.Equal(t, 
int64(25), sw.Weight("address1")) + assert.Equal(t, int64(20), sw.Weight("address2")) + assert.Equal(t, int64(30), sw.Weight("address3")) // address not set +} + +func TestWeightedChoice(t *testing.T) { + sw := NewSelectionWeighter() + rand.InitRandomSeed() + weights := map[string]int64{ + "address1": 10, + "address2": 20, + "address3": 30, + } + sw.SetWeights(weights) + + // Create entries based on weights + entries := []Entry{ + {Address: "address1", Part: 1}, + {Address: "address2", Part: 1}, + {Address: "address3", Part: 1}, + } + + // Run the weighted choice multiple times to check distribution + results := make(map[string]int) + for i := 0; i < 1000; i++ { + choice := sw.WeightedChoice(entries) + results[choice]++ + } + + // Check that each address was chosen at least once + assert.Greater(t, results["address1"], 0) + assert.Greater(t, results["address2"], 0) + assert.Greater(t, results["address3"], 0) + + weights = map[string]int64{ + "address1": 800, + "address2": 100, + "address3": 100, + } + sw.SetWeights(weights) + results = make(map[string]int) + for i := 0; i < 10000; i++ { + choice := sw.WeightedChoice(entries) + results[choice]++ + } + // Check that address1 is chosen most of the time + assert.Greater(t, results["address1"], 7000) + assert.InDelta(t, 1000, results["address2"], 400) + assert.InDelta(t, 1000, results["address3"], 400) + + weights = map[string]int64{ + "address1": 100, + "address2": 800, + "address3": 100, + } + sw.SetWeights(weights) + results = make(map[string]int) + for i := 0; i < 10000; i++ { + choice := sw.WeightedChoice(entries) + results[choice]++ + } + // Check that address1 is chosen most of the time + assert.Greater(t, results["address2"], 7000) + assert.InDelta(t, 1000, results["address1"], 400) + assert.InDelta(t, 1000, results["address3"], 400) +} diff --git a/protocol/rpcconsumer/consumer_relay_state_machine.go b/protocol/rpcconsumer/consumer_relay_state_machine.go index 348757245..b7fd41f68 100644 --- a/protocol/rpcconsumer/consumer_relay_state_machine.go +++ b/protocol/rpcconsumer/consumer_relay_state_machine.go @@ -114,7 +114,6 @@ func (rssi *RelayStateSendInstructions) IsDone() bool { func (crsm *ConsumerRelayStateMachine) GetRelayTaskChannel() chan RelayStateSendInstructions { relayTaskChannel := make(chan RelayStateSendInstructions) go func() { - batchNumber := 0 // Set batch number // A channel to be notified processing was done, true means we have results and can return gotResults := make(chan bool, 1) processingTimeout, relayTimeout := crsm.relaySender.getProcessingTimeout(crsm.GetProtocolMessage()) @@ -127,7 +126,7 @@ func (crsm *ConsumerRelayStateMachine) GetRelayTaskChannel() chan RelayStateSend readResultsFromProcessor := func() { // ProcessResults is reading responses while blocking until the conditions are met - utils.LavaFormatTrace("[StateMachine] Waiting for results", utils.LogAttr("batch", batchNumber)) + utils.LavaFormatTrace("[StateMachine] Waiting for results", utils.LogAttr("batch", crsm.usedProviders.BatchNumber())) crsm.parentRelayProcessor.WaitForResults(processingCtx) // Decide if we need to resend or not if crsm.parentRelayProcessor.HasRequiredNodeResults() { @@ -140,11 +139,11 @@ func (crsm *ConsumerRelayStateMachine) GetRelayTaskChannel() chan RelayStateSend returnCondition := make(chan error, 1) // Used for checking whether to return an error to the user or to allow other channels return their result first see detailed description on the switch case below validateReturnCondition := func(err error) { - batchOnStart 
:= batchNumber + batchOnStart := crsm.usedProviders.BatchNumber() time.Sleep(15 * time.Millisecond) - utils.LavaFormatTrace("[StateMachine] validating return condition", utils.LogAttr("batch", batchNumber)) + utils.LavaFormatTrace("[StateMachine] validating return condition", utils.LogAttr("batch", crsm.usedProviders.BatchNumber())) if batchOnStart == crsm.usedProviders.BatchNumber() && crsm.usedProviders.CurrentlyUsed() == 0 { - utils.LavaFormatTrace("[StateMachine] return condition triggered", utils.LogAttr("batch", batchNumber), utils.LogAttr("err", err)) + utils.LavaFormatTrace("[StateMachine] return condition triggered", utils.LogAttr("batch", crsm.usedProviders.BatchNumber()), utils.LogAttr("err", err)) returnCondition <- err } } @@ -165,16 +164,16 @@ func (crsm *ConsumerRelayStateMachine) GetRelayTaskChannel() chan RelayStateSend // Getting batch update for either errors sending message or successful batches case err := <-crsm.batchUpdate: if err != nil { // Error handling - utils.LavaFormatTrace("[StateMachine] err := <-crsm.batchUpdate", utils.LogAttr("err", err), utils.LogAttr("batch", batchNumber), utils.LogAttr("consecutiveBatchErrors", consecutiveBatchErrors)) + utils.LavaFormatTrace("[StateMachine] err := <-crsm.batchUpdate", utils.LogAttr("err", err), utils.LogAttr("batch", crsm.usedProviders.BatchNumber()), utils.LogAttr("consecutiveBatchErrors", consecutiveBatchErrors)) // Sending a new batch failed (consumer's protocol side), handling the state machine consecutiveBatchErrors++ // Increase consecutive error counter if consecutiveBatchErrors > SendRelayAttempts { // If we failed sending a message more than "SendRelayAttempts" time in a row. - if batchNumber == 0 && consecutiveBatchErrors == SendRelayAttempts+1 { // First relay attempt. print on first failure only. + if crsm.usedProviders.BatchNumber() == 0 && consecutiveBatchErrors == SendRelayAttempts+1 { // First relay attempt. print on first failure only. utils.LavaFormatWarning("Failed Sending First Message", err, utils.LogAttr("consecutive errors", consecutiveBatchErrors)) } go validateReturnCondition(err) // Check if we have ongoing messages pending return. } else { - utils.LavaFormatTrace("[StateMachine] batchUpdate - err != nil - batch fail retry attempt", utils.LogAttr("batch", batchNumber), utils.LogAttr("consecutiveBatchErrors", consecutiveBatchErrors)) + utils.LavaFormatTrace("[StateMachine] batchUpdate - err != nil - batch fail retry attempt", utils.LogAttr("batch", crsm.usedProviders.BatchNumber()), utils.LogAttr("consecutiveBatchErrors", consecutiveBatchErrors)) // Failed sending message, but we still want to attempt sending more. relayTaskChannel <- RelayStateSendInstructions{ protocolMessage: crsm.GetProtocolMessage(), @@ -183,20 +182,10 @@ func (crsm *ConsumerRelayStateMachine) GetRelayTaskChannel() chan RelayStateSend continue } // Successfully sent message. - batchNumber++ // Reset consecutiveBatchErrors consecutiveBatchErrors = 0 - // Batch number validation, should never happen. 
- if batchNumber != crsm.usedProviders.BatchNumber() { - // Mismatch, return error - relayTaskChannel <- RelayStateSendInstructions{ - err: utils.LavaFormatError("Batch Number mismatch between state machine and used providers", nil, utils.LogAttr("batchNumber", batchNumber), utils.LogAttr("crsm.parentRelayProcessor.usedProviders.BatchNumber()", crsm.usedProviders.BatchNumber())), - done: true, - } - return - } case success := <-gotResults: - utils.LavaFormatTrace("[StateMachine] success := <-gotResults", utils.LogAttr("batch", batchNumber)) + utils.LavaFormatTrace("[StateMachine] success := <-gotResults", utils.LogAttr("batch", crsm.usedProviders.BatchNumber())) // If we had a successful result return what we currently have // Or we are done sending relays, and we have no other relays pending results. if success { // Check wether we can return the valid results or we need to send another relay @@ -204,8 +193,8 @@ func (crsm *ConsumerRelayStateMachine) GetRelayTaskChannel() chan RelayStateSend return } // If should retry == true, send a new batch. (success == false) - if crsm.ShouldRetry(batchNumber) { - utils.LavaFormatTrace("[StateMachine] success := <-gotResults - crsm.ShouldRetry(batchNumber)", utils.LogAttr("batch", batchNumber)) + if crsm.ShouldRetry(crsm.usedProviders.BatchNumber()) { + utils.LavaFormatTrace("[StateMachine] success := <-gotResults - crsm.ShouldRetry(batchNumber)", utils.LogAttr("batch", crsm.usedProviders.BatchNumber())) relayTaskChannel <- RelayStateSendInstructions{protocolMessage: crsm.GetProtocolMessage()} } else { go validateReturnCondition(nil) @@ -213,14 +202,14 @@ func (crsm *ConsumerRelayStateMachine) GetRelayTaskChannel() chan RelayStateSend go readResultsFromProcessor() case <-startNewBatchTicker.C: // Only trigger another batch for non BestResult relays or if we didn't pass the retry limit. - if crsm.ShouldRetry(batchNumber) { - utils.LavaFormatTrace("[StateMachine] ticker triggered", utils.LogAttr("batch", batchNumber)) + if crsm.ShouldRetry(crsm.usedProviders.BatchNumber()) { + utils.LavaFormatTrace("[StateMachine] ticker triggered", utils.LogAttr("batch", crsm.usedProviders.BatchNumber())) relayTaskChannel <- RelayStateSendInstructions{protocolMessage: crsm.GetProtocolMessage()} // Add ticker launch metrics go crsm.tickerMetricSetter.SetRelaySentByNewBatchTickerMetric(crsm.relaySender.GetChainIdAndApiInterface()) } case returnErr := <-returnCondition: - utils.LavaFormatTrace("[StateMachine] returnErr := <-returnCondition", utils.LogAttr("batch", batchNumber)) + utils.LavaFormatTrace("[StateMachine] returnErr := <-returnCondition", utils.LogAttr("batch", crsm.usedProviders.BatchNumber())) // we use this channel because there could be a race condition between us releasing the provider and about to send the return // to an error happening on another relay processor's routine. 
this can cause an error that returns to the user // if we don't release the case, it will cause the success case condition to not be executed @@ -239,7 +228,7 @@ func (crsm *ConsumerRelayStateMachine) GetRelayTaskChannel() chan RelayStateSend utils.LogAttr("consumerIp", userData.ConsumerIp), utils.LogAttr("protocolMessage.GetApi().Name", crsm.GetProtocolMessage().GetApi().Name), utils.LogAttr("GUID", crsm.ctx), - utils.LogAttr("batchNumber", batchNumber), + utils.LogAttr("batchNumber", crsm.usedProviders.BatchNumber()), utils.LogAttr("consecutiveBatchErrors", consecutiveBatchErrors), ) // returning the context error diff --git a/protocol/rpcconsumer/policies_map.go b/protocol/rpcconsumer/policies_map.go deleted file mode 100644 index d70d2de3d..000000000 --- a/protocol/rpcconsumer/policies_map.go +++ /dev/null @@ -1,47 +0,0 @@ -package rpcconsumer - -import ( - "sync" - - "github.com/lavanet/lava/v3/protocol/statetracker/updaters" - "github.com/lavanet/lava/v3/utils" -) - -type syncMapPolicyUpdaters struct { - localMap sync.Map -} - -func (sm *syncMapPolicyUpdaters) Store(key string, toSet *updaters.PolicyUpdater) { - sm.localMap.Store(key, toSet) -} - -func (sm *syncMapPolicyUpdaters) Load(key string) (ret *updaters.PolicyUpdater, ok bool) { - value, ok := sm.localMap.Load(key) - if !ok { - return nil, ok - } - ret, ok = value.(*updaters.PolicyUpdater) - if !ok { - utils.LavaFormatFatal("invalid usage of syncmap, could not cast result into a PolicyUpdater", nil) - } - return ret, true -} - -// LoadOrStore returns the existing value for the key if present. -// Otherwise, it stores and returns the given value. -// The loaded result is true if the value was loaded, false if stored. -// The function returns the value that was loaded or stored. -func (sm *syncMapPolicyUpdaters) LoadOrStore(key string, value *updaters.PolicyUpdater) (ret *updaters.PolicyUpdater, loaded bool) { - actual, loaded := sm.localMap.LoadOrStore(key, value) - if loaded { - // loaded from map - ret, loaded = actual.(*updaters.PolicyUpdater) - if !loaded { - utils.LavaFormatFatal("invalid usage of syncmap, could not cast result into a PolicyUpdater", nil) - } - return ret, loaded - } - - // stored in map - return value, false -} diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index 67f4a2446..c91967873 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -174,12 +174,15 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt for _, endpoint := range options.rpcEndpoints { chainMutexes[endpoint.ChainID] = &sync.Mutex{} // create a mutex per chain for shared resources } - var optimizers sync.Map - var consumerConsistencies sync.Map - var finalizationConsensuses sync.Map + + optimizers := &common.SafeSyncMap[string, *provideroptimizer.ProviderOptimizer]{} + consumerConsistencies := &common.SafeSyncMap[string, *ConsumerConsistency]{} + finalizationConsensuses := &common.SafeSyncMap[string, *finalizationconsensus.FinalizationConsensus]{} + var wg sync.WaitGroup parallelJobs := len(options.rpcEndpoints) wg.Add(parallelJobs) + errCh := make(chan error) consumerStateTracker.RegisterForUpdates(ctx, updaters.NewMetricsUpdater(consumerMetricsManager)) @@ -193,7 +196,7 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt } consumerStateTracker.RegisterForVersionUpdates(ctx, version.Version, &upgrade.ProtocolVersion{}) relaysMonitorAggregator := 
metrics.NewRelaysMonitorAggregator(options.cmdFlags.RelaysHealthIntervalFlag, consumerMetricsManager) - policyUpdaters := syncMapPolicyUpdaters{} + policyUpdaters := &common.SafeSyncMap[string, *updaters.PolicyUpdater]{} for _, rpcEndpoint := range options.rpcEndpoints { go func(rpcEndpoint *lavasession.RPCEndpoint) error { defer wg.Done() @@ -206,7 +209,12 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt chainID := rpcEndpoint.ChainID // create policyUpdaters per chain newPolicyUpdater := updaters.NewPolicyUpdater(chainID, consumerStateTracker, consumerAddr.String(), chainParser, *rpcEndpoint) - if policyUpdater, ok := policyUpdaters.LoadOrStore(chainID, newPolicyUpdater); ok { + policyUpdater, ok, err := policyUpdaters.LoadOrStore(chainID, newPolicyUpdater) + if err != nil { + errCh <- err + return utils.LavaFormatError("failed loading or storing policy updater", err, utils.LogAttr("endpoint", rpcEndpoint)) + } + if ok { err := policyUpdater.AddPolicySetter(chainParser, *rpcEndpoint) if err != nil { errCh <- err @@ -229,46 +237,33 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt // this is locked so we don't race optimizers creation chainMutexes[chainID].Lock() defer chainMutexes[chainID].Unlock() - value, exists := optimizers.Load(chainID) - if !exists { - // doesn't exist for this chain create a new one - baseLatency := common.AverageWorldLatency / 2 // we want performance to be half our timeout or better - optimizer = provideroptimizer.NewProviderOptimizer(options.strategy, averageBlockTime, baseLatency, options.maxConcurrentProviders) - optimizers.Store(chainID, optimizer) - } else { - var ok bool - optimizer, ok = value.(*provideroptimizer.ProviderOptimizer) - if !ok { - err = utils.LavaFormatError("failed loading optimizer, value is of the wrong type", nil, utils.Attribute{Key: "endpoint", Value: rpcEndpoint.Key()}) - return err - } + var loaded bool + var err error + + baseLatency := common.AverageWorldLatency / 2 // we want performance to be half our timeout or better + + // Create / Use existing optimizer + newOptimizer := provideroptimizer.NewProviderOptimizer(options.strategy, averageBlockTime, baseLatency, options.maxConcurrentProviders) + optimizer, _, err = optimizers.LoadOrStore(chainID, newOptimizer) + if err != nil { + return utils.LavaFormatError("failed loading optimizer", err, utils.LogAttr("endpoint", rpcEndpoint.Key())) } - value, exists = consumerConsistencies.Load(chainID) - if !exists { // doesn't exist for this chain create a new one - consumerConsistency = NewConsumerConsistency(chainID) - consumerConsistencies.Store(chainID, consumerConsistency) - } else { - var ok bool - consumerConsistency, ok = value.(*ConsumerConsistency) - if !ok { - err = utils.LavaFormatError("failed loading consumer consistency, value is of the wrong type", err, utils.Attribute{Key: "endpoint", Value: rpcEndpoint.Key()}) - return err - } + + // Create / Use existing ConsumerConsistency + newConsumerConsistency := NewConsumerConsistency(chainID) + consumerConsistency, _, err = consumerConsistencies.LoadOrStore(chainID, newConsumerConsistency) + if err != nil { + return utils.LavaFormatError("failed loading consumer consistency", err, utils.LogAttr("endpoint", rpcEndpoint.Key())) } - value, exists = finalizationConsensuses.Load(chainID) - if !exists { - // doesn't exist for this chain create a new one - finalizationConsensus = finalizationconsensus.NewFinalizationConsensus(rpcEndpoint.ChainID) + // Create / Use 
existing FinalizationConsensus
+			newFinalizationConsensus := finalizationconsensus.NewFinalizationConsensus(rpcEndpoint.ChainID)
+			finalizationConsensus, loaded, err = finalizationConsensuses.LoadOrStore(chainID, newFinalizationConsensus)
+			if err != nil {
+				return utils.LavaFormatError("failed loading finalization consensus", err, utils.LogAttr("endpoint", rpcEndpoint.Key()))
+			}
+			if !loaded { // when creating a new finalization consensus instance we need to register it for updates
 				consumerStateTracker.RegisterFinalizationConsensusForUpdates(ctx, finalizationConsensus)
-				finalizationConsensuses.Store(chainID, finalizationConsensus)
-			} else {
-				var ok bool
-				finalizationConsensus, ok = value.(*finalizationconsensus.FinalizationConsensus)
-				if !ok {
-					err = utils.LavaFormatError("failed loading finalization consensus, value is of the wrong type", nil, utils.Attribute{Key: "endpoint", Value: rpcEndpoint.Key()})
-					return err
-				}
 			}
 			return nil
 		}
@@ -278,7 +273,7 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
 			return err
 		}
 
-		if finalizationConsensus == nil || optimizer == nil {
+		if finalizationConsensus == nil || optimizer == nil || consumerConsistency == nil {
 			err = utils.LavaFormatError("failed getting assets, found a nil", nil, utils.Attribute{Key: "endpoint", Value: rpcEndpoint.Key()})
 			errCh <- err
 			return err
@@ -327,9 +322,9 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
 	utils.LavaFormatDebug("Starting Policy Updaters for all chains")
 	for chainId := range chainMutexes {
-		policyUpdater, ok := policyUpdaters.Load(chainId)
-		if !ok {
-			utils.LavaFormatError("could not load policy Updater for chain", nil, utils.LogAttr("chain", chainId))
+		policyUpdater, ok, err := policyUpdaters.Load(chainId)
+		if !ok || err != nil {
+			utils.LavaFormatError("could not load policy Updater for chain", err, utils.LogAttr("chain", chainId))
 			continue
 		}
 		consumerStateTracker.RegisterForPairingUpdates(ctx, policyUpdater, chainId)
@@ -619,6 +614,10 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
 	cmdRPCConsumer.Flags().DurationVar(&updaters.TimeOutForFetchingLavaBlocks, common.TimeOutForFetchingLavaBlocksFlag, time.Second*5, "setting the timeout for fetching lava blocks")
 	cmdRPCConsumer.Flags().String(common.UseStaticSpecFlag, "", "load offline spec provided path to spec file, used to test specs before they are proposed on chain")
 	cmdRPCConsumer.Flags().IntVar(&relayCountOnNodeError, common.SetRelayCountOnNodeErrorFlag, 2, "set the number of retries attempt on node errors")
+	// provider optimizer tier selection flags
+	cmdRPCConsumer.Flags().Float64Var(&provideroptimizer.ATierChance, common.SetProviderOptimizerBestTierPickChance, 0.75, "set the chance for picking a provider from the best group, default is 75% -> 0.75")
+	cmdRPCConsumer.Flags().Float64Var(&provideroptimizer.LastTierChance, common.SetProviderOptimizerWorstTierPickChance, 0.0, "set the chance for picking a provider from the worst group, default is 0% -> 0.0")
+	cmdRPCConsumer.Flags().IntVar(&provideroptimizer.OptimizerNumTiers, common.SetProviderOptimizerNumberOfTiersToCreate, 4, "set the number of groups to create, default is 4")
 	common.AddRollingLogConfig(cmdRPCConsumer)
 	return cmdRPCConsumer
 }
diff --git a/protocol/rpcconsumer/testing.go b/protocol/rpcconsumer/testing.go
index fffc60d21..6c3683186 100644
--- a/protocol/rpcconsumer/testing.go
+++ b/protocol/rpcconsumer/testing.go
@@ -79,6 +79,7 @@ func startTesting(ctx context.Context, clientCtx client.Context, 
rpcEndpoints [] if err != nil { return utils.LavaFormatError("panic severity critical error, aborting support for chain api due to node access, continuing with other endpoints", err, utils.Attribute{Key: "chainTrackerConfig", Value: chainTrackerConfig}, utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint}) } + chainTracker.StartAndServe(ctx) _ = chainTracker // let the chain tracker work and make queries return nil }(rpcProviderEndpoint) diff --git a/protocol/rpcprovider/chain_tackers.go b/protocol/rpcprovider/chain_tackers.go deleted file mode 100644 index 95a43a5ea..000000000 --- a/protocol/rpcprovider/chain_tackers.go +++ /dev/null @@ -1,38 +0,0 @@ -package rpcprovider - -import ( - "sync" - - "github.com/lavanet/lava/v3/protocol/chaintracker" - "github.com/lavanet/lava/v3/utils" -) - -type ChainTrackers struct { - stateTrackersPerChain sync.Map -} - -func (ct *ChainTrackers) GetTrackerPerChain(specID string) (chainTracker *chaintracker.ChainTracker, found bool) { - chainTrackerInf, found := ct.stateTrackersPerChain.Load(specID) - if !found { - return nil, found - } - var ok bool - chainTracker, ok = chainTrackerInf.(*chaintracker.ChainTracker) - if !ok { - utils.LavaFormatFatal("invalid usage of syncmap, could not cast result into a chaintracker", nil) - } - return chainTracker, true -} - -func (ct *ChainTrackers) SetTrackerForChain(specId string, chainTracker *chaintracker.ChainTracker) { - ct.stateTrackersPerChain.Store(specId, chainTracker) -} - -func (ct *ChainTrackers) GetLatestBlockNumForSpec(specID string) int64 { - chainTracker, found := ct.GetTrackerPerChain(specID) - if !found { - return 0 - } - latestBlock, _ := chainTracker.GetLatestBlockNum() - return latestBlock -} diff --git a/protocol/rpcprovider/provider_listener.go b/protocol/rpcprovider/provider_listener.go index 30eb0cda6..a0f083052 100644 --- a/protocol/rpcprovider/provider_listener.go +++ b/protocol/rpcprovider/provider_listener.go @@ -67,7 +67,6 @@ func NewProviderListener(ctx context.Context, networkAddress lavasession.Network grpc.MaxRecvMsgSize(1024 * 1024 * 32), // setting receive size to 32mb instead of 4mb default } grpcServer := grpc.NewServer(opts...) 
-	wrappedServer := grpcweb.WrapServer(grpcServer)
 	handler := func(resp http.ResponseWriter, req *http.Request) {
 		// Set CORS headers
diff --git a/protocol/rpcprovider/rpcprovider.go b/protocol/rpcprovider/rpcprovider.go
index 554d7cfba..9b860b433 100644
--- a/protocol/rpcprovider/rpcprovider.go
+++ b/protocol/rpcprovider/rpcprovider.go
@@ -133,7 +133,7 @@ type RPCProvider struct {
 	parallelConnections       uint
 	cache                     *performance.Cache
 	shardID                   uint // shardID is a flag that allows setting up multiple provider databases of the same chain
-	chainTrackers             *ChainTrackers
+	chainTrackers             *common.SafeSyncMap[string, *chaintracker.ChainTracker]
 	relaysMonitorAggregator   *metrics.RelaysMonitorAggregator
 	relaysHealthCheckEnabled  bool
 	relaysHealthCheckInterval time.Duration
@@ -152,7 +152,7 @@ func (rpcp *RPCProvider) Start(options *rpcProviderStartOptions) (err error) {
 		cancel()
 	}()
 	rpcp.providerUniqueId = strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10)
-	rpcp.chainTrackers = &ChainTrackers{}
+	rpcp.chainTrackers = &common.SafeSyncMap[string, *chaintracker.ChainTracker]{}
 	rpcp.parallelConnections = options.parallelConnections
 	rpcp.cache = options.cache
 	rpcp.providerMetricsManager = metrics.NewProviderMetricsManager(options.metricsListenAddress) // start up prometheus metrics
@@ -185,7 +185,7 @@ func (rpcp *RPCProvider) Start(options *rpcProviderStartOptions) (err error) {
 	// single reward server
 	if !options.staticProvider {
 		rewardDB := rewardserver.NewRewardDBWithTTL(options.rewardTTL)
-		rpcp.rewardServer = rewardserver.NewRewardServer(providerStateTracker, rpcp.providerMetricsManager, rewardDB, options.rewardStoragePath, options.rewardsSnapshotThreshold, options.rewardsSnapshotTimeoutSec, rpcp.chainTrackers)
+		rpcp.rewardServer = rewardserver.NewRewardServer(providerStateTracker, rpcp.providerMetricsManager, rewardDB, options.rewardStoragePath, options.rewardsSnapshotThreshold, options.rewardsSnapshotTimeoutSec, rpcp)
 		rpcp.providerStateTracker.RegisterForEpochUpdates(ctx, rpcp.rewardServer)
 		rpcp.providerStateTracker.RegisterPaymentUpdatableForPayments(ctx, rpcp.rewardServer)
 	}
@@ -409,42 +409,45 @@ func (rpcp *RPCProvider) SetupEndpoint(ctx context.Context, rpcProviderEndpoint
 	chainCommonSetup := func() error {
 		rpcp.chainMutexes[chainID].Lock()
 		defer rpcp.chainMutexes[chainID].Unlock()
-		var found bool
-		chainTracker, found = rpcp.chainTrackers.GetTrackerPerChain(chainID)
-		if !found {
-			consistencyErrorCallback := func(oldBlock, newBlock int64) {
-				utils.LavaFormatError("Consistency issue detected", nil,
-					utils.Attribute{Key: "oldBlock", Value: oldBlock},
-					utils.Attribute{Key: "newBlock", Value: newBlock},
-					utils.Attribute{Key: "Chain", Value: rpcProviderEndpoint.ChainID},
-					utils.Attribute{Key: "apiInterface", Value: apiInterface},
-				)
-			}
-			blocksToSaveChainTracker := uint64(blocksToFinalization + blocksInFinalizationData)
-			chainTrackerConfig := chaintracker.ChainTrackerConfig{
-				BlocksToSave:        blocksToSaveChainTracker,
-				AverageBlockTime:    averageBlockTime,
-				ServerBlockMemory:   ChainTrackerDefaultMemory + blocksToSaveChainTracker,
-				NewLatestCallback:   recordMetricsOnNewBlock,
-				ConsistencyCallback: consistencyErrorCallback,
-				Pmetrics:            rpcp.providerMetricsManager,
-			}
+		var loaded bool
+		consistencyErrorCallback := func(oldBlock, newBlock int64) {
+			utils.LavaFormatError("Consistency issue detected", nil,
+				utils.Attribute{Key: "oldBlock", Value: oldBlock},
+				utils.Attribute{Key: "newBlock", Value: newBlock},
+				utils.Attribute{Key: "Chain", Value: rpcProviderEndpoint.ChainID},
+				utils.Attribute{Key: "apiInterface", Value: apiInterface},
+			)
+		}
+		blocksToSaveChainTracker := uint64(blocksToFinalization + blocksInFinalizationData)
+		chainTrackerConfig := chaintracker.ChainTrackerConfig{
+			BlocksToSave:        blocksToSaveChainTracker,
+			AverageBlockTime:    averageBlockTime,
+			ServerBlockMemory:   ChainTrackerDefaultMemory + blocksToSaveChainTracker,
+			NewLatestCallback:   recordMetricsOnNewBlock,
+			ConsistencyCallback: consistencyErrorCallback,
+			Pmetrics:            rpcp.providerMetricsManager,
+		}

-			chainTracker, err = chaintracker.NewChainTracker(ctx, chainFetcher, chainTrackerConfig)
-			if err != nil {
-				return utils.LavaFormatError("panic severity critical error, aborting support for chain api due to node access, continuing with other endpoints", err, utils.Attribute{Key: "chainTrackerConfig", Value: chainTrackerConfig}, utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint})
-			}
+		chainTracker, err = chaintracker.NewChainTracker(ctx, chainFetcher, chainTrackerConfig)
+		if err != nil {
+			return utils.LavaFormatError("panic severity critical error, aborting support for chain api due to node access, continuing with other endpoints", err, utils.Attribute{Key: "chainTrackerConfig", Value: chainTrackerConfig}, utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint})
+		}
+
+		chainTrackerLoaded, loaded, err := rpcp.chainTrackers.LoadOrStore(chainID, chainTracker)
+		if err != nil {
+			utils.LavaFormatFatal("failed to load or store chain tracker", err, utils.LogAttr("chainID", chainID))
+		}
+		if !loaded { // this is the first time we are setting up the chain tracker, we need to register for spec verifications
+			chainTracker.StartAndServe(ctx)
 			utils.LavaFormatDebug("Registering for spec verifications for endpoint", utils.LogAttr("rpcEndpoint", rpcEndpoint))
 			// we register for spec verifications only once, and this triggers all chainFetchers of that specId when it triggers
 			err = rpcp.providerStateTracker.RegisterForSpecVerifications(ctx, specValidator, rpcEndpoint.ChainID)
 			if err != nil {
 				return utils.LavaFormatError("failed to RegisterForSpecUpdates, panic severity critical error, aborting support for chain api due to invalid chain parser, continuing with others", err, utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint.String()})
 			}
-
-			// Any validation needs to be before we store chain tracker for given chain id
-			rpcp.chainTrackers.SetTrackerForChain(rpcProviderEndpoint.ChainID, chainTracker)
-		} else {
+		} else { // loaded an existing chain tracker. use the same one instead
+			chainTracker = chainTrackerLoaded
 			utils.LavaFormatDebug("reusing chain tracker", utils.Attribute{Key: "chain", Value: rpcProviderEndpoint.ChainID})
 		}
 		return nil
@@ -516,6 +519,19 @@ func (rpcp *RPCProvider) SetupEndpoint(ctx context.Context, rpcProviderEndpoint
 	return nil
 }

+func (rpcp *RPCProvider) GetLatestBlockNumForSpec(specID string) int64 {
+	chainTracker, found, err := rpcp.chainTrackers.Load(specID)
+	if err != nil {
+		utils.LavaFormatFatal("failed to load chain tracker", err, utils.LogAttr("specID", specID))
+	}
+	if !found {
+		return 0
+	}
+
+	block, _ := chainTracker.GetLatestBlockNum()
+	return block
+}
+
 func ParseEndpointsCustomName(viper_endpoints *viper.Viper, endpointsConfigName string, geolocation uint64) (endpoints []*lavasession.RPCProviderEndpoint, err error) {
 	err = viper_endpoints.UnmarshalKey(endpointsConfigName, &endpoints)
 	if err != nil {
diff --git a/protocol/statetracker/events.go b/protocol/statetracker/events.go
index ff4b8dccc..fa1383e9f 100644
--- a/protocol/statetracker/events.go
+++ b/protocol/statetracker/events.go
@@ -24,6 +24,7 @@ import (
 	"github.com/lavanet/lava/v3/app"
 	"github.com/lavanet/lava/v3/protocol/chainlib"
 	"github.com/lavanet/lava/v3/protocol/chaintracker"
+	"github.com/lavanet/lava/v3/protocol/common"
 	"github.com/lavanet/lava/v3/protocol/rpcprovider/rewardserver"
 	updaters "github.com/lavanet/lava/v3/protocol/statetracker/updaters"
 	"github.com/lavanet/lava/v3/utils"
@@ -122,6 +123,7 @@ func eventsLookup(ctx context.Context, clientCtx client.Context, blocks, fromBlo
 	if err != nil {
 		return utils.LavaFormatError("failed setting up chain tracker", err)
 	}
+	chainTracker.StartAndServe(ctx)
 	_ = chainTracker
 	select {
 	case <-ctx.Done():
@@ -666,7 +668,7 @@ func countTransactionsPerDay(ctx context.Context, clientCtx client.Context, bloc
 	// j are blocks in that day
 	// starting from current day and going backwards
 	var wg sync.WaitGroup
-	totalTxPerDay := sync.Map{}
+	totalTxPerDay := &common.SafeSyncMap[int64, int]{}

 	// Process each day from the earliest to the latest
 	for i := int64(1); i <= numberOfDays; i++ {
@@ -703,14 +705,13 @@ func countTransactionsPerDay(ctx context.Context, clientCtx client.Context, bloc
 			transactionResults := blockResults.TxsResults
 			utils.LavaFormatInfo("Number of tx for block", utils.LogAttr("_routine", end-k), utils.LogAttr("block_number", k), utils.LogAttr("number_of_tx", len(transactionResults)))
 			// Update totalTxPerDay safely
-			actual, _ := totalTxPerDay.LoadOrStore(i, len(transactionResults))
-			if actual != nil {
-				val, ok := actual.(int)
-				if !ok {
-					utils.LavaFormatError("Failed converting int", nil)
-					return
-				}
-				totalTxPerDay.Store(i, val+len(transactionResults))
+			actual, loaded, err := totalTxPerDay.LoadOrStore(i, len(transactionResults))
+			if err != nil {
+				utils.LavaFormatError("failed to load or store", err)
+				return
+			}
+			if loaded {
+				totalTxPerDay.Store(i, actual+len(transactionResults))
 			}
 		}(k)
 	}
diff --git a/protocol/statetracker/state_tracker.go b/protocol/statetracker/state_tracker.go
index 96c833781..5ff131270 100644
--- a/protocol/statetracker/state_tracker.go
+++ b/protocol/statetracker/state_tracker.go
@@ -130,6 +130,7 @@ func NewStateTracker(ctx context.Context, txFactory tx.Factory, clientCtx client
 	}
 	cst.AverageBlockTime = chainTrackerConfig.AverageBlockTime
 	cst.chainTracker, err = chaintracker.NewChainTracker(ctx, chainFetcher, chainTrackerConfig)
+	cst.chainTracker.StartAndServe(ctx)
 	cst.chainTracker.RegisterForBlockTimeUpdates(cst) // registering for block time updates
 	return cst, err
 }
diff --git a/scripts/test/init_payment_e2e.sh b/scripts/test/init_payment_e2e.sh
index b41535c1f..87b118a72 100755
--- a/scripts/test/init_payment_e2e.sh
+++ b/scripts/test/init_payment_e2e.sh
@@ -32,6 +32,7 @@ lavad tx pairing stake-provider "LAV1" $STAKE "127.0.0.1:2262,1" 1 $(operator_ad

 # subscribed clients
 lavad tx subscription buy "DefaultPlan" $(lavad keys show user1 -a) 10 -y --from user1 --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE
+wait_next_block
 sleep_until_next_epoch

 # the end
diff --git a/testutil/e2e/allowedErrorList.go b/testutil/e2e/allowedErrorList.go
index bb95b79e5..5b083bae7 100644
--- a/testutil/e2e/allowedErrorList.go
+++ b/testutil/e2e/allowedErrorList.go
@@ -13,6 +13,7 @@ var allowedErrors = map[string]string{

 var allowedErrorsDuringEmergencyMode = map[string]string{
 	"connection refused":           "Connection to tendermint port sometimes can happen as we shut down the node and we try to fetch info during emergency mode",
+	"Connection refused":           "Connection to tendermint port sometimes can happen as we shut down the node and we try to fetch info during emergency mode",
 	"connection reset by peer":     "Connection to tendermint port sometimes can happen as we shut down the node and we try to fetch info during emergency mode",
 	"Failed Querying EpochDetails": "Connection to tendermint port sometimes can happen as we shut down the node and we try to fetch info during emergency mode",
 }
diff --git a/x/protocol/types/params.go b/x/protocol/types/params.go
index 9b9fd3dcb..20d3bfa88 100644
--- a/x/protocol/types/params.go
+++ b/x/protocol/types/params.go
@@ -12,7 +12,7 @@ import (

 var _ paramtypes.ParamSet = (*Params)(nil)

 const (
-	TARGET_VERSION = "3.1.5"
+	TARGET_VERSION = "3.1.7"
 	MIN_VERSION    = "2.2.2"
 )
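
For reference: the generic `common.SafeSyncMap` used throughout this patch is not defined in the diff itself. Below is a minimal sketch of the contract the new call sites rely on (`Load`, `Store`, and the three-value `LoadOrStore`), inferred purely from the usage above; the internals (a `sync.Map` plus type assertions reported as errors) are an assumption, not the actual implementation in `protocol/common`.

package main

import (
	"fmt"
	"sync"
)

// SafeSyncMap is a hypothetical reconstruction of common.SafeSyncMap, inferred
// from the call sites in this patch. It wraps sync.Map and reports a failed
// type assertion as an error instead of panicking. The real type may differ.
type SafeSyncMap[K comparable, V any] struct {
	m sync.Map
}

// Store inserts or overwrites the value for key.
func (s *SafeSyncMap[K, V]) Store(key K, value V) {
	s.m.Store(key, value)
}

// Load returns the value for key, whether it was found, and an error if the
// stored value was not of type V.
func (s *SafeSyncMap[K, V]) Load(key K) (V, bool, error) {
	var zero V
	raw, found := s.m.Load(key)
	if !found {
		return zero, false, nil
	}
	value, ok := raw.(V)
	if !ok {
		return zero, false, fmt.Errorf("unexpected type in map: %T", raw)
	}
	return value, true, nil
}

// LoadOrStore returns the existing value for key when one is already present
// (loaded == true); otherwise it stores value and returns it (loaded == false).
func (s *SafeSyncMap[K, V]) LoadOrStore(key K, value V) (V, bool, error) {
	raw, loaded := s.m.LoadOrStore(key, value)
	actual, ok := raw.(V)
	if !ok {
		var zero V
		return zero, loaded, fmt.Errorf("unexpected type in map: %T", raw)
	}
	return actual, loaded, nil
}

// Example mirroring the rpcprovider usage: the first caller for a chain ID
// stores its tracker, later callers get loaded == true and reuse the stored one.
func main() {
	trackers := &SafeSyncMap[string, int]{} // int stands in for *chaintracker.ChainTracker
	first, loaded, err := trackers.LoadOrStore("LAV1", 100)
	fmt.Println(first, loaded, err) // 100 false <nil>
	second, loaded, err := trackers.LoadOrStore("LAV1", 200)
	fmt.Println(second, loaded, err) // 100 true <nil>
}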