diff --git a/go.mod b/go.mod index c085e3b624bc..ffac0ad364c2 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/google/go-containerregistry v0.15.2 github.com/google/go-github/v42 v42.0.0 github.com/googleapis/gax-go/v2 v2.12.0 + github.com/h2non/filetype v1.1.3 github.com/hashicorp/go-retryablehttp v0.7.4 github.com/hashicorp/golang-lru v0.5.1 github.com/jlaffaye/ftp v0.2.0 diff --git a/go.sum b/go.sum index cfbdb13a7866..b19e519c8c55 100644 --- a/go.sum +++ b/go.sum @@ -315,6 +315,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= +github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= diff --git a/pkg/handlers/archive.go b/pkg/handlers/archive.go index 8ad25cb4abeb..151b72830307 100644 --- a/pkg/handlers/archive.go +++ b/pkg/handlers/archive.go @@ -6,8 +6,13 @@ import ( "errors" "fmt" "io" + "os" + "os/exec" + "path/filepath" + "strings" "time" + "github.com/h2non/filetype" "github.com/mholt/archiver/v4" "github.com/trufflesecurity/trufflehog/v3/pkg/common" @@ -26,14 +31,18 @@ var ( maxTimeout = time.Duration(30) * time.Second ) +// Ensure the Archive satisfies the interfaces at compile time. +var _ SpecializedHandler = (*Archive)(nil) + // Archive is a handler for extracting and decompressing archives. 
type Archive struct { - size int + size int + currentDepth int } // New sets a default maximum size and current size counter. -func (d *Archive) New() { - d.size = 0 +func (a *Archive) New() { + a.size, a.currentDepth = 0, 0 } // SetArchiveMaxSize sets the maximum size of the archive. @@ -52,14 +61,14 @@ func SetArchiveMaxTimeout(timeout time.Duration) { } // FromFile extracts the files from an archive. -func (d *Archive) FromFile(originalCtx context.Context, data io.Reader) chan ([]byte) { - archiveChan := make(chan ([]byte), 512) +func (a *Archive) FromFile(originalCtx context.Context, data io.Reader) chan []byte { + archiveChan := make(chan []byte, 512) go func() { ctx, cancel := context.WithTimeout(originalCtx, maxTimeout) logger := logContext.AddLogger(ctx).Logger() defer cancel() defer close(archiveChan) - err := d.openArchive(ctx, 0, data, archiveChan) + err := a.openArchive(ctx, 0, data, archiveChan) if err != nil { if errors.Is(err, archiver.ErrNoMatch) { return @@ -71,7 +80,7 @@ func (d *Archive) FromFile(originalCtx context.Context, data io.Reader) chan ([] } // openArchive takes a reader and extracts the contents up to the maximum depth. 
-func (d *Archive) openArchive(ctx context.Context, depth int, reader io.Reader, archiveChan chan ([]byte)) error { +func (a *Archive) openArchive(ctx context.Context, depth int, reader io.Reader, archiveChan chan []byte) error { if depth >= maxDepth { return fmt.Errorf("max archive depth reached") } @@ -97,14 +106,14 @@ func (d *Archive) openArchive(ctx context.Context, depth int, reader io.Reader, if err != nil { return err } - fileBytes, err := d.ReadToMax(ctx, compReader) + fileBytes, err := a.ReadToMax(ctx, compReader) if err != nil { return err } newReader := bytes.NewReader(fileBytes) - return d.openArchive(ctx, depth+1, newReader, archiveChan) + return a.openArchive(ctx, depth+1, newReader, archiveChan) case archiver.Extractor: - err := archive.Extract(context.WithValue(ctx, depthKey, depth+1), reader, nil, d.extractorHandler(archiveChan)) + err := archive.Extract(context.WithValue(ctx, depthKey, depth+1), reader, nil, a.extractorHandler(archiveChan)) if err != nil { return err } @@ -114,7 +123,7 @@ func (d *Archive) openArchive(ctx context.Context, depth int, reader io.Reader, } // IsFiletype returns true if the provided reader is an archive. -func (d *Archive) IsFiletype(ctx context.Context, reader io.Reader) (io.Reader, bool) { +func (a *Archive) IsFiletype(ctx context.Context, reader io.Reader) (io.Reader, bool) { format, readerB, err := archiver.Identify("", reader) if err != nil { return readerB, false @@ -129,7 +138,7 @@ func (d *Archive) IsFiletype(ctx context.Context, reader io.Reader) (io.Reader, } // extractorHandler is applied to each file in an archiver.Extractor file. 
-func (d *Archive) extractorHandler(archiveChan chan ([]byte)) func(context.Context, archiver.File) error { +func (a *Archive) extractorHandler(archiveChan chan []byte) func(context.Context, archiver.File) error { return func(ctx context.Context, f archiver.File) error { logger := logContext.AddLogger(ctx).Logger() logger.V(5).Info("Handling extracted file.", "filename", f.Name()) @@ -142,13 +151,13 @@ func (d *Archive) extractorHandler(archiveChan chan ([]byte)) func(context.Conte if err != nil { return err } - fileBytes, err := d.ReadToMax(ctx, fReader) + fileBytes, err := a.ReadToMax(ctx, fReader) if err != nil { return err } fileContent := bytes.NewReader(fileBytes) - err = d.openArchive(ctx, depth, fileContent, archiveChan) + err = a.openArchive(ctx, depth, fileContent, archiveChan) if err != nil { return err } @@ -157,7 +166,7 @@ func (d *Archive) extractorHandler(archiveChan chan ([]byte)) func(context.Conte } // ReadToMax reads up to the max size. -func (d *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte, err error) { +func (a *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte, err error) { // Archiver v4 is in alpha and using an experimental version of // rardecode. There is a bug somewhere with rar decoder format 29 // that can lead to a panic. 
An issue is open in rardecode repo @@ -175,7 +184,7 @@ func (d *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte, } }() fileContent := bytes.Buffer{} - logger.V(5).Info("Remaining buffer capacity", "bytes", maxSize-d.size) + logger.V(5).Info("Remaining buffer capacity", "bytes", maxSize-a.size) for i := 0; i <= maxSize/512; i++ { if common.IsDone(ctx) { return nil, ctx.Err() @@ -185,17 +194,284 @@ func (d *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte, if err != nil && !errors.Is(err, io.EOF) { return []byte{}, err } - d.size += bRead + a.size += bRead if len(fileChunk) > 0 { fileContent.Write(fileChunk[0:bRead]) } if bRead < 512 { return fileContent.Bytes(), nil } - if d.size >= maxSize && bRead == 512 { + if a.size >= maxSize && bRead == 512 { logger.V(2).Info("Max archive size reached.") return fileContent.Bytes(), nil } } return fileContent.Bytes(), nil } + +const ( + arMimeType = "application/x-unix-archive" + rpmMimeType = "application/x-rpm" +) + +// Define a map of mime types to corresponding command-line tools +var mimeTools = map[string][]string{ + arMimeType: {"ar"}, + rpmMimeType: {"rpm2cpio", "cpio"}, +} + +// Check if the command-line tool is installed. +func isToolInstalled(tool string) bool { + _, err := exec.LookPath(tool) + return err == nil } + +// Ensure all tools are available for given mime type. +func ensureToolsForMimeType(mimeType string) error { + tools, exists := mimeTools[mimeType] + if !exists { + return fmt.Errorf("unsupported mime type %q", mimeType) + } + + for _, tool := range tools { + if !isToolInstalled(tool) { + return fmt.Errorf("required tool %s is not installed", tool) + } + } + + return nil } + +// HandleSpecialized takes a file path and an io.Reader representing the input file, +// and processes it based on its extension, such as handling Debian (.deb) and RPM (.rpm) packages. 
+// It returns an io.Reader that can be used to read the processed content of the file, +// and an error if any issues occurred during processing. +// The caller is responsible for closing the returned reader. +func (a *Archive) HandleSpecialized(ctx context.Context, reader io.Reader) (io.Reader, bool, error) { + mimeType, reader, err := determineMimeType(reader) + if err != nil { + return nil, false, err + } + + switch mimeType { + case arMimeType: // includes .deb files + if err := ensureToolsForMimeType(mimeType); err != nil { + return nil, false, err + } + reader, err = a.extractDebContent(ctx, reader) + case rpmMimeType: + if err := ensureToolsForMimeType(mimeType); err != nil { + return nil, false, err + } + reader, err = a.extractRpmContent(ctx, reader) + default: + return reader, false, nil + } + + if err != nil { + return nil, false, fmt.Errorf("unable to extract file with MIME type %s: %w", mimeType, err) + } + return reader, true, nil +} + +// extractDebContent takes a .deb file as an io.Reader, extracts its contents +// into a temporary directory, and returns a Reader for the extracted data archive. +// It handles the extraction process by using the 'ar' command and manages temporary +// files and directories for the operation. +// The caller is responsible for closing the returned reader. 
+func (a *Archive) extractDebContent(ctx context.Context, file io.Reader) (io.ReadCloser, error) { + if a.currentDepth >= maxDepth { + return nil, fmt.Errorf("max archive depth reached") + } + + tmpEnv, err := a.createTempEnv(ctx, file) + if err != nil { + return nil, err + } + defer os.Remove(tmpEnv.tempFileName) + defer os.RemoveAll(tmpEnv.extractPath) + + cmd := exec.Command("ar", "x", tmpEnv.tempFile.Name()) + cmd.Dir = tmpEnv.extractPath + if err := executeCommand(cmd); err != nil { + return nil, err + } + + handler := func(ctx context.Context, env tempEnv, file string) (string, error) { + if strings.HasPrefix(file, "data.tar.") { + return file, nil + } + return a.handleNestedFileMIME(ctx, env, file) + } + + dataArchiveName, err := a.handleExtractedFiles(ctx, tmpEnv, handler) + if err != nil { + return nil, err + } + + return openDataArchive(tmpEnv.extractPath, dataArchiveName) +} + +// extractRpmContent takes an .rpm file as an io.Reader, extracts its contents +// into a temporary directory, and returns a Reader for the extracted data archive. +// It handles the extraction process by using the 'rpm2cpio' and 'cpio' commands and manages temporary +// files and directories for the operation. +// The caller is responsible for closing the returned reader. +func (a *Archive) extractRpmContent(ctx context.Context, file io.Reader) (io.ReadCloser, error) { + if a.currentDepth >= maxDepth { + return nil, fmt.Errorf("max archive depth reached") + } + + tmpEnv, err := a.createTempEnv(ctx, file) + if err != nil { + return nil, err + } + defer os.Remove(tmpEnv.tempFileName) + defer os.RemoveAll(tmpEnv.extractPath) + + // Use rpm2cpio to convert the RPM file to a cpio archive and then extract it using cpio command. 
+ cmd := exec.Command("sh", "-c", "rpm2cpio "+tmpEnv.tempFile.Name()+" | cpio -id") + cmd.Dir = tmpEnv.extractPath + if err := executeCommand(cmd); err != nil { + return nil, err + } + + handler := func(ctx context.Context, env tempEnv, file string) (string, error) { + if strings.HasSuffix(file, ".tar.gz") { + return file, nil + } + return a.handleNestedFileMIME(ctx, env, file) + } + + dataArchiveName, err := a.handleExtractedFiles(ctx, tmpEnv, handler) + if err != nil { + return nil, err + } + + return openDataArchive(tmpEnv.extractPath, dataArchiveName) +} + +func (a *Archive) handleNestedFileMIME(ctx context.Context, tempEnv tempEnv, fileName string) (string, error) { + nestedFile, err := os.Open(filepath.Join(tempEnv.extractPath, fileName)) + if err != nil { + return "", err + } + defer nestedFile.Close() + + mimeType, reader, err := determineMimeType(nestedFile) + if err != nil { + return "", err + } + + switch mimeType { + case arMimeType: + _, _, err = a.HandleSpecialized(ctx, reader) + case rpmMimeType: + _, _, err = a.HandleSpecialized(ctx, reader) + default: + return "", nil + } + + if err != nil { + return "", err + } + + return fileName, nil +} + +// determineMimeType reads from the provided reader to detect the MIME type. +// It returns the detected MIME type and a new reader that includes the read portion. +func determineMimeType(reader io.Reader) (string, io.Reader, error) { + buffer := make([]byte, 512) + n, err := reader.Read(buffer) + if err != nil && !errors.Is(err, io.EOF) { + return "", nil, fmt.Errorf("unable to read file for MIME type detection: %w", err) + } + + // Create a new reader that starts with the buffer we just read + // and continues with the rest of the original reader. 
+ reader = io.MultiReader(bytes.NewReader(buffer[:n]), reader) + + kind, err := filetype.Match(buffer) + if err != nil { + return "", nil, fmt.Errorf("unable to determine file type: %w", err) + } + + return kind.MIME.Value, reader, nil +} + +// handleExtractedFiles processes each file in the extracted directory using a provided handler function. +// The function iterates through the files, applying the handleFile function to each, and returns the name +// of the data archive it finds. This centralizes the logic for handling specialized files such as .deb and .rpm +// by using the appropriate handling function passed as an argument. This design allows for flexibility and reuse +// of this function across various extraction processes in the package. +func (a *Archive) handleExtractedFiles(ctx context.Context, env tempEnv, handleFile func(context.Context, tempEnv, string) (string, error)) (string, error) { + extractedFiles, err := os.ReadDir(env.extractPath) + if err != nil { + return "", fmt.Errorf("unable to read extracted directory: %w", err) + } + + var dataArchiveName string + for _, file := range extractedFiles { + name, err := handleFile(ctx, env, file.Name()) + if err != nil { + return "", err + } + if name != "" { + dataArchiveName = name + break + } + } + + return dataArchiveName, nil +} + +type tempEnv struct { + tempFile *os.File + tempFileName string + extractPath string +} + +// createTempEnv creates a temporary file and a temporary directory for extracting archives. +// The caller is responsible for removing these temporary resources +// (both the file and directory) when they are no longer needed. 
+func (a *Archive) createTempEnv(ctx context.Context, file io.Reader) (tempEnv, error) { + tempFile, err := os.CreateTemp("", "tmp") + if err != nil { + return tempEnv{}, fmt.Errorf("unable to create temporary file: %w", err) + } + + extractPath, err := os.MkdirTemp("", "tmp_archive") + if err != nil { + return tempEnv{}, fmt.Errorf("unable to create temporary directory: %w", err) + } + + b, err := a.ReadToMax(ctx, file) + if err != nil { + return tempEnv{}, err + } + + if _, err = tempFile.Write(b); err != nil { + return tempEnv{}, fmt.Errorf("unable to write to temporary file: %w", err) + } + + return tempEnv{tempFile: tempFile, tempFileName: tempFile.Name(), extractPath: extractPath}, nil +} + +func executeCommand(cmd *exec.Cmd) error { + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("unable to execute command: %w; error: %s", err, stderr.String()) + } + return nil +} + +func openDataArchive(extractPath string, dataArchiveName string) (io.ReadCloser, error) { + dataArchivePath := filepath.Join(extractPath, dataArchiveName) + dataFile, err := os.Open(dataArchivePath) + if err != nil { + return nil, fmt.Errorf("unable to open file: %w", err) + } + return dataFile, nil +} diff --git a/pkg/handlers/archive_test.go b/pkg/handlers/archive_test.go index a7dca30e4218..9608ad43e9c8 100644 --- a/pkg/handlers/archive_test.go +++ b/pkg/handlers/archive_test.go @@ -2,7 +2,9 @@ package handlers import ( "context" + "io" "net/http" + "os" "regexp" "strings" "testing" @@ -122,3 +124,39 @@ func TestHandleFile(t *testing.T) { assert.True(t, HandleFile(context.Background(), reader, &sources.Chunk{}, ch)) assert.Equal(t, 1, len(ch)) } + +func TestExtractDebContent(t *testing.T) { + // Open the sample .deb file from the testdata folder. 
+ file, err := os.Open("testdata/test.deb") + assert.Nil(t, err) + defer file.Close() + + ctx := context.Background() + a := &Archive{} + + reader, err := a.extractDebContent(ctx, file) + assert.Nil(t, err) + + content, err := io.ReadAll(reader) + assert.Nil(t, err) + expectedLength := 1015582 + assert.Equal(t, expectedLength, len(string(content))) +} + +func TestExtractRPMContent(t *testing.T) { + // Open the sample .rpm file from the testdata folder. + file, err := os.Open("testdata/test.rpm") + assert.Nil(t, err) + defer file.Close() + + ctx := context.Background() + a := &Archive{} + + reader, err := a.extractRpmContent(ctx, file) + assert.Nil(t, err) + + content, err := io.ReadAll(reader) + assert.Nil(t, err) + expectedLength := 1822720 + assert.Equal(t, expectedLength, len(string(content))) +} diff --git a/pkg/handlers/handlers.go b/pkg/handlers/handlers.go index 84eeadcd0331..31696483e2bf 100644 --- a/pkg/handlers/handlers.go +++ b/pkg/handlers/handlers.go @@ -13,39 +13,59 @@ func DefaultHandlers() []Handler { } } +// SpecializedHandler defines the interface for handlers that can process specialized archives. +// It includes a method to handle specialized archives and determine if the file is of a special type. +type SpecializedHandler interface { + // HandleSpecialized examines the provided file reader within the context and determines if it is a specialized archive. + // It returns a reader with any necessary modifications, a boolean indicating if the file was specialized, + // and an error if something went wrong during processing. + HandleSpecialized(context.Context, io.Reader) (io.Reader, bool, error) +} + type Handler interface { FromFile(context.Context, io.Reader) chan ([]byte) IsFiletype(context.Context, io.Reader) (io.Reader, bool) New() } -func HandleFile(ctx context.Context, file io.Reader, chunkSkel *sources.Chunk, chunksChan chan (*sources.Chunk)) bool { - // Find a handler for this file. 
- var handler Handler +// HandleFile processes a given file by selecting an appropriate handler from DefaultHandlers. +// It first checks if the handler implements SpecializedHandler for any special processing, +// then falls back to regular file type handling. If successful, it reads the file in chunks, +// packages them in the provided chunk skeleton, and sends them to chunksChan. +// The function returns true if processing was successful and false otherwise. +// Context is used for cancellation, and the caller is responsible for canceling it if needed. +func HandleFile(ctx context.Context, file io.Reader, chunkSkel *sources.Chunk, chunksChan chan *sources.Chunk) bool { for _, h := range DefaultHandlers() { h.New() + var ( + isSpecial bool + err error + ) + + // Check if the handler implements SpecializedHandler and process accordingly. + if specialHandler, ok := h.(SpecializedHandler); ok { + if file, isSpecial, err = specialHandler.HandleSpecialized(ctx, file); isSpecial && err == nil { + return handleChunks(ctx, h.FromFile(ctx, file), chunkSkel, chunksChan) + } + } + var isType bool if file, isType = h.IsFiletype(ctx, file); isType { - handler = h - break + return handleChunks(ctx, h.FromFile(ctx, file), chunkSkel, chunksChan) } } - if handler == nil { - return false - } + return false +} - // Process the file and read all []byte chunks from handlerChan. - handlerChan := handler.FromFile(ctx, file) +func handleChunks(ctx context.Context, handlerChan chan []byte, chunkSkel *sources.Chunk, chunksChan chan *sources.Chunk) bool { for { select { case data, open := <-handlerChan: if !open { - // We finished reading everything from handlerChan. return true } chunk := *chunkSkel chunk.Data = data - // Send data on chunksChan. 
select { case chunksChan <- &chunk: case <-ctx.Done(): diff --git a/pkg/handlers/testdata/test.deb b/pkg/handlers/testdata/test.deb new file mode 100644 index 000000000000..1b078b955d0c Binary files /dev/null and b/pkg/handlers/testdata/test.deb differ diff --git a/pkg/handlers/testdata/test.rpm b/pkg/handlers/testdata/test.rpm new file mode 100644 index 000000000000..40373635e02f Binary files /dev/null and b/pkg/handlers/testdata/test.rpm differ