diff --git a/pkg/storage/stores/shipper/bloomshipper/block_downloader.go b/pkg/storage/stores/shipper/bloomshipper/block_downloader.go index 81355f78e84e..8bad4fd45fcf 100644 --- a/pkg/storage/stores/shipper/bloomshipper/block_downloader.go +++ b/pkg/storage/stores/shipper/bloomshipper/block_downloader.go @@ -2,7 +2,6 @@ package bloomshipper import ( "context" - "errors" "fmt" "os" "path" @@ -14,6 +13,7 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/grafana/dskit/services" + "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "go.uber.org/atomic" "k8s.io/utils/keymutex" @@ -82,7 +82,7 @@ func newBlockDownloader(config config.Config, blockClient BlockClient, limits Li for i := 0; i < config.BlocksDownloadingQueue.WorkersCount; i++ { b.wg.Add(1) - go b.serveDownloadingTasks(fmt.Sprintf("worker-%d", i)) + go b.runDownloadWorker(fmt.Sprintf("worker-%d", i)) } return b, nil } @@ -92,20 +92,20 @@ type BlockDownloadingTask struct { block BlockRef // ErrCh is a send-only channel to write an error to ErrCh chan<- error - // ResultsCh is a send-only channel to return the block querier for the downloaded block - ResultsCh chan<- blockWithQuerier + // ResCh is a send-only channel to return the block querier for the downloaded block + ResCh chan<- blockWithQuerier } func NewBlockDownloadingTask(ctx context.Context, block BlockRef, resCh chan<- blockWithQuerier, errCh chan<- error) *BlockDownloadingTask { return &BlockDownloadingTask{ - ctx: ctx, - block: block, - ErrCh: errCh, - ResultsCh: resCh, + ctx: ctx, + block: block, + ErrCh: errCh, + ResCh: resCh, } } -func (d *blockDownloader) serveDownloadingTasks(workerID string) { +func (d *blockDownloader) runDownloadWorker(workerID string) { // defer first, so it gets executed as last of the deferred functions defer d.wg.Done() @@ -134,13 +134,13 @@ func (d *blockDownloader) serveDownloadingTasks(workerID string) { } idx = newIdx + result, err := d.strategy.downloadBlock(task, logger) if err != nil { task.ErrCh <- err - continue + } else { + task.ResCh <- result } - task.ResultsCh <- result - continue } } @@ -245,25 +245,23 @@ func downloadBlockToDirectory(logger log.Logger, task *BlockDownloadingTask, wor return directory, nil } -func (d *blockDownloader) downloadBlocks(ctx context.Context, tenantID string, references []BlockRef) (chan blockWithQuerier, chan error) { +func (d *blockDownloader) fetch(ctx context.Context, tenantID string, references []BlockRef) (<-chan blockWithQuerier, <-chan error, error) { d.activeUsersService.UpdateUserTimestamp(tenantID, time.Now()) - // we need to have errCh with size that can keep max count of errors to prevent the case when - // the queue worker reported the error to this channel before the current goroutine - // and this goroutine will go to the deadlock because it won't be able to report an error - // because nothing reads this channel at this moment. + resCh := make(chan blockWithQuerier, len(references)) errCh := make(chan error, len(references)) - blocksCh := make(chan blockWithQuerier, len(references)) + var err error for _, reference := range references { - task := NewBlockDownloadingTask(ctx, reference, blocksCh, errCh) + task := NewBlockDownloadingTask(ctx, reference, resCh, errCh) level.Debug(d.logger).Log("msg", "enqueuing task to download block", "block", reference.BlockPath) - err := d.queue.Enqueue(tenantID, nil, task, nil) + err = d.queue.Enqueue(tenantID, nil, task, nil) if err != nil { - errCh <- fmt.Errorf("error enquing downloading task for block %s : %w", reference.BlockPath, err) - return blocksCh, errCh + err = errors.Wrapf(err, "failed to enqueue download task for block %s", reference.BlockPath) + break } } - return blocksCh, errCh + + return resCh, errCh, err } type blockWithQuerier struct { diff --git a/pkg/storage/stores/shipper/bloomshipper/block_downloader_test.go b/pkg/storage/stores/shipper/bloomshipper/block_downloader_test.go index ffe715c857ec..dee6012dbd82 100644 --- a/pkg/storage/stores/shipper/bloomshipper/block_downloader_test.go +++ b/pkg/storage/stores/shipper/bloomshipper/block_downloader_test.go @@ -41,7 +41,7 @@ func Test_blockDownloader_downloadBlocks(t *testing.T) { }, }, blockClient, overrides, log.NewNopLogger(), prometheus.DefaultRegisterer) require.NoError(t, err) - blocksCh, errorsCh := downloader.downloadBlocks(context.Background(), "fake", blockReferences) + blocksCh, errorsCh, _ := downloader.fetch(context.Background(), "fake", blockReferences) downloadedBlocks := make(map[string]any, len(blockReferences)) done := make(chan bool) go func() { @@ -110,7 +110,7 @@ func Test_blockDownloader_downloadBlock(t *testing.T) { t.Cleanup(downloader.stop) require.NoError(t, err) - blocksCh, errorsCh := downloader.downloadBlocks(context.Background(), "fake", blockReferences) + blocksCh, errorsCh, _ := downloader.fetch(context.Background(), "fake", blockReferences) downloadedBlocks := make(map[string]any, len(blockReferences)) done := make(chan bool) go func() { @@ -131,7 +131,7 @@ func Test_blockDownloader_downloadBlock(t *testing.T) { require.Len(t, downloadedBlocks, 20, "all 20 block must be downloaded") require.Equal(t, int32(20), blockClient.getBlockCalls.Load()) - blocksCh, errorsCh = downloader.downloadBlocks(context.Background(), "fake", blockReferences) + blocksCh, errorsCh, _ = downloader.fetch(context.Background(), "fake", blockReferences) downloadedBlocks = make(map[string]any, len(blockReferences)) done = make(chan bool) go func() { @@ -203,7 +203,7 @@ func Test_blockDownloader_downloadBlock_deduplication(t *testing.T) { waitGroup.Add(1) go func() { defer waitGroup.Done() - blocksCh, errCh := downloader.downloadBlocks(context.Background(), "fake", blockReferences) + blocksCh, errCh, _ := downloader.fetch(context.Background(), "fake", blockReferences) var err error select { case <-blocksCh: diff --git a/pkg/storage/stores/shipper/bloomshipper/shipper.go b/pkg/storage/stores/shipper/bloomshipper/shipper.go index 36bfba913c98..05ffa8b815e6 100644 --- a/pkg/storage/stores/shipper/bloomshipper/shipper.go +++ b/pkg/storage/stores/shipper/bloomshipper/shipper.go @@ -76,7 +76,11 @@ func (s *Shipper) GetBlockRefs(ctx context.Context, tenantID string, from, throu func (s *Shipper) Fetch(ctx context.Context, tenantID string, blocks []BlockRef, callback ForEachBlockCallback) error { cancelContext, cancelFunc := context.WithCancel(ctx) defer cancelFunc() - blocksChannel, errorsChannel := s.blockDownloader.downloadBlocks(cancelContext, tenantID, blocks) + + resCh, errCh, err := s.blockDownloader.fetch(cancelContext, tenantID, blocks) + if err != nil { + return err + } // track how many blocks are still remaning to be downloaded remaining := len(blocks) @@ -85,7 +89,9 @@ func (s *Shipper) Fetch(ctx context.Context, tenantID string, blocks []BlockRef, select { case <-ctx.Done(): return fmt.Errorf("failed to fetch blocks: %w", ctx.Err()) - case result, sentBeforeClosed := <-blocksChannel: + case err := <-errCh: + return fmt.Errorf("failed to fetch blocks: %w", err) + case result, sentBeforeClosed := <-resCh: if !sentBeforeClosed { return nil } @@ -97,8 +103,6 @@ func (s *Shipper) Fetch(ctx context.Context, tenantID string, blocks []BlockRef, if remaining == 0 { return nil } - case err := <-errorsChannel: - return fmt.Errorf("error downloading blocks : %w", err) } } }