diff --git a/chain/test_chain.go b/chain/test_chain.go index c1400311..0137c6c3 100644 --- a/chain/test_chain.go +++ b/chain/test_chain.go @@ -9,7 +9,6 @@ import ( "github.com/crytic/medusa/chain/config" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/tracing" - "github.com/ethereum/go-ethereum/eth/tracers" "github.com/ethereum/go-ethereum/triedb" "github.com/ethereum/go-ethereum/triedb/hashdb" "github.com/holiman/uint256" @@ -256,7 +255,7 @@ func (t *TestChain) Clone(onCreateFunc func(chain *TestChain) error) (*TestChain // Now add each transaction/message to it. messages := t.blocks[i].Messages for j := 0; j < len(messages); j++ { - err = targetChain.PendingBlockAddTx(messages[j], nil) + err = targetChain.PendingBlockAddTx(messages[j]) if err != nil { return nil, err } @@ -728,15 +727,7 @@ func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTi // PendingBlockAddTx takes a message (internal txs) and adds it to the current pending block, updating the header // with relevant execution information. If a pending block was not created, an error is returned. // Returns an error if one occurred. -func (t *TestChain) PendingBlockAddTx(message *core.Message, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) error { - // Caller can specify any tracer they wish to use for this transaction - if getTracerFn == nil { - // TODO: Figure out whether it is possible to identify _which_ transaction you want to trace (versus all) - getTracerFn = func(txIndex int, txHash common.Hash) *tracers.Tracer { - return t.transactionTracerRouter.NativeTracer().Tracer - } - } - +func (t *TestChain) PendingBlockAddTx(message *core.Message, additionalTracers ...*TestChainTracer) error { // If we don't have a pending block, return an error if t.pendingBlock == nil { return errors.New("could not add tx to the chain's pending block because no pending block was created") @@ -757,12 +748,20 @@ func (t *TestChain) PendingBlockAddTx(message *core.Message, getTracerFn func(tx ConfigExtensions: t.vmConfigExtensions, } - // Set the tracer to be used in the vm config - tracer := getTracerFn(len(t.pendingBlock.Messages), tx.Hash()) - if tracer != nil { - vmConfig.Tracer = tracer.Hooks + // Figure out whether we need to attach any more tracers + var extendedTracerRouter *TestChainTracerRouter + if len(additionalTracers) > 0 { + // If we have more tracers, extend the transaction tracer router's tracers with additional ones + extendedTracerRouter = NewTestChainTracerRouter() + extendedTracerRouter.AddTracer(t.transactionTracerRouter.NativeTracer()) + extendedTracerRouter.AddTracers(additionalTracers...) + } else { + extendedTracerRouter = t.transactionTracerRouter } + // Update the VM's tracer + vmConfig.Tracer = extendedTracerRouter.NativeTracer().Tracer.Hooks + // Set tx context t.state.SetTxContext(tx.Hash(), len(t.pendingBlock.Messages)) diff --git a/chain/test_chain_test.go b/chain/test_chain_test.go index 693260fc..5d33ceea 100644 --- a/chain/test_chain_test.go +++ b/chain/test_chain_test.go @@ -260,7 +260,7 @@ func TestChainDynamicDeployments(t *testing.T) { assert.NoError(t, err) // Add our transaction to the block - err = chain.PendingBlockAddTx(&msg, nil) + err = chain.PendingBlockAddTx(&msg) assert.NoError(t, err) // Commit the pending block to the chain, so it becomes the new head. @@ -385,7 +385,7 @@ func TestChainDeploymentWithArgs(t *testing.T) { assert.NoError(t, err) // Add our transaction to the block - err = chain.PendingBlockAddTx(&msg, nil) + err = chain.PendingBlockAddTx(&msg) assert.NoError(t, err) // Commit the pending block to the chain, so it becomes the new head. @@ -494,7 +494,7 @@ func TestChainCloning(t *testing.T) { assert.NoError(t, err) // Add our transaction to the block - err = chain.PendingBlockAddTx(&msg, nil) + err = chain.PendingBlockAddTx(&msg) assert.NoError(t, err) // Commit the pending block to the chain, so it becomes the new head. @@ -588,7 +588,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) { assert.NoError(t, err) // Add our transaction to the block - err = chain.PendingBlockAddTx(&msg, nil) + err = chain.PendingBlockAddTx(&msg) assert.NoError(t, err) // Commit the pending block to the chain, so it becomes the new head. @@ -627,7 +627,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) { _, err := recreatedChain.PendingBlockCreate() assert.NoError(t, err) for _, message := range chain.blocks[i].Messages { - err = recreatedChain.PendingBlockAddTx(message, nil) + err = recreatedChain.PendingBlockAddTx(message) assert.NoError(t, err) } err = recreatedChain.PendingBlockCommit() diff --git a/fuzzing/calls/call_sequence_execution.go b/fuzzing/calls/call_sequence_execution.go index 14fc1ecc..ca983f0d 100644 --- a/fuzzing/calls/call_sequence_execution.go +++ b/fuzzing/calls/call_sequence_execution.go @@ -7,8 +7,6 @@ import ( "github.com/crytic/medusa/fuzzing/contracts" "github.com/crytic/medusa/fuzzing/executiontracer" "github.com/crytic/medusa/utils" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/eth/tracers" ) // ExecuteCallSequenceFetchElementFunc describes a function that is called to obtain the next call sequence element to @@ -28,7 +26,7 @@ type ExecuteCallSequenceExecutionCheckFunc func(currentExecutedSequence CallSequ // A "post element executed check" function is provided to check whether execution should stop after each element is // executed. // Returns the call sequence which was executed and an error if one occurs. -func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc ExecuteCallSequenceFetchElementFunc, executionCheckFunc ExecuteCallSequenceExecutionCheckFunc, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) (CallSequence, error) { +func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc ExecuteCallSequenceFetchElementFunc, executionCheckFunc ExecuteCallSequenceExecutionCheckFunc, additionalTracers ...*chain.TestChainTracer) (CallSequence, error) { // If there is no fetch element function provided, throw an error if fetchElementFunc == nil { return nil, fmt.Errorf("could not execute call sequence on chain as the 'fetch element function' provided was nil") @@ -90,7 +88,7 @@ func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc Exe } // Try to add our transaction to this block. - err = chain.PendingBlockAddTx(callSequenceElement.Call.ToCoreMessage(), getTracerFn) + err = chain.PendingBlockAddTx(callSequenceElement.Call.ToCoreMessage(), additionalTracers...) if err != nil { // If we encountered a block gas limit error, this tx is too expensive to fit in this block. @@ -168,7 +166,7 @@ func ExecuteCallSequence(chain *chain.TestChain, callSequence CallSequence) (Cal return nil, nil } - return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil, nil) + return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil) } // ExecuteCallSequenceWithTracer attaches an executiontracer.ExecutionTracer to ExecuteCallSequenceIteratively and attaches execution traces to the call sequence elements. @@ -176,9 +174,6 @@ func ExecuteCallSequenceWithExecutionTracer(testChain *chain.TestChain, contract // Create a new execution tracer executionTracer := executiontracer.NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts()) defer executionTracer.Close() - getTracerFunc := func(txIndex int, txHash common.Hash) *tracers.Tracer { - return executionTracer.NativeTracer().Tracer - } // Execute our sequence with a simple fetch operation provided to obtain each element. fetchElementFunc := func(currentIndex int) (*CallSequenceElement, error) { @@ -188,8 +183,8 @@ func ExecuteCallSequenceWithExecutionTracer(testChain *chain.TestChain, contract return nil, nil } - // Execute the call sequence - executedCallSeq, err := ExecuteCallSequenceIteratively(testChain, fetchElementFunc, nil, getTracerFunc) + // Execute the call sequence and attach the execution tracer + executedCallSeq, err := ExecuteCallSequenceIteratively(testChain, fetchElementFunc, nil, executionTracer.NativeTracer()) // By default, we only trace the last element in the call sequence. traceFrom := len(callSequence) - 1 diff --git a/fuzzing/corpus/corpus.go b/fuzzing/corpus/corpus.go index 011b414d..8426f156 100644 --- a/fuzzing/corpus/corpus.go +++ b/fuzzing/corpus/corpus.go @@ -210,7 +210,7 @@ func (c *Corpus) initializeSequences(sequenceFiles *corpusDirectory[calls.CallSe } // Execute each call sequence, populating runtime data and collecting coverage data along the way. - _, err = calls.ExecuteCallSequenceIteratively(testChain, fetchElementFunc, executionCheckFunc, nil) + _, err = calls.ExecuteCallSequenceIteratively(testChain, fetchElementFunc, executionCheckFunc) // If we failed to replay a sequence and measure coverage due to an unexpected error, report it. if err != nil { diff --git a/fuzzing/executiontracer/execution_tracer.go b/fuzzing/executiontracer/execution_tracer.go index 651f9379..46bfb01a 100644 --- a/fuzzing/executiontracer/execution_tracer.go +++ b/fuzzing/executiontracer/execution_tracer.go @@ -119,6 +119,8 @@ func (t *ExecutionTracer) OnTxStart(vm *tracing.VMContext, tx *coretypes.Transac t.trace = newExecutionTrace(t.contractDefinitions) t.currentCallFrame = nil t.onNextCaptureState = nil + t.traceMap = make(map[common.Hash]*ExecutionTrace) + // Store our evm reference t.evmContext = vm } diff --git a/fuzzing/fuzzer.go b/fuzzing/fuzzer.go index 6f1d6e8b..960ebfe2 100644 --- a/fuzzing/fuzzer.go +++ b/fuzzing/fuzzer.go @@ -463,7 +463,7 @@ func chainSetupFromCompilations(fuzzer *Fuzzer, testChain *chain.TestChain) (*ex } // Add our transaction to the block - err = testChain.PendingBlockAddTx(msg.ToCoreMessage(), nil) + err = testChain.PendingBlockAddTx(msg.ToCoreMessage()) if err != nil { return nil, err } diff --git a/fuzzing/fuzzer_worker.go b/fuzzing/fuzzer_worker.go index 5f651325..7ac958b2 100644 --- a/fuzzing/fuzzer_worker.go +++ b/fuzzing/fuzzer_worker.go @@ -317,7 +317,7 @@ func (fw *FuzzerWorker) testNextCallSequence() (calls.CallSequence, []ShrinkCall } // Execute our call sequence. - testedCallSequence, err := calls.ExecuteCallSequenceIteratively(fw.chain, fetchElementFunc, executionCheckFunc, nil) + testedCallSequence, err := calls.ExecuteCallSequenceIteratively(fw.chain, fetchElementFunc, executionCheckFunc) // If we encountered an error, report it. if err != nil { @@ -383,7 +383,7 @@ func (fw *FuzzerWorker) testShrunkenCallSequence(possibleShrunkSequence calls.Ca } // Execute our call sequence. - _, err = calls.ExecuteCallSequenceIteratively(fw.chain, fetchElementFunc, executionCheckFunc, nil) + _, err = calls.ExecuteCallSequenceIteratively(fw.chain, fetchElementFunc, executionCheckFunc) if err != nil { return false, err }