Skip to content

Commit

Permalink
Prevent ZipSlip
Browse files Browse the repository at this point in the history
  • Loading branch information
tempusfrangit committed Feb 29, 2024
1 parent c2f3c5e commit 0762f82
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 2 deletions.
32 changes: 30 additions & 2 deletions pkg/extract/tar.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"os"
"path/filepath"
"strings"
"time"

"github.com/replicate/pget/pkg/logging"
Expand Down Expand Up @@ -43,17 +44,21 @@ func TarFile(reader io.Reader, destDir string, overwrite bool) error {
return err
}

if err := guardAgainstZipSlip(header, destDir); err != nil {
return err
}

switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil {
if err := os.MkdirAll(target, cleanFileMode(os.FileMode(header.Mode))); err != nil {
return err
}
case tar.TypeReg:
openFlags := os.O_CREATE | os.O_WRONLY
if overwrite {
openFlags |= os.O_TRUNC
}
targetFile, err := os.OpenFile(target, openFlags, os.FileMode(header.Mode))
targetFile, err := os.OpenFile(target, openFlags, cleanFileMode(os.FileMode(header.Mode)))
if err != nil {
return err
}
Expand Down Expand Up @@ -127,3 +132,26 @@ func createSymlink(oldName, newName string, overwrite bool) error {
}
return os.Symlink(oldName, newName)
}

func guardAgainstZipSlip(header *tar.Header, destDir string) error {
if header.Name == "" {
return fmt.Errorf("tar file contains entry with empty name")
}
target, err := filepath.Abs(filepath.Join(destDir, header.Name))
if err != nil {
return fmt.Errorf("error getting absolute path of destDir %s: %w", header.Name, err)
}
filePath, err := filepath.Abs(target)
if err != nil {
return fmt.Errorf("error getting absolute path of %s: %w", target, err)
}
if !strings.HasPrefix(filePath, destDir) {
return fmt.Errorf("archive (tar) file contains file (%s) outside of target directory: %s", filePath, target)
}
return nil
}

func cleanFileMode(mode os.FileMode) os.FileMode {
mask := os.ModeSticky | os.ModeSetuid | os.ModeSetgid
return mode &^ mask
}
107 changes: 107 additions & 0 deletions pkg/extract/tar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,110 @@ func assertSymlinkTarget(t *testing.T, oldName, newName string) {
assert.Equal(t, fileStat.Sys().(*syscall.Stat_t).Ino,
realTarget.Sys().(*syscall.Stat_t).Ino)
}

func TestGuardAgainstZipSlip(t *testing.T) {
tests := []struct {
description string
header *tar.Header
destDir string
expectedError string
}{
{
description: "valid file path within directory",
header: &tar.Header{
Name: "valid_file",
},
destDir: "/tmp/valid_dir",
expectedError: "",
},
{
description: "file path outside directory",
header: &tar.Header{
Name: "../invalid_file",
},
destDir: "/tmp/valid_dir",
expectedError: "archive (tar) file contains file (/tmp/invalid_file) outside of target directory: ",
},
{
description: "directory traversal with invalid file",
header: &tar.Header{
Name: "./../../tmp/invalid_dir/invalid_file",
},
destDir: "/tmp/valid_dir",
expectedError: "archive (tar) file contains file (/tmp/invalid_dir/invalid_file) outside of target directory: ",
},
{
description: "Empty header name",
header: &tar.Header{
Name: "",
},
destDir: "/tmp",
expectedError: "tar file contains entry with empty name",
},
}

for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
err := guardAgainstZipSlip(test.header, test.destDir)
if test.expectedError != "" {
if assert.Error(t, err) {
assert.Contains(t, err.Error(), test.expectedError)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestCleanFileMode(t *testing.T) {
testCases := []struct {
name string
input os.FileMode
expected os.FileMode
}{
{
name: "TestWithoutStickyBit",
input: 0755,
expected: 0755,
},
{
name: "TestWithStickyBit",
input: os.ModeSticky | 0755,
expected: 0755,
},
{
name: "TestWithoutSetuidBit",
input: 0600,
expected: 0600,
},
{
name: "TestWithSetuidBit",
input: os.ModeSetuid | 0600,
expected: 0600,
},
{
name: "TestWithoutSetgidBit",
input: 0777,
expected: 0777,
},
{
name: "TestWithSetgidBit",
input: os.ModeSetgid | 0777,
expected: 0777,
},
{
name: "TestWithAllBits",
input: os.ModeSticky | os.ModeSetuid | os.ModeSetgid | 0777,
expected: 0777,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := cleanFileMode(tc.input)
if result != tc.expected {
t.Errorf("cleanFileMode() = %v, want %v", result, tc.expected)
}
})
}
}

0 comments on commit 0762f82

Please sign in to comment.