From e0cefe603592f1ad1a600dbca33848b247c54826 Mon Sep 17 00:00:00 2001
From: Keran Yang
Date: Fri, 25 Aug 2023 17:47:41 -0400
Subject: [PATCH] chore: add some small refactors mainly for source (#986)

1. Remove retry on Ack. For most of our existing sources, the Ack() functions simply `return make([]error, len(offsets))`. The only exception is the RedisStream source, which uses `err := br.Client.XAck()` to batch-acknowledge and populates the error array with that same err. Hence, for ALL sources, acking is either ALL success or ALL failure, and we don't need to retry acking.
2. Replace `err ==` comparisons with `errors.Is` (see the short note after the diff).
3. Fix some typos and minor grammar issues, since IntelliJ is giving me a lot of red lines :)
4. Remove some gRPC unit tests because I don't think they are necessary. The whole point of mocking a dependency is that we don't have to worry about how the dependency constructs the returned value. Some of our gRPC unit tests implement the dependency and then verify the functionality of that implementation, which doesn't make sense to me. For example, in `TestHGRPCBasedUDF_ApplyWithMockClient`, we implement `multiplyBy2` and verify that the result is multiplied by 2.

Signed-off-by: Keran Yang
---
 pkg/reduce/pbq/store/wal/bootstrap.go | 3 +-
 pkg/reduce/pnf/ordered.go | 3 +-
 pkg/shared/kvs/jetstream/kv_watch.go | 5 +-
 pkg/sinks/udsink/udsink_grpc_test.go | 7 +-
 pkg/sources/forward/data_forward.go | 59 ++---------
 pkg/sources/generator/tickgen.go | 81 +++++++--------
 pkg/sources/generator/tickgen_test.go | 2 +-
 pkg/sources/http/http.go | 21 ++--
 pkg/sources/kafka/reader.go | 100 +++++++++---------
 pkg/sources/kafka/reader_test.go | 6 +-
 pkg/sources/nats/nats.go | 15 ++-
 pkg/sources/redisstreams/redisstream.go | 19 ++--
 pkg/sources/source.go | 5 +-
 pkg/sources/udsource/grpc_udsource_test.go | 5 +
 pkg/sources/udsource/user_defined_source.go | 10 +-
 pkg/udf/rpc/grpc_map_test.go | 93 ++---------------
 pkg/udf/rpc/grpc_mapstream_test.go | 109 +-------------------
 pkg/udf/rpc/grpc_reduce_test.go | 91 +---------------
 18 files changed, 164 insertions(+), 470 deletions(-)

diff --git a/pkg/reduce/pbq/store/wal/bootstrap.go b/pkg/reduce/pbq/store/wal/bootstrap.go index 3dd64b259a..10a1cf3e94 100644 --- a/pkg/reduce/pbq/store/wal/bootstrap.go +++ b/pkg/reduce/pbq/store/wal/bootstrap.go @@ -18,6 +18,7 @@ package wal import ( "encoding/binary" + "errors" "fmt" "io" "os" @@ -111,7 +112,7 @@ func (w *WAL) Read(size int64) ([]*isb.ReadMessage, bool, error) { for size > w.rOffset-start && !w.isEnd() { message, sizeRead, err := decodeReadMessage(w.fp) if err != nil { - if err == errChecksumMismatch { + if errors.Is(err, errChecksumMismatch) { w.corrupted = true } return nil, false, err diff --git a/pkg/reduce/pnf/ordered.go b/pkg/reduce/pnf/ordered.go index 8f15dba4a0..b5d4782a74 100644 --- a/pkg/reduce/pnf/ordered.go +++ b/pkg/reduce/pnf/ordered.go @@ -19,6 +19,7 @@ package pnf import ( "container/list" "context" + "errors" "strconv" "sync" "time" @@ -132,7 +133,7 @@ func (op *OrderedProcessor) reduceOp(ctx context.Context, t *ForwardTask) { start := time.Now() err := t.pf.Process(ctx) if err != nil { - if err == ctx.Err() { + if errors.Is(err, ctx.Err()) { udfError.With(map[string]string{ metrics.LabelVertex: op.vertexName, metrics.LabelPipeline: op.pipelineName, diff --git a/pkg/shared/kvs/jetstream/kv_watch.go b/pkg/shared/kvs/jetstream/kv_watch.go index c742fb64e3..15d6a6e6a2 100644 --- a/pkg/shared/kvs/jetstream/kv_watch.go +++ b/pkg/shared/kvs/jetstream/kv_watch.go @@ -18,14 +18,15 @@ package jetstream import ( "context" + "errors" "fmt"
"time" "github.com/nats-io/nats.go" - "github.com/numaproj/numaflow/pkg/shared/kvs" "go.uber.org/zap" jsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats" + "github.com/numaproj/numaflow/pkg/shared/kvs" "github.com/numaproj/numaflow/pkg/shared/logging" ) @@ -229,7 +230,7 @@ retryLoop: } else { // if there are no keys in the store, return zero time because there are no updates // upstream will handle it - if err == nats.ErrNoKeysFound { + if errors.Is(err, nats.ErrNoKeysFound) { return time.Time{} } jsw.log.Errorw("Failed to get keys", zap.String("watcher", jsw.GetKVName()), zap.Error(err)) diff --git a/pkg/sinks/udsink/udsink_grpc_test.go b/pkg/sinks/udsink/udsink_grpc_test.go index 875fe2d757..ba0255b576 100644 --- a/pkg/sinks/udsink/udsink_grpc_test.go +++ b/pkg/sinks/udsink/udsink_grpc_test.go @@ -18,6 +18,7 @@ package udsink import ( "context" + "errors" "fmt" "testing" "time" @@ -47,7 +48,7 @@ func Test_gRPCBasedUDSink_WaitUntilReadyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -103,7 +104,7 @@ func Test_gRPCBasedUDSink_ApplyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -145,7 +146,7 @@ func Test_gRPCBasedUDSink_ApplyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() diff --git a/pkg/sources/forward/data_forward.go b/pkg/sources/forward/data_forward.go index a0a204aa22..793d7b764e 100644 --- a/pkg/sources/forward/data_forward.go +++ b/pkg/sources/forward/data_forward.go @@ -20,12 +20,10 @@ import ( "context" "errors" "fmt" - "math" "sync" "time" "go.uber.org/zap" - "k8s.io/apimachinery/pkg/util/wait" dfv1 "github.com/numaproj/numaflow/pkg/apis/numaflow/v1alpha1" "github.com/numaproj/numaflow/pkg/forward" @@ -166,7 +164,7 @@ func (isdf *DataForward) Start() <-chan struct{} { } } - // publisher was created by the forwarder, so it should be closed by the forwarder. + // the publisher was created by the forwarder, so it should be closed by the forwarder. for _, toVertexPublishers := range isdf.toVertexWMPublishers { for _, pub := range toVertexPublishers { if err := pub.Close(); err != nil { @@ -374,7 +372,7 @@ func (isdf *DataForward) forwardAChunk(ctx context.Context) { // when we apply transformer, we don't handle partial errors (it's either non or all, non will return early), // so we should be able to ack all the readOffsets including data messages and control messages - err = isdf.ackFromBuffer(ctx, readOffsets) + err = isdf.ackFromSource(ctx, readOffsets) // implicit return for posterity :-) if err != nil { isdf.opts.logger.Errorw("failed to ack from source", zap.Error(err)) @@ -387,54 +385,11 @@ func (isdf *DataForward) forwardAChunk(ctx context.Context) { forwardAChunkProcessingTime.With(map[string]string{metrics.LabelVertex: isdf.vertexName, metrics.LabelPipeline: isdf.pipelineName, metrics.LabelPartitionName: isdf.reader.GetName()}).Observe(float64(time.Since(start).Microseconds())) } -// ackFromBuffer acknowledges an array of offsets back to the reader -// and is a blocking call or until shutdown has been initiated. 
-func (isdf *DataForward) ackFromBuffer(ctx context.Context, offsets []isb.Offset) error { - var ackRetryBackOff = wait.Backoff{ - Factor: 1, - Jitter: 0.1, - Steps: math.MaxInt, - Duration: time.Millisecond * 10, - } - var ackOffsets = offsets - attempt := 0 - - ctxClosedErr := wait.ExponentialBackoff(ackRetryBackOff, func() (done bool, err error) { - errs := isdf.reader.Ack(ctx, ackOffsets) - attempt += 1 - summarizedErr := errorArrayToMap(errs) - var failedOffsets []isb.Offset - if len(summarizedErr) > 0 { - isdf.opts.logger.Errorw("Failed to ack from buffer, retrying", zap.Any("errors", summarizedErr), zap.Int("attempt", attempt)) - // no point retrying if ctx.Done has been invoked - select { - case <-ctx.Done(): - // no point in retrying after we have been asked to stop. - return false, ctx.Err() - default: - // retry only the failed offsets - for i, offset := range ackOffsets { - if errs[i] != nil { - failedOffsets = append(failedOffsets, offset) - } - } - ackOffsets = failedOffsets - if ok, _ := isdf.IsShuttingDown(); ok { - ackErr := fmt.Errorf("AckFromBuffer, Stop called while stuck on an internal error, %v", summarizedErr) - return false, ackErr - } - return false, nil - } - } else { - return true, nil - } - }) - - if ctxClosedErr != nil { - isdf.opts.logger.Errorw("Context closed while waiting to ack messages inside forward", zap.Error(ctxClosedErr)) - } - - return ctxClosedErr +func (isdf *DataForward) ackFromSource(ctx context.Context, offsets []isb.Offset) error { + // for all the sources, we either ack all offsets or none. + // when a batch ack fails, the source Ack() function populate the error array with the same error; + // hence we can just return the first error. + return isdf.reader.Ack(ctx, offsets)[0] } // writeToBuffers is a blocking call until all the messages have been forwarded to all the toBuffers, or a shutdown diff --git a/pkg/sources/generator/tickgen.go b/pkg/sources/generator/tickgen.go index 81d14d7ba6..d969753cd5 100644 --- a/pkg/sources/generator/tickgen.go +++ b/pkg/sources/generator/tickgen.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Package generator contains an implementation of a in memory generator that generates +// Package generator contains an implementation of an in-memory generator that generates // payloads in json format. package generator @@ -76,7 +76,7 @@ var recordGenerator = func(size int32, value *uint64, createdTS int64) []byte { } size = size - 8 if size > 0 { - // padding to guarantee size of the message + // padding to guarantee the size of the message b := make([]byte, size) _, err := rand.Read(b) // we do not care about failures here. if err != nil { @@ -94,8 +94,8 @@ var recordGenerator = func(size int32, value *uint64, createdTS int64) []byte { } type memgen struct { - // srcchan provides a go channel that supplies generated data - srcchan chan record + // srcChan provides a go channel that supplies generated data + srcChan chan record // rpu - records per time unit rpu int // keyCount is the number of unique keys in the payload @@ -108,15 +108,14 @@ type memgen struct { // timeunit - ticker will fire once per timeunit and generates // a number of records equal to the number passed to rpu. 
timeunit time.Duration - // genfn function that generates a payload as a byte array - genfn func(int32, *uint64, int64) []byte - // name is the name of the source node - name string + // genFn function that generates a payload as a byte array + genFn func(int32, *uint64, int64) []byte + // name is the name of the source vertex + vertexName string // pipelineName is the name of the pipeline pipelineName string - // cancel function . - // once terminated the source will not generate any more records. - cancel context.CancelFunc + // cancelFn terminates the source will not generate any more records. + cancelFn context.CancelFunc // forwarder to read from the source and write to the inter step buffer. forwarder *sourceforward.DataForward // lifecycleCtx context is used to control the lifecycle of this instance. @@ -182,36 +181,36 @@ func NewMemGen( value = vertexInstance.Vertex.Spec.Source.Generator.Value } - gensrc := &memgen{ + genSrc := &memgen{ rpu: rpu, keyCount: keyCount, value: value, msgSize: msgSize, timeunit: timeunit, - name: vertexInstance.Vertex.Spec.Name, + vertexName: vertexInstance.Vertex.Spec.Name, pipelineName: vertexInstance.Vertex.Spec.PipelineName, - genfn: recordGenerator, + genFn: recordGenerator, vertexInstance: vertexInstance, - srcchan: make(chan record, rpu*int(keyCount)*5), + srcChan: make(chan record, rpu*int(keyCount)*5), readTimeout: 3 * time.Second, // default timeout } for _, o := range opts { - if err := o(gensrc); err != nil { + if err := o(genSrc); err != nil { return nil, err } } - if gensrc.logger == nil { - gensrc.logger = logging.NewLogger() + if genSrc.logger == nil { + genSrc.logger = logging.NewLogger() } // this context is to be used internally for controlling the lifecycle of generator cctx, cancel := context.WithCancel(context.Background()) - gensrc.lifecycleCtx = cctx - gensrc.cancel = cancel + genSrc.lifecycleCtx = cctx + genSrc.cancelFn = cancel - forwardOpts := []sourceforward.Option{sourceforward.WithLogger(gensrc.logger)} + forwardOpts := []sourceforward.Option{sourceforward.WithLogger(genSrc.logger)} if x := vertexInstance.Vertex.Spec.Limits; x != nil { if x.ReadBatchSize != nil { forwardOpts = append(forwardOpts, sourceforward.WithReadBatchSize(int64(*x.ReadBatchSize))) @@ -219,28 +218,28 @@ func NewMemGen( } // attach a source publisher so the source can assign the watermarks. - gensrc.sourcePublishWM = gensrc.buildSourceWatermarkPublisher(publishWMStores) + genSrc.sourcePublishWM = genSrc.buildSourceWatermarkPublisher(publishWMStores) - // we pass in the context to forwarder as well so that it can shut down when we cancel the context - forwarder, err := sourceforward.NewDataForward(vertexInstance.Vertex, gensrc, writers, fsd, transformerApplier, fetchWM, gensrc, toVertexPublisherStores, forwardOpts...) + // we pass in the context to forwarder as well so that it can shut down when we cancelFn the context + forwarder, err := sourceforward.NewDataForward(vertexInstance.Vertex, genSrc, writers, fsd, transformerApplier, fetchWM, genSrc, toVertexPublisherStores, forwardOpts...) 
if err != nil { return nil, err } - gensrc.forwarder = forwarder + genSrc.forwarder = forwarder - return gensrc, nil + return genSrc, nil } func (mg *memgen) buildSourceWatermarkPublisher(publishWMStores store.WatermarkStore) publish.Publisher { // for tickgen, it can be the name of the replica entityName := fmt.Sprintf("%s-%d", mg.vertexInstance.Vertex.Name, mg.vertexInstance.Replica) processorEntity := processor.NewProcessorEntity(entityName) - // source publisher toVertexPartitionCount will be 1, because we publish watermarks within source itself. + // source publisher toVertexPartitionCount will be 1, because we publish watermarks within the source itself. return publish.NewPublish(mg.lifecycleCtx, processorEntity, publishWMStores, 1, publish.IsSource(), publish.WithDelay(mg.vertexInstance.Vertex.Spec.Watermark.GetMaxDelay())) } func (mg *memgen) GetName() string { - return mg.name + return mg.vertexName } // GetPartitionIdx returns the partition number for the source vertex buffer @@ -250,20 +249,20 @@ func (mg *memgen) GetPartitionIdx() int32 { } func (mg *memgen) IsEmpty() bool { - return len(mg.srcchan) == 0 + return len(mg.srcChan) == 0 } -func (mg *memgen) Read(ctx context.Context, count int64) ([]*isb.ReadMessage, error) { +func (mg *memgen) Read(_ context.Context, count int64) ([]*isb.ReadMessage, error) { msgs := make([]*isb.ReadMessage, 0, count) // timeout should not be re-triggered for every run of the for loop. it is for the entire Read() call. timeout := time.After(mg.readTimeout) loop: for i := int64(0); i < count; i++ { - // since the Read call is blocking, and runs in an infinite loop + // since the Read call is blocking, and runs in an infinite loop, // we implement Read With Wait semantics select { - case r := <-mg.srcchan: - tickgenSourceReadCount.With(map[string]string{metrics.LabelVertex: mg.name, metrics.LabelPipeline: mg.pipelineName}).Inc() + case r := <-mg.srcChan: + tickgenSourceReadCount.With(map[string]string{metrics.LabelVertex: mg.vertexName, metrics.LabelPipeline: mg.pipelineName}).Inc() msgs = append(msgs, mg.newReadMessage(r.key, r.data, r.offset)) case <-timeout: mg.logger.Debugw("Timed out waiting for messages to read.", zap.Duration("waited", mg.readTimeout)) @@ -294,7 +293,7 @@ func (mg *memgen) Close() error { } func (mg *memgen) Stop() { - mg.cancel() + mg.cancelFn() mg.forwarder.Stop() } @@ -323,10 +322,10 @@ func (mg *memgen) NewWorker(ctx context.Context, rate int) func(chan time.Time, <-tickChan } close(done) - close(mg.srcchan) + close(mg.srcChan) return case ts := <-tickChan: - tickgenSourceCount.With(map[string]string{metrics.LabelVertex: mg.name, metrics.LabelPipeline: mg.pipelineName}) + tickgenSourceCount.With(map[string]string{metrics.LabelVertex: mg.vertexName, metrics.LabelPipeline: mg.pipelineName}) // we would generate all the keys in a round robin fashion // even if there are multiple pods, all the pods will generate same keys in the same order. // TODO: alternatively, we could also think about generating a subset of keys per pod. @@ -334,13 +333,13 @@ func (mg *memgen) NewWorker(ctx context.Context, rate int) func(chan time.Time, for i := 0; i < rate; i++ { for k := int32(0); k < mg.keyCount; k++ { key := fmt.Sprintf("key-%d-%d", mg.vertexInstance.Replica, k) - payload := mg.genfn(mg.msgSize, mg.value, t) + payload := mg.genFn(mg.msgSize, mg.value, t) r := record{data: payload, offset: time.Now().UTC().UnixNano(), key: key} select { case <-ctx.Done(): log.Info("Context.Done is called. 
returning from the inner function") return - case mg.srcchan <- r: + case mg.srcChan <- r: } } } @@ -366,7 +365,7 @@ func (mg *memgen) generator(ctx context.Context, rate int, timeunit time.Duratio // make sure that there is only one worker all the time. // even when there is back pressure, max number of go routines inflight should be 1. - // at the same time, we dont want to miss any ticks that cannot be processed. + // at the same time, we don't want to miss any ticks that cannot be processed. worker := mg.NewWorker(childCtx, rate) go worker(tickChan, doneChan) @@ -375,7 +374,7 @@ func (mg *memgen) generator(ctx context.Context, rate int, timeunit time.Duratio for { select { // we don't need to wait for ticker to fire to return - // when context closes + // when the context closes case <-ctx.Done(): log.Info("Context.Done is called. exiting generator loop.") childCancel() @@ -423,7 +422,7 @@ func parseTime(payload []byte) int64 { return 0 } - // for now let's pretend that the time unit is nanos and that the time attribute is known + // for now, let's pretend that the time unit is nanos and that the time attribute is known eventTime := anyJson[timeAttr] if i, ok := eventTime.(float64); ok { return int64(i) diff --git a/pkg/sources/generator/tickgen_test.go b/pkg/sources/generator/tickgen_test.go index 8546762d65..c0e016ba89 100644 --- a/pkg/sources/generator/tickgen_test.go +++ b/pkg/sources/generator/tickgen_test.go @@ -88,7 +88,7 @@ func TestRead(t *testing.T) { // wait for the context to be completely stopped. for { - _, ok := <-mgen.srcchan + _, ok := <-mgen.srcChan if !ok { break } diff --git a/pkg/sources/http/http.go b/pkg/sources/http/http.go index b89094a1d4..0a2a8b0cf0 100644 --- a/pkg/sources/http/http.go +++ b/pkg/sources/http/http.go @@ -19,6 +19,7 @@ package http import ( "context" "crypto/tls" + "errors" "fmt" "io" "net/http" @@ -46,7 +47,7 @@ import ( ) type httpSource struct { - name string + vertexName string pipelineName string ready bool readTimeout time.Duration @@ -97,16 +98,16 @@ func New( opts ...Option) (*httpSource, error) { h := &httpSource{ - name: vertexInstance.Vertex.Spec.Name, + vertexName: vertexInstance.Vertex.Spec.Name, pipelineName: vertexInstance.Vertex.Spec.PipelineName, ready: false, bufferSize: 1000, // default size readTimeout: 1 * time.Second, // default timeout } + for _, o := range opts { - operr := o(h) - if operr != nil { - return nil, operr + if err := o(h); err != nil { + return nil, err } } if h.logger == nil { @@ -185,7 +186,7 @@ func New( } go func() { h.logger.Info("Starting http source server") - if err := server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { + if err := server.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { h.logger.Fatalw("Failed to listen-and-server on http source server", zap.Error(err)) } h.logger.Info("Shutdown http source server") @@ -209,13 +210,13 @@ func New( h.cancelFunc = cancel entityName := fmt.Sprintf("%s-%d", vertexInstance.Vertex.Name, vertexInstance.Replica) processorEntity := processor.NewProcessorEntity(entityName) - // source publisher toVertexPartitionCount will be 1, because we publish watermarks within source itself. + // source publisher toVertexPartitionCount will be 1, because we publish watermarks within the source itself. 
h.sourcePublishWM = publish.NewPublish(ctx, processorEntity, publishWMStores, 1, publish.IsSource(), publish.WithDelay(vertexInstance.Vertex.Spec.Watermark.GetMaxDelay())) return h, nil } func (h *httpSource) GetName() string { - return h.name + return h.vertexName } // GetPartitionIdx returns the partition number for the source vertex buffer @@ -231,7 +232,7 @@ loop: for i := int64(0); i < count; i++ { select { case m := <-h.messages: - httpSourceReadCount.With(map[string]string{metrics.LabelVertex: h.name, metrics.LabelPipeline: h.pipelineName}).Inc() + httpSourceReadCount.With(map[string]string{metrics.LabelVertex: h.vertexName, metrics.LabelPipeline: h.pipelineName}).Inc() msgs = append(msgs, m) case <-timeout: h.logger.Debugw("Timed out waiting for messages to read.", zap.Duration("waited", h.readTimeout), zap.Int("read", len(msgs))) @@ -251,7 +252,7 @@ func (h *httpSource) PublishSourceWatermarks(msgs []*isb.ReadMessage) { } if len(msgs) > 0 && !oldest.IsZero() { h.logger.Debugf("Publishing watermark %v to source", oldest) - // toVertexPartitionIdx is 0, because we publish watermarks within source itself. + // toVertexPartitionIdx is 0, because we publish watermarks within the source itself. h.sourcePublishWM.PublishWatermark(wmb.Watermark(oldest), nil, 0) // Source publisher does not care about the offset } } diff --git a/pkg/sources/kafka/reader.go b/pkg/sources/kafka/reader.go index cd7e2e57c7..e25aade959 100644 --- a/pkg/sources/kafka/reader.go +++ b/pkg/sources/kafka/reader.go @@ -42,7 +42,7 @@ import ( type KafkaSource struct { // name of the source vertex - name string + vertexName string // name of the pipeline pipelineName string // group name for the source vertex @@ -54,19 +54,19 @@ type KafkaSource struct { // forwarder that writes the consumed data to destination forwarder *sourceforward.DataForward // context cancel function - cancelfn context.CancelFunc + cancelFn context.CancelFunc // lifecycle context - lifecyclectx context.Context - // consumergroup handler for kafka consumer group + lifecycleCtx context.Context + // handler for a kafka consumer group handler *consumerHandler // sarama config for kafka consumer group config *sarama.Config // logger logger *zap.SugaredLogger - // channel to inidcate that we are done - stopch chan struct{} + // channel to indicate that we are done + stopCh chan struct{} // size of the buffer that holds consumed but yet to be forwarded messages - handlerbuffer int + handlerBuffer int // read timeout for the from buffer readTimeout time.Duration // client used to calculate pending messages @@ -129,7 +129,7 @@ func WithLogger(l *zap.SugaredLogger) Option { // WithBufferSize is used to return size of message channel information func WithBufferSize(s int) Option { return func(o *KafkaSource) error { - o.handlerbuffer = s + o.handlerBuffer = s return nil } } @@ -151,7 +151,7 @@ func WithGroupName(gn string) Option { } func (r *KafkaSource) GetName() string { - return r.name + return r.vertexName } // GetPartitionIdx returns the partition number for the source vertex buffer @@ -167,7 +167,7 @@ loop: for i := int64(0); i < count; i++ { select { case m := <-r.handler.messages: - kafkaSourceReadCount.With(map[string]string{metrics.LabelVertex: r.name, metrics.LabelPipeline: r.pipelineName}).Inc() + kafkaSourceReadCount.With(map[string]string{metrics.LabelVertex: r.vertexName, metrics.LabelPipeline: r.pipelineName}).Inc() msgs = append(msgs, toReadMessage(m)) case <-timeout: // log that timeout has happened and don't return an error @@ -189,7 
+189,7 @@ func (r *KafkaSource) PublishSourceWatermarks(msgs []*isb.ReadMessage) { } for p, t := range oldestTimestamps { publisher := r.loadSourceWatermarkPublisher(p) - // toVertexPartitionIdx is 0 because we publish watermarks within source itself. + // toVertexPartitionIdx is 0 because we publish watermarks within the source itself. publisher.PublishWatermark(wmb.Watermark(t), nil, 0) // Source publisher does not care about the offset } } @@ -201,10 +201,10 @@ func (r *KafkaSource) loadSourceWatermarkPublisher(partitionID int32) publish.Pu if p, ok := r.sourcePublishWMs[partitionID]; ok { return p } - entityName := fmt.Sprintf("%s-%s-%d", r.pipelineName, r.name, partitionID) + entityName := fmt.Sprintf("%s-%s-%d", r.pipelineName, r.vertexName, partitionID) processorEntity := processor.NewProcessorEntity(entityName) - // toVertexPartitionCount is 1 because we publish watermarks within source itself. - sourcePublishWM := publish.NewPublish(r.lifecyclectx, processorEntity, r.srcPublishWMStores, 1, publish.IsSource(), publish.WithDelay(r.watermarkMaxDelay)) + // toVertexPartitionCount is 1 because we publish watermarks within the source itself. + sourcePublishWM := publish.NewPublish(r.lifecycleCtx, processorEntity, r.srcPublishWMStores, 1, publish.IsSource(), publish.WithDelay(r.watermarkMaxDelay)) r.sourcePublishWMs[partitionID] = sourcePublishWM return sourcePublishWM } @@ -221,12 +221,12 @@ func (r *KafkaSource) Ack(_ context.Context, offsets []isb.Offset) []error { // we need to mark the offset of the next message to read pOffset, err := offset.Sequence() if err != nil { - kafkaSourceOffsetAckErrors.With(map[string]string{metrics.LabelVertex: r.name, metrics.LabelPipeline: r.pipelineName}).Inc() - r.logger.Errorw("Unable to extract partition offset of type int64 from the supplied offset. skipping and continuing", zap.String("suppliedoffset", offset.String()), zap.Error(err)) + kafkaSourceOffsetAckErrors.With(map[string]string{metrics.LabelVertex: r.vertexName, metrics.LabelPipeline: r.pipelineName}).Inc() + r.logger.Errorw("Unable to extract partition offset of type int64 from the supplied offset. skipping and continuing", zap.String("supplied-offset", offset.String()), zap.Error(err)) continue } r.handler.sess.MarkOffset(topic, offset.PartitionIdx(), pOffset, "") - kafkaSourceAckCount.With(map[string]string{metrics.LabelVertex: r.name, metrics.LabelPipeline: r.pipelineName}).Inc() + kafkaSourceAckCount.With(map[string]string{metrics.LabelVertex: r.vertexName, metrics.LabelPipeline: r.pipelineName}).Inc() } // How come it does not return errors at all? @@ -277,19 +277,19 @@ func (r *KafkaSource) ForceStop() { func (r *KafkaSource) Close() error { r.logger.Info("Closing kafka reader...") // finally, shut down the client - r.cancelfn() + r.cancelFn() if r.adminClient != nil { // closes the underlying sarama client as well. 
if err := r.adminClient.Close(); err != nil { r.logger.Errorw("Error in closing kafka admin client", zap.Error(err)) } } - <-r.stopch + <-r.stopCh r.logger.Info("Kafka reader closed") return nil } -func (r *KafkaSource) Pending(ctx context.Context) (int64, error) { +func (r *KafkaSource) Pending(_ context.Context) (int64, error) { if r.adminClient == nil || r.saramaClient == nil { return isb.PendingNotAvailable, nil } @@ -320,11 +320,11 @@ func (r *KafkaSource) Pending(ctx context.Context) (int64, error) { } totalPending += partitionOffset - block.Offset } - kafkaPending.WithLabelValues(r.name, r.pipelineName, r.topic, r.groupName).Set(float64(totalPending)) + kafkaPending.WithLabelValues(r.vertexName, r.pipelineName, r.topic, r.groupName).Set(float64(totalPending)) return totalPending, nil } -// NewKafkaSource returns a KafkaSource reader based on Kafka Consumer Group . +// NewKafkaSource returns a KafkaSource reader based on Kafka Consumer Group. func NewKafkaSource( vertexInstance *dfv1.VertexInstance, writers map[string][]isb.BufferWriter, @@ -336,24 +336,23 @@ func NewKafkaSource( opts ...Option) (*KafkaSource, error) { source := vertexInstance.Vertex.Spec.Source.Kafka - kafkasource := &KafkaSource{ - name: vertexInstance.Vertex.Spec.Name, + kafkaSource := &KafkaSource{ + vertexName: vertexInstance.Vertex.Spec.Name, pipelineName: vertexInstance.Vertex.Spec.PipelineName, topic: source.Topic, brokers: source.Brokers, readTimeout: 1 * time.Second, // default timeout - handlerbuffer: 100, // default buffer size for kafka reads + handlerBuffer: 100, // default buffer size for kafka reads srcPublishWMStores: publishWMStores, - sourcePublishWMs: make(map[int32]publish.Publisher, 0), + sourcePublishWMs: make(map[int32]publish.Publisher), watermarkMaxDelay: vertexInstance.Vertex.Spec.Watermark.GetMaxDelay(), lock: new(sync.RWMutex), logger: logging.NewLogger(), // default logger } for _, o := range opts { - operr := o(kafkasource) - if operr != nil { - return nil, operr + if err := o(kafkaSource); err != nil { + return nil, err } } @@ -380,34 +379,34 @@ func NewKafkaSource( } } - sarama.Logger = zap.NewStdLog(kafkasource.logger.Desugar()) + sarama.Logger = zap.NewStdLog(kafkaSource.logger.Desugar()) // return errors from the underlying kafka client using the Errors channel config.Consumer.Return.Errors = true - kafkasource.config = config + kafkaSource.config = config ctx, cancel := context.WithCancel(context.Background()) - kafkasource.cancelfn = cancel - kafkasource.lifecyclectx = ctx + kafkaSource.cancelFn = cancel + kafkaSource.lifecycleCtx = ctx - kafkasource.stopch = make(chan struct{}) + kafkaSource.stopCh = make(chan struct{}) - handler := newConsumerHandler(kafkasource.handlerbuffer) - kafkasource.handler = handler + handler := newConsumerHandler(kafkaSource.handlerBuffer) + kafkaSource.handler = handler - forwardOpts := []sourceforward.Option{sourceforward.WithLogger(kafkasource.logger)} + forwardOpts := []sourceforward.Option{sourceforward.WithLogger(kafkaSource.logger)} if x := vertexInstance.Vertex.Spec.Limits; x != nil { if x.ReadBatchSize != nil { forwardOpts = append(forwardOpts, sourceforward.WithReadBatchSize(int64(*x.ReadBatchSize))) } } - forwarder, err := sourceforward.NewDataForward(vertexInstance.Vertex, kafkasource, writers, fsd, transformerApplier, fetchWM, kafkasource, toVertexPublisherStores, forwardOpts...) 
+ forwarder, err := sourceforward.NewDataForward(vertexInstance.Vertex, kafkaSource, writers, fsd, transformerApplier, fetchWM, kafkaSource, toVertexPublisherStores, forwardOpts...) if err != nil { - kafkasource.logger.Errorw("Error instantiating the forwarder", zap.Error(err)) + kafkaSource.logger.Errorw("Error instantiating the forwarder", zap.Error(err)) return nil, err } - kafkasource.forwarder = forwarder - return kafkasource, nil + kafkaSource.forwarder = forwarder + return kafkaSource, nil } // refreshAdminClient refreshes the admin client @@ -416,7 +415,8 @@ func (r *KafkaSource) refreshAdminClient() error { return fmt.Errorf("failed to refresh controller, %w", err) } // we are not closing the old admin client because it will close the underlying sarama client as well - // it is safe to not close the admin client, since we are using the same sarama client we will not leak any resources(tcp connections) + // it is safe to not close the admin client, + // since we are using the same sarama client, we will not leak any resources(tcp connections) admin, err := sarama.NewClusterAdminFromClient(r.saramaClient) if err != nil { return fmt.Errorf("failed to create new admin client, %w", err) @@ -425,8 +425,8 @@ func (r *KafkaSource) refreshAdminClient() error { return nil } -func configFromOpts(yamlconfig string) (*sarama.Config, error) { - config, err := sharedutil.GetSaramaConfigFromYAMLString(yamlconfig) +func configFromOpts(yamlConfig string) (*sarama.Config, error) { + config, err := sharedutil.GetSaramaConfigFromYAMLString(yamlConfig) if err != nil { return nil, err } @@ -447,7 +447,7 @@ func (r *KafkaSource) startConsumer() { defer wg.Done() for { select { - case <-r.lifecyclectx.Done(): + case <-r.lifecycleCtx.Done(): return case cErr := <-client.Errors(): r.logger.Errorw("Kafka consumer error", zap.Error(cErr)) @@ -459,22 +459,22 @@ func (r *KafkaSource) startConsumer() { go func() { defer wg.Done() for { - // `Consume` should be called inside an infinite loop, when a - // server-side rebalance happens, the consumer session will need to be + // `Consume` should be called inside an infinite loop; when a + // server-side re-balance happens, the consumer session will need to be // recreated to get the new claims - if conErr := client.Consume(r.lifecyclectx, []string{r.topic}, r.handler); conErr != nil { + if conErr := client.Consume(r.lifecycleCtx, []string{r.topic}, r.handler); conErr != nil { // Panic on errors to let it crash and restart the process r.logger.Panicw("Kafka consumer failed with error: ", zap.Error(conErr)) } // check if context was cancelled, signaling that the consumer should stop - if r.lifecyclectx.Err() != nil { + if r.lifecycleCtx.Err() != nil { return } } }() wg.Wait() _ = client.Close() - close(r.stopch) + close(r.stopCh) } func toReadMessage(m *sarama.ConsumerMessage) *isb.ReadMessage { diff --git a/pkg/sources/kafka/reader_test.go b/pkg/sources/kafka/reader_test.go index f0277a4ff9..631beff7c1 100644 --- a/pkg/sources/kafka/reader_test.go +++ b/pkg/sources/kafka/reader_test.go @@ -68,7 +68,7 @@ func TestNewKafkasource(t *testing.T) { // config is all set and initialized correctly assert.NotNil(t, ks.config) - assert.Equal(t, 100, ks.handlerbuffer) + assert.Equal(t, 100, ks.handlerBuffer) assert.Equal(t, 100*time.Millisecond, ks.readTimeout) assert.Equal(t, 100, cap(ks.handler.messages)) assert.NotNil(t, ks.forwarder) @@ -136,7 +136,7 @@ func TestDefaultBufferSize(t *testing.T) { } ks, _ := NewKafkaSource(vi, toBuffers, myForwardToAllTest{}, applier.Terminal, 
fetchWatermark, toVertexWmStores, publishWMStore, WithLogger(logging.NewLogger()), WithReadTimeOut(100*time.Millisecond), WithGroupName("default")) - assert.Equal(t, 100, ks.handlerbuffer) + assert.Equal(t, 100, ks.handlerBuffer) } @@ -169,6 +169,6 @@ func TestBufferSizeOverrides(t *testing.T) { } ks, _ := NewKafkaSource(vi, toBuffers, myForwardToAllTest{}, applier.Terminal, fetchWatermark, toVertexWmStores, publishWMStore, WithLogger(logging.NewLogger()), WithBufferSize(110), WithReadTimeOut(100*time.Millisecond), WithGroupName("default")) - assert.Equal(t, 110, ks.handlerbuffer) + assert.Equal(t, 110, ks.handlerBuffer) } diff --git a/pkg/sources/nats/nats.go b/pkg/sources/nats/nats.go index 8c790630e5..d61be5801c 100644 --- a/pkg/sources/nats/nats.go +++ b/pkg/sources/nats/nats.go @@ -53,7 +53,7 @@ type natsSource struct { messages chan *isb.ReadMessage readTimeout time.Duration - cancelfn context.CancelFunc + cancelFn context.CancelFunc forwarder *sourceforward.DataForward // source watermark publisher sourcePublishWM publish.Publisher @@ -76,9 +76,8 @@ func New( readTimeout: 1 * time.Second, // default timeout } for _, o := range opts { - operr := o(n) - if operr != nil { - return nil, operr + if err := o(n); err != nil { + return nil, err } } if n.logger == nil { @@ -99,10 +98,10 @@ func New( } n.forwarder = forwarder ctx, cancel := context.WithCancel(context.Background()) - n.cancelfn = cancel + n.cancelFn = cancel entityName := fmt.Sprintf("%s-%d", vertexInstance.Vertex.Name, vertexInstance.Replica) processorEntity := processor.NewProcessorEntity(entityName) - // toVertexPartitionCount is 1 because we publish watermarks within source itself. + // toVertexPartitionCount is 1 because we publish watermarks within the source itself. n.sourcePublishWM = publish.NewPublish(ctx, processorEntity, publishWMStores, 1, publish.IsSource(), publish.WithDelay(vertexInstance.Vertex.Spec.Watermark.GetMaxDelay())) source := vertexInstance.Vertex.Spec.Source.Nats @@ -249,7 +248,7 @@ func (ns *natsSource) PublishSourceWatermarks(msgs []*isb.ReadMessage) { } } if len(msgs) > 0 && !oldest.IsZero() { - // toVertexPartitionIdx is 0 because we publish watermarks within source itself. + // toVertexPartitionIdx is 0 because we publish watermarks within the source itself. 
ns.sourcePublishWM.PublishWatermark(wmb.Watermark(oldest), nil, 0) // Source publisher does not care about the offset } } @@ -262,7 +261,7 @@ func (ns *natsSource) NoAck(_ context.Context, _ []isb.Offset) {} func (ns *natsSource) Close() error { ns.logger.Info("Shutting down nats source server...") - ns.cancelfn() + ns.cancelFn() if err := ns.sub.Unsubscribe(); err != nil { ns.logger.Errorw("Failed to unsubscribe nats subscription", zap.Error(err)) } diff --git a/pkg/sources/redisstreams/redisstream.go b/pkg/sources/redisstreams/redisstream.go index ad8c3475af..b3730f6ace 100644 --- a/pkg/sources/redisstreams/redisstream.go +++ b/pkg/sources/redisstreams/redisstream.go @@ -54,7 +54,7 @@ type redisStreamsSource struct { // forwarder that writes the consumed data to destination forwarder *sourceforward.DataForward // context cancel function - cancelfn context.CancelFunc + cancelFn context.CancelFunc // source watermark publisher sourcePublishWM publish.Publisher } @@ -131,9 +131,8 @@ func New( } for _, o := range opts { - operr := o(redisStreamsSource) - if operr != nil { - return nil, operr + if err := o(redisStreamsSource); err != nil { + return nil, err } } if redisStreamsReader.Log == nil { @@ -156,10 +155,10 @@ func New( // Create Watermark Publisher ctx, cancel := context.WithCancel(context.Background()) - redisStreamsSource.cancelfn = cancel + redisStreamsSource.cancelFn = cancel entityName := fmt.Sprintf("%s-%d", vertexInstance.Vertex.Name, vertexInstance.Replica) processorEntity := processor.NewProcessorEntity(entityName) - // toVertexPartitionCount is 1 because we publish watermarks within source itself. + // toVertexPartitionCount is 1 because we publish watermarks within the source itself. redisStreamsSource.sourcePublishWM = publish.NewPublish(ctx, processorEntity, publishWMStores, 1, publish.IsSource(), publish.WithDelay(vertexInstance.Vertex.Spec.Watermark.GetMaxDelay())) // create the ConsumerGroup here if not already created @@ -231,7 +230,7 @@ func produceMsg(inMsg redis.XMessage, replica int32) (*isb.ReadMessage, error) { if err != nil { return nil, fmt.Errorf("failed to json serialize RedisStream values: %v; inMsg=%+v", err, inMsg) } - keys := []string{} + var keys []string for k := range inMsg.Values { keys = append(keys, k) } @@ -246,7 +245,7 @@ func produceMsg(inMsg redis.XMessage, replica int32) (*isb.ReadMessage, error) { ID: inMsg.ID, Keys: keys, }, - Body: isb.Body{Payload: []byte(jsonSerialized)}, + Body: isb.Body{Payload: jsonSerialized}, } return &isb.ReadMessage{ @@ -298,14 +297,14 @@ func (rsSource *redisStreamsSource) PublishSourceWatermarks(msgs []*isb.ReadMess } } if len(msgs) > 0 && !oldest.IsZero() { - // toVertexPartitionIdx is 0 because we publish watermarks within source itself. + // toVertexPartitionIdx is 0 because we publish watermarks within the source itself. 
rsSource.sourcePublishWM.PublishWatermark(wmb.Watermark(oldest), nil, 0) // Source publisher does not care about the offset } } func (rsSource *redisStreamsSource) Close() error { rsSource.Log.Info("Shutting down redis source server...") - rsSource.cancelfn() + rsSource.cancelFn() rsSource.Log.Info("Redis source server shutdown") return nil } diff --git a/pkg/sources/source.go b/pkg/sources/source.go index 26dfe73e3a..84a9101c95 100644 --- a/pkg/sources/source.go +++ b/pkg/sources/source.go @@ -285,8 +285,7 @@ func (sp *SourceProcessor) getSourcer( if l := sp.VertexInstance.Vertex.Spec.Limits; l != nil && l.ReadTimeout != nil { readOptions = append(readOptions, udsource.WithReadTimeout(l.ReadTimeout.Duration)) } - udsource, err := udsource.New(sp.VertexInstance, writers, fsd, transformerApplier, udsGRPCClient, fetchWM, toVertexPublisherStores, publishWMStores, readOptions...) - return udsource, err + return udsource.New(sp.VertexInstance, writers, fsd, transformerApplier, udsGRPCClient, fetchWM, toVertexPublisherStores, publishWMStores, readOptions...) } else if x := src.Generator; x != nil { readOptions := []generator.Option{ generator.WithLogger(logger), @@ -361,7 +360,7 @@ func (sp *SourceProcessor) getTransformerGoWhereDecider(shuffleFuncMap map[strin } for _, edge := range sp.VertexInstance.Vertex.Spec.ToEdges { - // If returned tags is not "DROP", and there's no conditions defined in the edge, treat it as "ALL"? + // If returned tags are not "DROP", and there are no conditions defined in the edge, treat it as "ALL". if edge.Conditions == nil || edge.Conditions.Tags == nil || len(edge.Conditions.Tags.Values) == 0 { if edge.ToVertexType == dfv1.VertexTypeReduceUDF && edge.GetToVertexPartitionCount() > 1 { // Need to shuffle toVertexPartition := shuffleFuncMap[fmt.Sprintf("%s:%s", edge.From, edge.To)].Shuffle(keys) diff --git a/pkg/sources/udsource/grpc_udsource_test.go b/pkg/sources/udsource/grpc_udsource_test.go index 7662a2a7c3..fcfacb1acb 100644 --- a/pkg/sources/udsource/grpc_udsource_test.go +++ b/pkg/sources/udsource/grpc_udsource_test.go @@ -25,6 +25,7 @@ import ( "time" sourcepb "github.com/numaproj/numaflow-go/pkg/apis/proto/source/v1" + "go.uber.org/goleak" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -40,6 +41,10 @@ import ( "github.com/stretchr/testify/assert" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + type rpcMsg struct { msg proto.Message } diff --git a/pkg/sources/udsource/user_defined_source.go b/pkg/sources/udsource/user_defined_source.go index dcfc61199e..b90c2ee42f 100644 --- a/pkg/sources/udsource/user_defined_source.go +++ b/pkg/sources/udsource/user_defined_source.go @@ -56,8 +56,8 @@ func WithReadTimeout(t time.Duration) Option { } type userDefinedSource struct { - // name of the user-defined source - name string + // name of the user-defined source vertex + vertexName string // name of the pipeline pipelineName string // sourceApplier applies the user-defined source functions @@ -92,7 +92,7 @@ func New( opts ...Option) (*userDefinedSource, error) { u := &userDefinedSource{ - name: vertexInstance.Vertex.Spec.Name, + vertexName: vertexInstance.Vertex.Spec.Name, pipelineName: vertexInstance.Vertex.Spec.PipelineName, sourceApplier: sourceApplier, srcPublishWMStores: publishWMStores, @@ -126,7 +126,7 @@ func New( // GetName returns the name of the user-defined source vertex func (u *userDefinedSource) GetName() string { - return u.name + return u.vertexName } // GetPartitionIdx returns the 
partition number for the user-defined source. @@ -200,7 +200,7 @@ func (u *userDefinedSource) loadSourceWatermarkPublisher(partitionID int32) publ if p, ok := u.srcWMPublishers[partitionID]; ok { return p } - entityName := fmt.Sprintf("%s-%s-%d", u.pipelineName, u.name, partitionID) + entityName := fmt.Sprintf("%s-%s-%d", u.pipelineName, u.vertexName, partitionID) processorEntity := processor.NewProcessorEntity(entityName) // toVertexPartitionCount is 1 because we publish watermarks within the source itself. sourcePublishWM := publish.NewPublish(u.lifecycleCtx, processorEntity, u.srcPublishWMStores, 1, publish.IsSource()) diff --git a/pkg/udf/rpc/grpc_map_test.go b/pkg/udf/rpc/grpc_map_test.go index d0d4dc3743..05fd5d5c2d 100644 --- a/pkg/udf/rpc/grpc_map_test.go +++ b/pkg/udf/rpc/grpc_map_test.go @@ -18,7 +18,7 @@ package rpc import ( "context" - "encoding/json" + "errors" "fmt" "testing" "time" @@ -28,14 +28,12 @@ import ( "github.com/numaproj/numaflow-go/pkg/apis/proto/map/v1/mapmock" "github.com/stretchr/testify/assert" "go.uber.org/goleak" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/numaproj/numaflow/pkg/isb" - "github.com/numaproj/numaflow/pkg/isb/testutils" "github.com/numaproj/numaflow/pkg/sdkclient/mapper" ) @@ -75,7 +73,7 @@ func TestGRPCBasedMap_WaitUntilReadyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -111,7 +109,7 @@ func TestGRPCBasedMap_BasicApplyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -163,7 +161,7 @@ func TestGRPCBasedMap_BasicApplyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -213,7 +211,7 @@ func TestGRPCBasedMap_BasicApplyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -270,7 +268,7 @@ func TestGRPCBasedMap_BasicApplyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -317,7 +315,7 @@ func TestGRPCBasedMap_BasicApplyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -349,80 +347,3 @@ func TestGRPCBasedMap_BasicApplyWithMockClient(t *testing.T) { }) }) } - -func TestHGRPCBasedUDF_ApplyWithMockClient(t *testing.T) { - multiplyBy2 := func(body []byte) interface{} { - var result testutils.PayloadForTest - _ = json.Unmarshal(body, &result) - result.Value = result.Value * 2 - return result - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mapmock.NewMockMapClient(ctrl) - mockClient.EXPECT().MapFn(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, datum *mappb.MapRequest, opts ...grpc.CallOption) (*mappb.MapResponse, 
error) { - var originalValue testutils.PayloadForTest - _ = json.Unmarshal(datum.GetValue(), &originalValue) - doubledValue, _ := json.Marshal(multiplyBy2(datum.GetValue()).(testutils.PayloadForTest)) - var results []*mappb.MapResponse_Result - if originalValue.Value%2 == 0 { - results = append(results, &mappb.MapResponse_Result{ - Keys: []string{"even"}, - Value: doubledValue, - }) - } else { - results = append(results, &mappb.MapResponse_Result{ - Keys: []string{"odd"}, - Value: doubledValue, - }) - } - datumList := &mappb.MapResponse{ - Results: results, - } - return datumList, nil - }, - ).AnyTimes() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockUDSGRPCBasedMap(mockClient) - - var count = int64(10) - readMessages := testutils.BuildTestReadMessages(count, time.Unix(1661169600, 0)) - - var results = make([][]byte, len(readMessages)) - var resultKeys = make([][]string, len(readMessages)) - for idx, readMessage := range readMessages { - apply, err := u.ApplyMap(ctx, &readMessage) - assert.NoError(t, err) - results[idx] = apply[0].Payload - resultKeys[idx] = apply[0].Header.Keys - } - - var expectedResults = make([][]byte, count) - var expectedKeys = make([][]string, count) - for idx, readMessage := range readMessages { - var readMessagePayload testutils.PayloadForTest - _ = json.Unmarshal(readMessage.Payload, &readMessagePayload) - if readMessagePayload.Value%2 == 0 { - expectedKeys[idx] = []string{"even"} - } else { - expectedKeys[idx] = []string{"odd"} - } - marshal, _ := json.Marshal(multiplyBy2(readMessage.Payload)) - expectedResults[idx] = marshal - } - - assert.Equal(t, expectedResults, results) - assert.Equal(t, expectedKeys, resultKeys) -} diff --git a/pkg/udf/rpc/grpc_mapstream_test.go b/pkg/udf/rpc/grpc_mapstream_test.go index 394afcb455..186cdfdd96 100644 --- a/pkg/udf/rpc/grpc_mapstream_test.go +++ b/pkg/udf/rpc/grpc_mapstream_test.go @@ -18,7 +18,6 @@ package rpc import ( "context" - "encoding/json" "errors" "fmt" "io" @@ -32,7 +31,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/numaproj/numaflow/pkg/isb" - "github.com/numaproj/numaflow/pkg/isb/testutils" "github.com/numaproj/numaflow/pkg/sdkclient/mapstreamer" ) @@ -52,7 +50,7 @@ func TestGRPCBasedMapStream_WaitUntilReadyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -91,7 +89,7 @@ func TestGRPCBasedUDF_BasicApplyStreamWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -155,7 +153,7 @@ func TestGRPCBasedUDF_BasicApplyStreamWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -188,104 +186,3 @@ func TestGRPCBasedUDF_BasicApplyStreamWithMockClient(t *testing.T) { }) }) } - -func TestHGRPCBasedUDF_ApplyStreamWithMockClient(t *testing.T) { - multiplyBy2 := func(body []byte) interface{} { - var result testutils.PayloadForTest - _ = json.Unmarshal(body, &result) - result.Value = result.Value * 2 - return result - } - - ctrl := gomock.NewController(t) - 
defer ctrl.Finish() - - var count = int64(1) - readMessages := testutils.BuildTestReadMessages(count, time.Unix(1661169600, 0)) - - mockClient := mapstreammock.NewMockMapStreamClient(ctrl) - mockMapStreamClient := mapstreammock.NewMockMapStream_MapStreamFnClient(ctrl) - for _, message := range readMessages { - keys := message.Keys - payload := message.Body.Payload - parentMessageInfo := message.MessageInfo - var datum = &mapstreampb.MapStreamRequest{ - Keys: keys, - Value: payload, - EventTime: timestamppb.New(parentMessageInfo.EventTime), - Watermark: timestamppb.New(message.Watermark), - } - mockMapStreamClient.EXPECT().Recv().DoAndReturn( - func() (*mapstreampb.MapStreamResponse, error) { - var originalValue testutils.PayloadForTest - _ = json.Unmarshal(datum.GetValue(), &originalValue) - doubledValue, _ := json.Marshal(multiplyBy2(datum.GetValue()).(testutils.PayloadForTest)) - var element *mapstreampb.MapStreamResponse_Result - if originalValue.Value%2 == 0 { - element = &mapstreampb.MapStreamResponse_Result{ - Keys: []string{"even"}, - Value: doubledValue, - } - } else { - element = &mapstreampb.MapStreamResponse_Result{ - Keys: []string{"odd"}, - Value: doubledValue, - } - } - - response := &mapstreampb.MapStreamResponse{ - Result: element, - } - - return response, nil - }, - ).Times(1) - } - mockMapStreamClient.EXPECT().Recv().Return(nil, io.EOF).Times(1) - - mockClient.EXPECT().MapStreamFn(gomock.Any(), gomock.Any()).Return(mockMapStreamClient, nil).AnyTimes() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockUDSGRPCBasedMapStream(mockClient) - - var results = make([][]byte, len(readMessages)) - var resultKeys = make([][]string, len(readMessages)) - idx := 0 - for _, readMessage := range readMessages { - writeMessageCh := make(chan isb.WriteMessage) - go func() { - err := u.ApplyMapStream(ctx, &readMessage, writeMessageCh) - assert.NoError(t, err) - }() - for m := range writeMessageCh { - results[idx] = m.Payload - resultKeys[idx] = m.Header.Keys - idx++ - } - } - - var expectedResults = make([][]byte, count) - var expectedKeys = make([][]string, count) - for idx, readMessage := range readMessages { - var readMessagePayload testutils.PayloadForTest - _ = json.Unmarshal(readMessage.Payload, &readMessagePayload) - if readMessagePayload.Value%2 == 0 { - expectedKeys[idx] = []string{"even"} - } else { - expectedKeys[idx] = []string{"odd"} - } - marshal, _ := json.Marshal(multiplyBy2(readMessage.Payload)) - expectedResults[idx] = marshal - } - - assert.Equal(t, expectedResults, results) - assert.Equal(t, expectedKeys, resultKeys) -} diff --git a/pkg/udf/rpc/grpc_reduce_test.go b/pkg/udf/rpc/grpc_reduce_test.go index 4aa64b3fce..7247f8079b 100644 --- a/pkg/udf/rpc/grpc_reduce_test.go +++ b/pkg/udf/rpc/grpc_reduce_test.go @@ -18,7 +18,6 @@ package rpc import ( "context" - "encoding/json" "errors" "fmt" "io" @@ -52,7 +51,7 @@ func TestGRPCBasedReduce_WaitUntilReadyWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -98,7 +97,7 @@ func TestGRPCBasedUDF_BasicReduceWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") 
} }() @@ -148,7 +147,7 @@ func TestGRPCBasedUDF_BasicReduceWithMockClient(t *testing.T) { defer cancel() go func() { <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { t.Log(t.Name(), "test timeout") } }() @@ -228,87 +227,3 @@ func TestGRPCBasedUDF_BasicReduceWithMockClient(t *testing.T) { assert.Error(t, err, ctx.Err()) }) } - -func TestHGRPCBasedUDF_Reduce(t *testing.T) { - sumFunc := func(dataStreamCh <-chan *reducepb.ReduceRequest) interface{} { - var sum testutils.PayloadForTest - for datum := range dataStreamCh { - var payLoad testutils.PayloadForTest - _ = json.Unmarshal(datum.GetValue(), &payLoad) - sum.Value += payLoad.Value - } - return sum - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - messageCh := make(chan *isb.ReadMessage, 10) - datumStreamCh := make(chan *reducepb.ReduceRequest, 10) - messages := testutils.BuildTestReadMessages(10, time.Now()) - - go func() { - for index := range messages { - messageCh <- &messages[index] - datumStreamCh <- createDatum(&messages[index]) - } - close(messageCh) - close(datumStreamCh) - }() - - mockClient := reducemock.NewMockReduceClient(ctrl) - mockReduceClient := reducemock.NewMockReduce_ReduceFnClient(ctrl) - mockReduceClient.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() - mockReduceClient.EXPECT().CloseSend().Return(nil).AnyTimes() - mockReduceClient.EXPECT().Recv().DoAndReturn( - func() (*reducepb.ReduceResponse, error) { - result := sumFunc(datumStreamCh) - sumValue, _ := json.Marshal(result.(testutils.PayloadForTest)) - var Results []*reducepb.ReduceResponse_Result - Results = append(Results, &reducepb.ReduceResponse_Result{ - Keys: []string{"sum"}, - Value: sumValue, - }) - datumList := &reducepb.ReduceResponse{ - Results: Results, - } - return datumList, nil - }).Times(1) - mockReduceClient.EXPECT().Recv().Return(&reducepb.ReduceResponse{ - Results: []*reducepb.ReduceResponse_Result{ - { - Keys: []string{"reduced_result_key"}, - Value: []byte(`forward_message`), - }, - }, - }, io.EOF).Times(1) - - mockClient.EXPECT().ReduceFn(gomock.Any(), gomock.Any()).Return(mockReduceClient, nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockUDSGRPCBasedReduce(mockClient) - - partitionID := &partition.ID{ - Start: time.Unix(60, 0), - End: time.Unix(120, 0), - Slot: "test", - } - - result, err := u.ApplyReduce(ctx, partitionID, messageCh) - - var resultPayload testutils.PayloadForTest - _ = json.Unmarshal(result[0].Payload, &resultPayload) - - assert.NoError(t, err) - assert.Equal(t, []string{"sum"}, result[0].Keys) - assert.Equal(t, int64(45), resultPayload.Value) - assert.Equal(t, time.Unix(120, 0).Add(-1*time.Millisecond), result[0].EventTime) -}
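
Note on item 2 of the commit message: `==` only matches the exact sentinel error value, while `errors.Is` also matches a sentinel that has been wrapped with `%w` somewhere down the call chain; it is a safe drop-in even where the error is never wrapped, because `errors.Is(err, target)` is true whenever `err == target`. Below is a minimal, self-contained Go sketch of the difference. It is not part of the patch; the `decode` function and its wrapping call site are hypothetical, and only the sentinel name is borrowed from the WAL code above.

package main

import (
	"errors"
	"fmt"
)

// errChecksumMismatch mirrors the sentinel error used by the WAL reader above.
var errChecksumMismatch = errors.New("checksum mismatch")

// decode is a hypothetical callee that wraps the sentinel with extra context via %w.
func decode() error {
	return fmt.Errorf("decode message at offset 42: %w", errChecksumMismatch)
}

func main() {
	err := decode()

	// Direct comparison fails: err is a wrapper around the sentinel, not the sentinel itself.
	fmt.Println(err == errChecksumMismatch) // false

	// errors.Is walks the wrap chain and still recognizes the sentinel.
	fmt.Println(errors.Is(err, errChecksumMismatch)) // true
}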