Skip to content

Commit

Permalink
Handle Overwriting in the Consumer
Browse files Browse the repository at this point in the history
Make the consumer handle overwriting explicitly. This addresses edge
cases with tar when extracting files.
  • Loading branch information
tempusfrangit committed Feb 23, 2024
1 parent ba016dd commit 16ca4d2
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 17 deletions.
4 changes: 4 additions & 0 deletions cmd/multifile/multifile.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
return fmt.Errorf("error getting consumer: %w", err)
}

if viper.GetBool(config.OptForce) {
consumer.EnableOverwrite()
}

getter := &pget.Getter{
Downloader: download.GetBufferMode(downloadOpts),
Consumer: consumer,
Expand Down
4 changes: 4 additions & 0 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ func rootExecute(ctx context.Context, urlString, dest string) error {
return err
}

if viper.GetBool(config.OptForce) {
consumer.EnableOverwrite()
}

getter := pget.Getter{
Downloader: download.GetBufferMode(downloadOpts),
Consumer: consumer,
Expand Down
1 change: 1 addition & 0 deletions pkg/consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ import "io"

type Consumer interface {
Consume(reader io.Reader, destPath string) error
EnableOverwrite()
}
4 changes: 4 additions & 0 deletions pkg/consumer/null.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ func (f *NullWriter) Consume(reader io.Reader, destPath string) error {
_, _ = io.Copy(io.Discard, reader)
return nil
}

func (f *NullWriter) EnableOverwrite() {
// no-op
}
10 changes: 8 additions & 2 deletions pkg/consumer/tar_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@ import (
"github.com/replicate/pget/pkg/extract"
)

type TarExtractor struct{}
type TarExtractor struct {
overwrite bool
}

var _ Consumer = &TarExtractor{}

func (f *TarExtractor) Consume(reader io.Reader, destPath string) error {
err := extract.TarFile(reader, destPath)
err := extract.TarFile(reader, destPath, f.overwrite)
if err != nil {
return fmt.Errorf("error extracting file: %w", err)
}
return nil
}

func (f *TarExtractor) EnableOverwrite() {
f.overwrite = true
}
17 changes: 12 additions & 5 deletions pkg/consumer/write_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ import (
"os"
)

type FileWriter struct{}
type FileWriter struct {
overwrite bool
}

var _ Consumer = &FileWriter{}

func (f *FileWriter) Consume(reader io.Reader, destPath string) error {
// NOTE(morgan): We check if the file exists early on allowing a fast fail, it is safe
// to just apply os.O_TRUNC. Getting to this point without checking existence and
// the `--force` flag is a programming error further up the stack.
out, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
openFlags := os.O_WRONLY | os.O_CREATE
if f.overwrite {
openFlags |= os.O_TRUNC
}
out, err := os.OpenFile(destPath, openFlags, 0644)
if err != nil {
return fmt.Errorf("error writing file: %w", err)
}
Expand All @@ -26,3 +29,7 @@ func (f *FileWriter) Consume(reader io.Reader, destPath string) error {
}
return nil
}

func (f *FileWriter) EnableOverwrite() {
f.overwrite = true
}
36 changes: 30 additions & 6 deletions pkg/extract/tar.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type link struct {
newName string
}

func TarFile(reader io.Reader, destDir string) error {
func TarFile(reader io.Reader, destDir string, overwrite bool) error {
var links []*link

startTime := time.Now()
Expand Down Expand Up @@ -49,7 +49,11 @@ func TarFile(reader io.Reader, destDir string) error {
return err
}
case tar.TypeReg:
targetFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY, os.FileMode(header.Mode))
openFlags := os.O_CREATE | os.O_WRONLY
if overwrite {
openFlags |= os.O_TRUNC
}
targetFile, err := os.OpenFile(target, openFlags, os.FileMode(header.Mode))
if err != nil {
return err
}
Expand All @@ -68,7 +72,7 @@ func TarFile(reader io.Reader, destDir string) error {
}
}

if err := createLinks(links, destDir); err != nil {
if err := createLinks(links, destDir, overwrite); err != nil {
return fmt.Errorf("error creating links: %w", err)
}

Expand All @@ -81,7 +85,7 @@ func TarFile(reader io.Reader, destDir string) error {
return nil
}

func createLinks(links []*link, destDir string) error {
func createLinks(links []*link, destDir string, overwrite bool) error {
for _, link := range links {
targetDir := filepath.Dir(link.newName)
if err := os.MkdirAll(targetDir, 0755); err != nil {
Expand All @@ -90,11 +94,11 @@ func createLinks(links []*link, destDir string) error {
switch link.linkType {
case tar.TypeLink:
oldPath := filepath.Join(destDir, link.oldName)
if err := os.Link(oldPath, link.newName); err != nil {
if err := createHardLink(oldPath, link.newName, overwrite); err != nil {
return fmt.Errorf("error creating hard link from %s to %s: %w", oldPath, link.newName, err)
}
case tar.TypeSymlink:
if err := os.Symlink(link.oldName, link.newName); err != nil {
if err := createSymlink(link.oldName, link.newName, overwrite); err != nil {
return fmt.Errorf("error creating symlink from %s to %s: %w", link.oldName, link.newName, err)
}
default:
Expand All @@ -103,3 +107,23 @@ func createLinks(links []*link, destDir string) error {
}
return nil
}

func createHardLink(oldName, newName string, overwrite bool) error {
if overwrite {
err := os.Remove(newName)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("error removing existing file: %w", err)
}
}
return os.Link(oldName, newName)
}

func createSymlink(oldName, newName string, overwrite bool) error {
if overwrite {
err := os.Remove(newName)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("error removing existing symlink/file: %w", err)
}
}
return os.Symlink(oldName, newName)
}
47 changes: 43 additions & 4 deletions pkg/extract/tar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ import (

func TestCreateLinks(t *testing.T) {
tests := []struct {
name string
links []*link
expectedError bool
name string
links []*link
expectedError bool
overwrite bool
createOverwritenFile bool
}{
{
name: "EmptyLink",
Expand All @@ -40,6 +42,40 @@ func TestCreateLinks(t *testing.T) {
{tar.TypeSymlink, "", "testLinkSym"},
},
},
{
name: "HardLink_OverwriteEnabled_File Exists",
links: []*link{{tar.TypeLink, "", "testLinkHard"}},
overwrite: true,
createOverwritenFile: true,
},
{
name: "HardLink_OverwriteDisabled_FileExists",
links: []*link{{tar.TypeLink, "", "testLinkHard"}},
createOverwritenFile: true,
expectedError: true,
},
{
name: "HardLink_OverwriteEnabled_FileDoesNotExist",
links: []*link{{tar.TypeLink, "", "testLinkHard"}},
overwrite: true,
},
{
name: "SymLink_OverwriteEnabled_FileExists",
links: []*link{{tar.TypeSymlink, "", "testLinkSym"}},
overwrite: true,
createOverwritenFile: true,
},
{
name: "SymLink_OverwriteDisabled_FileExists",
links: []*link{{tar.TypeSymlink, "", "testLinkSym"}},
createOverwritenFile: true,
expectedError: true,
},
{
name: "SymLink_OverwriteEnabled_FileDoesNotExist",
links: []*link{{tar.TypeSymlink, "", "testLinkSym"}},
overwrite: true,
},
}

for _, tt := range tests {
Expand All @@ -56,6 +92,9 @@ func TestCreateLinks(t *testing.T) {
for _, link := range tt.links {
if link.linkType == tar.TypeLink || link.linkType == tar.TypeSymlink {
testFile, err := os.CreateTemp(destDir, "test-")
if tt.createOverwritenFile {
_, err = os.Create(filepath.Join(destDir, link.newName))
}
if err != nil {
t.Fatalf("Test failed, could not create test file: %v", err)
}
Expand All @@ -65,7 +104,7 @@ func TestCreateLinks(t *testing.T) {
}
}

err = createLinks(tt.links, destDir)
err = createLinks(tt.links, destDir, tt.overwrite)

// Validation
if tt.expectedError {
Expand Down

0 comments on commit 16ca4d2

Please sign in to comment.