diff --git a/CHANGELOG.md b/CHANGELOG.md index 839523749..734497bbd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Allow to set preferred node id to execute query + ## v3.89.2 * Returned log.XXX methods for create fields, removed from public at v3.85.0 diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 41dc4cfad..3c1029315 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -20,6 +20,7 @@ type ( Item interface { IsAlive() bool Close(ctx context.Context) error + NodeID() uint32 } ItemConstraint[T any] interface { *T @@ -30,7 +31,7 @@ type ( clock clockwork.Clock limit int createTimeout time.Duration - createItem func(ctx context.Context) (PT, error) + createItem func(ctx context.Context, preferredNodeId uint32) (PT, error) closeTimeout time.Duration closeItem func(ctx context.Context, item PT) idleTimeToLive time.Duration @@ -48,7 +49,7 @@ type ( Pool[PT ItemConstraint[T], T any] struct { config Config[PT, T] - createItem func(ctx context.Context) (PT, error) + createItem func(ctx context.Context, preferredNodeId uint32) (PT, error) closeItem func(ctx context.Context, item PT) mu xsync.RWMutex @@ -61,9 +62,12 @@ type ( done chan struct{} } Option[PT ItemConstraint[T], T any] func(c *Config[PT, T]) + SessionCallOption struct { + preferredNodeIdOption uint32 + } ) -func WithCreateItemFunc[PT ItemConstraint[T], T any](f func(ctx context.Context) (PT, error)) Option[PT, T] { +func WithCreateItemFunc[PT ItemConstraint[T], T any](f func(ctx context.Context, preferredNodeId uint32) (PT, error)) Option[PT, T] { return func(c *Config[PT, T]) { c.createItem = f } @@ -173,7 +177,7 @@ func New[PT ItemConstraint[T], T any]( } // defaultCreateItem returns a new item -func defaultCreateItem[T any, PT ItemConstraint[T]](context.Context) (PT, error) { +func defaultCreateItem[T any, PT ItemConstraint[T]](context.Context, uint32) (PT, error) { var item T return &item, nil @@ -182,8 +186,8 @@ func defaultCreateItem[T any, PT ItemConstraint[T]](context.Context) (PT, error) // makeAsyncCreateItemFunc wraps the createItem function with timeout handling func makeAsyncCreateItemFunc[PT ItemConstraint[T], T any]( //nolint:funlen p *Pool[PT, T], -) func(ctx context.Context) (PT, error) { - return func(ctx context.Context) (PT, error) { +) func(ctx context.Context, preferredNodeId uint32) (PT, error) { + return func(ctx context.Context, preferredNodeId uint32) (PT, error) { if !xsync.WithLock(&p.mu, func() bool { if len(p.index)+p.createInProgress < p.config.limit { p.createInProgress++ @@ -222,7 +226,7 @@ func makeAsyncCreateItemFunc[PT ItemConstraint[T], T any]( //nolint:funlen defer cancelCreate() } - newItem, err := p.config.createItem(createCtx) + newItem, err := p.config.createItem(createCtx, preferredNodeId) if newItem != nil { p.mu.WithLock(func() { var useCounter uint64 @@ -314,7 +318,7 @@ func (p *Pool[PT, T]) changeState(changeState func() Stats) { } } -func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item PT) error) (finalErr error) { +func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item PT) error, preferredNodeId uint32) (finalErr error) { if onTry := p.config.trace.OnTry; onTry != nil { onDone := onTry(&ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/pool.(*Pool).try"), @@ -334,7 +338,7 @@ func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item default: } - item, err := p.getItem(ctx) + item, err := p.getItem(ctx, preferredNodeId) if err != nil { if xerrors.IsYdb(err) { return xerrors.WithStackTrace(xerrors.Retryable(err)) @@ -358,6 +362,7 @@ func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item func (p *Pool[PT, T]) With( ctx context.Context, f func(ctx context.Context, item PT) error, + preferredNodeId uint32, opts ...retry.Option, ) (finalErr error) { var attempts int @@ -375,7 +380,7 @@ func (p *Pool[PT, T]) With( err := retry.Retry(ctx, func(ctx context.Context) error { attempts++ - err := p.try(ctx, f) + err := p.try(ctx, f, preferredNodeId) if err != nil { return xerrors.WithStackTrace(err) } @@ -460,8 +465,13 @@ func (p *Pool[PT, T]) putWaitCh(ch *chan PT) { //nolint:gocritic } // p.mu must be held. -func (p *Pool[PT, T]) peekFirstIdle() (item PT, touched time.Time) { +func (p *Pool[PT, T]) peekFirstIdle(preferredNodeId uint32) (item PT, touched time.Time) { el := p.idle.Front() + if preferredNodeId != 0 { + for el != nil && el.Value.NodeID() != preferredNodeId { + el = el.Next() + } + } if el == nil { return } @@ -478,8 +488,8 @@ func (p *Pool[PT, T]) peekFirstIdle() (item PT, touched time.Time) { // to prevent session from dying in the internalPoolGC after it was returned // to be used only in outgoing functions that make session busy. // p.mu must be held. -func (p *Pool[PT, T]) removeFirstIdle() PT { - idle, _ := p.peekFirstIdle() +func (p *Pool[PT, T]) removeFirstIdle(preferredNodeId uint32) PT { + idle, _ := p.peekFirstIdle(preferredNodeId) if idle != nil { info := p.removeIdle(idle) p.index[idle] = info @@ -567,7 +577,7 @@ func (p *Pool[PT, T]) pushIdle(item PT, now time.Time) { const maxAttempts = 100 -func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) { //nolint:funlen +func (p *Pool[PT, T]) getItem(ctx context.Context, preferredNodeId uint32) (item PT, finalErr error) { //nolint:funlen var ( start = p.config.clock.Now() attempt int @@ -593,8 +603,9 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) { / } if item := xsync.WithLock(&p.mu, func() PT { //nolint:nestif - return p.removeFirstIdle() + return p.removeFirstIdle(preferredNodeId) }); item != nil { + if item.IsAlive() { info := xsync.WithLock(&p.mu, func() itemInfo[PT, T] { info, has := p.index[item] @@ -625,7 +636,7 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) { / } } - item, err := p.createItem(ctx) + item, err := p.createItem(ctx, preferredNodeId) if item != nil { return item, nil } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 881fd7932..785e555a1 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -38,6 +38,7 @@ type ( onClose func() error onIsAlive func() bool + onNodeID func() uint32 } testWaitChPool struct { xsync.Pool[chan *testItem] @@ -113,6 +114,13 @@ func (t *testItem) ID() string { return "" } +func (t *testItem) NodeID() uint32 { + if t.onNodeID != nil { + return t.onNodeID() + } + return 0 +} + func (t *testItem) Close(context.Context) error { if t.closed.Len() > 0 { debug.PrintStack() @@ -135,8 +143,8 @@ func caller() string { return fmt.Sprintf("%s:%d", path.Base(file), line) } -func mustGetItem[PT ItemConstraint[T], T any](t testing.TB, p *Pool[PT, T]) PT { - s, err := p.getItem(context.Background()) +func mustGetItem[PT ItemConstraint[T], T any](t testing.TB, p *Pool[PT, T], nodeId uint32) PT { + s, err := p.getItem(context.Background(), nodeId) if err != nil { t.Helper() t.Fatalf("%s: %v", caller(), err) @@ -167,10 +175,113 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithTrace[*testItem, testItem](defaultTrace), ) err := p.With(rootCtx, func(ctx context.Context, testItem *testItem) error { + require.EqualValues(t, 0, testItem.NodeID()) return nil - }) + }, 0) require.NoError(t, err) }) + t.Run("RequireNodeIdFromPool", func(t *testing.T) { + var nextNodeId uint32 + nextNodeId = 0 + var newSessionCalled uint32 + p := New[*testItem, testItem](rootCtx, + WithTrace[*testItem, testItem](defaultTrace), + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { + newSessionCalled++ + var ( + nodeId = nextNodeId + v = testItem{ + v: 0, + onNodeID: func() uint32 { + return nodeId + }, + } + ) + return &v, nil + }), + ) + + item := mustGetItem(t, p, 0) + require.EqualValues(t, 0, item.NodeID()) + require.EqualValues(t, true, item.IsAlive()) + mustPutItem(t, p, item) + + nextNodeId = 32 + + item = mustGetItem(t, p, 32) + require.EqualValues(t, 32, item.NodeID()) + mustPutItem(t, p, item) + + nextNodeId = 33 + + item = mustGetItem(t, p, 33) + require.EqualValues(t, 33, item.NodeID()) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 32) + require.EqualValues(t, 32, item.NodeID()) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 33) + require.EqualValues(t, 33, item.NodeID()) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 32) + item2 := mustGetItem(t, p, 33) + require.EqualValues(t, 32, item.NodeID()) + require.EqualValues(t, 33, item2.NodeID()) + mustPutItem(t, p, item2) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 32) + item2 = mustGetItem(t, p, 33) + require.EqualValues(t, 32, item.NodeID()) + require.EqualValues(t, 33, item2.NodeID()) + mustPutItem(t, p, item) + mustPutItem(t, p, item2) + + item = mustGetItem(t, p, 32) + item2 = mustGetItem(t, p, 33) + item3 := mustGetItem(t, p, 0) + require.EqualValues(t, 32, item.NodeID()) + require.EqualValues(t, 33, item2.NodeID()) + require.EqualValues(t, 0, item3.NodeID()) + mustPutItem(t, p, item) + mustPutItem(t, p, item2) + mustPutItem(t, p, item3) + + require.EqualValues(t, 3, newSessionCalled) + }) + t.Run("CreateSessionOnGivenNode", func(t *testing.T) { + var newSessionCalled uint32 + p := New[*testItem, testItem](rootCtx, + WithTrace[*testItem, testItem](defaultTrace), + WithCreateItemFunc(func(ctx context.Context, nodeId uint32) (*testItem, error) { + _ = ctx + newSessionCalled++ + var ( + v = testItem{ + v: 0, + onNodeID: func() uint32 { + return nodeId + }, + } + ) + return &v, nil + }), + ) + + item := mustGetItem(t, p, 32) + require.EqualValues(t, 32, item.NodeID()) + require.EqualValues(t, true, item.IsAlive()) + mustPutItem(t, p, item) + + item = mustGetItem(t, p, 32) + require.EqualValues(t, 32, item.NodeID()) + mustPutItem(t, p, item) + + require.EqualValues(t, 1, newSessionCalled) + }) t.Run("WithLimit", func(t *testing.T) { p := New[*testItem, testItem](rootCtx, WithLimit[*testItem, testItem](1), WithTrace[*testItem, testItem](defaultTrace), @@ -184,7 +295,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithItemUsageLimit[*testItem, testItem](5), WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&newCounter, 1) var v testItem @@ -200,7 +311,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo } return nil - }) + }, 0) require.NoError(t, err) require.EqualValues(t, 2, newCounter) }) @@ -208,7 +319,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo var newCounter int64 p := New(rootCtx, WithLimit[*testItem, testItem](1), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&newCounter, 1) var v testItem @@ -218,7 +329,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo ) err := p.With(rootCtx, func(ctx context.Context, item *testItem) error { return nil - }) + }, 0) require.NoError(t, err) require.EqualValues(t, p.config.limit, atomic.LoadInt64(&newCounter)) }) @@ -242,7 +353,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithLimit[*testItem, testItem](3), WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { var ( idx = created.Add(1) - 1 v = testItem{ @@ -270,9 +381,9 @@ func TestPool(t *testing.T) { //nolint:gocyclo require.Zero(t, p.idle.Len()) var ( - s1 = mustGetItem(t, p) - s2 = mustGetItem(t, p) - s3 = mustGetItem(t, p) + s1 = mustGetItem(t, p, 0) + s2 = mustGetItem(t, p, 0) + s3 = mustGetItem(t, p, 0) ) require.Len(t, p.index, 3) @@ -347,7 +458,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo // second call getItem from pool with limit === 1 will skip // create item step (because pool have not enough space for // creating new items) and will freeze until wait free item from pool - mustGetItem(t, p) + mustGetItem(t, p, 0) go func() { p.config.trace.OnGet = func(ctx *context.Context, call stack.Caller) func(item any, attempts int, err error) { @@ -356,7 +467,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo return nil } - _, err := p.getItem(context.Background()) + _, err := p.getItem(context.Background(), 0) got <- err }() @@ -405,7 +516,8 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New[*testItem, testItem](rootCtx, WithLimit[*testItem, testItem](2), WithCreateItemTimeout[*testItem, testItem](0), - WithCreateItemFunc[*testItem, testItem](func(ctx context.Context) (*testItem, error) { + WithCreateItemFunc[*testItem, testItem](func(ctx context.Context, preferredNodeId uint32) (*testItem, error) { + _ = preferredNodeId v := testItem{ v: 0, onClose: func() error { @@ -425,8 +537,8 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithTrace[*testItem, testItem](defaultTrace), ) - s1 := mustGetItem(t, p) - s2 := mustGetItem(t, p) + s1 := mustGetItem(t, p, 0) + s2 := mustGetItem(t, p, 0) // Put both items at the absolutely same time. // That is, both items must be updated their lastUsage timestamp. @@ -442,7 +554,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo // on get item from idle list the pool must check the item idle timestamp // both existing items must be closed // getItem must create a new item and return it from getItem - s3 := mustGetItem(t, p) + s3 := mustGetItem(t, p, 0) require.Len(t, p.index, 1) @@ -464,7 +576,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo require.Len(t, p.index, 1) require.Equal(t, 1, p.idle.Len()) - s4 := mustGetItem(t, p) + s4 := mustGetItem(t, p, 0) require.Equal(t, s3, s4) require.Len(t, p.index, 1) require.Equal(t, 0, p.idle.Len()) @@ -485,7 +597,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New(rootCtx, WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&counter, 1) if atomic.LoadInt64(&counter) < 10 { @@ -499,7 +611,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo ) err := p.With(rootCtx, func(ctx context.Context, item *testItem) error { return nil - }) + }, 0) require.NoError(t, err) require.GreaterOrEqual(t, atomic.LoadInt64(&counter), int64(10)) }) @@ -508,7 +620,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New(rootCtx, WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&counter, 1) if atomic.LoadInt64(&counter) < 10 { @@ -522,7 +634,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo ) err := p.With(rootCtx, func(ctx context.Context, item *testItem) error { return nil - }) + }, 0) require.NoError(t, err) require.GreaterOrEqual(t, atomic.LoadInt64(&counter), int64(10)) }) @@ -532,7 +644,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New(rootCtx, WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&counter, 1) if atomic.LoadInt64(&counter) < 10 { @@ -546,7 +658,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo ) err := p.With(rootCtx, func(ctx context.Context, item *testItem) error { return nil - }) + }, 0) require.NoError(t, err) require.GreaterOrEqual(t, atomic.LoadInt64(&counter), int64(10)) }) @@ -555,7 +667,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New(rootCtx, WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&counter, 1) if atomic.LoadInt64(&counter) < 10 { @@ -569,7 +681,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo ) err := p.With(rootCtx, func(ctx context.Context, item *testItem) error { return nil - }) + }, 0) require.NoError(t, err) require.GreaterOrEqual(t, atomic.LoadInt64(&counter), int64(10)) }) @@ -597,7 +709,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo time.Duration(r.Int64(int64(time.Second))), ) defer childCancel() - s, err := p.createItem(childCtx) + s, err := p.createItem(childCtx, 0) if s == nil && err == nil { errCh <- fmt.Errorf("unexpected result: <%v, %w>", s, err) } @@ -625,7 +737,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New[*testItem, testItem](ctx, WithLimit[*testItem, testItem](1)) err := p.With(ctx, func(ctx context.Context, testItem *testItem) error { return nil - }) + }, 0) require.ErrorIs(t, err, context.Canceled) }) t.Run("DeadlineExceeded", func(t *testing.T) { @@ -634,7 +746,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo p := New[*testItem, testItem](ctx, WithLimit[*testItem, testItem](1)) err := p.With(ctx, func(ctx context.Context, testItem *testItem) error { return nil - }) + }, 0) require.ErrorIs(t, err, context.DeadlineExceeded) }) }) @@ -661,7 +773,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo err := p.With(ctx, func(ctx context.Context, item *testItem) error { return testErr - }, + }, 0, retry.WithFastBackoff( testutil.BackoffFunc(func(n int) <-chan time.Time { ch := make(chan time.Time) @@ -703,7 +815,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo ) p := New(rootCtx, WithLimit[*testItem, testItem](1), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { atomic.AddInt64(&createCounter, 1) v := &testItem{ @@ -721,7 +833,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo ) err := p.With(rootCtx, func(ctx context.Context, testItem *testItem) error { return nil - }) + }, 0) require.NoError(t, err) require.GreaterOrEqual(t, atomic.LoadInt64(&createCounter), atomic.LoadInt64(&closeCounter)) err = p.Close(rootCtx) @@ -740,7 +852,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithLimit[*testItem, testItem](1), WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { newItems.Add(1) v := &testItem{ @@ -765,7 +877,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo } return nil - }) + }, 0) require.NoError(t, err) require.GreaterOrEqual(t, newItems.Load(), int64(9)) require.GreaterOrEqual(t, newItems.Load(), deleteItems.Load()) @@ -801,7 +913,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo WithLimit[*testItem, testItem](1), WithCreateItemTimeout[*testItem, testItem](50*time.Millisecond), WithCloseItemTimeout[*testItem, testItem](50*time.Millisecond), - WithCreateItemFunc(func(context.Context) (*testItem, error) { + WithCreateItemFunc(func(context.Context, uint32) (*testItem, error) { created.Add(1) v := testItem{ v: 0, @@ -821,20 +933,20 @@ func TestPool(t *testing.T) { //nolint:gocyclo _ = p.Close(context.Background()) }() - s := mustGetItem(t, p) + s := mustGetItem(t, p, 0) assertCreated(1) mustPutItem(t, p, s) assertClosed(0) - mustGetItem(t, p) + mustGetItem(t, p, 0) assertCreated(1) p.closeItem(context.Background(), s) delete(p.index, s) assertClosed(1) - mustGetItem(t, p) + mustGetItem(t, p, 0) assertCreated(2) }) t.Run("Racy", func(t *testing.T) { @@ -862,7 +974,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo defer childCancel() err := p.With(childCtx, func(ctx context.Context, testItem *testItem) error { return nil - }) + }, 0) if err != nil && !xerrors.Is(err, errClosedPool, context.Canceled) { t.Failed() } @@ -897,7 +1009,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo defer wg.Done() err := p.With(rootCtx, func(ctx context.Context, testItem *testItem) error { return nil - }) + }, 0) if err != nil && !xerrors.Is(err, errClosedPool, context.Canceled) { t.Failed() } @@ -917,7 +1029,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo // replace default async closer for sync testing WithSyncCloseItem[*testItem, testItem](), ) - item := mustGetItem(t, p) + item := mustGetItem(t, p, 0) if err := p.putItem(context.Background(), item); err != nil { t.Fatalf("unexpected error on put session into non-full client: %v, wand: %v", err, nil) } @@ -934,7 +1046,7 @@ func TestPool(t *testing.T) { //nolint:gocyclo // replace default async closer for sync testing WithSyncCloseItem[*testItem, testItem](), ) - item := mustGetItem(t, p) + item := mustGetItem(t, p, 0) mustPutItem(t, p, item) require.Panics(t, func() { diff --git a/internal/query/client.go b/internal/query/client.go index e360dcb7c..93c3af032 100644 --- a/internal/query/client.go +++ b/internal/query/client.go @@ -2,6 +2,8 @@ package query import ( "context" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" + balancerContext "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "time" "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" @@ -39,7 +41,7 @@ type ( closer.Closer Stats() pool.Stats - With(ctx context.Context, f func(ctx context.Context, s *Session) error, opts ...retry.Option) error + With(ctx context.Context, f func(ctx context.Context, s *Session) error, preferredNodeId uint32, opts ...retry.Option) error } Client struct { config *config.Config @@ -198,6 +200,7 @@ func do( ctx context.Context, pool sessionPool, op func(ctx context.Context, s *Session) error, + preferredNodeId uint32, opts ...retry.Option, ) (finalErr error) { err := pool.With(ctx, func(ctx context.Context, s *Session) error { @@ -213,7 +216,7 @@ func do( s.SetStatus(session.StatusIdle) return nil - }, opts...) + }, preferredNodeId, opts...) if err != nil { return xerrors.WithStackTrace(err) } @@ -239,7 +242,7 @@ func (c *Client) Do(ctx context.Context, op query.Operation, opts ...options.DoO err := do(ctx, c.pool, func(ctx context.Context, s *Session) error { return op(ctx, s) - }, + }, 0, append([]retry.Option{ retry.WithTrace(&trace.Retry{ OnRetry: func(info trace.RetryLoopStartInfo) func(trace.RetryLoopDoneInfo) { @@ -286,7 +289,7 @@ func doTx( } return nil - }, opts...) + }, 0, opts...) if err != nil { return xerrors.WithStackTrace(err) } @@ -304,7 +307,7 @@ func clientQueryRow( } return nil - }, settings.RetryOpts()...) + }, 0, settings.RetryOpts()...) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -347,7 +350,7 @@ func clientExec(ctx context.Context, pool sessionPool, q string, opts ...options } return nil - }, settings.RetryOpts()...) + }, 0, settings.RetryOpts()...) if err != nil { return xerrors.WithStackTrace(err) } @@ -400,7 +403,7 @@ func clientQuery(ctx context.Context, pool sessionPool, q string, opts ...option } return nil - }, settings.RetryOpts()...) + }, 0, settings.RetryOpts()...) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -443,7 +446,7 @@ func clientQueryResultSet( } return nil - }, settings.RetryOpts()...) + }, 0, settings.RetryOpts()...) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -530,7 +533,7 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, cfg *config.Config) * pool.WithCreateItemTimeout[*Session, Session](cfg.SessionCreateTimeout()), pool.WithCloseItemTimeout[*Session, Session](cfg.SessionDeleteTimeout()), pool.WithIdleTimeToLive[*Session, Session](cfg.SessionIdleTimeToLive()), - pool.WithCreateItemFunc(func(ctx context.Context) (_ *Session, err error) { + pool.WithCreateItemFunc(func(ctx context.Context, nodeId uint32) (_ *Session, err error) { var ( createCtx context.Context cancelCreate context.CancelFunc @@ -542,6 +545,12 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, cfg *config.Config) * } defer cancelCreate() + if nodeId != 0 { + cc = conn.WithContextModifier(cc, func(ctx context.Context) context.Context { + return balancerContext.WithNodeID(ctx, nodeId) + }) + } + s, err := createSession(createCtx, client, session.WithConn(cc), session.WithDeleteTimeout(cfg.SessionDeleteTimeout()), diff --git a/internal/table/client.go b/internal/table/client.go index e463eba16..2da3a3cb1 100644 --- a/internal/table/client.go +++ b/internal/table/client.go @@ -2,6 +2,8 @@ package table import ( "context" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" + balancerContext "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/jonboulle/clockwork" "github.com/ydb-platform/ydb-go-genproto/Ydb_Table_V1" @@ -40,7 +42,12 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, config *config.Config pool.WithCreateItemTimeout[*session, session](config.CreateSessionTimeout()), pool.WithCloseItemTimeout[*session, session](config.DeleteTimeout()), pool.WithClock[*session, session](config.Clock()), - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, nodeId uint32) (*session, error) { + if nodeId != 0 { + cc = conn.WithContextModifier(cc, func(ctx context.Context) context.Context { + return balancerContext.WithNodeID(ctx, nodeId) + }) + } return newSession(ctx, cc, config) }), pool.WithTrace[*session, session](&pool.Trace{ @@ -212,7 +219,7 @@ func (c *Client) Do(ctx context.Context, op table.Operation, opts ...table.Optio err := do(ctx, c.pool, c.config, op, func(err error) { attempts++ - }, config.RetryOptions...) + }, config.PreferredNodeId, config.RetryOptions...) if err != nil { return xerrors.WithStackTrace(err) } @@ -263,7 +270,7 @@ func (c *Client) DoTx(ctx context.Context, op table.TxOperation, opts ...table.O } return nil - }, config.RetryOptions...) + }, config.PreferredNodeId, config.RetryOptions...) } func (c *Client) BulkUpsert( diff --git a/internal/table/retry.go b/internal/table/retry.go index 9e193fa5a..1992baee2 100644 --- a/internal/table/retry.go +++ b/internal/table/retry.go @@ -18,7 +18,7 @@ type sessionPool interface { closer.Closer Stats() pool.Stats - With(ctx context.Context, f func(ctx context.Context, s *session) error, opts ...retry.Option) error + With(ctx context.Context, f func(ctx context.Context, s *session) error, preferredNodeId uint32, opts ...retry.Option) error } func do( @@ -27,6 +27,7 @@ func do( config *config.Config, op table.Operation, onAttempt func(err error), + preferredNodeId uint32, opts ...retry.Option, ) (err error) { return retryBackoff(ctx, pool, @@ -54,6 +55,7 @@ func do( return nil }, + preferredNodeId, opts..., ) } @@ -62,6 +64,7 @@ func retryBackoff( ctx context.Context, pool sessionPool, op table.Operation, + preferredNodeId uint32, opts ...retry.Option, ) error { return pool.With(ctx, func(ctx context.Context, s *session) error { @@ -72,7 +75,7 @@ func retryBackoff( } return nil - }, opts...) + }, preferredNodeId, opts...) } func (c *Client) retryOptions(opts ...table.Option) *table.Options { diff --git a/internal/table/retry_test.go b/internal/table/retry_test.go index 2086d4fbb..1d9e19bdb 100644 --- a/internal/table/retry_test.go +++ b/internal/table/retry_test.go @@ -48,7 +48,7 @@ func TestDoBackoffRetryCancelation(t *testing.T) { func(ctx context.Context, _ table.Session) error { return testErr }, - nil, + nil, 0, retry.WithFastBackoff( testutil.BackoffFunc(func(n int) <-chan time.Time { ch := make(chan time.Time) @@ -86,7 +86,7 @@ func TestDoBadSession(t *testing.T) { xtest.TestManyTimes(t, func(t testing.TB) { closed := make(map[table.Session]bool) p := pool.New[*session, session](ctx, - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, _ uint32) (*session, error) { s := simpleSession(t) s.onClose = append(s.onClose, func(s *session) { closed[s] = true @@ -112,7 +112,7 @@ func TestDoBadSession(t *testing.T) { return xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) }, - func(err error) {}, + func(err error) {}, 0, ) if !xerrors.Is(err, context.Canceled) { t.Errorf("unexpected error: %v", err) @@ -137,7 +137,7 @@ func TestDoCreateSessionError(t *testing.T) { ctx, cancel := xcontext.WithTimeout(rootCtx, 30*time.Millisecond) defer cancel() p := pool.New[*session, session](ctx, - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, _ uint32) (*session, error) { return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_UNAVAILABLE)) }), pool.WithSyncCloseItem[*session, session](), @@ -146,7 +146,7 @@ func TestDoCreateSessionError(t *testing.T) { func(ctx context.Context, s table.Session) error { return nil }, - nil, + nil, 0, ) if !xerrors.Is(err, context.DeadlineExceeded) { t.Errorf("unexpected error: %v", err) @@ -190,6 +190,7 @@ func TestDoImmediateReturn(t *testing.T) { return testErr }, nil, + 0, retry.WithFastBackoff( testutil.BackoffFunc(func(n int) <-chan time.Time { panic("this code will not be called") @@ -306,7 +307,7 @@ func TestDoContextDeadline(t *testing.T) { } ctx := xtest.Context(t) p := pool.New[*session, session](ctx, - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, _ uint32) (*session, error) { return newSession(ctx, client.cc, config.New()) }), pool.WithSyncCloseItem[*session, session](), @@ -331,7 +332,7 @@ func TestDoContextDeadline(t *testing.T) { return errs[r.Int(len(errs))] } }, - nil, + nil, 0, ) }) } @@ -355,7 +356,7 @@ func TestDoWithCustomErrors(t *testing.T) { limit = 10 ctx = context.Background() p = pool.New[*session, session](ctx, - pool.WithCreateItemFunc[*session, session](func(ctx context.Context) (*session, error) { + pool.WithCreateItemFunc[*session, session](func(ctx context.Context, _ uint32) (*session, error) { return simpleSession(t), nil }), pool.WithLimit[*session, session](limit), @@ -435,7 +436,7 @@ func TestDoWithCustomErrors(t *testing.T) { return nil }, - nil, + nil, 0, ) //nolint:nestif if test.retriable { @@ -485,7 +486,7 @@ func (s *singleSession) Stats() pool.Stats { } func (s *singleSession) With(ctx context.Context, - f func(ctx context.Context, s *session) error, opts ...retry.Option, + f func(ctx context.Context, s *session) error, _ uint32, opts ...retry.Option, ) error { return retry.Retry(ctx, func(ctx context.Context) error { return f(ctx, s.s) diff --git a/retry/retry.go b/retry/retry.go index 4394b6146..aef9b111c 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -19,14 +19,15 @@ import ( type retryOperation func(context.Context) (err error) type retryOptions struct { - label string - call call - trace *trace.Retry - idempotent bool - stackTrace bool - fastBackoff backoff.Backoff - slowBackoff backoff.Backoff - budget budget.Budget + label string + call call + trace *trace.Retry + idempotent bool + stackTrace bool + fastBackoff backoff.Backoff + slowBackoff backoff.Backoff + budget budget.Budget + preferredNodeId uint32 panicCallback func(e interface{}) } diff --git a/table/table.go b/table/table.go index 961a994c3..3de0ac37c 100644 --- a/table/table.go +++ b/table/table.go @@ -505,12 +505,21 @@ type Options struct { TxCommitOptions []options.CommitTransactionOption RetryOptions []retry.Option Trace *trace.Table + PreferredNodeId uint32 } type Option interface { ApplyTableOption(opts *Options) } +type SessionOption struct { + PreferredNodeId uint32 +} + +func (o SessionOption) ApplyTableOption(opts *Options) { + opts.PreferredNodeId = o.PreferredNodeId +} + var _ Option = labelOption("") type labelOption string @@ -548,6 +557,10 @@ func WithIdempotent() retryOptionsOption { return []retry.Option{retry.WithIdempotent(true)} } +func WithPreferredNodeid(nodeId uint32) SessionOption { + return SessionOption{PreferredNodeId: nodeId} +} + var _ Option = txSettingsOption{} type txSettingsOption struct {