From 79c866e841810f338dde47efd2630902da1a963a Mon Sep 17 00:00:00 2001 From: Xiaoxuan Wang <103478229+wangxiaoxuan273@users.noreply.github.com> Date: Mon, 8 Apr 2024 17:21:02 +0800 Subject: [PATCH] fix: cancel goroutine before the next one is created (#739) To fix #738 Signed-off-by: Xiaoxuan Wang --- copy_test.go | 26 ++++++++++++++++---------- go.mod | 2 +- go.sum | 4 ++-- internal/syncutil/limit.go | 14 ++++++++++++-- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/copy_test.go b/copy_test.go index 02421524c..c9e9bfcd3 100644 --- a/copy_test.go +++ b/copy_test.go @@ -1777,9 +1777,12 @@ func TestCopyGraph_WithOptions(t *testing.T) { }) t.Run("MountFrom error", func(t *testing.T) { - root = descs[6] + root = descs[3] dst := &countingStorage{storage: cas.NewMemory()} - opts = oras.CopyGraphOptions{} + opts = oras.CopyGraphOptions{ + // to make the run result deterministic, we limit concurrency to 1 + Concurrency: 1, + } var numMountFrom atomic.Int64 e := errors.New("mountFrom error") opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { @@ -1790,7 +1793,7 @@ func TestCopyGraph_WithOptions(t *testing.T) { t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e) } - if got, expected := dst.numExists.Load(), int64(7); got != expected { + if got, expected := dst.numExists.Load(), int64(2); got != expected { t.Errorf("count(Exists()) = %d, want %d", got, expected) } if got, expected := dst.numFetch.Load(), int64(0); got != expected { @@ -1799,13 +1802,13 @@ func TestCopyGraph_WithOptions(t *testing.T) { if got, expected := dst.numPush.Load(), int64(0); got != expected { t.Errorf("count(Push()) = %d, want %d", got, expected) } - if got, expected := numMountFrom.Load(), int64(4); got != expected { + if got, expected := numMountFrom.Load(), int64(1); got != expected { t.Errorf("count(MountFrom()) = %d, want %d", got, expected) } }) t.Run("MountFrom OnMounted error", func(t *testing.T) { - root = descs[6] + root = descs[3] dst := &countingStorage{storage: cas.NewMemory()} var numMount atomic.Int64 dst.mount = func(ctx context.Context, @@ -1828,7 +1831,10 @@ func TestCopyGraph_WithOptions(t *testing.T) { } return nil } - opts = oras.CopyGraphOptions{} + opts = oras.CopyGraphOptions{ + // to make the run result deterministic, we limit concurrency to 1 + Concurrency: 1, + } var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64 opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { numPreCopy.Add(1) @@ -1851,7 +1857,7 @@ func TestCopyGraph_WithOptions(t *testing.T) { t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e) } - if got, expected := dst.numExists.Load(), int64(7); got != expected { + if got, expected := dst.numExists.Load(), int64(2); got != expected { t.Errorf("count(Exists()) = %d, want %d", got, expected) } if got, expected := dst.numFetch.Load(), int64(0); got != expected { @@ -1860,13 +1866,13 @@ func TestCopyGraph_WithOptions(t *testing.T) { if got, expected := dst.numPush.Load(), int64(0); got != expected { t.Errorf("count(Push()) = %d, want %d", got, expected) } - if got, expected := numMount.Load(), int64(4); got != expected { + if got, expected := numMount.Load(), int64(1); got != expected { t.Errorf("count(Mount()) = %d, want %d", got, expected) } - if got, expected := numOnMounted.Load(), int64(4); got != expected { + if got, expected := numOnMounted.Load(), int64(1); got != expected { t.Errorf("count(OnMounted()) = %d, want %d", got, expected) } - if got, expected := numMountFrom.Load(), int64(4); got != expected { + if got, expected := numMountFrom.Load(), int64(1); got != expected { t.Errorf("count(MountFrom()) = %d, want %d", got, expected) } if got, expected := numPreCopy.Load(), int64(0); got != expected { diff --git a/go.mod b/go.mod index 85b83d90b..bd267939b 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,5 @@ go 1.21 require ( github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/image-spec v1.1.0 - golang.org/x/sync v0.6.0 + golang.org/x/sync v0.7.0 ) diff --git a/go.sum b/go.sum index 9b89e8aea..eec227b27 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,5 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= diff --git a/internal/syncutil/limit.go b/internal/syncutil/limit.go index 2a05d4ea2..f2caabf72 100644 --- a/internal/syncutil/limit.go +++ b/internal/syncutil/limit.go @@ -17,6 +17,7 @@ package syncutil import ( "context" + "sync/atomic" "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" @@ -68,15 +69,24 @@ type GoFunc[T any] func(ctx context.Context, region *LimitedRegion, t T) error // Go concurrently invokes fn on items. func Go[T any](ctx context.Context, limiter *semaphore.Weighted, fn GoFunc[T], items ...T) error { eg, egCtx := errgroup.WithContext(ctx) + var egErr atomic.Value for _, item := range items { - region := LimitRegion(ctx, limiter) + region := LimitRegion(egCtx, limiter) if err := region.Start(); err != nil { + if egErr, ok := egErr.Load().(error); ok && egErr != nil { + return egErr + } return err } eg.Go(func(t T) func() error { return func() error { defer region.End() - return fn(egCtx, region, t) + err := fn(egCtx, region, t) + if err != nil { + egErr.CompareAndSwap(nil, err) + return err + } + return nil } }(item)) }