Skip to content

Commit

Permalink
Move Enabling Overwrite to .GetConsumer()
Browse files Browse the repository at this point in the history
Move the enabling of the force-flag (telling the consumer to overwrite
files) to the function .GetConsumer() in config. Config is already doing
the lifting to construct the consumer, add the query to Viper at that
time.
  • Loading branch information
tempusfrangit committed Feb 29, 2024
1 parent 0762f82 commit 6c4f3dd
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 49 deletions.
4 changes: 0 additions & 4 deletions cmd/multifile/multifile.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,6 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
return fmt.Errorf("error getting consumer: %w", err)
}

if viper.GetBool(config.OptForce) {
consumer.EnableOverwrite()
}

getter := &pget.Getter{
Downloader: download.GetBufferMode(downloadOpts),
Consumer: consumer,
Expand Down
4 changes: 0 additions & 4 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,6 @@ func rootExecute(ctx context.Context, urlString, dest string) error {
return err
}

if viper.GetBool(config.OptForce) {
consumer.EnableOverwrite()
}

getter := pget.Getter{
Downloader: download.GetBufferMode(downloadOpts),
Consumer: consumer,
Expand Down
5 changes: 3 additions & 2 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,12 @@ func ResolveOverridesToMap(resolveOverrides []string) (map[string]string, error)
// calls viper.GetString(OptExtract) internally.
func GetConsumer() (consumer.Consumer, error) {
consumerName := viper.GetString(OptOutputConsumer)
enableOverwrite := viper.GetBool(OptForce)
switch consumerName {
case ConsumerFile:
return &consumer.FileWriter{}, nil
return &consumer.FileWriter{Overwrite: enableOverwrite}, nil
case ConsumerTarExtractor:
return &consumer.TarExtractor{}, nil
return &consumer.TarExtractor{Overwrite: enableOverwrite}, nil
case ConsumerNull:
return &consumer.NullWriter{}, nil
default:
Expand Down
1 change: 0 additions & 1 deletion pkg/consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ import "io"

type Consumer interface {
Consume(reader io.Reader, destPath string) error
EnableOverwrite()
}
4 changes: 0 additions & 4 deletions pkg/consumer/null.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,3 @@ func (f *NullWriter) Consume(reader io.Reader, destPath string) error {
_, _ = io.Copy(io.Discard, reader)
return nil
}

func (f *NullWriter) EnableOverwrite() {
// no-op
}
8 changes: 2 additions & 6 deletions pkg/consumer/tar_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,15 @@ import (
)

type TarExtractor struct {
overwrite bool
Overwrite bool
}

var _ Consumer = &TarExtractor{}

func (f *TarExtractor) Consume(reader io.Reader, destPath string) error {
err := extract.TarFile(reader, destPath, f.overwrite)
err := extract.TarFile(reader, destPath, f.Overwrite)
if err != nil {
return fmt.Errorf("error extracting file: %w", err)
}
return nil
}

func (f *TarExtractor) EnableOverwrite() {
f.overwrite = true
}
8 changes: 2 additions & 6 deletions pkg/consumer/write_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import (
)

type FileWriter struct {
overwrite bool
Overwrite bool
}

var _ Consumer = &FileWriter{}

func (f *FileWriter) Consume(reader io.Reader, destPath string) error {
openFlags := os.O_WRONLY | os.O_CREATE
if f.overwrite {
if f.Overwrite {
openFlags |= os.O_TRUNC
}
out, err := os.OpenFile(destPath, openFlags, 0644)
Expand All @@ -29,7 +29,3 @@ func (f *FileWriter) Consume(reader io.Reader, destPath string) error {
}
return nil
}

func (f *FileWriter) EnableOverwrite() {
f.overwrite = true
}
44 changes: 22 additions & 22 deletions pkg/extract/tar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (

func TestCreateLinks(t *testing.T) {
tests := []struct {
name string
links []*link
expectedError bool
overwrite bool
createOverwritenFile bool
name string
links []*link
expectedError bool
overwrite bool
createFileToOverwrite bool
}{
{
name: "EmptyLink",
Expand All @@ -43,33 +43,33 @@ func TestCreateLinks(t *testing.T) {
},
},
{
name: "HardLink_OverwriteEnabled_File Exists",
links: []*link{{tar.TypeLink, "", "testLinkHard"}},
overwrite: true,
createOverwritenFile: true,
name: "HardLink_OverwriteEnabled_File Exists",
links: []*link{{tar.TypeLink, "", "testLinkHard"}},
overwrite: true,
createFileToOverwrite: true,
},
{
name: "HardLink_OverwriteDisabled_FileExists",
links: []*link{{tar.TypeLink, "", "testLinkHard"}},
createOverwritenFile: true,
expectedError: true,
name: "HardLink_OverwriteDisabled_FileExists",
links: []*link{{tar.TypeLink, "", "testLinkHard"}},
createFileToOverwrite: true,
expectedError: true,
},
{
name: "HardLink_OverwriteEnabled_FileDoesNotExist",
links: []*link{{tar.TypeLink, "", "testLinkHard"}},
overwrite: true,
},
{
name: "SymLink_OverwriteEnabled_FileExists",
links: []*link{{tar.TypeSymlink, "", "testLinkSym"}},
overwrite: true,
createOverwritenFile: true,
name: "SymLink_OverwriteEnabled_FileExists",
links: []*link{{tar.TypeSymlink, "", "testLinkSym"}},
overwrite: true,
createFileToOverwrite: true,
},
{
name: "SymLink_OverwriteDisabled_FileExists",
links: []*link{{tar.TypeSymlink, "", "testLinkSym"}},
createOverwritenFile: true,
expectedError: true,
name: "SymLink_OverwriteDisabled_FileExists",
links: []*link{{tar.TypeSymlink, "", "testLinkSym"}},
createFileToOverwrite: true,
expectedError: true,
},
{
name: "SymLink_OverwriteEnabled_FileDoesNotExist",
Expand All @@ -92,7 +92,7 @@ func TestCreateLinks(t *testing.T) {
for _, link := range tt.links {
if link.linkType == tar.TypeLink || link.linkType == tar.TypeSymlink {
testFile, err := os.CreateTemp(destDir, "test-")
if tt.createOverwritenFile {
if tt.createFileToOverwrite {
_, err = os.Create(filepath.Join(destDir, link.newName))
}
if err != nil {
Expand Down

0 comments on commit 6c4f3dd

Please sign in to comment.