Skip to content

Commit

Permalink
Fix panic when execution tracing cheatcode contracts (#411)
Browse files Browse the repository at this point in the history
* fix panic when execution tracing cheatcode contracts.

* Update chain/test_chain.go

Co-authored-by: alpharush <[email protected]>

---------

Co-authored-by: alpharush <[email protected]>
  • Loading branch information
anishnaik and 0xalpharush authored Jul 24, 2024
1 parent b4d4c56 commit 924247c
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 34 deletions.
29 changes: 14 additions & 15 deletions chain/test_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/crytic/medusa/chain/config"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/tracing"
"github.com/ethereum/go-ethereum/eth/tracers"
"github.com/ethereum/go-ethereum/triedb"
"github.com/ethereum/go-ethereum/triedb/hashdb"
"github.com/holiman/uint256"
Expand Down Expand Up @@ -256,7 +255,7 @@ func (t *TestChain) Clone(onCreateFunc func(chain *TestChain) error) (*TestChain
// Now add each transaction/message to it.
messages := t.blocks[i].Messages
for j := 0; j < len(messages); j++ {
err = targetChain.PendingBlockAddTx(messages[j], nil)
err = targetChain.PendingBlockAddTx(messages[j])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -728,15 +727,7 @@ func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTi
// PendingBlockAddTx takes a message (internal txs) and adds it to the current pending block, updating the header
// with relevant execution information. If a pending block was not created, an error is returned.
// Returns an error if one occurred.
func (t *TestChain) PendingBlockAddTx(message *core.Message, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) error {
// Caller can specify any tracer they wish to use for this transaction
if getTracerFn == nil {
// TODO: Figure out whether it is possible to identify _which_ transaction you want to trace (versus all)
getTracerFn = func(txIndex int, txHash common.Hash) *tracers.Tracer {
return t.transactionTracerRouter.NativeTracer().Tracer
}
}

func (t *TestChain) PendingBlockAddTx(message *core.Message, additionalTracers ...*TestChainTracer) error {
// If we don't have a pending block, return an error
if t.pendingBlock == nil {
return errors.New("could not add tx to the chain's pending block because no pending block was created")
Expand All @@ -757,12 +748,20 @@ func (t *TestChain) PendingBlockAddTx(message *core.Message, getTracerFn func(tx
ConfigExtensions: t.vmConfigExtensions,
}

// Set the tracer to be used in the vm config
tracer := getTracerFn(len(t.pendingBlock.Messages), tx.Hash())
if tracer != nil {
vmConfig.Tracer = tracer.Hooks
// Figure out whether we need to attach any more tracers
var extendedTracerRouter *TestChainTracerRouter
if len(additionalTracers) > 0 {
// If we have more tracers, extend the transaction tracer router's tracers with additional ones
extendedTracerRouter = NewTestChainTracerRouter()
extendedTracerRouter.AddTracer(t.transactionTracerRouter.NativeTracer())
extendedTracerRouter.AddTracers(additionalTracers...)
} else {
extendedTracerRouter = t.transactionTracerRouter
}

// Update the VM's tracer
vmConfig.Tracer = extendedTracerRouter.NativeTracer().Tracer.Hooks

// Set tx context
t.state.SetTxContext(tx.Hash(), len(t.pendingBlock.Messages))

Expand Down
10 changes: 5 additions & 5 deletions chain/test_chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func TestChainDynamicDeployments(t *testing.T) {
assert.NoError(t, err)

// Add our transaction to the block
err = chain.PendingBlockAddTx(&msg, nil)
err = chain.PendingBlockAddTx(&msg)
assert.NoError(t, err)

// Commit the pending block to the chain, so it becomes the new head.
Expand Down Expand Up @@ -385,7 +385,7 @@ func TestChainDeploymentWithArgs(t *testing.T) {
assert.NoError(t, err)

// Add our transaction to the block
err = chain.PendingBlockAddTx(&msg, nil)
err = chain.PendingBlockAddTx(&msg)
assert.NoError(t, err)

// Commit the pending block to the chain, so it becomes the new head.
Expand Down Expand Up @@ -494,7 +494,7 @@ func TestChainCloning(t *testing.T) {
assert.NoError(t, err)

// Add our transaction to the block
err = chain.PendingBlockAddTx(&msg, nil)
err = chain.PendingBlockAddTx(&msg)
assert.NoError(t, err)

// Commit the pending block to the chain, so it becomes the new head.
Expand Down Expand Up @@ -588,7 +588,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) {
assert.NoError(t, err)

// Add our transaction to the block
err = chain.PendingBlockAddTx(&msg, nil)
err = chain.PendingBlockAddTx(&msg)
assert.NoError(t, err)

// Commit the pending block to the chain, so it becomes the new head.
Expand Down Expand Up @@ -627,7 +627,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) {
_, err := recreatedChain.PendingBlockCreate()
assert.NoError(t, err)
for _, message := range chain.blocks[i].Messages {
err = recreatedChain.PendingBlockAddTx(message, nil)
err = recreatedChain.PendingBlockAddTx(message)
assert.NoError(t, err)
}
err = recreatedChain.PendingBlockCommit()
Expand Down
15 changes: 5 additions & 10 deletions fuzzing/calls/call_sequence_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"github.com/crytic/medusa/fuzzing/contracts"
"github.com/crytic/medusa/fuzzing/executiontracer"
"github.com/crytic/medusa/utils"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/eth/tracers"
)

// ExecuteCallSequenceFetchElementFunc describes a function that is called to obtain the next call sequence element to
Expand All @@ -28,7 +26,7 @@ type ExecuteCallSequenceExecutionCheckFunc func(currentExecutedSequence CallSequ
// A "post element executed check" function is provided to check whether execution should stop after each element is
// executed.
// Returns the call sequence which was executed and an error if one occurs.
func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc ExecuteCallSequenceFetchElementFunc, executionCheckFunc ExecuteCallSequenceExecutionCheckFunc, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) (CallSequence, error) {
func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc ExecuteCallSequenceFetchElementFunc, executionCheckFunc ExecuteCallSequenceExecutionCheckFunc, additionalTracers ...*chain.TestChainTracer) (CallSequence, error) {
// If there is no fetch element function provided, throw an error
if fetchElementFunc == nil {
return nil, fmt.Errorf("could not execute call sequence on chain as the 'fetch element function' provided was nil")
Expand Down Expand Up @@ -90,7 +88,7 @@ func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc Exe
}

// Try to add our transaction to this block.
err = chain.PendingBlockAddTx(callSequenceElement.Call.ToCoreMessage(), getTracerFn)
err = chain.PendingBlockAddTx(callSequenceElement.Call.ToCoreMessage(), additionalTracers...)

if err != nil {
// If we encountered a block gas limit error, this tx is too expensive to fit in this block.
Expand Down Expand Up @@ -168,17 +166,14 @@ func ExecuteCallSequence(chain *chain.TestChain, callSequence CallSequence) (Cal
return nil, nil
}

return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil, nil)
return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil)
}

// ExecuteCallSequenceWithTracer attaches an executiontracer.ExecutionTracer to ExecuteCallSequenceIteratively and attaches execution traces to the call sequence elements.
func ExecuteCallSequenceWithExecutionTracer(testChain *chain.TestChain, contractDefinitions contracts.Contracts, callSequence CallSequence, verboseTracing bool) (CallSequence, error) {
// Create a new execution tracer
executionTracer := executiontracer.NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts())
defer executionTracer.Close()
getTracerFunc := func(txIndex int, txHash common.Hash) *tracers.Tracer {
return executionTracer.NativeTracer().Tracer
}

// Execute our sequence with a simple fetch operation provided to obtain each element.
fetchElementFunc := func(currentIndex int) (*CallSequenceElement, error) {
Expand All @@ -188,8 +183,8 @@ func ExecuteCallSequenceWithExecutionTracer(testChain *chain.TestChain, contract
return nil, nil
}

// Execute the call sequence
executedCallSeq, err := ExecuteCallSequenceIteratively(testChain, fetchElementFunc, nil, getTracerFunc)
// Execute the call sequence and attach the execution tracer
executedCallSeq, err := ExecuteCallSequenceIteratively(testChain, fetchElementFunc, nil, executionTracer.NativeTracer())

// By default, we only trace the last element in the call sequence.
traceFrom := len(callSequence) - 1
Expand Down
2 changes: 1 addition & 1 deletion fuzzing/corpus/corpus.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (c *Corpus) initializeSequences(sequenceFiles *corpusDirectory[calls.CallSe
}

// Execute each call sequence, populating runtime data and collecting coverage data along the way.
_, err = calls.ExecuteCallSequenceIteratively(testChain, fetchElementFunc, executionCheckFunc, nil)
_, err = calls.ExecuteCallSequenceIteratively(testChain, fetchElementFunc, executionCheckFunc)

// If we failed to replay a sequence and measure coverage due to an unexpected error, report it.
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions fuzzing/executiontracer/execution_tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ func (t *ExecutionTracer) OnTxStart(vm *tracing.VMContext, tx *coretypes.Transac
t.trace = newExecutionTrace(t.contractDefinitions)
t.currentCallFrame = nil
t.onNextCaptureState = nil
t.traceMap = make(map[common.Hash]*ExecutionTrace)

// Store our evm reference
t.evmContext = vm
}
Expand Down
2 changes: 1 addition & 1 deletion fuzzing/fuzzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ func chainSetupFromCompilations(fuzzer *Fuzzer, testChain *chain.TestChain) (*ex
}

// Add our transaction to the block
err = testChain.PendingBlockAddTx(msg.ToCoreMessage(), nil)
err = testChain.PendingBlockAddTx(msg.ToCoreMessage())
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions fuzzing/fuzzer_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func (fw *FuzzerWorker) testNextCallSequence() (calls.CallSequence, []ShrinkCall
}

// Execute our call sequence.
testedCallSequence, err := calls.ExecuteCallSequenceIteratively(fw.chain, fetchElementFunc, executionCheckFunc, nil)
testedCallSequence, err := calls.ExecuteCallSequenceIteratively(fw.chain, fetchElementFunc, executionCheckFunc)

// If we encountered an error, report it.
if err != nil {
Expand Down Expand Up @@ -383,7 +383,7 @@ func (fw *FuzzerWorker) testShrunkenCallSequence(possibleShrunkSequence calls.Ca
}

// Execute our call sequence.
_, err = calls.ExecuteCallSequenceIteratively(fw.chain, fetchElementFunc, executionCheckFunc, nil)
_, err = calls.ExecuteCallSequenceIteratively(fw.chain, fetchElementFunc, executionCheckFunc)
if err != nil {
return false, err
}
Expand Down

0 comments on commit 924247c

Please sign in to comment.