From 16ca4d211864f3cc790a25ec5d4b1041b61ff470 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Fri, 23 Feb 2024 14:35:00 -0800 Subject: [PATCH] Handle Overwriting in the Consumer Make the consumer handle overwriting explicitly. This addresses edge cases with tar when extracting files. --- cmd/multifile/multifile.go | 4 +++ cmd/root/root.go | 4 +++ pkg/consumer/consumer.go | 1 + pkg/consumer/null.go | 4 +++ pkg/consumer/tar_extractor.go | 10 ++++++-- pkg/consumer/write_file.go | 17 +++++++++---- pkg/extract/tar.go | 36 ++++++++++++++++++++++----- pkg/extract/tar_test.go | 47 ++++++++++++++++++++++++++++++++--- 8 files changed, 106 insertions(+), 17 deletions(-) diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index 0cb461e..4099fca 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -131,6 +131,10 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { return fmt.Errorf("error getting consumer: %w", err) } + if viper.GetBool(config.OptForce) { + consumer.EnableOverwrite() + } + getter := &pget.Getter{ Downloader: download.GetBufferMode(downloadOpts), Consumer: consumer, diff --git a/cmd/root/root.go b/cmd/root/root.go index 6aeff55..30dad43 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -229,6 +229,10 @@ func rootExecute(ctx context.Context, urlString, dest string) error { return err } + if viper.GetBool(config.OptForce) { + consumer.EnableOverwrite() + } + getter := pget.Getter{ Downloader: download.GetBufferMode(downloadOpts), Consumer: consumer, diff --git a/pkg/consumer/consumer.go b/pkg/consumer/consumer.go index 39bedef..c406a20 100644 --- a/pkg/consumer/consumer.go +++ b/pkg/consumer/consumer.go @@ -4,4 +4,5 @@ import "io" type Consumer interface { Consume(reader io.Reader, destPath string) error + EnableOverwrite() } diff --git a/pkg/consumer/null.go b/pkg/consumer/null.go index 6bb3836..4ac2e7e 100644 --- a/pkg/consumer/null.go +++ b/pkg/consumer/null.go @@ -13,3 +13,7 @@ func (f *NullWriter) Consume(reader io.Reader, destPath string) error { _, _ = io.Copy(io.Discard, reader) return nil } + +func (f *NullWriter) EnableOverwrite() { + // no-op +} diff --git a/pkg/consumer/tar_extractor.go b/pkg/consumer/tar_extractor.go index 012a1e6..fcbb393 100644 --- a/pkg/consumer/tar_extractor.go +++ b/pkg/consumer/tar_extractor.go @@ -7,14 +7,20 @@ import ( "github.com/replicate/pget/pkg/extract" ) -type TarExtractor struct{} +type TarExtractor struct { + overwrite bool +} var _ Consumer = &TarExtractor{} func (f *TarExtractor) Consume(reader io.Reader, destPath string) error { - err := extract.TarFile(reader, destPath) + err := extract.TarFile(reader, destPath, f.overwrite) if err != nil { return fmt.Errorf("error extracting file: %w", err) } return nil } + +func (f *TarExtractor) EnableOverwrite() { + f.overwrite = true +} diff --git a/pkg/consumer/write_file.go b/pkg/consumer/write_file.go index 008b21f..80dd423 100644 --- a/pkg/consumer/write_file.go +++ b/pkg/consumer/write_file.go @@ -6,15 +6,18 @@ import ( "os" ) -type FileWriter struct{} +type FileWriter struct { + overwrite bool +} var _ Consumer = &FileWriter{} func (f *FileWriter) Consume(reader io.Reader, destPath string) error { - // NOTE(morgan): We check if the file exists early on allowing a fast fail, it is safe - // to just apply os.O_TRUNC. Getting to this point without checking existence and - // the `--force` flag is a programming error further up the stack. - out, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + openFlags := os.O_WRONLY | os.O_CREATE + if f.overwrite { + openFlags |= os.O_TRUNC + } + out, err := os.OpenFile(destPath, openFlags, 0644) if err != nil { return fmt.Errorf("error writing file: %w", err) } @@ -26,3 +29,7 @@ func (f *FileWriter) Consume(reader io.Reader, destPath string) error { } return nil } + +func (f *FileWriter) EnableOverwrite() { + f.overwrite = true +} diff --git a/pkg/extract/tar.go b/pkg/extract/tar.go index 23b836d..aa74b25 100644 --- a/pkg/extract/tar.go +++ b/pkg/extract/tar.go @@ -17,7 +17,7 @@ type link struct { newName string } -func TarFile(reader io.Reader, destDir string) error { +func TarFile(reader io.Reader, destDir string, overwrite bool) error { var links []*link startTime := time.Now() @@ -49,7 +49,11 @@ func TarFile(reader io.Reader, destDir string) error { return err } case tar.TypeReg: - targetFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY, os.FileMode(header.Mode)) + openFlags := os.O_CREATE | os.O_WRONLY + if overwrite { + openFlags |= os.O_TRUNC + } + targetFile, err := os.OpenFile(target, openFlags, os.FileMode(header.Mode)) if err != nil { return err } @@ -68,7 +72,7 @@ func TarFile(reader io.Reader, destDir string) error { } } - if err := createLinks(links, destDir); err != nil { + if err := createLinks(links, destDir, overwrite); err != nil { return fmt.Errorf("error creating links: %w", err) } @@ -81,7 +85,7 @@ func TarFile(reader io.Reader, destDir string) error { return nil } -func createLinks(links []*link, destDir string) error { +func createLinks(links []*link, destDir string, overwrite bool) error { for _, link := range links { targetDir := filepath.Dir(link.newName) if err := os.MkdirAll(targetDir, 0755); err != nil { @@ -90,11 +94,11 @@ func createLinks(links []*link, destDir string) error { switch link.linkType { case tar.TypeLink: oldPath := filepath.Join(destDir, link.oldName) - if err := os.Link(oldPath, link.newName); err != nil { + if err := createHardLink(oldPath, link.newName, overwrite); err != nil { return fmt.Errorf("error creating hard link from %s to %s: %w", oldPath, link.newName, err) } case tar.TypeSymlink: - if err := os.Symlink(link.oldName, link.newName); err != nil { + if err := createSymlink(link.oldName, link.newName, overwrite); err != nil { return fmt.Errorf("error creating symlink from %s to %s: %w", link.oldName, link.newName, err) } default: @@ -103,3 +107,23 @@ func createLinks(links []*link, destDir string) error { } return nil } + +func createHardLink(oldName, newName string, overwrite bool) error { + if overwrite { + err := os.Remove(newName) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("error removing existing file: %w", err) + } + } + return os.Link(oldName, newName) +} + +func createSymlink(oldName, newName string, overwrite bool) error { + if overwrite { + err := os.Remove(newName) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("error removing existing symlink/file: %w", err) + } + } + return os.Symlink(oldName, newName) +} diff --git a/pkg/extract/tar_test.go b/pkg/extract/tar_test.go index e21f5af..7e39820 100644 --- a/pkg/extract/tar_test.go +++ b/pkg/extract/tar_test.go @@ -12,9 +12,11 @@ import ( func TestCreateLinks(t *testing.T) { tests := []struct { - name string - links []*link - expectedError bool + name string + links []*link + expectedError bool + overwrite bool + createOverwritenFile bool }{ { name: "EmptyLink", @@ -40,6 +42,40 @@ func TestCreateLinks(t *testing.T) { {tar.TypeSymlink, "", "testLinkSym"}, }, }, + { + name: "HardLink_OverwriteEnabled_File Exists", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + overwrite: true, + createOverwritenFile: true, + }, + { + name: "HardLink_OverwriteDisabled_FileExists", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + createOverwritenFile: true, + expectedError: true, + }, + { + name: "HardLink_OverwriteEnabled_FileDoesNotExist", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + overwrite: true, + }, + { + name: "SymLink_OverwriteEnabled_FileExists", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + overwrite: true, + createOverwritenFile: true, + }, + { + name: "SymLink_OverwriteDisabled_FileExists", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + createOverwritenFile: true, + expectedError: true, + }, + { + name: "SymLink_OverwriteEnabled_FileDoesNotExist", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + overwrite: true, + }, } for _, tt := range tests { @@ -56,6 +92,9 @@ func TestCreateLinks(t *testing.T) { for _, link := range tt.links { if link.linkType == tar.TypeLink || link.linkType == tar.TypeSymlink { testFile, err := os.CreateTemp(destDir, "test-") + if tt.createOverwritenFile { + _, err = os.Create(filepath.Join(destDir, link.newName)) + } if err != nil { t.Fatalf("Test failed, could not create test file: %v", err) } @@ -65,7 +104,7 @@ func TestCreateLinks(t *testing.T) { } } - err = createLinks(tt.links, destDir) + err = createLinks(tt.links, destDir, tt.overwrite) // Validation if tt.expectedError {