Skip to content

Commit

Permalink
refac: Update RPC client creation
Browse files Browse the repository at this point in the history
Updates:
- Strengthen client validation
- Allow providing a list of RPC endpoints
  • Loading branch information
AntiD2ta committed Sep 12, 2024
1 parent d734acd commit 37dc11b
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ limitations under the License.
package contracts

import (
"context"
"fmt"
"math/big"

"github.com/NethermindEth/sedge/configs"
"github.com/ethereum/go-ethereum/ethclient"
Expand All @@ -36,17 +38,39 @@ func connectToRPCETH(RPCs []string) (*ethclient.Client, error) {
return nil, fmt.Errorf("failed to connect to any RPC URL")
}

func ConnectClient(network string) (*ethclient.Client, error) {
rpcs, err := configs.GetPublicRPCs(network)
if err != nil {
return nil, fmt.Errorf("failed to get public RPC: %w", err)
func ConnectClient(network string, RPCs ...string) (*ethclient.Client, error) {
var rpcs []string
var err error

if len(RPCs) == 0 {
rpcs, err = configs.GetPublicRPCs(network)
if err != nil {
return nil, fmt.Errorf("failed to get public RPC: %w", err)
}
} else {
rpcs = RPCs
}

// Connect to the RPC endpoint
client, err := connectToRPCETH(rpcs)
if err != nil {
return nil, fmt.Errorf("failed to connect to RPC: %w", err)
}

return client, nil
// Verify that the client is indeed an Ethereum RPC client
if client != nil {
// Try to get the chain ID, which is a basic operation that should work for any Ethereum client
chainID, err := client.ChainID(context.Background())
if err == nil {
expectedChainID := configs.NetworksConfigs()[network].ChainID
if chainID.Cmp(new(big.Int).SetUint64(expectedChainID)) == 0 {
// If we successfully got the chain ID and it matches the expected one,
// we can be reasonably sure this is the correct Ethereum client
return client, nil
}
}
// If there was an error or chain ID mismatch, close the client and continue to the next URL
client.Close()
}

return nil, fmt.Errorf("failed to connect to RPC: %w", err)
}
102 changes: 102 additions & 0 deletions internal/lido/contracts/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
Copyright 2022 Nethermind
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package contracts

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestConnectClient(t *testing.T) {
tcs := []struct {
name string
network string
wantErr bool
}{
{
name: "ConnectClient, Holesky",
network: "holesky",
wantErr: false,
},
{
name: "ConnectClient, invalid Network",
network: "invalid",
wantErr: true,
},
{
name: "ConnectClient, Mainnet",
network: "mainnet",
wantErr: false,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
client, err := ConnectClient(tc.network)
if tc.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.NotNil(t, client)
}
})
}
}

func TestConnectClientWithRPCs(t *testing.T) {
tcs := []struct {
name string
network string
RPCs []string
wantErr bool
}{
{
name: "ConnectClientWithRPCs, Holesky",
network: "holesky",
RPCs: []string{"https://endpoints.omniatech.io/v1/eth/holesky/public", "https://ethereum-holesky.blockpi.network/v1/rpc/public"},
wantErr: false,
},
{
name: "ConnectClientWithRPCs, Holesky, invalid RPC",
network: "holesky",
RPCs: []string{"https://www.google.com"},
wantErr: true,
},
{
name: "ConnectClientWithRPCs, invalid Network RPCs",
network: "holesky",
RPCs: []string{"https://eth.llamarpc.com"}, // Mainnet RPC
wantErr: true,
},
{
name: "ConnectClient, Mainnet",
network: "mainnet",
RPCs: []string{"https://eth.llamarpc.com"},
wantErr: false,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
client, err := ConnectClient(tc.network, tc.RPCs...)
if tc.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.NotNil(t, client)
}
})
}
}

0 comments on commit 37dc11b

Please sign in to comment.