Skip to content

Commit

Permalink
Implement cancellation of stale downloads
Browse files Browse the repository at this point in the history
Cancel downloads whenever there's no data flowing for a certain amount
of time. Use k0scontext.WithInactivityTimeout for that.

Signed-off-by: Tom Wieczorek <[email protected]>
  • Loading branch information
twz123 committed Sep 23, 2024
1 parent 2db0d85 commit 7f83b8f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
33 changes: 29 additions & 4 deletions internal/http/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,21 @@ import (
"fmt"
"io"
"net/http"
"time"

internalio "github.com/k0sproject/k0s/internal/io"
"github.com/k0sproject/k0s/pkg/build"
"github.com/k0sproject/k0s/pkg/k0scontext"
)

type DownloadOption func(*downloadOptions)

// Download downloads the contents of the given URL, writing the HTTP response body to target.
// Stalled downloads will be aborted if there's no data transfer for some time.
func Download(ctx context.Context, url string, target io.Writer, options ...DownloadOption) (err error) {
opts := downloadOptions{}
opts := downloadOptions{
stalenessTimeout: time.Minute,
}
for _, opt := range options {
opt(&opts)
}
Expand All @@ -48,8 +54,9 @@ func Download(ctx context.Context, url string, target io.Writer, options ...Down
}
req.Header.Set("User-Agent", "k0s/"+build.Version)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Create a context with an inactivity timeout to cancel the download if it stalls.
ctx, cancel, keepAlive := k0scontext.WithInactivityTimeout(ctx, opts.stalenessTimeout)
defer cancel(nil)

// Execute the request.
resp, err := client.Do(req.WithContext(ctx))
Expand All @@ -73,8 +80,17 @@ func Download(ctx context.Context, url string, target io.Writer, options ...Down
return err
}

// Monitor writes. Keep the download context alive as long as data is flowing.
writeMonitor := internalio.WriterFunc(func(p []byte) (int, error) {
len := len(p)
if len > 0 {
keepAlive()
}
return len, nil
})

// Run the actual data transfer.
if _, err := io.Copy(target, resp.Body); err != nil {
if _, err := io.Copy(io.MultiWriter(writeMonitor, target), resp.Body); err != nil {
if cause := context.Cause(ctx); cause != nil && !errors.Is(err, cause) {
err = fmt.Errorf("%w (%w)", cause, err)
}
Expand All @@ -85,6 +101,15 @@ func Download(ctx context.Context, url string, target io.Writer, options ...Down
return nil
}

// WithStalenessTimeout returns a DownloadOption that overrides the
// inactivity timeout after which a stalled download is aborted.
// Downloads configured without this option use a one-minute timeout.
func WithStalenessTimeout(stalenessTimeout time.Duration) DownloadOption {
	return func(o *downloadOptions) { o.stalenessTimeout = stalenessTimeout }
}

// downloadOptions collects the configurable aspects of a download,
// populated via DownloadOption functions passed to Download.
type downloadOptions struct {
	// stalenessTimeout is the maximum duration without any data transfer
	// before an in-flight download is canceled. Download initializes it
	// to one minute; WithStalenessTimeout overrides it.
	stalenessTimeout time.Duration
	// downloadFileNameOptions is embedded; its fields are declared
	// elsewhere in this package.
	downloadFileNameOptions
}
30 changes: 30 additions & 0 deletions internal/http/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import (
"strings"
"sync/atomic"
"testing"
"time"

internalhttp "github.com/k0sproject/k0s/internal/http"
internalio "github.com/k0sproject/k0s/internal/io"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -83,6 +85,34 @@ func TestDownload_ExcessContentLength(t *testing.T) {
assert.Equal(t, "yolo", downloaded.String())
}

// TestDownload_CancelDownload verifies that an in-flight download is
// aborted when the surrounding context is canceled, and that the returned
// error carries both the cancellation cause and context.Canceled.
func TestDownload_CancelDownload(t *testing.T) {
	ctx, cancel := context.WithCancelCause(context.TODO())
	t.Cleanup(func() { cancel(nil) })

	// Serve an effectively endless response body: write a chunk, then wait
	// until either the request is canceled or a tiny delay elapses before
	// writing the next one.
	baseURL := startFakeDownloadServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		for {
			if _, err := w.Write([]byte(t.Name())); !assert.NoError(t, err) {
				return
			}

			select {
			case <-r.Context().Done():
				// Client went away; stop writing.
				return
			case <-time.After(time.Microsecond):
			}
		}
	}))

	// Cancel the download from inside the write path: the first chunk that
	// reaches the writer triggers cancellation with a sentinel cause.
	err := internalhttp.Download(ctx, baseURL, internalio.WriterFunc(func(p []byte) (int, error) {
		cancel(assert.AnError)
		return len(p), nil
	}))

	// The error is expected to mention the download phase and wrap both the
	// sentinel cause and the underlying context.Canceled.
	assert.ErrorContains(t, err, "while downloading: ")
	assert.ErrorIs(t, err, assert.AnError)
	assert.ErrorIs(t, err, context.Canceled)
}

func TestDownload_RedirectLoop(t *testing.T) {
// The current implementation doesn't detect loops, but it stops after 10 redirects.

Expand Down

0 comments on commit 7f83b8f

Please sign in to comment.