From 4027037e51ba32b76adeedc89c565da999f9e6c7 Mon Sep 17 00:00:00 2001 From: Jussi Nummelin Date: Thu, 19 Sep 2024 22:29:28 +0300 Subject: [PATCH] Disable grab download resume This is to mitigate cases like #4296 By default grab tries to resume the download if the file name determined from either the url or from content-type headers already exists. This makes things go side ways, if the existing file is smaller than the new one, the old content would still be there and only the "extra" new bytes would get written. I.e. the download would be "resumed". :facepalm: This is probably not a fix for the root cause in #4296 as the only way I've been able to make grab fail with `bad content length` is by crafting a custom http server that maliciously borks `Content-Length` header. This is a minimal possible fix that we can easily backport. @twz123 is already working on bigger refactoring of autopilot download functionality that gets rid of grab. Grab seems to bring more (bad) surprises than real benefits. In the end, we just download files and we should pretty much always just replace them. No need for full library dependecy for that. Signed-off-by: Jussi Nummelin --- .../controller/signal/common/download.go | 9 +- .../controller/signal/k0s/download.go | 9 + pkg/autopilot/download/downloader.go | 10 + pkg/autopilot/download/downloader_test.go | 233 ++++++++++++++++++ 4 files changed, 260 insertions(+), 1 deletion(-) create mode 100644 pkg/autopilot/download/downloader_test.go diff --git a/pkg/autopilot/controller/signal/common/download.go b/pkg/autopilot/controller/signal/common/download.go index 3d4e6952ef50..c0bd425585d4 100644 --- a/pkg/autopilot/controller/signal/common/download.go +++ b/pkg/autopilot/controller/signal/common/download.go @@ -32,6 +32,8 @@ type DownloadManifest struct { apdl.Config SuccessState string + // A hook that will be called after the download itself is done succesfully + AfterTransferSuccess func() error } type DownloadManifestBuilder interface { @@ -90,7 +92,12 @@ func (r *downloadController) Reconcile(ctx context.Context, req cr.Request) (cr. } else { logger.Infof("Download of '%s' successful", manifest.URL) - + // When download is successful, run the post-download hook + if manifest.AfterTransferSuccess != nil { + if err := manifest.AfterTransferSuccess(); err != nil { + return cr.Result{}, fmt.Errorf("failed to run post-download hook: %w", err) + } + } // When the download is complete move the status to the success state signalData.Status = apsigv2.NewStatus(manifest.SuccessState) } diff --git a/pkg/autopilot/controller/signal/k0s/download.go b/pkg/autopilot/controller/signal/k0s/download.go index 044b2d75ae1a..233e1fc5e5aa 100644 --- a/pkg/autopilot/controller/signal/k0s/download.go +++ b/pkg/autopilot/controller/signal/k0s/download.go @@ -16,6 +16,8 @@ package k0s import ( "crypto/sha256" + "os" + "path/filepath" apcomm "github.com/k0sproject/k0s/pkg/autopilot/common" apdel "github.com/k0sproject/k0s/pkg/autopilot/controller/delegate" @@ -88,6 +90,13 @@ func (b downloadManifestBuilderK0s) Build(signalNode crcli.Object, signalData ap ExpectedHash: signalData.Command.K0sUpdate.Sha256, Hasher: sha256.New(), DownloadDir: b.k0sBinaryDir, + Filename: filepath.Join(b.k0sBinaryDir, "k0s.tmp"), + }, + // After the download is done, we need to rename the file to the correct name + AfterTransferSuccess: func() error { + src := filepath.Join(b.k0sBinaryDir, "k0s.tmp") + dst := filepath.Join(b.k0sBinaryDir, "k0s") + return os.Rename(src, dst) }, SuccessState: Cordoning, } diff --git a/pkg/autopilot/download/downloader.go b/pkg/autopilot/download/downloader.go index 11cdecdcec95..d80ace5bcc77 100644 --- a/pkg/autopilot/download/downloader.go +++ b/pkg/autopilot/download/downloader.go @@ -34,6 +34,7 @@ type Config struct { ExpectedHash string Hasher hash.Hash DownloadDir string + Filename string } type downloader struct { @@ -72,6 +73,15 @@ func (d *downloader) Download(ctx context.Context) error { dlreq.SetChecksum(d.config.Hasher, expectedHash, true) } + // We're never really resuming downloads, so disable this feature. + // This also allows to re-download the file if it's already present. + dlreq.NoResume = true + + if d.config.Filename != "" { + d.logger.Infof("Setting filename to %s", d.config.Filename) + dlreq.Filename = d.config.Filename + } + client := grab.NewClient() // Set user agent to mitigate 403 errors from GitHub // See https://github.com/cavaliergopher/grab/issues/104 diff --git a/pkg/autopilot/download/downloader_test.go b/pkg/autopilot/download/downloader_test.go new file mode 100644 index 000000000000..2b62fa3f72d3 --- /dev/null +++ b/pkg/autopilot/download/downloader_test.go @@ -0,0 +1,233 @@ +// Copyright 2021 k0s authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package download + +import ( + "crypto/rand" + "io" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/cavaliergopher/grab/v3" + "github.com/fsnotify/fsnotify" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +type content struct { + content string + checksum string +} + +func TestDownloadWithGrab(t *testing.T) { + // go runServer() + + tmp := t.TempDir() + dlPath := filepath.Join(tmp, "index.html") + + go func() { + // Watch the file for changes + w, err := fsnotify.NewWatcher() + assert.NoError(t, err) + defer w.Close() + err = w.Add(tmp) + assert.NoError(t, err) + for { + select { + case event := <-w.Events: + t.Logf("event: %s -- %s", event.Op.String(), event.Name) + case err := <-w.Errors: + panic(err) + } + } + }() + + // Dump some data to the file to make sure it will be overwritten + err := os.WriteFile(dlPath, []byte("foobar"), 0644) + assert.NoError(t, err) + + url := "http://localhost:8888/index.html" + // Download should succeed and get "abc" as content + dlreq, err := grab.NewRequest(tmp, url) + assert.NoError(t, err) + + // dlreq.NoResume = true + + client := grab.NewClient() + resp := client.Do(dlreq) + <-resp.Done + assert.NoError(t, resp.Err()) + + data, err := getFileContent(dlPath) + assert.NoError(t, err) + assert.NotContains(t, data, "foobar") + + // Now create the file again with more data than the previous download + // this way we can test if the download will overwrite the file entirely + randomData := make([]byte, 2048) // nginx index.html is 615 bytes + _, err = rand.Read(randomData) + assert.NoError(t, err) + err = os.WriteFile(dlPath, []byte("foobar"), 0644) + assert.NoError(t, err) + + dlreq, err = grab.NewRequest(tmp, url) + dlreq.NoResume = true + assert.NoError(t, err) + client = grab.NewClient() + resp = client.Do(dlreq) + <-resp.Done + assert.NoError(t, resp.Err()) + + s, err := os.Stat(dlPath) + assert.NoError(t, err) + assert.Equal(t, int64(615), s.Size()) + + time.Sleep(5 * time.Second) +} + +func TestDownloadRaw(t *testing.T) { + go runServer() + + tmp := t.TempDir() + dlPath := filepath.Join(tmp, "foo.txt") + + // Dump some data to the file to make sure it will be overwritten + err := os.WriteFile(dlPath, []byte("foobar"), 0644) + assert.NoError(t, err) + + url := "http://localhost:8080/foo.txt" + // Download using plain http.Get + resp, err := http.Get(url) + assert.NoError(t, err) + defer resp.Body.Close() + // Dump all the data to file + f, err := os.Create(dlPath) + assert.NoError(t, err) + defer f.Close() + _, err = io.Copy(f, resp.Body) + assert.NoError(t, err) + // Check the content + data, err := getFileContent(dlPath) + assert.NoError(t, err) + assert.Equal(t, "abc", data) +} + +func runServer() { + http.HandleFunc("/foo.txt", func(w http.ResponseWriter, r *http.Request) { + body := []byte("a") + // w.Header().Set("Content-Length", strconv.Itoa(len(body))) + if r.Method == http.MethodGet { + _, err := w.Write(body) + if err != nil { + logrus.Errorf("Failed to write response: %v", err) + } + } + }) + + http.ListenAndServe(":8080", nil) +} + +// func TestDownload(t *testing.T) { + +// // content is a map of content and its checksum +// // counter is used to switch between content +// content := map[int]content{ +// 0: {content: "a", checksum: "ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"}, +// 1: {content: "bb", checksum: "961b6dd3ede3cb8ecbaacbd68de040cd78eb2ed5889130cceb4c49268ea4d506"}, +// // The next one has wrong checksum so it needs to fail +// 2: {content: "aaa", checksum: "961b6dd3ede3cb8ecbaacbd68de040cd78eb2ed5889130cceb4c49268ea4d506"}, +// } + +// counter := 0 + +// http.HandleFunc("/foo.txt", func(w http.ResponseWriter, r *http.Request) { +// // body := []byte(content[counter].content) +// body := []byte("a") +// w.Header().Set("Content-Length", strconv.Itoa(len(body))) +// // w.Header().Set("Content-Length", "42") +// // w.WriteHeader(http.StatusOK) + +// if r.Method == http.MethodGet { +// _, err := w.Write(body) +// if err != nil { +// t.Errorf("Failed to write response: %v", err) +// } + +// // increment the counter so next GET request will have different content +// counter++ +// } +// }) + +// go func() { +// assert.NoError(t, http.ListenAndServe(":8080", nil)) +// }() + +// tmp := t.TempDir() +// dlPath := filepath.Join(tmp, "foo.txt") +// url := "http://localhost:8080/foo.txt" +// // Download should succeed and get "a" as content +// err := dl(url, content[0].checksum, tmp) +// assert.NoError(t, err) +// data, err := getFileContent(dlPath) +// assert.NoError(t, err) +// assert.Equal(t, content[0].content, data) + +// // Write the file with different content and test if the download will overwrite it +// os.WriteFile(dlPath, []byte("foobar"), 0644) + +// // Download should succeed and get "a" as content +// err = dl(url, content[0].checksum, tmp) +// assert.NoError(t, err) +// data, err = getFileContent(dlPath) +// assert.NoError(t, err) +// assert.Equal(t, content[0].content, data) + +// // // Download should succeed and get "bb" as content +// // err = dl(url, content[1].checksum, tmp) +// // assert.NoError(t, err) +// // data, err = getFileContent(dlPath) +// // assert.NoError(t, err) +// // assert.Equal(t, content[1].content, data) + +// // err = dl(url, content[2].checksum, tmp) +// // assert.Error(t, err) + +// } + +// func dl(url, hash, dir string) error { +// logger := logrus.New().WithField("component", "downloader") +// config := Config{ +// URL: url, +// ExpectedHash: "", +// DownloadDir: dir, +// Hasher: sha256.New(), +// } +// d := NewDownloader(config, logger) +// ctx, cancel := context.WithTimeout(context.Background(), 10000*time.Second) +// defer cancel() + +// return d.Download(ctx) +// } + +func getFileContent(path string) (string, error) { + content, err := os.ReadFile(path) + if err != nil { + return "", err + } + return string(content), nil +}