From 37dc11b320f696160a161997719e12a49bf8bfc4 Mon Sep 17 00:00:00 2001 From: AntiD2ta Date: Thu, 12 Sep 2024 16:33:31 +0000 Subject: [PATCH] refac: Update RPC client creation Updates: - Strengthen client validation - Allow providing a list of RPC endpoints --- .../{contract_setup.go => client.go} | 36 +++++-- internal/lido/contracts/client_test.go | 102 ++++++++++++++++++ 2 files changed, 132 insertions(+), 6 deletions(-) rename internal/lido/contracts/{contract_setup.go => client.go} (51%) create mode 100644 internal/lido/contracts/client_test.go diff --git a/internal/lido/contracts/contract_setup.go b/internal/lido/contracts/client.go similarity index 51% rename from internal/lido/contracts/contract_setup.go rename to internal/lido/contracts/client.go index 051956e4..de0f61cf 100644 --- a/internal/lido/contracts/contract_setup.go +++ b/internal/lido/contracts/client.go @@ -16,7 +16,9 @@ limitations under the License. package contracts import ( + "context" "fmt" + "math/big" "github.com/NethermindEth/sedge/configs" "github.com/ethereum/go-ethereum/ethclient" @@ -36,17 +38,39 @@ func connectToRPCETH(RPCs []string) (*ethclient.Client, error) { return nil, fmt.Errorf("failed to connect to any RPC URL") } -func ConnectClient(network string) (*ethclient.Client, error) { - rpcs, err := configs.GetPublicRPCs(network) - if err != nil { - return nil, fmt.Errorf("failed to get public RPC: %w", err) +func ConnectClient(network string, RPCs ...string) (*ethclient.Client, error) { + var rpcs []string + var err error + + if len(RPCs) == 0 { + rpcs, err = configs.GetPublicRPCs(network) + if err != nil { + return nil, fmt.Errorf("failed to get public RPC: %w", err) + } + } else { + rpcs = RPCs } - // Connect to the RPC endpoint client, err := connectToRPCETH(rpcs) if err != nil { return nil, fmt.Errorf("failed to connect to RPC: %w", err) } - return client, nil + // Verify that the client is indeed an Ethereum RPC client + if client != nil { + // Try to get the chain ID, which is a basic operation that should work for any Ethereum client + chainID, err := client.ChainID(context.Background()) + if err == nil { + expectedChainID := configs.NetworksConfigs()[network].ChainID + if chainID.Cmp(new(big.Int).SetUint64(expectedChainID)) == 0 { + // If we successfully got the chain ID and it matches the expected one, + // we can be reasonably sure this is the correct Ethereum client + return client, nil + } + } + // If there was an error or chain ID mismatch, close the client and continue to the next URL + client.Close() + } + + return nil, fmt.Errorf("failed to connect to RPC: %w", err) } diff --git a/internal/lido/contracts/client_test.go b/internal/lido/contracts/client_test.go new file mode 100644 index 00000000..ced442c2 --- /dev/null +++ b/internal/lido/contracts/client_test.go @@ -0,0 +1,102 @@ +/* +Copyright 2022 Nethermind + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package contracts + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConnectClient(t *testing.T) { + tcs := []struct { + name string + network string + wantErr bool + }{ + { + name: "ConnectClient, Holesky", + network: "holesky", + wantErr: false, + }, + { + name: "ConnectClient, invalid Network", + network: "invalid", + wantErr: true, + }, + { + name: "ConnectClient, Mainnet", + network: "mainnet", + wantErr: false, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + client, err := ConnectClient(tc.network) + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, client) + } + }) + } +} + +func TestConnectClientWithRPCs(t *testing.T) { + tcs := []struct { + name string + network string + RPCs []string + wantErr bool + }{ + { + name: "ConnectClientWithRPCs, Holesky", + network: "holesky", + RPCs: []string{"https://endpoints.omniatech.io/v1/eth/holesky/public", "https://ethereum-holesky.blockpi.network/v1/rpc/public"}, + wantErr: false, + }, + { + name: "ConnectClientWithRPCs, Holesky, invalid RPC", + network: "holesky", + RPCs: []string{"https://www.google.com"}, + wantErr: true, + }, + { + name: "ConnectClientWithRPCs, invalid Network RPCs", + network: "holesky", + RPCs: []string{"https://eth.llamarpc.com"}, // Mainnet RPC + wantErr: true, + }, + { + name: "ConnectClient, Mainnet", + network: "mainnet", + RPCs: []string{"https://eth.llamarpc.com"}, + wantErr: false, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + client, err := ConnectClient(tc.network, tc.RPCs...) + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, client) + } + }) + } +}