From cd181e4c6dbf1a086e8c991d7ed3f88562a40c2b Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Thu, 14 Nov 2024 11:00:31 +0800 Subject: [PATCH] fix: make NewDataSyncService idempotent of dispatcher (#37576) issue: #37547 Signed-off-by: chyezh --- .../flushcommon/pipeline/data_sync_service.go | 30 ++++++------ .../pipeline/flow_graph_dd_node.go | 4 +- .../pipeline/flow_graph_dd_node_test.go | 3 +- .../flow_graph_dmstream_input_node.go | 48 ++++++++++--------- .../flow_graph_dmstream_input_node_test.go | 17 ++++--- .../pipeline/flow_graph_time_tick_node.go | 4 +- .../flusher/flusherimpl/channel_lifetime.go | 1 + pkg/mq/msgdispatcher/manager.go | 13 +++++ pkg/mq/msgdispatcher/manager_test.go | 25 ++++++++++ 9 files changed, 95 insertions(+), 50 deletions(-) diff --git a/internal/flushcommon/pipeline/data_sync_service.go b/internal/flushcommon/pipeline/data_sync_service.go index 5196742cbfd87..3e69ad016d433 100644 --- a/internal/flushcommon/pipeline/data_sync_service.go +++ b/internal/flushcommon/pipeline/data_sync_service.go @@ -224,7 +224,7 @@ func getServiceWithChannel(initCtx context.Context, params *util.PipelineParams, unflushed, flushed []*datapb.SegmentInfo, input <-chan *msgstream.MsgPack, wbTaskObserverCallback writebuffer.TaskObserverCallback, dropCallback func(), -) (*DataSyncService, error) { +) (dss *DataSyncService, err error) { var ( channelName = info.GetVchan().GetChannelName() collectionID = info.GetVchan().GetCollectionID() @@ -269,13 +269,10 @@ func getServiceWithChannel(initCtx context.Context, params *util.PipelineParams, fg := flowgraph.NewTimeTickedFlowGraph(params.Ctx) nodeList := []flowgraph.Node{} - dmStreamNode, err := newDmInputNode(initCtx, params.DispClient, info.GetVchan().GetSeekPosition(), config, input) - if err != nil { - return nil, err - } + dmStreamNode := newDmInputNode(config, input) nodeList = append(nodeList, dmStreamNode) - ddNode, err := newDDNode( + ddNode := newDDNode( params.Ctx, collectionID, channelName, @@ -285,9 +282,6 @@ func getServiceWithChannel(initCtx context.Context, params *util.PipelineParams, params.CompactionExecutor, params.MsgHandler, ) - if err != nil { - return nil, err - } nodeList = append(nodeList, ddNode) if len(info.GetSchema().GetFunctions()) > 0 { @@ -304,10 +298,7 @@ func getServiceWithChannel(initCtx context.Context, params *util.PipelineParams, } nodeList = append(nodeList, writeNode) - ttNode, err := newTTNode(config, params.WriteBufferManager, params.CheckpointUpdater) - if err != nil { - return nil, err - } + ttNode := newTTNode(config, params.WriteBufferManager, params.CheckpointUpdater) nodeList = append(nodeList, ttNode) if err := fg.AssembleNodes(nodeList...); err != nil { @@ -358,7 +349,18 @@ func NewDataSyncService(initCtx context.Context, pipelineParams *util.PipelinePa if metaCache, err = getMetaCacheWithTickler(initCtx, pipelineParams, info, tickler, unflushedSegmentInfos, flushedSegmentInfos); err != nil { return nil, err } - return getServiceWithChannel(initCtx, pipelineParams, info, metaCache, unflushedSegmentInfos, flushedSegmentInfos, nil, nil, nil) + + input, err := createNewInputFromDispatcher(initCtx, pipelineParams.DispClient, info.GetVchan().GetChannelName(), info.GetVchan().GetSeekPosition()) + if err != nil { + return nil, err + } + ds, err := getServiceWithChannel(initCtx, pipelineParams, info, metaCache, unflushedSegmentInfos, flushedSegmentInfos, input, nil, nil) + if err != nil { + // deregister channel if failed to init flowgraph to avoid resource leak. 
+ pipelineParams.DispClient.Deregister(info.GetVchan().GetChannelName()) + return nil, err + } + return ds, nil } func NewStreamingNodeDataSyncService( diff --git a/internal/flushcommon/pipeline/flow_graph_dd_node.go b/internal/flushcommon/pipeline/flow_graph_dd_node.go index 286389baba100..ddfb5ca81fe0e 100644 --- a/internal/flushcommon/pipeline/flow_graph_dd_node.go +++ b/internal/flushcommon/pipeline/flow_graph_dd_node.go @@ -332,7 +332,7 @@ func (ddn *ddNode) Close() { func newDDNode(ctx context.Context, collID typeutil.UniqueID, vChannelName string, droppedSegmentIDs []typeutil.UniqueID, sealedSegments []*datapb.SegmentInfo, growingSegments []*datapb.SegmentInfo, executor compaction.Executor, handler flusher.MsgHandler, -) (*ddNode, error) { +) *ddNode { baseNode := BaseNode{} baseNode.SetMaxQueueLength(paramtable.Get().DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()) baseNode.SetMaxParallelism(paramtable.Get().DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32()) @@ -364,5 +364,5 @@ func newDDNode(ctx context.Context, collID typeutil.UniqueID, vChannelName strin zap.Int("No. growing segments", len(growingSegments)), ) - return dd, nil + return dd } diff --git a/internal/flushcommon/pipeline/flow_graph_dd_node_test.go b/internal/flushcommon/pipeline/flow_graph_dd_node_test.go index fe3c3f9f568e3..a7920eaa3acbc 100644 --- a/internal/flushcommon/pipeline/flow_graph_dd_node_test.go +++ b/internal/flushcommon/pipeline/flow_graph_dd_node_test.go @@ -76,7 +76,7 @@ func TestFlowGraph_DDNode_newDDNode(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - ddNode, err := newDDNode( + ddNode := newDDNode( context.Background(), collectionID, channelName, @@ -86,7 +86,6 @@ func TestFlowGraph_DDNode_newDDNode(t *testing.T) { compaction.NewExecutor(), nil, ) - require.NoError(t, err) require.NotNil(t, ddNode) assert.Equal(t, fmt.Sprintf("ddNode-%s", ddNode.vChannelName), ddNode.Name()) diff --git a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go index d035592495c7c..cb62f585f63a4 100644 --- a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go +++ b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go @@ -39,30 +39,10 @@ import ( // // messages between two timeticks to the following flowgraph node. In DataNode, the following flow graph node is // flowgraph ddNode. 
-func newDmInputNode(initCtx context.Context, dispatcherClient msgdispatcher.Client, seekPos *msgpb.MsgPosition, dmNodeConfig *nodeConfig, input <-chan *msgstream.MsgPack) (*flowgraph.InputNode, error) {
-	log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
-		zap.Int64("collectionID", dmNodeConfig.collectionID),
-		zap.String("vchannel", dmNodeConfig.vChannelName))
-	var err error
+func newDmInputNode(dmNodeConfig *nodeConfig, input <-chan *msgstream.MsgPack) *flowgraph.InputNode {
 	if input == nil {
-		if seekPos != nil && len(seekPos.MsgID) != 0 {
-			input, err = dispatcherClient.Register(initCtx, dmNodeConfig.vChannelName, seekPos, common.SubscriptionPositionUnknown)
-			if err != nil {
-				return nil, err
-			}
-			log.Info("datanode seek successfully when register to msgDispatcher",
-				zap.ByteString("msgID", seekPos.GetMsgID()),
-				zap.Time("tsTime", tsoutil.PhysicalTime(seekPos.GetTimestamp())),
-				zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp()))))
-		} else {
-			input, err = dispatcherClient.Register(initCtx, dmNodeConfig.vChannelName, nil, common.SubscriptionPositionEarliest)
-			if err != nil {
-				return nil, err
-			}
-			log.Info("datanode consume successfully when register to msgDispatcher")
-		}
+		panic("unreachable: input channel is nil for input node")
 	}
-
 	name := fmt.Sprintf("dmInputNode-data-%s", dmNodeConfig.vChannelName)
 	node := flowgraph.NewInputNode(
 		input,
@@ -74,5 +54,27 @@ func newDmInputNode(initCtx context.Context, dispatcherClient msgdispatcher.Clie
 		dmNodeConfig.collectionID,
 		metrics.AllLabel,
 	)
-	return node, nil
+	return node
+}
+
+func createNewInputFromDispatcher(initCtx context.Context, dispatcherClient msgdispatcher.Client, vchannel string, seekPos *msgpb.MsgPosition) (<-chan *msgstream.MsgPack, error) {
+	log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()),
+		zap.String("vchannel", vchannel))
+	if seekPos != nil && len(seekPos.MsgID) != 0 {
+		input, err := dispatcherClient.Register(initCtx, vchannel, seekPos, common.SubscriptionPositionUnknown)
+		if err != nil {
+			return nil, err
+		}
+		log.Info("datanode seek successfully when register to msgDispatcher",
+			zap.ByteString("msgID", seekPos.GetMsgID()),
+			zap.Time("tsTime", tsoutil.PhysicalTime(seekPos.GetTimestamp())),
+			zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp()))))
+		return input, err
+	}
+	input, err := dispatcherClient.Register(initCtx, vchannel, nil, common.SubscriptionPositionEarliest)
+	if err != nil {
+		return nil, err
+	}
+	log.Info("datanode consume successfully when register to msgDispatcher")
+	return input, err
 }
diff --git a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go
index 0c0782fcb8785..2b213bdd3d2e7 100644
--- a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go
+++ b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go
@@ -23,14 +23,11 @@ import (
 	"github.com/cockroachdb/errors"
 	"github.com/stretchr/testify/assert"
 
-	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/dependency"
 	"github.com/milvus-io/milvus/pkg/mq/common"
-	"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
-	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
 type mockMsgStreamFactory struct {
@@ -107,10 +104,16 @@ func (mtm *mockTtMsgStream) EnableProduce(can bool) {
 }
 
 func TestNewDmInputNode(t *testing.T) {
-	client := msgdispatcher.NewClient(&mockMsgStreamFactory{}, typeutil.DataNodeRole, paramtable.GetNodeID())
-	_, err := newDmInputNode(context.Background(), client, new(msgpb.MsgPosition), &nodeConfig{
+	assert.Panics(t, func() {
+		newDmInputNode(&nodeConfig{
+			msFactory:    &mockMsgStreamFactory{},
+			vChannelName: "mock_vchannel_0",
+		}, nil)
+	})
+
+	node := newDmInputNode(&nodeConfig{
 		msFactory:    &mockMsgStreamFactory{},
 		vChannelName: "mock_vchannel_0",
-	}, nil)
-	assert.NoError(t, err)
+	}, make(<-chan *msgstream.MsgPack))
+	assert.NotNil(t, node)
 }
diff --git a/internal/flushcommon/pipeline/flow_graph_time_tick_node.go b/internal/flushcommon/pipeline/flow_graph_time_tick_node.go
index 9f720cbaa4e5a..e1cb4f64dbfc4 100644
--- a/internal/flushcommon/pipeline/flow_graph_time_tick_node.go
+++ b/internal/flushcommon/pipeline/flow_graph_time_tick_node.go
@@ -140,7 +140,7 @@ func (ttn *ttNode) updateChannelCP(channelPos *msgpb.MsgPosition, curTs time.Tim
 	ttn.lastUpdateTime.Store(curTs)
 }
 
-func newTTNode(config *nodeConfig, wbManager writebuffer.BufferManager, cpUpdater *util.ChannelCheckpointUpdater) (*ttNode, error) {
+func newTTNode(config *nodeConfig, wbManager writebuffer.BufferManager, cpUpdater *util.ChannelCheckpointUpdater) *ttNode {
 	baseNode := BaseNode{}
 	baseNode.SetMaxQueueLength(paramtable.Get().DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32())
 	baseNode.SetMaxParallelism(paramtable.Get().DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32())
@@ -156,5 +156,5 @@ func newTTNode(config *nodeConfig, wbManager writebuffer.BufferManager, cpUpdate
 		dropCallback: config.dropCallback,
 	}
 
-	return tt, nil
+	return tt
 }
diff --git a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go
index dbc8007777598..f8c8eaed5cca3 100644
--- a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go
+++ b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go
@@ -131,6 +131,7 @@ func (c *channelLifetime) Run() error {
 		func() { go func() { c.Cancel() }() },
 	)
 	if err != nil {
+		handler.Close()
 		return err
 	}
 	ds.Start()
diff --git a/pkg/mq/msgdispatcher/manager.go b/pkg/mq/msgdispatcher/manager.go
index 72ff293825fde..6fd9a22f20354 100644
--- a/pkg/mq/msgdispatcher/manager.go
+++ b/pkg/mq/msgdispatcher/manager.go
@@ -88,6 +88,19 @@ func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos,
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
+	if _, ok := c.soloDispatchers[vchannel]; ok {
+		// the dispatcher doesn't allow multiple subscriptions on the same vchannel at the same time
+		log.Warn("unreachable: solo vchannel dispatcher already exists")
+		return nil, fmt.Errorf("solo vchannel dispatcher already exists")
+	}
+	if c.mainDispatcher != nil {
+		if _, err := c.mainDispatcher.GetTarget(vchannel); err == nil {
+			// the dispatcher doesn't allow multiple subscriptions on the same vchannel at the same time
+			log.Warn("unreachable: vchannel has been registered in main dispatcher")
+			return nil, fmt.Errorf("vchannel has been registered in main dispatcher")
+		}
+	}
+
 	isMain := c.mainDispatcher == nil
 	d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, pos, c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets, false)
 	if err != nil {
diff --git a/pkg/mq/msgdispatcher/manager_test.go b/pkg/mq/msgdispatcher/manager_test.go
index 7271b2edc268b..4f42392cb5150 100644
--- a/pkg/mq/msgdispatcher/manager_test.go
+++ b/pkg/mq/msgdispatcher/manager_test.go
@@ -146,6 +146,31 @@ func TestManager(t *testing.T) {
 			c.Close()
 		})
 	})
+
+	t.Run("test_repeated_vchannel", func(t *testing.T) {
+		prefix := fmt.Sprintf("mock%d", time.Now().UnixNano())
+		c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory())
+		go c.Run()
+		assert.NotNil(t, c)
+		ctx := context.Background()
+		_, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
+		assert.NoError(t, err)
+		_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
+		assert.NoError(t, err)
+		_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
+		assert.NoError(t, err)
+
+		_, err = c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown)
+		assert.Error(t, err)
+		_, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown)
+		assert.Error(t, err)
+		_, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown)
+		assert.Error(t, err)
+
+		assert.NotPanics(t, func() {
+			c.Close()
+		})
+	})
 }
 
 type vchannelHelper struct {
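
Note: the change above boils down to a register / build / deregister-on-failure pattern. NewDataSyncService now obtains the input channel from the dispatcher via createNewInputFromDispatcher before calling getServiceWithChannel, deregisters the vchannel if the flowgraph fails to assemble, and dispatcherManager.Add rejects a second registration of the same vchannel. The minimal Go sketch below illustrates that pattern only; Client, Service, buildPipeline, and NewService are simplified placeholder names, not code from this patch.

// Sketch of the register / build / deregister-on-failure pattern (placeholder names).
package example

import (
	"context"
	"fmt"
)

// Client stands in for the message-dispatcher client used by the pipeline.
type Client interface {
	Register(ctx context.Context, vchannel string) (<-chan any, error)
	Deregister(vchannel string)
}

type Service struct{}

// buildPipeline is a placeholder for flowgraph assembly (input node, ddNode, write node, ttNode).
func buildPipeline(ctx context.Context, vchannel string, input <-chan any) (*Service, error) {
	return &Service{}, nil
}

// NewService registers the vchannel first, then builds the pipeline, and
// deregisters again if the build fails, so a later retry can register the
// same vchannel without tripping the duplicate-registration guard.
func NewService(ctx context.Context, dc Client, vchannel string) (*Service, error) {
	input, err := dc.Register(ctx, vchannel)
	if err != nil {
		return nil, err
	}
	svc, err := buildPipeline(ctx, vchannel, input)
	if err != nil {
		dc.Deregister(vchannel) // avoid leaking the dispatcher subscription
		return nil, fmt.Errorf("build pipeline for %s: %w", vchannel, err)
	}
	return svc, nil
}

Keeping registration outside the flowgraph constructor is what makes construction safe to retry: a failed attempt cleans up its own subscription instead of leaving the vchannel registered in the dispatcher.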