From 6c4f3ddbcb9bdfe5e4b5321114a6f6d4bce8042b Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Thu, 29 Feb 2024 09:27:53 -0800 Subject: [PATCH] Move Enabling Overwrite to .GetConsumer() Move the enabling of the force-flag (telling the consumer to overwrite files) to the function .GetConsumer() in config. Config is already doing the lifting to construct the consumer, add the query to Viper at that time. --- cmd/multifile/multifile.go | 4 ---- cmd/root/root.go | 4 ---- pkg/config/config.go | 5 ++-- pkg/consumer/consumer.go | 1 - pkg/consumer/null.go | 4 ---- pkg/consumer/tar_extractor.go | 8 ++----- pkg/consumer/write_file.go | 8 ++----- pkg/extract/tar_test.go | 44 +++++++++++++++++------------------ 8 files changed, 29 insertions(+), 49 deletions(-) diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index 4099fca..0cb461e 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -131,10 +131,6 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { return fmt.Errorf("error getting consumer: %w", err) } - if viper.GetBool(config.OptForce) { - consumer.EnableOverwrite() - } - getter := &pget.Getter{ Downloader: download.GetBufferMode(downloadOpts), Consumer: consumer, diff --git a/cmd/root/root.go b/cmd/root/root.go index 30dad43..6aeff55 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -229,10 +229,6 @@ func rootExecute(ctx context.Context, urlString, dest string) error { return err } - if viper.GetBool(config.OptForce) { - consumer.EnableOverwrite() - } - getter := pget.Getter{ Downloader: download.GetBufferMode(downloadOpts), Consumer: consumer, diff --git a/pkg/config/config.go b/pkg/config/config.go index 8cef10a..57d3f53 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -155,11 +155,12 @@ func ResolveOverridesToMap(resolveOverrides []string) (map[string]string, error) // calls viper.GetString(OptExtract) internally. func GetConsumer() (consumer.Consumer, error) { consumerName := viper.GetString(OptOutputConsumer) + enableOverwrite := viper.GetBool(OptForce) switch consumerName { case ConsumerFile: - return &consumer.FileWriter{}, nil + return &consumer.FileWriter{Overwrite: enableOverwrite}, nil case ConsumerTarExtractor: - return &consumer.TarExtractor{}, nil + return &consumer.TarExtractor{Overwrite: enableOverwrite}, nil case ConsumerNull: return &consumer.NullWriter{}, nil default: diff --git a/pkg/consumer/consumer.go b/pkg/consumer/consumer.go index c406a20..39bedef 100644 --- a/pkg/consumer/consumer.go +++ b/pkg/consumer/consumer.go @@ -4,5 +4,4 @@ import "io" type Consumer interface { Consume(reader io.Reader, destPath string) error - EnableOverwrite() } diff --git a/pkg/consumer/null.go b/pkg/consumer/null.go index 4ac2e7e..6bb3836 100644 --- a/pkg/consumer/null.go +++ b/pkg/consumer/null.go @@ -13,7 +13,3 @@ func (f *NullWriter) Consume(reader io.Reader, destPath string) error { _, _ = io.Copy(io.Discard, reader) return nil } - -func (f *NullWriter) EnableOverwrite() { - // no-op -} diff --git a/pkg/consumer/tar_extractor.go b/pkg/consumer/tar_extractor.go index fcbb393..cd02071 100644 --- a/pkg/consumer/tar_extractor.go +++ b/pkg/consumer/tar_extractor.go @@ -8,19 +8,15 @@ import ( ) type TarExtractor struct { - overwrite bool + Overwrite bool } var _ Consumer = &TarExtractor{} func (f *TarExtractor) Consume(reader io.Reader, destPath string) error { - err := extract.TarFile(reader, destPath, f.overwrite) + err := extract.TarFile(reader, destPath, f.Overwrite) if err != nil { return fmt.Errorf("error extracting file: %w", err) } return nil } - -func (f *TarExtractor) EnableOverwrite() { - f.overwrite = true -} diff --git a/pkg/consumer/write_file.go b/pkg/consumer/write_file.go index 80dd423..44ecff4 100644 --- a/pkg/consumer/write_file.go +++ b/pkg/consumer/write_file.go @@ -7,14 +7,14 @@ import ( ) type FileWriter struct { - overwrite bool + Overwrite bool } var _ Consumer = &FileWriter{} func (f *FileWriter) Consume(reader io.Reader, destPath string) error { openFlags := os.O_WRONLY | os.O_CREATE - if f.overwrite { + if f.Overwrite { openFlags |= os.O_TRUNC } out, err := os.OpenFile(destPath, openFlags, 0644) @@ -29,7 +29,3 @@ func (f *FileWriter) Consume(reader io.Reader, destPath string) error { } return nil } - -func (f *FileWriter) EnableOverwrite() { - f.overwrite = true -} diff --git a/pkg/extract/tar_test.go b/pkg/extract/tar_test.go index 8daec62..3e53790 100644 --- a/pkg/extract/tar_test.go +++ b/pkg/extract/tar_test.go @@ -12,11 +12,11 @@ import ( func TestCreateLinks(t *testing.T) { tests := []struct { - name string - links []*link - expectedError bool - overwrite bool - createOverwritenFile bool + name string + links []*link + expectedError bool + overwrite bool + createFileToOverwrite bool }{ { name: "EmptyLink", @@ -43,16 +43,16 @@ func TestCreateLinks(t *testing.T) { }, }, { - name: "HardLink_OverwriteEnabled_File Exists", - links: []*link{{tar.TypeLink, "", "testLinkHard"}}, - overwrite: true, - createOverwritenFile: true, + name: "HardLink_OverwriteEnabled_File Exists", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + overwrite: true, + createFileToOverwrite: true, }, { - name: "HardLink_OverwriteDisabled_FileExists", - links: []*link{{tar.TypeLink, "", "testLinkHard"}}, - createOverwritenFile: true, - expectedError: true, + name: "HardLink_OverwriteDisabled_FileExists", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + createFileToOverwrite: true, + expectedError: true, }, { name: "HardLink_OverwriteEnabled_FileDoesNotExist", @@ -60,16 +60,16 @@ func TestCreateLinks(t *testing.T) { overwrite: true, }, { - name: "SymLink_OverwriteEnabled_FileExists", - links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, - overwrite: true, - createOverwritenFile: true, + name: "SymLink_OverwriteEnabled_FileExists", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + overwrite: true, + createFileToOverwrite: true, }, { - name: "SymLink_OverwriteDisabled_FileExists", - links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, - createOverwritenFile: true, - expectedError: true, + name: "SymLink_OverwriteDisabled_FileExists", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + createFileToOverwrite: true, + expectedError: true, }, { name: "SymLink_OverwriteEnabled_FileDoesNotExist", @@ -92,7 +92,7 @@ func TestCreateLinks(t *testing.T) { for _, link := range tt.links { if link.linkType == tar.TypeLink || link.linkType == tar.TypeSymlink { testFile, err := os.CreateTemp(destDir, "test-") - if tt.createOverwritenFile { + if tt.createFileToOverwrite { _, err = os.Create(filepath.Join(destDir, link.newName)) } if err != nil {