diff --git a/go.mod b/go.mod index 3a0ddbe..edb1946 100644 --- a/go.mod +++ b/go.mod @@ -9,10 +9,12 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.5 github.com/jarcoal/httpmock v1.3.1 github.com/mitchellh/hashstructure/v2 v2.0.2 + github.com/pierrec/lz4 v2.6.1+incompatible github.com/rs/zerolog v1.32.0 github.com/spf13/cobra v1.8.0 github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.9.0 + github.com/ulikunitz/xz v0.5.11 golang.org/x/sync v0.6.0 golang.org/x/tools v0.19.0 gotest.tools/gotestsum v1.11.0 diff --git a/go.sum b/go.sum index 454ddb3..fd8dccf 100644 --- a/go.sum +++ b/go.sum @@ -441,6 +441,8 @@ github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT9 github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM= +github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -583,6 +585,8 @@ github.com/tomarrell/wrapcheck/v2 v2.8.1 h1:HxSqDSN0sAt0yJYsrcYVoEeyM4aI9yAm3KQp github.com/tomarrell/wrapcheck/v2 v2.8.1/go.mod h1:/n2Q3NZ4XFT50ho6Hbxg+RV1uyo2Uow/Vdm9NQcl5SE= github.com/tommy-muehle/go-mnd/v2 v2.5.1 h1:NowYhSdyE/1zwK9QCLeRb6USWdoif80Ie+v+yU8u1Zw= github.com/tommy-muehle/go-mnd/v2 v2.5.1/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw= +github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8= +github.com/ulikunitz/xz v0.5.11/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ultraware/funlen v0.1.0 h1:BuqclbkY6pO+cvxoq7OsktIXZpgBSkYTQtmwhAK81vI= github.com/ultraware/funlen v0.1.0/go.mod h1:XJqmOQja6DpxarLj6Jj1U7JuoS8PvL4nEqDaQhy22p4= github.com/ultraware/whitespace v0.1.0 h1:O1HKYoh0kIeqE8sFqZf1o0qbORXUCOQFrlaQyZsczZw= diff --git a/pkg/extract/compression.go b/pkg/extract/compression.go new file mode 100644 index 0000000..74a9377 --- /dev/null +++ b/pkg/extract/compression.go @@ -0,0 +1,137 @@ +package extract + +import ( + "compress/bzip2" + "compress/gzip" + "compress/lzw" + "encoding/binary" + "io" + + "github.com/pierrec/lz4" + "github.com/ulikunitz/xz" + + "github.com/replicate/pget/pkg/logging" +) + +const ( + peekSize = 8 + + gzipMagic = 0x1F8B + bzipMagic = 0x425A + xzMagic = 0xFD377A585A00 + lzwMagic = 0x1F9D + lz4Magic = 0x184D2204 +) + +var _ decompressor = gzipDecompressor{} +var _ decompressor = bzip2Decompressor{} +var _ decompressor = xzDecompressor{} +var _ decompressor = lzwDecompressor{} +var _ decompressor = lz4Decompressor{} +var _ decompressor = noOpDecompressor{} + +// decompressor represents different compression formats. +type decompressor interface { + decompress(r io.Reader) (io.Reader, error) +} + +// detectFormat returns the appropriate extractor according to the magic number. +func detectFormat(input []byte) decompressor { + log := logging.GetLogger() + inputSize := len(input) + + if inputSize < 2 { + return noOpDecompressor{} + } + // pad to 8 bytes + if inputSize < 8 { + input = append(input, make([]byte, peekSize-inputSize)...) + } + + magic16 := binary.BigEndian.Uint16(input) + magic32 := binary.BigEndian.Uint32(input) + // We need to pre-pend the padding since we're reading into something bigendian and exceeding the + // 48bits size of the magic number bytes. The 16 and 32 bit magic numbers are complete bytes and + // therefore do not need any padding. + magic48 := binary.BigEndian.Uint64(append(make([]byte, 2), input[0:6]...)) + + switch true { + case magic16 == gzipMagic: + log.Debug(). + Str("type", "gzip"). + Msg("Compression Format") + return gzipDecompressor{} + case magic16 == bzipMagic: + log.Debug(). + Str("type", "bzip2"). + Msg("Compression Format") + return bzip2Decompressor{} + case magic16 == lzwMagic: + compressionByte := input[2] + // litWidth is guaranteed to be at least 9 per specification, the high order 3 bits of byte[2] are the litWidth + // the low order 5 bits are only used by non-unix implementations, we are going to ignore them. + litWidth := int(compressionByte>>5) + 9 + log.Debug(). + Str("type", "lzw"). + Int("litWidth", litWidth). + Msg("Compression Format") + return lzwDecompressor{ + order: lzw.MSB, + litWidth: litWidth, + } + case magic32 == lz4Magic: + log.Debug(). + Str("type", "lz4"). + Msg("Compression Format") + return lz4Decompressor{} + case magic48 == xzMagic: + log.Debug(). + Str("type", "xz"). + Msg("Compression Format") + return xzDecompressor{} + default: + log.Debug(). + Str("type", "none"). + Msg("Compression Format") + return noOpDecompressor{} + } +} + +type gzipDecompressor struct{} + +func (d gzipDecompressor) decompress(r io.Reader) (io.Reader, error) { + return gzip.NewReader(r) +} + +type bzip2Decompressor struct{} + +func (d bzip2Decompressor) decompress(r io.Reader) (io.Reader, error) { + return bzip2.NewReader(r), nil +} + +type xzDecompressor struct{} + +func (d xzDecompressor) decompress(r io.Reader) (io.Reader, error) { + return xz.NewReader(r) +} + +type lzwDecompressor struct { + litWidth int + order lzw.Order +} + +func (d lzwDecompressor) decompress(r io.Reader) (io.Reader, error) { + return lzw.NewReader(r, d.order, d.litWidth), nil +} + +type lz4Decompressor struct{} + +func (d lz4Decompressor) decompress(r io.Reader) (io.Reader, error) { + return lz4.NewReader(r), nil +} + +type noOpDecompressor struct{} + +func (d noOpDecompressor) decompress(r io.Reader) (io.Reader, error) { + return r, nil +} diff --git a/pkg/extract/compression_test.go b/pkg/extract/compression_test.go new file mode 100644 index 0000000..07c9ed0 --- /dev/null +++ b/pkg/extract/compression_test.go @@ -0,0 +1,56 @@ +package extract + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDetectFormat(t *testing.T) { + tests := []struct { + name string + input []byte + expectType string + }{ + { + name: "GZIP", + input: []byte{0x1f, 0x8b}, + expectType: "extract.gzipDecompressor", + }, + { + name: "BZIP2", + input: []byte{0x42, 0x5a}, + expectType: "extract.bzip2Decompressor", + }, + { + name: "XZ", + input: []byte{0xfd, 0x37, 0x7a, 0x58, 0x5a, 0x00}, + expectType: "extract.xzDecompressor", + }, + { + name: "Less than 2 bytes", + input: []byte{0x1f}, + expectType: "extract.noOpDecompressor", + }, + { + name: "UNKNOWN", + input: []byte{0xde, 0xad}, + expectType: "extract.noOpDecompressor", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := detectFormat(tt.input) + assert.Equal(t, tt.expectType, stringFromInterface(result)) + }) + } +} + +func stringFromInterface(i interface{}) string { + if i == nil { + return "" + } + return fmt.Sprintf("%T", i) +} diff --git a/pkg/extract/peekreader.go b/pkg/extract/peekreader.go new file mode 100644 index 0000000..6cbf9b9 --- /dev/null +++ b/pkg/extract/peekreader.go @@ -0,0 +1,49 @@ +package extract + +import ( + "bytes" + "errors" + "io" +) + +type readPeeker interface { + io.Reader + Peek(int) ([]byte, error) +} + +var _ io.Reader = &peekReader{} +var _ readPeeker = &peekReader{} + +type peekReader struct { + reader io.Reader + buffer *bytes.Buffer +} + +func (p *peekReader) Read(b []byte) (int, error) { + if p.buffer != nil { + if p.buffer.Len() > 0 { + n, err := p.buffer.Read(b) + if errors.Is(err, io.EOF) { + err = nil + } + return n, err + } + } + return p.reader.Read(b) +} + +func (p *peekReader) Peek(n int) ([]byte, error) { + return p.peek(n) +} + +func (p *peekReader) peek(n int) ([]byte, error) { + if p.buffer == nil { + p.buffer = bytes.NewBuffer(make([]byte, 0, n)) + } + // Read the next n bytes from the reader + _, err := io.CopyN(p.buffer, p.reader, int64(n)) + if err != nil { + return p.buffer.Bytes(), err + } + return p.buffer.Bytes(), nil +} diff --git a/pkg/extract/peekreader_test.go b/pkg/extract/peekreader_test.go new file mode 100644 index 0000000..80e4177 --- /dev/null +++ b/pkg/extract/peekreader_test.go @@ -0,0 +1,72 @@ +package extract + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPeekReader_Read(t *testing.T) { + tests := []struct { + name string + readerContents string + wantBytesPeek int + wantBytesRead int + wantErr bool + }{ + { + name: "read from buffer only", + readerContents: "abc123", + wantBytesPeek: 6, + wantBytesRead: 6, + wantErr: false, + }, + { + name: "read from reader only", + readerContents: "abc123", + wantBytesRead: 3, + wantErr: false, + }, + { + name: "read from both buffer and reader", + readerContents: "abc123", + wantBytesPeek: 3, + wantBytesRead: 6, + wantErr: false, + }, + { + name: "read empty reader and buffer", + readerContents: "", + wantBytesRead: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reader := strings.NewReader(tt.readerContents) + p := &peekReader{reader: reader} + if tt.wantBytesPeek > 0 { + peekBytes, err := p.Peek(tt.wantBytesPeek) + assert.NoError(t, err) + assert.Equal(t, tt.readerContents[0:tt.wantBytesPeek], string(peekBytes)) + } + + var totalBytesRead int + var err error + readBytes := make([]byte, tt.wantBytesRead) + for totalBytesRead < tt.wantBytesRead && err == nil { + bytesRead, err := p.Read(readBytes[totalBytesRead:]) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + totalBytesRead += bytesRead + } + assert.Equal(t, tt.wantBytesRead, totalBytesRead) + assert.Equal(t, tt.readerContents[0:tt.wantBytesRead], string(readBytes)) + }) + } +} diff --git a/pkg/extract/tar.go b/pkg/extract/tar.go index 1cc1810..92e3f0d 100644 --- a/pkg/extract/tar.go +++ b/pkg/extract/tar.go @@ -22,10 +22,23 @@ type link struct { newName string } -func TarFile(reader io.Reader, destDir string, overwrite bool) error { +func TarFile(r io.Reader, destDir string, overwrite bool) error { var links []*link startTime := time.Now() + peekReader := &peekReader{reader: r} + peekData, err := peekReader.peek(peekSize) + if err != nil { + return fmt.Errorf("error reading peek data: %w", err) + } + decompressor := detectFormat(peekData) + if err != nil { + return fmt.Errorf("error detecting format: %w", err) + } + reader, err := decompressor.decompress(peekReader) + if err != nil { + return fmt.Errorf("error creating decompressed stream: %w", err) + } tarReader := tar.NewReader(reader) logger := logging.GetLogger()