From fc2ecbcfabd26a084a7ae7b5f358dcc313d9474c Mon Sep 17 00:00:00 2001 From: Daniil Cherednik Date: Fri, 1 Nov 2024 14:45:12 +0100 Subject: [PATCH] Allow setting a preferred node ID to execute a query --- CHANGELOG.md | 2 + context.go | 5 + internal/conn/middleware.go | 12 +++ internal/operation/context.go | 18 ++++ internal/pool/pool.go | 35 ++++--- internal/pool/pool_test.go | 168 ++++++++++++++++++++++++++++------ internal/query/client.go | 7 +- internal/query/client_test.go | 4 +- internal/table/client.go | 5 +- internal/table/retry_test.go | 8 +- 10 files changed, 214 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 622c7c3b5..690dcff23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Allow setting a preferred node ID to execute a query + ## v3.90.1 * Small broken change: added method `ID()` into `spans.Span` interface (need to implement in adapter) * Fixed traceparent header for tracing grpc requests diff --git a/context.go b/context.go index 02df85dce..8a0c0fa94 100644 --- a/context.go +++ b/context.go @@ -19,3 +19,8 @@ func WithOperationTimeout(ctx context.Context, operationTimeout time.Duration) c func WithOperationCancelAfter(ctx context.Context, operationCancelAfter time.Duration) context.Context { return operation.WithCancelAfter(ctx, operationCancelAfter) } + +// WithPreferredNodeID allows setting the preferred node to get a session from +func WithPreferredNodeID(ctx context.Context, nodeID uint32) context.Context { + return operation.WithPreferredNodeID(ctx, nodeID) +} diff --git a/internal/conn/middleware.go b/internal/conn/middleware.go index 07ab761e4..889c32470 100644 --- a/internal/conn/middleware.go +++ b/internal/conn/middleware.go @@ -4,6 +4,8 @@ import ( "context" "google.golang.org/grpc" + + balancerContext "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" ) var _ grpc.ClientConnInterface = (*middleware)(nil) @@ -30,6 +32,16 @@ func (m *middleware) NewStream( return m.newStream(ctx, desc, method, opts...) } +func ModifyConn(cc grpc.ClientConnInterface, nodeID uint32) grpc.ClientConnInterface { + if nodeID != 0 { + return WithContextModifier(cc, func(ctx context.Context) context.Context { + return balancerContext.WithNodeID(ctx, nodeID) + }) + } + + return cc +} + func WithContextModifier( cc grpc.ClientConnInterface, modifyCtx func(ctx context.Context) context.Context, diff --git a/internal/operation/context.go b/internal/operation/context.go index 2f76340a5..5d0720a72 100644 --- a/internal/operation/context.go +++ b/internal/operation/context.go @@ -8,6 +8,7 @@ import ( type ( ctxOperationTimeoutKey struct{} ctxOperationCancelAfterKey struct{} + ctxWithPreferredNodeIDKey struct{} ) // WithTimeout returns a copy of parent context in which YDB operation timeout @@ -33,6 +34,10 @@ func WithCancelAfter(ctx context.Context, operationCancelAfter time.Duration) co return context.WithValue(ctx, ctxOperationCancelAfterKey{}, operationCancelAfter) } +func WithPreferredNodeID(ctx context.Context, nodeID uint32) context.Context { + return context.WithValue(ctx, ctxWithPreferredNodeIDKey{}, nodeID) +} + // ctxTimeout returns the timeout within given context after which // YDB should try to cancel operation and return result regardless of the cancelation.
func ctxTimeout(ctx context.Context) (d time.Duration, ok bool) { @@ -57,3 +62,16 @@ func ctxUntilDeadline(ctx context.Context) (time.Duration, bool) { return 0, false } + +func CtxPreferredNodeID(ctx context.Context) uint32 { + x := ctx.Value(ctxWithPreferredNodeIDKey{}) + if x == nil { + return 0 + } + val, ok := x.(uint32) + if !ok { + return 0 + } + + return val +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 41dc4cfad..1cc7d8768 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -8,6 +8,7 @@ import ( "github.com/jonboulle/clockwork" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/operation" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" @@ -20,6 +21,7 @@ type ( Item interface { IsAlive() bool Close(ctx context.Context) error + NodeID() uint32 } ItemConstraint[T any] interface { *T @@ -30,7 +32,7 @@ type ( clock clockwork.Clock limit int createTimeout time.Duration - createItem func(ctx context.Context) (PT, error) + createItem func(ctx context.Context, preferredNodeID uint32) (PT, error) closeTimeout time.Duration closeItem func(ctx context.Context, item PT) idleTimeToLive time.Duration @@ -48,7 +50,7 @@ type ( Pool[PT ItemConstraint[T], T any] struct { config Config[PT, T] - createItem func(ctx context.Context) (PT, error) + createItem func(ctx context.Context, preferredNodeID uint32) (PT, error) closeItem func(ctx context.Context, item PT) mu xsync.RWMutex @@ -63,7 +65,7 @@ type ( Option[PT ItemConstraint[T], T any] func(c *Config[PT, T]) ) -func WithCreateItemFunc[PT ItemConstraint[T], T any](f func(ctx context.Context) (PT, error)) Option[PT, T] { +func WithCreateItemFunc[PT ItemConstraint[T], T any](f func(context.Context, uint32) (PT, error)) Option[PT, T] { return func(c *Config[PT, T]) { c.createItem = f } @@ -173,7 +175,7 @@ func New[PT ItemConstraint[T], T any]( } // defaultCreateItem returns a new item -func defaultCreateItem[T any, PT ItemConstraint[T]](context.Context) (PT, error) { +func defaultCreateItem[T any, PT ItemConstraint[T]](context.Context, uint32) (PT, error) { var item T return &item, nil @@ -182,8 +184,8 @@ func defaultCreateItem[T any, PT ItemConstraint[T]](context.Context) (PT, error) // makeAsyncCreateItemFunc wraps the createItem function with timeout handling func makeAsyncCreateItemFunc[PT ItemConstraint[T], T any]( //nolint:funlen p *Pool[PT, T], -) func(ctx context.Context) (PT, error) { - return func(ctx context.Context) (PT, error) { +) func(ctx context.Context, preferredNodeID uint32) (PT, error) { + return func(ctx context.Context, preferredNodeID uint32) (PT, error) { if !xsync.WithLock(&p.mu, func() bool { if len(p.index)+p.createInProgress < p.config.limit { p.createInProgress++ @@ -222,7 +224,7 @@ func makeAsyncCreateItemFunc[PT ItemConstraint[T], T any]( //nolint:funlen defer cancelCreate() } - newItem, err := p.config.createItem(createCtx) + newItem, err := p.config.createItem(createCtx, preferredNodeID) if newItem != nil { p.mu.WithLock(func() { var useCounter uint64 @@ -314,7 +316,7 @@ func (p *Pool[PT, T]) changeState(changeState func() Stats) { } } -func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item PT) error) (finalErr error) { +func (p *Pool[PT, T]) try(ctx context.Context, f func(context.Context, PT) error) (finalErr error) { if onTry := p.config.trace.OnTry; onTry != nil { onDone := onTry(&ctx, 
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/pool.(*Pool).try"), @@ -460,8 +462,13 @@ func (p *Pool[PT, T]) putWaitCh(ch *chan PT) { //nolint:gocritic } // p.mu must be held. -func (p *Pool[PT, T]) peekFirstIdle() (item PT, touched time.Time) { +func (p *Pool[PT, T]) peekFirstIdle(preferredNodeID uint32) (item PT, touched time.Time) { el := p.idle.Front() + if preferredNodeID != 0 { + for el != nil && el.Value.NodeID() != preferredNodeID { + el = el.Next() + } + } if el == nil { return } @@ -478,8 +485,8 @@ func (p *Pool[PT, T]) peekFirstIdle() (item PT, touched time.Time) { // to prevent session from dying in the internalPoolGC after it was returned // to be used only in outgoing functions that make session busy. // p.mu must be held. -func (p *Pool[PT, T]) removeFirstIdle() PT { - idle, _ := p.peekFirstIdle() +func (p *Pool[PT, T]) removeFirstIdle(preferredNodeID uint32) PT { + idle, _ := p.peekFirstIdle(preferredNodeID) if idle != nil { info := p.removeIdle(idle) p.index[idle] = info @@ -585,6 +592,8 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) { / } } + preferredNodeID := operation.CtxPreferredNodeID(ctx) + for ; attempt < maxAttempts; attempt++ { select { case <-p.done: @@ -593,7 +602,7 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) { / } if item := xsync.WithLock(&p.mu, func() PT { //nolint:nestif - return p.removeFirstIdle() + return p.removeFirstIdle(preferredNodeID) }); item != nil { if item.IsAlive() { info := xsync.WithLock(&p.mu, func() itemInfo[PT, T] { @@ -625,7 +634,7 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) { / } } - item, err := p.createItem(ctx) + item, err := p.createItem(ctx, preferredNodeID) if item != nil { return item, nil } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 881fd7932..8d764bfa1 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -20,6 +20,7 @@ import ( grpcStatus "google.golang.org/grpc/status" "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/operation" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" @@ -38,6 +39,7 @@ type ( onClose func() error onIsAlive func() bool + onNodeID func() uint32 } testWaitChPool struct { xsync.Pool[chan *testItem] @@ -113,6 +115,14 @@ func (t *testItem) ID() string { return "" } +func (t *testItem) NodeID() uint32 { + if t.onNodeID != nil { + return t.onNodeID() + } + + return 0 +} + func (t *testItem) Close(context.Context) error { if t.closed.Len() > 0 { debug.PrintStack() @@ -135,8 +145,8 @@ func caller() string { return fmt.Sprintf("%s:%d", path.Base(file), line) } -func mustGetItem[PT ItemConstraint[T], T any](t testing.TB, p *Pool[PT, T]) PT { - s, err := p.getItem(context.Background()) +func mustGetItem[PT ItemConstraint[T], T any](t testing.TB, p *Pool[PT, T], nodeID uint32) PT { + s, err := p.getItem(operation.WithPreferredNodeID(context.Background(), nodeID)) if err != nil { t.Helper() t.Fatalf("%s: %v", caller(), err) @@ -167,10 +177,113 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithTrace[*testItem, testItem](defaultTrace), ) err := p.With(rootCtx, func(ctx context.Context, testItem *testItem) error { + require.EqualValues(t, 0, testItem.NodeID()) + return nil }) require.NoError(t, err) }) + t.Run("RequireNodeIdFromPool", func(t *testing.T) { + var 
nextNodeID uint32 + nextNodeID = 0 + var newSessionCalled uint32 + p := New[*testItem, testItem](rootCtx, + WithTrace[*testItem, testItem](defaultTrace), + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { + newSessionCalled++ + var ( + nodeID = nextNodeID + v = testItem{ + v: 0, + onNodeID: func() uint32 { + return nodeID + }, + } + ) + + return &v, nil + }), + ) + + item := mustGetItem(t, p, 0) + require.EqualValues(t, 0, item.NodeID()) + require.EqualValues(t, true, item.IsAlive()) + mustPutItem(t, p, item) + + nextNodeID = 32 + + item = mustGetItem(t, p, 32) + require.EqualValues(t, 32, item.NodeID()) + mustPutItem(t, p, item) + + nextNodeID = 33 + + item = mustGetItem(t, p, 33) + require.EqualValues(t, 33, item.NodeID()) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 32) + require.EqualValues(t, 32, item.NodeID()) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 33) + require.EqualValues(t, 33, item.NodeID()) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 32) + item2 := mustGetItem(t, p, 33) + require.EqualValues(t, 32, item.NodeID()) + require.EqualValues(t, 33, item2.NodeID()) + mustPutItem(t, p, item2) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 32) + item2 = mustGetItem(t, p, 33) + require.EqualValues(t, 32, item.NodeID()) + require.EqualValues(t, 33, item2.NodeID()) + mustPutItem(t, p, item) + mustPutItem(t, p, item2) + + item = mustGetItem(t, p, 32) + item2 = mustGetItem(t, p, 33) + item3 := mustGetItem(t, p, 0) + require.EqualValues(t, 32, item.NodeID()) + require.EqualValues(t, 33, item2.NodeID()) + require.EqualValues(t, 0, item3.NodeID()) + mustPutItem(t, p, item) + mustPutItem(t, p, item2) + mustPutItem(t, p, item3) + + require.EqualValues(t, 3, newSessionCalled) + }) + t.Run("CreateSessionOnGivenNode", func(t *testing.T) { + var newSessionCalled uint32 + p := New[*testItem, testItem](rootCtx, + WithTrace[*testItem, testItem](defaultTrace), + WithCreateItemFunc(func(_ context.Context, nodeID uint32) (*testItem, error) { + newSessionCalled++ + v := testItem{ + v: 0, + onNodeID: func() uint32 { + return nodeID + }, + } + + return &v, nil + }), + ) + + item := mustGetItem(t, p, 32) + require.EqualValues(t, 32, item.NodeID()) + require.EqualValues(t, true, item.IsAlive()) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 32) + require.EqualValues(t, 32, item.NodeID()) + mustPutItem(t, p, item) + + require.EqualValues(t, 1, newSessionCalled) + }) t.Run("WithLimit", func(t *testing.T) { p := New[*testItem, testItem](rootCtx, WithLimit[*testItem, testItem](1), WithTrace[*testItem, testItem](defaultTrace), @@ -184,7 +297,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithItemUsageLimit[*testItem, testItem](5), WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&newCounter, 1) var v testItem @@ -208,7 +321,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo var newCounter int64 p := New(rootCtx, WithLimit[*testItem, testItem](1), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&newCounter, 1) var v testItem @@ -242,7 +355,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithLimit[*testItem, testItem](3), WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), 
WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { var ( idx = created.Add(1) - 1 v = testItem{ @@ -270,9 +383,9 @@ func TestPool(t *testing.T) { //nolint:gocyclo require.Zero(t, p.idle.Len()) var ( - s1 = mustGetItem(t, p) - s2 = mustGetItem(t, p) - s3 = mustGetItem(t, p) + s1 = mustGetItem(t, p, 0) + s2 = mustGetItem(t, p, 0) + s3 = mustGetItem(t, p, 0) ) require.Len(t, p.index, 3) @@ -347,7 +460,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo // second call getItem from pool with limit === 1 will skip // create item step (because pool have not enough space for // creating new items) and will freeze until wait free item from pool - mustGetItem(t, p) + mustGetItem(t, p, 0) go func() { p.config.trace.OnGet = func(ctx *context.Context, call stack.Caller) func(item any, attempts int, err error) { @@ -405,7 +518,8 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New[*testItem, testItem](rootCtx, WithLimit[*testItem, testItem](2), WithCreateItemTimeout[*testItem, testItem](0), - WithCreateItemFunc[*testItem, testItem](func(ctx context.Context) (*testItem, error) { + WithCreateItemFunc[*testItem, testItem](func(ctx context.Context, preferredNodeID uint32) (*testItem, error) { + require.Equal(t, 0, preferredNodeID) v := testItem{ v: 0, onClose: func() error { @@ -425,8 +539,8 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithTrace[*testItem, testItem](defaultTrace), ) - s1 := mustGetItem(t, p) - s2 := mustGetItem(t, p) + s1 := mustGetItem(t, p, 0) + s2 := mustGetItem(t, p, 0) // Put both items at the absolutely same time. // That is, both items must be updated their lastUsage timestamp. @@ -442,7 +556,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo // on get item from idle list the pool must check the item idle timestamp // both existing items must be closed // getItem must create a new item and return it from getItem - s3 := mustGetItem(t, p) + s3 := mustGetItem(t, p, 0) require.Len(t, p.index, 1) @@ -464,7 +578,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo require.Len(t, p.index, 1) require.Equal(t, 1, p.idle.Len()) - s4 := mustGetItem(t, p) + s4 := mustGetItem(t, p, 0) require.Equal(t, s3, s4) require.Len(t, p.index, 1) require.Equal(t, 0, p.idle.Len()) @@ -485,7 +599,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New(rootCtx, WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&counter, 1) if atomic.LoadInt64(&counter) < 10 { @@ -508,7 +622,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New(rootCtx, WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&counter, 1) if atomic.LoadInt64(&counter) < 10 { @@ -532,7 +646,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New(rootCtx, WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { 
atomic.AddInt64(&counter, 1) if atomic.LoadInt64(&counter) < 10 { @@ -555,7 +669,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New(rootCtx, WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&counter, 1) if atomic.LoadInt64(&counter) < 10 { @@ -597,7 +711,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo time.Duration(r.Int64(int64(time.Second))), ) defer childCancel() - s, err := p.createItem(childCtx) + s, err := p.createItem(childCtx, 0) if s == nil && err == nil { errCh <- fmt.Errorf("unexpected result: <%v, %w>", s, err) } @@ -703,7 +817,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo ) p := New(rootCtx, WithLimit[*testItem, testItem](1), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&createCounter, 1) v := &testItem{ @@ -740,7 +854,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithLimit[*testItem, testItem](1), WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { newItems.Add(1) v := &testItem{ @@ -801,7 +915,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithLimit[*testItem, testItem](1), WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { created.Add(1) v := testItem{ v: 0, @@ -821,20 +935,20 @@ func TestPool(t *testing.T) { //nolint:gocyclo _ = p.Close(context.Background()) }() - s := mustGetItem(t, p) + s := mustGetItem(t, p, 0) assertCreated(1) mustPutItem(t, p, s) assertClosed(0) - mustGetItem(t, p) + mustGetItem(t, p, 0) assertCreated(1) p.closeItem(context.Background(), s) delete(p.index, s) assertClosed(1) - mustGetItem(t, p) + mustGetItem(t, p, 0) assertCreated(2) }) t.Run("Racy", func(t *testing.T) { @@ -917,7 +1031,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo // replace default async closer for sync testing WithSyncCloseItem[*testItem, testItem](), ) - item := mustGetItem(t, p) + item := mustGetItem(t, p, 0) if err := p.putItem(context.Background(), item); err != nil { t.Fatalf("unexpected error on put session into non-full client: %v, wand: %v", err, nil) } @@ -934,7 +1048,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo // replace default async closer for sync testing WithSyncCloseItem[*testItem, testItem](), ) - item := mustGetItem(t, p) + item := mustGetItem(t, p, 0) mustPutItem(t, p, item) require.Panics(t, func() { diff --git a/internal/query/client.go b/internal/query/client.go index aad2545de..e0a4fd443 100644 --- a/internal/query/client.go +++ b/internal/query/client.go @@ -11,6 +11,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator" "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" "github.com/ydb-platform/ydb-go-sdk/v3/internal/operation" "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/config" @@ -39,7 +40,7 @@ type ( closer.Closer 
Stats() pool.Stats - With(ctx context.Context, f func(ctx context.Context, s *Session) error, opts ...retry.Option) error + With(ctx context.Context, f func(context.Context, *Session) error, opts ...retry.Option) error } Client struct { config *config.Config @@ -562,7 +563,7 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, cfg *config.Config) * pool.WithCreateItemTimeout[*Session, Session](cfg.SessionCreateTimeout()), pool.WithCloseItemTimeout[*Session, Session](cfg.SessionDeleteTimeout()), pool.WithIdleTimeToLive[*Session, Session](cfg.SessionIdleTimeToLive()), - pool.WithCreateItemFunc(func(ctx context.Context) (_ *Session, err error) { + pool.WithCreateItemFunc(func(ctx context.Context, nodeID uint32) (_ *Session, err error) { var ( createCtx context.Context cancelCreate context.CancelFunc @@ -575,7 +576,7 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, cfg *config.Config) * defer cancelCreate() s, err := createSession(createCtx, client, - session.WithConn(cc), + session.WithConn(conn.ModifyConn(cc, nodeID)), session.WithDeleteTimeout(cfg.SessionDeleteTimeout()), session.WithTrace(cfg.Trace()), ) diff --git a/internal/query/client_test.go b/internal/query/client_test.go index 1a1933ab4..b08349c80 100644 --- a/internal/query/client_test.go +++ b/internal/query/client_test.go @@ -1590,7 +1590,9 @@ func testPool( ) *pool.Pool[*Session, Session] { return pool.New[*Session, Session](ctx, pool.WithLimit[*Session, Session](1), - pool.WithCreateItemFunc(createSession), + pool.WithCreateItemFunc(func(ctx context.Context, _ uint32) (*Session, error) { + return createSession(ctx) + }), pool.WithSyncCloseItem[*Session, Session](), ) } diff --git a/internal/table/client.go b/internal/table/client.go index e463eba16..a020ae6fc 100644 --- a/internal/table/client.go +++ b/internal/table/client.go @@ -8,6 +8,7 @@ import ( "google.golang.org/grpc" "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/config" @@ -40,8 +41,8 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, config *config.Config pool.WithCreateItemTimeout[*session, session](config.CreateSessionTimeout()), pool.WithCloseItemTimeout[*session, session](config.DeleteTimeout()), pool.WithClock[*session, session](config.Clock()), - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { - return newSession(ctx, cc, config) + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, nodeId uint32) (*session, error) { + return newSession(ctx, conn.ModifyConn(cc, nodeId), config) }), pool.WithTrace[*session, session](&pool.Trace{ OnNew: func(ctx *context.Context, call stack.Caller) func(limit int) { diff --git a/internal/table/retry_test.go b/internal/table/retry_test.go index 2086d4fbb..056aa67be 100644 --- a/internal/table/retry_test.go +++ b/internal/table/retry_test.go @@ -86,7 +86,7 @@ func TestDoBadSession(t *testing.T) { xtest.TestManyTimes(t, func(t testing.TB) { closed := make(map[table.Session]bool) p := pool.New[*session, session](ctx, - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, _ uint32) (*session, error) { s := simpleSession(t) s.onClose = append(s.onClose, func(s *session) { closed[s] = true @@ -137,7 +137,7 @@ func 
TestDoCreateSessionError(t *testing.T) { ctx, cancel := xcontext.WithTimeout(rootCtx, 30*time.Millisecond) defer cancel() p := pool.New[*session, session](ctx, - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, _ uint32) (*session, error) { return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_UNAVAILABLE)) }), pool.WithSyncCloseItem[*session, session](), @@ -306,7 +306,7 @@ func TestDoContextDeadline(t *testing.T) { } ctx := xtest.Context(t) p := pool.New[*session, session](ctx, - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, _ uint32) (*session, error) { return newSession(ctx, client.cc, config.New()) }), pool.WithSyncCloseItem[*session, session](), @@ -355,7 +355,7 @@ func TestDoWithCustomErrors(t *testing.T) { limit = 10 ctx = context.Background() p = pool.New[*session, session](ctx, - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, _ uint32) (*session, error) { return simpleSession(t), nil }), pool.WithLimit[*session, session](limit),
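
A minimal usage sketch of the new option follows (illustration only, not part of the patch). It assumes the existing public driver API of ydb-go-sdk/v3 -- ydb.Open / *ydb.Driver, db.Query().Do and query.Session -- none of which is introduced by this change; the only new symbol used below is ydb.WithPreferredNodeID.

package main

import (
	"context"

	"github.com/ydb-platform/ydb-go-sdk/v3"
	"github.com/ydb-platform/ydb-go-sdk/v3/query"
)

func executeOnNode(ctx context.Context, db *ydb.Driver, nodeID uint32) error {
	// Attach the preferred node ID to the context. The session pool first looks
	// for an idle session on this node; if none is found, it creates a new one
	// with the gRPC connection pinned to the node via the balancer context.
	ctx = ydb.WithPreferredNodeID(ctx, nodeID)

	return db.Query().Do(ctx, func(ctx context.Context, s query.Session) error {
		// s is served by the preferred node whenever the pool can satisfy the hint.
		return nil
	})
}

Passing a node ID of 0 keeps the current behaviour: ModifyConn returns the connection unchanged and the pool hands out any idle session.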