Skip to content

Commit

Permalink
Implement Stream Decompression for tar
Browse files Browse the repository at this point in the history
Implement stream decompression for tar files. This mean that -x becomes
smart enough to handle lzw, bzip2, gzip, and xz (common tar compression)
formats automatically. This will remove a sharp edge on pget and handle
cases of compressed tar files elegantly.
  • Loading branch information
tempusfrangit committed Mar 20, 2024
1 parent 6c4f66b commit cda1292
Show file tree
Hide file tree
Showing 7 changed files with 334 additions and 1 deletion.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ require (
github.com/hashicorp/go-retryablehttp v0.7.5
github.com/jarcoal/httpmock v1.3.1
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/pierrec/lz4 v2.6.1+incompatible
github.com/rs/zerolog v1.32.0
github.com/spf13/cobra v1.8.0
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.9.0
github.com/ulikunitz/xz v0.5.11
golang.org/x/sync v0.6.0
golang.org/x/tools v0.19.0
gotest.tools/gotestsum v1.11.0
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,8 @@ github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT9
github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM=
github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
Expand Down Expand Up @@ -583,6 +585,8 @@ github.com/tomarrell/wrapcheck/v2 v2.8.1 h1:HxSqDSN0sAt0yJYsrcYVoEeyM4aI9yAm3KQp
github.com/tomarrell/wrapcheck/v2 v2.8.1/go.mod h1:/n2Q3NZ4XFT50ho6Hbxg+RV1uyo2Uow/Vdm9NQcl5SE=
github.com/tommy-muehle/go-mnd/v2 v2.5.1 h1:NowYhSdyE/1zwK9QCLeRb6USWdoif80Ie+v+yU8u1Zw=
github.com/tommy-muehle/go-mnd/v2 v2.5.1/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw=
github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8=
github.com/ulikunitz/xz v0.5.11/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/ultraware/funlen v0.1.0 h1:BuqclbkY6pO+cvxoq7OsktIXZpgBSkYTQtmwhAK81vI=
github.com/ultraware/funlen v0.1.0/go.mod h1:XJqmOQja6DpxarLj6Jj1U7JuoS8PvL4nEqDaQhy22p4=
github.com/ultraware/whitespace v0.1.0 h1:O1HKYoh0kIeqE8sFqZf1o0qbORXUCOQFrlaQyZsczZw=
Expand Down
137 changes: 137 additions & 0 deletions pkg/extract/compression.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package extract

import (
"compress/bzip2"
"compress/gzip"
"compress/lzw"
"encoding/binary"
"io"

"github.com/pierrec/lz4"
"github.com/ulikunitz/xz"

"github.com/replicate/pget/pkg/logging"
)

const (
peekSize = 8

gzipMagic = 0x1F8B
bzipMagic = 0x425A
xzMagic = 0xFD377A585A00
lzwMagic = 0x1F9D
lz4Magic = 0x184D2204
)

var _ decompressor = gzipDecompressor{}
var _ decompressor = bzip2Decompressor{}
var _ decompressor = xzDecompressor{}
var _ decompressor = lzwDecompressor{}
var _ decompressor = lz4Decompressor{}
var _ decompressor = noOpDecompressor{}

// decompressor represents different compression formats.
type decompressor interface {
decompress(r io.Reader) (io.Reader, error)
}

// detectFormat returns the appropriate extractor according to the magic number.
func detectFormat(input []byte) decompressor {
log := logging.GetLogger()
inputSize := len(input)

if inputSize < 2 {
return noOpDecompressor{}
}
// pad to 8 bytes
if inputSize < 8 {
input = append(input, make([]byte, peekSize-inputSize)...)
}

magic16 := binary.BigEndian.Uint16(input)
magic32 := binary.BigEndian.Uint32(input)
// We need to pre-pend the padding since we're reading into something bigendian and exceeding the
// 48bits size of the magic number bytes. The 16 and 32 bit magic numbers are complete bytes and
// therefore do not need any padding.
magic48 := binary.BigEndian.Uint64(append(make([]byte, 2), input[0:6]...))

switch true {
case magic16 == gzipMagic:
log.Debug().
Str("type", "gzip").
Msg("Compression Format")
return gzipDecompressor{}
case magic16 == bzipMagic:
log.Debug().
Str("type", "bzip2").
Msg("Compression Format")
return bzip2Decompressor{}
case magic16 == lzwMagic:
compressionByte := input[2]
// litWidth is guaranteed to be at least 9 per specification, the high order 3 bits of byte[2] are the litWidth
// the low order 5 bits are only used by non-unix implementations, we are going to ignore them.
litWidth := int(compressionByte>>5) + 9
log.Debug().
Str("type", "lzw").
Int("litWidth", litWidth).
Msg("Compression Format")
return lzwDecompressor{
order: lzw.MSB,
litWidth: litWidth,
}
case magic32 == lz4Magic:
log.Debug().
Str("type", "lz4").
Msg("Compression Format")
return lz4Decompressor{}
case magic48 == xzMagic:
log.Debug().
Str("type", "xz").
Msg("Compression Format")
return xzDecompressor{}
default:
log.Debug().
Str("type", "none").
Msg("Compression Format")
return noOpDecompressor{}
}
}

type gzipDecompressor struct{}

func (d gzipDecompressor) decompress(r io.Reader) (io.Reader, error) {
return gzip.NewReader(r)
}

type bzip2Decompressor struct{}

func (d bzip2Decompressor) decompress(r io.Reader) (io.Reader, error) {
return bzip2.NewReader(r), nil
}

type xzDecompressor struct{}

func (d xzDecompressor) decompress(r io.Reader) (io.Reader, error) {
return xz.NewReader(r)
}

type lzwDecompressor struct {
litWidth int
order lzw.Order
}

func (d lzwDecompressor) decompress(r io.Reader) (io.Reader, error) {
return lzw.NewReader(r, d.order, d.litWidth), nil
}

type lz4Decompressor struct{}

func (d lz4Decompressor) decompress(r io.Reader) (io.Reader, error) {
return lz4.NewReader(r), nil
}

type noOpDecompressor struct{}

func (d noOpDecompressor) decompress(r io.Reader) (io.Reader, error) {
return r, nil
}
56 changes: 56 additions & 0 deletions pkg/extract/compression_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package extract

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestDetectFormat(t *testing.T) {
tests := []struct {
name string
input []byte
expectType string
}{
{
name: "GZIP",
input: []byte{0x1f, 0x8b},
expectType: "extract.gzipDecompressor",
},
{
name: "BZIP2",
input: []byte{0x42, 0x5a},
expectType: "extract.bzip2Decompressor",
},
{
name: "XZ",
input: []byte{0xfd, 0x37, 0x7a, 0x58, 0x5a, 0x00},
expectType: "extract.xzDecompressor",
},
{
name: "Less than 2 bytes",
input: []byte{0x1f},
expectType: "extract.noOpDecompressor",
},
{
name: "UNKNOWN",
input: []byte{0xde, 0xad},
expectType: "extract.noOpDecompressor",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := detectFormat(tt.input)
assert.Equal(t, tt.expectType, stringFromInterface(result))
})
}
}

func stringFromInterface(i interface{}) string {
if i == nil {
return ""
}
return fmt.Sprintf("%T", i)
}
49 changes: 49 additions & 0 deletions pkg/extract/peekreader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package extract

import (
"bytes"
"errors"
"io"
)

type readPeeker interface {
io.Reader
Peek(int) ([]byte, error)
}

var _ io.Reader = &peekReader{}
var _ readPeeker = &peekReader{}

type peekReader struct {
reader io.Reader
buffer *bytes.Buffer
}

func (p *peekReader) Read(b []byte) (int, error) {
if p.buffer != nil {
if p.buffer.Len() > 0 {
n, err := p.buffer.Read(b)
if errors.Is(err, io.EOF) {
err = nil
}
return n, err
}
}
return p.reader.Read(b)
}

func (p *peekReader) Peek(n int) ([]byte, error) {
return p.peek(n)
}

func (p *peekReader) peek(n int) ([]byte, error) {
if p.buffer == nil {
p.buffer = bytes.NewBuffer(make([]byte, 0, n))
}
// Read the next n bytes from the reader
_, err := io.CopyN(p.buffer, p.reader, int64(n))
if err != nil {
return p.buffer.Bytes(), err
}
return p.buffer.Bytes(), nil
}
72 changes: 72 additions & 0 deletions pkg/extract/peekreader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package extract

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestPeekReader_Read(t *testing.T) {
tests := []struct {
name string
readerContents string
wantBytesPeek int
wantBytesRead int
wantErr bool
}{
{
name: "read from buffer only",
readerContents: "abc123",
wantBytesPeek: 6,
wantBytesRead: 6,
wantErr: false,
},
{
name: "read from reader only",
readerContents: "abc123",
wantBytesRead: 3,
wantErr: false,
},
{
name: "read from both buffer and reader",
readerContents: "abc123",
wantBytesPeek: 3,
wantBytesRead: 6,
wantErr: false,
},
{
name: "read empty reader and buffer",
readerContents: "",
wantBytesRead: 0,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := strings.NewReader(tt.readerContents)
p := &peekReader{reader: reader}
if tt.wantBytesPeek > 0 {
peekBytes, err := p.Peek(tt.wantBytesPeek)
assert.NoError(t, err)
assert.Equal(t, tt.readerContents[0:tt.wantBytesPeek], string(peekBytes))
}

var totalBytesRead int
var err error
readBytes := make([]byte, tt.wantBytesRead)
for totalBytesRead < tt.wantBytesRead && err == nil {
bytesRead, err := p.Read(readBytes[totalBytesRead:])
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
totalBytesRead += bytesRead
}
assert.Equal(t, tt.wantBytesRead, totalBytesRead)
assert.Equal(t, tt.readerContents[0:tt.wantBytesRead], string(readBytes))
})
}
}
15 changes: 14 additions & 1 deletion pkg/extract/tar.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,23 @@ type link struct {
newName string
}

func TarFile(reader io.Reader, destDir string, overwrite bool) error {
func TarFile(r io.Reader, destDir string, overwrite bool) error {
var links []*link

startTime := time.Now()
peekReader := &peekReader{reader: r}
peekData, err := peekReader.peek(peekSize)
if err != nil {
return fmt.Errorf("error reading peek data: %w", err)
}
decompressor := detectFormat(peekData)
if err != nil {
return fmt.Errorf("error detecting format: %w", err)
}
reader, err := decompressor.decompress(peekReader)
if err != nil {
return fmt.Errorf("error creating decompressed stream: %w", err)
}
tarReader := tar.NewReader(reader)
logger := logging.GetLogger()

Expand Down

0 comments on commit cda1292

Please sign in to comment.