diff --git a/pkg/handlers/archive.go b/pkg/handlers/archive.go index 151b72830307..f743c2c9f942 100644 --- a/pkg/handlers/archive.go +++ b/pkg/handlers/archive.go @@ -246,8 +246,9 @@ func ensureToolsForMimeType(mimeType string) error { // and processes it based on its extension, such as handling Debian (.deb) and RPM (.rpm) packages. // It returns an io.Reader that can be used to read the processed content of the file, // and an error if any issues occurred during processing. +// If the file is specialized, the returned boolean is true with no error. // The caller is responsible for closing the returned reader. -func (a *Archive) HandleSpecialized(ctx context.Context, reader io.Reader) (io.Reader, bool, error) { +func (a *Archive) HandleSpecialized(ctx logContext.Context, reader io.Reader) (io.Reader, bool, error) { mimeType, reader, err := determineMimeType(reader) if err != nil { return nil, false, err @@ -279,7 +280,7 @@ func (a *Archive) HandleSpecialized(ctx context.Context, reader io.Reader) (io.R // It handles the extraction process by using the 'ar' command and manages temporary // files and directories for the operation. // The caller is responsible for closing the returned reader. -func (a *Archive) extractDebContent(ctx context.Context, file io.Reader) (io.ReadCloser, error) { +func (a *Archive) extractDebContent(ctx logContext.Context, file io.Reader) (io.ReadCloser, error) { if a.currentDepth >= maxDepth { return nil, fmt.Errorf("max archive depth reached") } @@ -297,7 +298,7 @@ func (a *Archive) extractDebContent(ctx context.Context, file io.Reader) (io.Rea return nil, err } - handler := func(ctx context.Context, env tempEnv, file string) (string, error) { + handler := func(ctx logContext.Context, env tempEnv, file string) (string, error) { if strings.HasPrefix(file, "data.tar.") { return file, nil } @@ -317,7 +318,7 @@ func (a *Archive) extractDebContent(ctx context.Context, file io.Reader) (io.Rea // It handles the extraction process by using the 'rpm2cpio' and 'cpio' commands and manages temporary // files and directories for the operation. // The caller is responsible for closing the returned reader. -func (a *Archive) extractRpmContent(ctx context.Context, file io.Reader) (io.ReadCloser, error) { +func (a *Archive) extractRpmContent(ctx logContext.Context, file io.Reader) (io.ReadCloser, error) { if a.currentDepth >= maxDepth { return nil, fmt.Errorf("max archive depth reached") } @@ -336,7 +337,7 @@ func (a *Archive) extractRpmContent(ctx context.Context, file io.Reader) (io.Rea return nil, err } - handler := func(ctx context.Context, env tempEnv, file string) (string, error) { + handler := func(ctx logContext.Context, env tempEnv, file string) (string, error) { if strings.HasSuffix(file, ".tar.gz") { return file, nil } @@ -351,7 +352,7 @@ func (a *Archive) extractRpmContent(ctx context.Context, file io.Reader) (io.Rea return openDataArchive(tmpEnv.extractPath, dataArchiveName) } -func (a *Archive) handleNestedFileMIME(ctx context.Context, tempEnv tempEnv, fileName string) (string, error) { +func (a *Archive) handleNestedFileMIME(ctx logContext.Context, tempEnv tempEnv, fileName string) (string, error) { nestedFile, err := os.Open(filepath.Join(tempEnv.extractPath, fileName)) if err != nil { return "", err @@ -360,7 +361,7 @@ func (a *Archive) handleNestedFileMIME(ctx context.Context, tempEnv tempEnv, fil mimeType, reader, err := determineMimeType(nestedFile) if err != nil { - return "", err + return "", fmt.Errorf("unable to determine MIME type of nested filename: %s, %w", nestedFile.Name(), err) } switch mimeType { @@ -373,7 +374,7 @@ func (a *Archive) handleNestedFileMIME(ctx context.Context, tempEnv tempEnv, fil } if err != nil { - return "", err + return "", fmt.Errorf("unable to extract file with MIME type %s: %w", mimeType, err) } return fileName, nil @@ -405,7 +406,7 @@ func determineMimeType(reader io.Reader) (string, io.Reader, error) { // of the data archive it finds. This centralizes the logic for handling specialized files such as .deb and .rpm // by using the appropriate handling function passed as an argument. This design allows for flexibility and reuse // of this function across various extraction processes in the package. -func (a *Archive) handleExtractedFiles(ctx context.Context, env tempEnv, handleFile func(context.Context, tempEnv, string) (string, error)) (string, error) { +func (a *Archive) handleExtractedFiles(ctx logContext.Context, env tempEnv, handleFile func(logContext.Context, tempEnv, string) (string, error)) (string, error) { extractedFiles, err := os.ReadDir(env.extractPath) if err != nil { return "", fmt.Errorf("unable to read extracted directory: %w", err) @@ -462,7 +463,7 @@ func executeCommand(cmd *exec.Cmd) error { var stderr bytes.Buffer cmd.Stderr = &stderr if err := cmd.Run(); err != nil { - return fmt.Errorf("unable to execute command: %w; error: %s", err, stderr.String()) + return fmt.Errorf("unable to execute command %v: %w; error: %s", cmd.String(), err, stderr.String()) } return nil } diff --git a/pkg/handlers/archive_test.go b/pkg/handlers/archive_test.go index 9608ad43e9c8..9b43582983e6 100644 --- a/pkg/handlers/archive_test.go +++ b/pkg/handlers/archive_test.go @@ -12,6 +12,7 @@ import ( diskbufferreader "github.com/bill-rich/disk-buffer-reader" "github.com/stretchr/testify/assert" + logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" ) @@ -131,7 +132,7 @@ func TestExtractDebContent(t *testing.T) { assert.Nil(t, err) defer file.Close() - ctx := context.Background() + ctx := logContext.AddLogger(context.Background()) a := &Archive{} reader, err := a.extractDebContent(ctx, file) @@ -149,7 +150,7 @@ func TestExtractRPMContent(t *testing.T) { assert.Nil(t, err) defer file.Close() - ctx := context.Background() + ctx := logContext.AddLogger(context.Background()) a := &Archive{} reader, err := a.extractRpmContent(ctx, file) diff --git a/pkg/handlers/handlers.go b/pkg/handlers/handlers.go index 31696483e2bf..cb18f7dd8eb9 100644 --- a/pkg/handlers/handlers.go +++ b/pkg/handlers/handlers.go @@ -4,6 +4,9 @@ import ( "context" "io" + diskbufferreader "github.com/bill-rich/disk-buffer-reader" + + logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" ) @@ -19,7 +22,7 @@ type SpecializedHandler interface { // HandleSpecialized examines the provided file reader within the context and determines if it is a specialized archive. // It returns a reader with any necessary modifications, a boolean indicating if the file was specialized, // and an error if something went wrong during processing. - HandleSpecialized(context.Context, io.Reader) (io.Reader, bool, error) + HandleSpecialized(logContext.Context, io.Reader) (io.Reader, bool, error) } type Handler interface { @@ -35,23 +38,40 @@ type Handler interface { // The function returns true if processing was successful and false otherwise. // Context is used for cancellation, and the caller is responsible for canceling it if needed. func HandleFile(ctx context.Context, file io.Reader, chunkSkel *sources.Chunk, chunksChan chan *sources.Chunk) bool { + aCtx := logContext.AddLogger(ctx) for _, h := range DefaultHandlers() { h.New() - var ( - isSpecial bool - err error - ) + + // The re-reader is used to reset the file reader after checking if the handler implements SpecializedHandler. + // This is necessary because the archive pkg doesn't correctly determine the file type when using + // an io.MultiReader, which is used by the SpecializedHandler. + reReader, err := diskbufferreader.New(file) + if err != nil { + aCtx.Logger().Error(err, "error creating re-reader reader") + return false + } + defer reReader.Close() // Check if the handler implements SpecializedHandler and process accordingly. if specialHandler, ok := h.(SpecializedHandler); ok { - if file, isSpecial, err = specialHandler.HandleSpecialized(ctx, file); isSpecial && err == nil { - return handleChunks(ctx, h.FromFile(ctx, file), chunkSkel, chunksChan) + file, isSpecial, err := specialHandler.HandleSpecialized(aCtx, reReader) + if isSpecial { + return handleChunks(aCtx, h.FromFile(ctx, file), chunkSkel, chunksChan) + } + + if err != nil { + aCtx.Logger().Error(err, "error handling file") } } + if err := reReader.Reset(); err != nil { + aCtx.Logger().Error(err, "error resetting re-reader") + return false + } + reReader.Stop() var isType bool - if file, isType = h.IsFiletype(ctx, file); isType { - return handleChunks(ctx, h.FromFile(ctx, file), chunkSkel, chunksChan) + if file, isType = h.IsFiletype(aCtx, reReader); isType { + return handleChunks(aCtx, h.FromFile(ctx, file), chunkSkel, chunksChan) } } return false