Skip to content

Commit

Permalink
concurrently scan the filesystem source
Browse files Browse the repository at this point in the history
Co-authored-by: Miccah Castorina <[email protected]>
  • Loading branch information
ahrav and mcastorina committed Feb 1, 2024
1 parent c2ae31d commit 40bc33e
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions pkg/sources/filesystem/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/go-errors/errors"
"github.com/go-logr/logr"
diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"

Expand All @@ -26,13 +27,14 @@ import (
const SourceType = sourcespb.SourceType_SOURCE_TYPE_FILESYSTEM

type Source struct {
name string
sourceId sources.SourceID
jobId sources.JobID
verify bool
paths []string
log logr.Logger
filter *common.Filter
name string
sourceId sources.SourceID
jobId sources.JobID
concurrency int
verify bool
paths []string
log logr.Logger
filter *common.Filter
sources.Progress
sources.CommonSourceUnitUnmarshaller
}
Expand All @@ -57,9 +59,10 @@ func (s *Source) JobID() sources.JobID {
}

// Init returns an initialized Filesystem source.
func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, _ int) error {
func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
s.log = aCtx.Logger()

s.concurrency = concurrency
s.name = name
s.sourceId = sourceId
s.jobId = jobId
Expand All @@ -82,6 +85,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, so

// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {

for i, path := range s.paths {
logger := ctx.Logger().WithValues("path", path)
if common.IsDone(ctx) {
Expand All @@ -93,7 +97,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
fileInfo, err := os.Stat(cleanPath)
if err != nil {
logger.Error(err, "unable to get file info")
continue
return nil
}

if fileInfo.IsDir() {
Expand All @@ -102,16 +106,22 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
err = s.scanFile(ctx, cleanPath, chunksChan)
}

if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
logger.Info("error scanning filesystem", "error", err)
}
}

return nil
}

func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sources.Chunk) error {
workerPool := new(errgroup.Group)
workerPool.SetLimit(s.concurrency)
defer func() { _ = workerPool.Wait() }()

return fs.WalkDir(os.DirFS(path), ".", func(relativePath string, d fs.DirEntry, err error) error {
if err != nil {
ctx.Logger().Error(err, "error walking directory")
return nil
}
fullPath := filepath.Join(path, relativePath)
Expand All @@ -126,9 +136,13 @@ func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sour
return nil
}

if err = s.scanFile(ctx, fullPath, chunksChan); err != nil {
ctx.Logger().Info("error scanning file", "path", fullPath, "error", err)
}
workerPool.Go(func() error {
if err = s.scanFile(ctx, fullPath, chunksChan); err != nil {
ctx.Logger().Info("error scanning file", "path", fullPath, "error", err)
}
return nil
})

return nil
})
}
Expand Down

0 comments on commit 40bc33e

Please sign in to comment.