diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index 4099fca..0cb461e 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -131,10 +131,6 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { return fmt.Errorf("error getting consumer: %w", err) } - if viper.GetBool(config.OptForce) { - consumer.EnableOverwrite() - } - getter := &pget.Getter{ Downloader: download.GetBufferMode(downloadOpts), Consumer: consumer, diff --git a/cmd/root/root.go b/cmd/root/root.go index 30dad43..6aeff55 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -229,10 +229,6 @@ func rootExecute(ctx context.Context, urlString, dest string) error { return err } - if viper.GetBool(config.OptForce) { - consumer.EnableOverwrite() - } - getter := pget.Getter{ Downloader: download.GetBufferMode(downloadOpts), Consumer: consumer, diff --git a/pkg/config/config.go b/pkg/config/config.go index 8cef10a..57d3f53 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -155,11 +155,12 @@ func ResolveOverridesToMap(resolveOverrides []string) (map[string]string, error) // calls viper.GetString(OptExtract) internally. func GetConsumer() (consumer.Consumer, error) { consumerName := viper.GetString(OptOutputConsumer) + enableOverwrite := viper.GetBool(OptForce) switch consumerName { case ConsumerFile: - return &consumer.FileWriter{}, nil + return &consumer.FileWriter{Overwrite: enableOverwrite}, nil case ConsumerTarExtractor: - return &consumer.TarExtractor{}, nil + return &consumer.TarExtractor{Overwrite: enableOverwrite}, nil case ConsumerNull: return &consumer.NullWriter{}, nil default: diff --git a/pkg/consumer/consumer.go b/pkg/consumer/consumer.go index c406a20..39bedef 100644 --- a/pkg/consumer/consumer.go +++ b/pkg/consumer/consumer.go @@ -4,5 +4,4 @@ import "io" type Consumer interface { Consume(reader io.Reader, destPath string) error - EnableOverwrite() } diff --git a/pkg/consumer/null.go b/pkg/consumer/null.go index 4ac2e7e..6bb3836 100644 --- a/pkg/consumer/null.go +++ b/pkg/consumer/null.go @@ -13,7 +13,3 @@ func (f *NullWriter) Consume(reader io.Reader, destPath string) error { _, _ = io.Copy(io.Discard, reader) return nil } - -func (f *NullWriter) EnableOverwrite() { - // no-op -} diff --git a/pkg/consumer/tar_extractor.go b/pkg/consumer/tar_extractor.go index fcbb393..cd02071 100644 --- a/pkg/consumer/tar_extractor.go +++ b/pkg/consumer/tar_extractor.go @@ -8,19 +8,15 @@ import ( ) type TarExtractor struct { - overwrite bool + Overwrite bool } var _ Consumer = &TarExtractor{} func (f *TarExtractor) Consume(reader io.Reader, destPath string) error { - err := extract.TarFile(reader, destPath, f.overwrite) + err := extract.TarFile(reader, destPath, f.Overwrite) if err != nil { return fmt.Errorf("error extracting file: %w", err) } return nil } - -func (f *TarExtractor) EnableOverwrite() { - f.overwrite = true -} diff --git a/pkg/consumer/write_file.go b/pkg/consumer/write_file.go index 80dd423..44ecff4 100644 --- a/pkg/consumer/write_file.go +++ b/pkg/consumer/write_file.go @@ -7,14 +7,14 @@ import ( ) type FileWriter struct { - overwrite bool + Overwrite bool } var _ Consumer = &FileWriter{} func (f *FileWriter) Consume(reader io.Reader, destPath string) error { openFlags := os.O_WRONLY | os.O_CREATE - if f.overwrite { + if f.Overwrite { openFlags |= os.O_TRUNC } out, err := os.OpenFile(destPath, openFlags, 0644) @@ -29,7 +29,3 @@ func (f *FileWriter) Consume(reader io.Reader, destPath string) error { } return nil } - -func (f *FileWriter) EnableOverwrite() { - f.overwrite = true -} diff --git a/pkg/extract/tar_test.go b/pkg/extract/tar_test.go index 8daec62..3e53790 100644 --- a/pkg/extract/tar_test.go +++ b/pkg/extract/tar_test.go @@ -12,11 +12,11 @@ import ( func TestCreateLinks(t *testing.T) { tests := []struct { - name string - links []*link - expectedError bool - overwrite bool - createOverwritenFile bool + name string + links []*link + expectedError bool + overwrite bool + createFileToOverwrite bool }{ { name: "EmptyLink", @@ -43,16 +43,16 @@ func TestCreateLinks(t *testing.T) { }, }, { - name: "HardLink_OverwriteEnabled_File Exists", - links: []*link{{tar.TypeLink, "", "testLinkHard"}}, - overwrite: true, - createOverwritenFile: true, + name: "HardLink_OverwriteEnabled_File Exists", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + overwrite: true, + createFileToOverwrite: true, }, { - name: "HardLink_OverwriteDisabled_FileExists", - links: []*link{{tar.TypeLink, "", "testLinkHard"}}, - createOverwritenFile: true, - expectedError: true, + name: "HardLink_OverwriteDisabled_FileExists", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + createFileToOverwrite: true, + expectedError: true, }, { name: "HardLink_OverwriteEnabled_FileDoesNotExist", @@ -60,16 +60,16 @@ func TestCreateLinks(t *testing.T) { overwrite: true, }, { - name: "SymLink_OverwriteEnabled_FileExists", - links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, - overwrite: true, - createOverwritenFile: true, + name: "SymLink_OverwriteEnabled_FileExists", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + overwrite: true, + createFileToOverwrite: true, }, { - name: "SymLink_OverwriteDisabled_FileExists", - links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, - createOverwritenFile: true, - expectedError: true, + name: "SymLink_OverwriteDisabled_FileExists", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + createFileToOverwrite: true, + expectedError: true, }, { name: "SymLink_OverwriteEnabled_FileDoesNotExist", @@ -92,7 +92,7 @@ func TestCreateLinks(t *testing.T) { for _, link := range tt.links { if link.linkType == tar.TypeLink || link.linkType == tar.TypeSymlink { testFile, err := os.CreateTemp(destDir, "test-") - if tt.createOverwritenFile { + if tt.createFileToOverwrite { _, err = os.Create(filepath.Join(destDir, link.newName)) } if err != nil {