Skip to content

Commit

Permalink
Disable grab download resume
Browse files Browse the repository at this point in the history
This is to mitigate cases like #4296

By default grab tries to resume the download if the file name determined from either the url or from content-type headers already exists. This makes things go side ways, if the existing file is smaller than the new one, the old content would still be there and only the "extra" new bytes would get written. I.e. the download would be "resumed". 🤦

This is probably not a fix for the root cause in #4296 as the only way I've been able to make grab fail with `bad content length` is by crafting a custom http server that maliciously borks `Content-Length` header.

This is a minimal possible fix that we can easily backport. @twz123 is already working on bigger refactoring of autopilot download functionality that gets rid of grab. Grab seems to bring more (bad) surprises than real benefits. In the end, we just download files and we should pretty much always just replace them. No need for full library dependecy for that.

Signed-off-by: Jussi Nummelin <[email protected]>
  • Loading branch information
jnummelin committed Sep 23, 2024
1 parent bb4dfd5 commit 4027037
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pkg/autopilot/controller/signal/common/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type DownloadManifest struct {
apdl.Config

SuccessState string
// A hook that will be called after the download itself is done succesfully
AfterTransferSuccess func() error
}

type DownloadManifestBuilder interface {
Expand Down Expand Up @@ -90,7 +92,12 @@ func (r *downloadController) Reconcile(ctx context.Context, req cr.Request) (cr.

} else {
logger.Infof("Download of '%s' successful", manifest.URL)

// When download is successful, run the post-download hook
if manifest.AfterTransferSuccess != nil {
if err := manifest.AfterTransferSuccess(); err != nil {
return cr.Result{}, fmt.Errorf("failed to run post-download hook: %w", err)
}
}
// When the download is complete move the status to the success state
signalData.Status = apsigv2.NewStatus(manifest.SuccessState)
}
Expand Down
9 changes: 9 additions & 0 deletions pkg/autopilot/controller/signal/k0s/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package k0s

import (
"crypto/sha256"
"os"
"path/filepath"

apcomm "github.com/k0sproject/k0s/pkg/autopilot/common"
apdel "github.com/k0sproject/k0s/pkg/autopilot/controller/delegate"
Expand Down Expand Up @@ -88,6 +90,13 @@ func (b downloadManifestBuilderK0s) Build(signalNode crcli.Object, signalData ap
ExpectedHash: signalData.Command.K0sUpdate.Sha256,
Hasher: sha256.New(),
DownloadDir: b.k0sBinaryDir,
Filename: filepath.Join(b.k0sBinaryDir, "k0s.tmp"),
},
// After the download is done, we need to rename the file to the correct name
AfterTransferSuccess: func() error {
src := filepath.Join(b.k0sBinaryDir, "k0s.tmp")
dst := filepath.Join(b.k0sBinaryDir, "k0s")
return os.Rename(src, dst)
},
SuccessState: Cordoning,
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/autopilot/download/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Config struct {
ExpectedHash string
Hasher hash.Hash
DownloadDir string
Filename string
}

type downloader struct {
Expand Down Expand Up @@ -72,6 +73,15 @@ func (d *downloader) Download(ctx context.Context) error {
dlreq.SetChecksum(d.config.Hasher, expectedHash, true)
}

// We're never really resuming downloads, so disable this feature.
// This also allows to re-download the file if it's already present.
dlreq.NoResume = true

if d.config.Filename != "" {
d.logger.Infof("Setting filename to %s", d.config.Filename)
dlreq.Filename = d.config.Filename
}

client := grab.NewClient()
// Set user agent to mitigate 403 errors from GitHub
// See https://github.com/cavaliergopher/grab/issues/104
Expand Down
233 changes: 233 additions & 0 deletions pkg/autopilot/download/downloader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
// Copyright 2021 k0s authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package download

import (
"crypto/rand"
"io"
"net/http"
"os"
"path/filepath"
"testing"
"time"

"github.com/cavaliergopher/grab/v3"
"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)

type content struct {

Check failure on line 32 in pkg/autopilot/download/downloader_test.go

View workflow job for this annotation

GitHub Actions / Lint

type `content` is unused (unused)
content string
checksum string
}

func TestDownloadWithGrab(t *testing.T) {
// go runServer()

tmp := t.TempDir()
dlPath := filepath.Join(tmp, "index.html")

go func() {
// Watch the file for changes
w, err := fsnotify.NewWatcher()
assert.NoError(t, err)
defer w.Close()
err = w.Add(tmp)
assert.NoError(t, err)
for {
select {
case event := <-w.Events:
t.Logf("event: %s -- %s", event.Op.String(), event.Name)
case err := <-w.Errors:
panic(err)
}
}
}()

// Dump some data to the file to make sure it will be overwritten
err := os.WriteFile(dlPath, []byte("foobar"), 0644)
assert.NoError(t, err)

url := "http://localhost:8888/index.html"
// Download should succeed and get "abc" as content
dlreq, err := grab.NewRequest(tmp, url)
assert.NoError(t, err)

// dlreq.NoResume = true

client := grab.NewClient()
resp := client.Do(dlreq)
<-resp.Done
assert.NoError(t, resp.Err())

data, err := getFileContent(dlPath)
assert.NoError(t, err)
assert.NotContains(t, data, "foobar")

// Now create the file again with more data than the previous download
// this way we can test if the download will overwrite the file entirely
randomData := make([]byte, 2048) // nginx index.html is 615 bytes
_, err = rand.Read(randomData)
assert.NoError(t, err)
err = os.WriteFile(dlPath, []byte("foobar"), 0644)
assert.NoError(t, err)

dlreq, err = grab.NewRequest(tmp, url)
dlreq.NoResume = true
assert.NoError(t, err)
client = grab.NewClient()
resp = client.Do(dlreq)
<-resp.Done
assert.NoError(t, resp.Err())

s, err := os.Stat(dlPath)
assert.NoError(t, err)
assert.Equal(t, int64(615), s.Size())

time.Sleep(5 * time.Second)
}

func TestDownloadRaw(t *testing.T) {
go runServer()

tmp := t.TempDir()
dlPath := filepath.Join(tmp, "foo.txt")

// Dump some data to the file to make sure it will be overwritten
err := os.WriteFile(dlPath, []byte("foobar"), 0644)
assert.NoError(t, err)

url := "http://localhost:8080/foo.txt"
// Download using plain http.Get
resp, err := http.Get(url)
assert.NoError(t, err)
defer resp.Body.Close()
// Dump all the data to file
f, err := os.Create(dlPath)
assert.NoError(t, err)
defer f.Close()
_, err = io.Copy(f, resp.Body)
assert.NoError(t, err)
// Check the content
data, err := getFileContent(dlPath)
assert.NoError(t, err)
assert.Equal(t, "abc", data)
}

func runServer() {
http.HandleFunc("/foo.txt", func(w http.ResponseWriter, r *http.Request) {
body := []byte("a")
// w.Header().Set("Content-Length", strconv.Itoa(len(body)))
if r.Method == http.MethodGet {
_, err := w.Write(body)
if err != nil {
logrus.Errorf("Failed to write response: %v", err)
}
}
})

http.ListenAndServe(":8080", nil)

Check failure on line 142 in pkg/autopilot/download/downloader_test.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `http.ListenAndServe` is not checked (errcheck)
}

// func TestDownload(t *testing.T) {

// // content is a map of content and its checksum
// // counter is used to switch between content
// content := map[int]content{
// 0: {content: "a", checksum: "ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"},
// 1: {content: "bb", checksum: "961b6dd3ede3cb8ecbaacbd68de040cd78eb2ed5889130cceb4c49268ea4d506"},
// // The next one has wrong checksum so it needs to fail
// 2: {content: "aaa", checksum: "961b6dd3ede3cb8ecbaacbd68de040cd78eb2ed5889130cceb4c49268ea4d506"},
// }

// counter := 0

// http.HandleFunc("/foo.txt", func(w http.ResponseWriter, r *http.Request) {
// // body := []byte(content[counter].content)
// body := []byte("a")
// w.Header().Set("Content-Length", strconv.Itoa(len(body)))
// // w.Header().Set("Content-Length", "42")
// // w.WriteHeader(http.StatusOK)

// if r.Method == http.MethodGet {
// _, err := w.Write(body)
// if err != nil {
// t.Errorf("Failed to write response: %v", err)
// }

// // increment the counter so next GET request will have different content
// counter++
// }
// })

// go func() {
// assert.NoError(t, http.ListenAndServe(":8080", nil))
// }()

// tmp := t.TempDir()
// dlPath := filepath.Join(tmp, "foo.txt")
// url := "http://localhost:8080/foo.txt"
// // Download should succeed and get "a" as content
// err := dl(url, content[0].checksum, tmp)
// assert.NoError(t, err)
// data, err := getFileContent(dlPath)
// assert.NoError(t, err)
// assert.Equal(t, content[0].content, data)

// // Write the file with different content and test if the download will overwrite it
// os.WriteFile(dlPath, []byte("foobar"), 0644)

// // Download should succeed and get "a" as content
// err = dl(url, content[0].checksum, tmp)
// assert.NoError(t, err)
// data, err = getFileContent(dlPath)
// assert.NoError(t, err)
// assert.Equal(t, content[0].content, data)

// // // Download should succeed and get "bb" as content
// // err = dl(url, content[1].checksum, tmp)
// // assert.NoError(t, err)
// // data, err = getFileContent(dlPath)
// // assert.NoError(t, err)
// // assert.Equal(t, content[1].content, data)

// // err = dl(url, content[2].checksum, tmp)
// // assert.Error(t, err)

// }

// func dl(url, hash, dir string) error {
// logger := logrus.New().WithField("component", "downloader")
// config := Config{
// URL: url,
// ExpectedHash: "",
// DownloadDir: dir,
// Hasher: sha256.New(),
// }
// d := NewDownloader(config, logger)
// ctx, cancel := context.WithTimeout(context.Background(), 10000*time.Second)
// defer cancel()

// return d.Download(ctx)
// }

func getFileContent(path string) (string, error) {
content, err := os.ReadFile(path)
if err != nil {
return "", err
}
return string(content), nil
}

0 comments on commit 4027037

Please sign in to comment.