From 368e5263facfcbc6ca929e18254e88c6cdd7ce51 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Sat, 27 Apr 2024 12:59:34 -0700 Subject: [PATCH] Improve Stream Decompression Instead of implementing the peek reader, wrap everything in bufio.Reader and lean on its Peek() capabilities. This also allows for a simpler bytes.HasPrefix use for the magic numbers instead of needing to deal with the endianness of the magic bytes -- critically this eliminates the padding for the 48-bit magic bytes header for some compression types. --- pkg/extract/compression.go | 36 +++++++++-------- pkg/extract/peekreader.go | 49 ----------------------- pkg/extract/peekreader_test.go | 72 ---------------------------------- pkg/extract/tar.go | 10 ++--- 4 files changed, 23 insertions(+), 144 deletions(-) delete mode 100644 pkg/extract/peekreader.go delete mode 100644 pkg/extract/peekreader_test.go diff --git a/pkg/extract/compression.go b/pkg/extract/compression.go index 74a9377..cbbba56 100644 --- a/pkg/extract/compression.go +++ b/pkg/extract/compression.go @@ -1,10 +1,10 @@ package extract import ( + "bytes" "compress/bzip2" "compress/gzip" "compress/lzw" - "encoding/binary" "io" "github.com/pierrec/lz4" @@ -15,12 +15,14 @@ import ( const ( peekSize = 8 +) - gzipMagic = 0x1F8B - bzipMagic = 0x425A - xzMagic = 0xFD377A585A00 - lzwMagic = 0x1F9D - lz4Magic = 0x184D2204 +var ( + gzipMagic = []byte{0x1F, 0x8B} + bzipMagic = []byte{0x42, 0x5A} + xzMagic = []byte{0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00} + lzwMagic = []byte{0x1F, 0x9D} + lz4Magic = []byte{0x18, 0x4D, 0x22, 0x04} ) var _ decompressor = gzipDecompressor{} @@ -48,25 +50,25 @@ func detectFormat(input []byte) decompressor { input = append(input, make([]byte, peekSize-inputSize)...) } - magic16 := binary.BigEndian.Uint16(input) - magic32 := binary.BigEndian.Uint32(input) - // We need to pre-pend the padding since we're reading into something bigendian and exceeding the - // 48bits size of the magic number bytes. 
The 16 and 32 bit magic numbers are complete bytes and - // therefore do not need any padding. - magic48 := binary.BigEndian.Uint64(append(make([]byte, 2), input[0:6]...)) + // magic16 := binary.BigEndian.Uint16(input) + // magic32 := binary.BigEndian.Uint32(input) + // // We need to pre-pend the padding since we're reading into something bigendian and exceeding the + // // 48bits size of the magic number bytes. The 16 and 32 bit magic numbers are complete bytes and + // // therefore do not need any padding. + // magic48 := binary.BigEndian.Uint64(append(make([]byte, 2), input[0:6]...)) switch true { - case magic16 == gzipMagic: + case bytes.HasPrefix(input, gzipMagic): log.Debug(). Str("type", "gzip"). Msg("Compression Format") return gzipDecompressor{} - case magic16 == bzipMagic: + case bytes.HasPrefix(input, bzipMagic): log.Debug(). Str("type", "bzip2"). Msg("Compression Format") return bzip2Decompressor{} - case magic16 == lzwMagic: + case bytes.HasPrefix(input, lzwMagic): compressionByte := input[2] // litWidth is guaranteed to be at least 9 per specification, the high order 3 bits of byte[2] are the litWidth // the low order 5 bits are only used by non-unix implementations, we are going to ignore them. @@ -79,12 +81,12 @@ func detectFormat(input []byte) decompressor { order: lzw.MSB, litWidth: litWidth, } - case magic32 == lz4Magic: + case bytes.HasPrefix(input, lz4Magic): log.Debug(). Str("type", "lz4"). Msg("Compression Format") return lz4Decompressor{} - case magic48 == xzMagic: + case bytes.HasPrefix(input, xzMagic): log.Debug(). Str("type", "xz"). 
Msg("Compression Format") diff --git a/pkg/extract/peekreader.go b/pkg/extract/peekreader.go deleted file mode 100644 index 6cbf9b9..0000000 --- a/pkg/extract/peekreader.go +++ /dev/null @@ -1,49 +0,0 @@ -package extract - -import ( - "bytes" - "errors" - "io" -) - -type readPeeker interface { - io.Reader - Peek(int) ([]byte, error) -} - -var _ io.Reader = &peekReader{} -var _ readPeeker = &peekReader{} - -type peekReader struct { - reader io.Reader - buffer *bytes.Buffer -} - -func (p *peekReader) Read(b []byte) (int, error) { - if p.buffer != nil { - if p.buffer.Len() > 0 { - n, err := p.buffer.Read(b) - if errors.Is(err, io.EOF) { - err = nil - } - return n, err - } - } - return p.reader.Read(b) -} - -func (p *peekReader) Peek(n int) ([]byte, error) { - return p.peek(n) -} - -func (p *peekReader) peek(n int) ([]byte, error) { - if p.buffer == nil { - p.buffer = bytes.NewBuffer(make([]byte, 0, n)) - } - // Read the next n bytes from the reader - _, err := io.CopyN(p.buffer, p.reader, int64(n)) - if err != nil { - return p.buffer.Bytes(), err - } - return p.buffer.Bytes(), nil -} diff --git a/pkg/extract/peekreader_test.go b/pkg/extract/peekreader_test.go deleted file mode 100644 index 80e4177..0000000 --- a/pkg/extract/peekreader_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package extract - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestPeekReader_Read(t *testing.T) { - tests := []struct { - name string - readerContents string - wantBytesPeek int - wantBytesRead int - wantErr bool - }{ - { - name: "read from buffer only", - readerContents: "abc123", - wantBytesPeek: 6, - wantBytesRead: 6, - wantErr: false, - }, - { - name: "read from reader only", - readerContents: "abc123", - wantBytesRead: 3, - wantErr: false, - }, - { - name: "read from both buffer and reader", - readerContents: "abc123", - wantBytesPeek: 3, - wantBytesRead: 6, - wantErr: false, - }, - { - name: "read empty reader and buffer", - readerContents: "", - 
wantBytesRead: 0, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - reader := strings.NewReader(tt.readerContents) - p := &peekReader{reader: reader} - if tt.wantBytesPeek > 0 { - peekBytes, err := p.Peek(tt.wantBytesPeek) - assert.NoError(t, err) - assert.Equal(t, tt.readerContents[0:tt.wantBytesPeek], string(peekBytes)) - } - - var totalBytesRead int - var err error - readBytes := make([]byte, tt.wantBytesRead) - for totalBytesRead < tt.wantBytesRead && err == nil { - bytesRead, err := p.Read(readBytes[totalBytesRead:]) - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - totalBytesRead += bytesRead - } - assert.Equal(t, tt.wantBytesRead, totalBytesRead) - assert.Equal(t, tt.readerContents[0:tt.wantBytesRead], string(readBytes)) - }) - } -} diff --git a/pkg/extract/tar.go b/pkg/extract/tar.go index 92e3f0d..89cfaf7 100644 --- a/pkg/extract/tar.go +++ b/pkg/extract/tar.go @@ -2,6 +2,7 @@ package extract import ( "archive/tar" + "bufio" "errors" "fmt" "io" @@ -26,16 +27,13 @@ func TarFile(r io.Reader, destDir string, overwrite bool) error { var links []*link startTime := time.Now() - peekReader := &peekReader{reader: r} - peekData, err := peekReader.peek(peekSize) + peekableReader := bufio.NewReader(r) + peekData, err := peekableReader.Peek(peekSize) if err != nil { return fmt.Errorf("error reading peek data: %w", err) } decompressor := detectFormat(peekData) - if err != nil { - return fmt.Errorf("error detecting format: %w", err) - } - reader, err := decompressor.decompress(peekReader) + reader, err := decompressor.decompress(peekableReader) if err != nil { return fmt.Errorf("error creating decompressed stream: %w", err) }