Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose Dialers inside Zk and Region #249

Merged
merged 5 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion admin_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func newAdminClient(zkquorum string, options ...Option) AdminClient {
for _, option := range options {
option(c)
}
c.zkClient = zk.NewClient(zkquorum, c.zkTimeout)
c.zkClient = zk.NewClient(zkquorum, c.zkTimeout, c.zkDialer)
return c
}

Expand Down
28 changes: 26 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"net"
"sync"
"time"

Expand Down Expand Up @@ -97,9 +98,15 @@ type client struct {
closeOnce sync.Once

newRegionClientFn func(string, region.ClientType, int, time.Duration,
string, time.Duration, compression.Codec) hrpc.RegionClient
string, time.Duration, compression.Codec, func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient

compressionCodec compression.Codec

// zkDialer is used in the zkClient to connect to the quorum
zkDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// regionDialer is passed into the region client to connect to hbase in a custom way,
// such as SOCKS proxy.
regionDialer func(ctx context.Context, network, addr string) (net.Conn, error)
}

// NewClient creates a new HBase client.
Expand Down Expand Up @@ -140,7 +147,7 @@ func newClient(zkquorum string, options ...Option) *client {

//Have to create the zkClient after the Options have been set
//since the zkTimeout could be changed as an option
c.zkClient = zk.NewClient(zkquorum, c.zkTimeout)
c.zkClient = zk.NewClient(zkquorum, c.zkTimeout, c.zkDialer)

return c
}
Expand Down Expand Up @@ -268,6 +275,23 @@ func CompressionCodec(codec string) Option {
}
}

// ZooKeeperDialer will return an option to pass the given dialer function
// into the ZooKeeper client Connect() call, which allows for customizing
// network connections.
func ZooKeeperDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(c *client) {
c.zkDialer = dialer
}
}

// RegionDialer will return an option that uses the specified Dialer for
// connecting to region servers. This allows for connecting through proxies.
func RegionDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(c *client) {
c.regionDialer = dialer
}
}

// Close closes connections to hbase master and regionservers
func (c *client) Close() {
c.closeOnce.Do(func() {
Expand Down
2 changes: 2 additions & 0 deletions debug_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"

"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/region"
)
Expand All @@ -25,6 +26,7 @@ func TestDebugStateSanity(t *testing.T) {
defaultEffectiveUser,
region.DefaultReadTimeout,
client.compressionCodec,
nil,
)
newClientFn := func() hrpc.RegionClient {
return regClient
Expand Down
3 changes: 2 additions & 1 deletion mockrc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bytes"
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -177,7 +178,7 @@ func init() {

func newMockRegionClient(addr string, ctype region.ClientType, queueSize int,
flushInterval time.Duration, effectiveUser string,
readTimeout time.Duration, codec compression.Codec) hrpc.RegionClient {
readTimeout time.Duration, codec compression.Codec, dialer func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient {
m.Lock()
clients[addr]++
m.Unlock()
Expand Down
3 changes: 3 additions & 0 deletions region/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ type client struct {

// compressor for cellblocks. if nil, then no compression
compressor *compressor

// dialer is used to connect to region servers in non-standard ways
dialer func(ctx context.Context, network, addr string) (net.Conn, error)
}

// QueueRPC will add an rpc call to the queue for processing by the writer goroutine
Expand Down
14 changes: 11 additions & 3 deletions region/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ import (

// NewClient creates a new RegionClient.
func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time.Duration,
effectiveUser string, readTimeout time.Duration, codec compression.Codec) hrpc.RegionClient {
effectiveUser string, readTimeout time.Duration, codec compression.Codec,
dialer func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient {
c := &client{
addr: addr,
ctype: ctype,
Expand All @@ -36,14 +37,21 @@ func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time.
if codec != nil {
c.compressor = &compressor{Codec: codec}
}

if dialer != nil {
c.dialer = dialer
} else {
var d net.Dialer
c.dialer = d.DialContext
}

return c
}

func (c *client) Dial(ctx context.Context) error {
c.dialOnce.Do(func() {
var d net.Dialer
var err error
c.conn, err = d.DialContext(ctx, "tcp", c.addr)
c.conn, err = c.dialer(ctx, "tcp", c.addr)
if err != nil {
c.fail(fmt.Errorf("failed to dial RegionServer: %s", err))
return
Expand Down
4 changes: 2 additions & 2 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,11 +828,11 @@ func (c *client) establishRegion(reg hrpc.RegionInfo, addr string) {
// master that we don't add to the cache
// TODO: consider combining this case with the regular regionserver path
client = c.newRegionClientFn(addr, c.clientType, c.rpcQueueSize, c.flushInterval,
c.effectiveUser, c.regionReadTimeout, nil)
c.effectiveUser, c.regionReadTimeout, nil, c.regionDialer)
} else {
client = c.clients.put(addr, reg, func() hrpc.RegionClient {
return c.newRegionClientFn(addr, c.clientType, c.rpcQueueSize, c.flushInterval,
c.effectiveUser, c.regionReadTimeout, c.compressionCodec)
c.effectiveUser, c.regionReadTimeout, c.compressionCodec, c.regionDialer)
})
}

Expand Down
5 changes: 3 additions & 2 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"math/rand"
"net"
"reflect"
"strconv"
"strings"
Expand All @@ -37,7 +38,7 @@ import (
func newRegionClientFn(addr string) func() hrpc.RegionClient {
return func() hrpc.RegionClient {
return newMockRegionClient(addr, region.RegionClient,
0, 0, "root", region.DefaultReadTimeout, nil)
0, 0, "root", region.DefaultReadTimeout, nil, nil)
}
}

Expand Down Expand Up @@ -301,7 +302,7 @@ func TestEstablishRegionDialFail(t *testing.T) {

newRegionClientFnCallCount := 0
c.newRegionClientFn = func(_ string, _ region.ClientType, _ int, _ time.Duration,
_ string, _ time.Duration, _ compression.Codec) hrpc.RegionClient {
_ string, _ time.Duration, _ compression.Codec, _ func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient {
var rc hrpc.RegionClient
if newRegionClientFnCallCount == 0 {
rc = rcFailDial
Expand Down
21 changes: 19 additions & 2 deletions zk/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package zk

import (
"context"
"encoding/binary"
"fmt"
"net"
Expand Down Expand Up @@ -58,19 +59,27 @@ type Client interface {
type client struct {
zks []string
sessionTimeout time.Duration
dialer func(ctx context.Context, network, addr string) (net.Conn, error)
}

// NewClient establishes connection to zookeeper and returns the client
func NewClient(zkquorum string, st time.Duration) Client {
func NewClient(zkquorum string, st time.Duration, dialer func(ctx context.Context, network, addr string) (net.Conn, error)) Client {
return &client{
zks: strings.Split(zkquorum, ","),
sessionTimeout: st,
dialer: dialer,
}
}

// LocateResource returns address of the server for the specified resource.
func (c *client) LocateResource(resource ResourceName) (string, error) {
conn, _, err := zk.Connect(c.zks, c.sessionTimeout)
var conn *zk.Conn
var err error
if c.dialer != nil {
conn, _, err = zk.Connect(c.zks, c.sessionTimeout, zk.WithDialer(makeZKDialer(c.dialer)))
} else {
conn, _, err = zk.Connect(c.zks, c.sessionTimeout)
}
if err != nil {
return "", fmt.Errorf("error connecting to ZooKeeper at %v: %s", c.zks, err)
}
Expand Down Expand Up @@ -116,3 +125,11 @@ func (c *client) LocateResource(resource ResourceName) (string, error) {
}
return net.JoinHostPort(*server.HostName, fmt.Sprint(*server.Port)), nil
}

func makeZKDialer(ctxDialer func(ctx context.Context, network, addr string) (net.Conn, error)) zk.Dialer {
return func(network, addr string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return ctxDialer(ctx, network, addr)
}
}
Loading