From 0e1763f3c9d48b02a091eba8dc9f23043f9dd100 Mon Sep 17 00:00:00 2001 From: s4nsec Date: Tue, 23 Jul 2024 19:26:38 +0400 Subject: [PATCH] Adapt value generation tracer to interface changes --- fuzzing/executiontracer/execution_tracer.go | 1 - fuzzing/fuzzer_worker.go | 2 +- .../valuegeneration_tracer.go | 125 ++++++++++-------- 3 files changed, 72 insertions(+), 56 deletions(-) diff --git a/fuzzing/executiontracer/execution_tracer.go b/fuzzing/executiontracer/execution_tracer.go index db5632a4..d7b3157c 100644 --- a/fuzzing/executiontracer/execution_tracer.go +++ b/fuzzing/executiontracer/execution_tracer.go @@ -6,7 +6,6 @@ import ( "github.com/crytic/medusa/chain" "github.com/crytic/medusa/fuzzing/contracts" - "github.com/crytic/medusa/utils" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" diff --git a/fuzzing/fuzzer_worker.go b/fuzzing/fuzzer_worker.go index c5ffb39c..a4b870a2 100644 --- a/fuzzing/fuzzer_worker.go +++ b/fuzzing/fuzzer_worker.go @@ -578,7 +578,7 @@ func (fw *FuzzerWorker) run(baseTestChain *chain.TestChain) (bool, error) { // execution and connect it to the chain if fw.fuzzer.config.Fuzzing.Testing.ExperimentalValueGenerationEnabled { fw.valueGenerationTracer = valuegenerationtracer.NewValueGenerationTracer(fw.fuzzer.contractDefinitions) - initializedChain.AddTracer(fw.valueGenerationTracer, true, false) + initializedChain.AddTracer(fw.valueGenerationTracer.NativeTracer(), true, false) } return nil }) diff --git a/fuzzing/valuegenerationtracer/valuegeneration_tracer.go b/fuzzing/valuegenerationtracer/valuegeneration_tracer.go index ba6dcf6d..0c1e95e0 100644 --- a/fuzzing/valuegenerationtracer/valuegeneration_tracer.go +++ b/fuzzing/valuegenerationtracer/valuegeneration_tracer.go @@ -1,6 +1,7 @@ package valuegenerationtracer import ( + "github.com/crytic/medusa/chain" "github.com/crytic/medusa/chain/types" "github.com/crytic/medusa/compilation/abiutils" "github.com/crytic/medusa/fuzzing/contracts" @@ -8,8 +9,10 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" - coreTypes "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/tracing" + coretypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/eth/tracers" "golang.org/x/exp/slices" "math/big" ) @@ -31,7 +34,7 @@ type ValueGenerationTrace struct { type ValueGenerationTracer struct { // evm refers to the EVM instance last captured. - evm *vm.EVM + evmContext *tracing.VMContext // trace represents the current execution trace captured by this tracer. trace *ValueGenerationTrace @@ -47,39 +50,32 @@ type ValueGenerationTracer struct { // after some state is captured, on the next state capture (e.g. detecting a log instruction, but // using this structure to execute code later once the log is committed). onNextCaptureState []func() + + // nativeTracer is the underlying tracer used to capture EVM execution. + nativeTracer *chain.TestChainTracer } +// NativeTracer returns the underlying TestChainTracer. +func (t *ValueGenerationTracer) NativeTracer() *chain.TestChainTracer { + return t.nativeTracer +} func NewValueGenerationTracer(contractDefinitions contracts.Contracts) *ValueGenerationTracer { tracer := &ValueGenerationTracer{ contractDefinitions: contractDefinitions, } - return tracer -} -func (v *ValueGenerationTracer) CaptureTxStart(gasLimit uint64) { - v.trace = newValueGenerationTrace(v.contractDefinitions) - v.currentCallFrame = nil - v.onNextCaptureState = nil -} - -func (v *ValueGenerationTracer) CaptureTxEnd(restGas uint64) { -} - -func (v *ValueGenerationTracer) CaptureStart(env *vm.EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { - v.evm = env - v.captureEnteredCallFrame(from, to, input, create, value) -} - -func (v *ValueGenerationTracer) CaptureEnd(output []byte, gasUsed uint64, err error) { - v.trace.transactionOutputValues = append(v.trace.transactionOutputValues, v.captureExitedCallFrame(output, err)) -} - -func (v *ValueGenerationTracer) CaptureEnter(typ vm.OpCode, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { - v.captureEnteredCallFrame(from, to, input, typ == vm.CREATE || typ == vm.CREATE2, value) -} - -func (v *ValueGenerationTracer) CaptureExit(output []byte, gasUsed uint64, err error) { - v.trace.transactionOutputValues = append(v.trace.transactionOutputValues, v.captureExitedCallFrame(output, err)) + innerTracer := &tracers.Tracer{ + Hooks: &tracing.Hooks{ + OnTxStart: tracer.OnTxStart, + OnEnter: tracer.OnEnter, + OnTxEnd: tracer.OnTxEnd, + OnExit: tracer.OnExit, + OnOpcode: tracer.OnOpcode, + OnLog: tracer.OnLog, + }, + } + tracer.nativeTracer = &chain.TestChainTracer{Tracer: innerTracer, CaptureTxEndSetAdditionalResults: nil} + return tracer } func newValueGenerationTrace(contracts contracts.Contracts) *ValueGenerationTrace { @@ -205,7 +201,7 @@ func (v *ValueGenerationTracer) captureExitedCallFrame(output []byte, err error) if v.currentCallFrame.ToRuntimeBytecode == nil { // As long as this isn't a failed contract creation, we should be able to fetch "to" byte code on exit. if !v.currentCallFrame.IsContractCreation() || err == nil { - v.currentCallFrame.ToRuntimeBytecode = v.evm.StateDB.GetCode(v.currentCallFrame.ToAddress) + v.currentCallFrame.ToRuntimeBytecode = v.evmContext.StateDB.GetCode(v.currentCallFrame.ToAddress) } } if v.currentCallFrame.CodeRuntimeBytecode == nil { @@ -214,7 +210,7 @@ func (v *ValueGenerationTracer) captureExitedCallFrame(output []byte, err error) if v.currentCallFrame.CodeAddress == v.currentCallFrame.ToAddress { v.currentCallFrame.CodeRuntimeBytecode = v.currentCallFrame.ToRuntimeBytecode } else { - v.currentCallFrame.CodeRuntimeBytecode = v.evm.StateDB.GetCode(v.currentCallFrame.CodeAddress) + v.currentCallFrame.CodeRuntimeBytecode = v.evmContext.StateDB.GetCode(v.currentCallFrame.CodeAddress) } } @@ -237,28 +233,6 @@ func (v *ValueGenerationTracer) captureExitedCallFrame(output []byte, err error) return returnValue } -func (v *ValueGenerationTracer) CaptureState(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) { - // TODO: look for RET opcode (for now try getting them from currentCallFrame.ReturnData) - // Execute all "on next capture state" events and clear them. - for _, eventHandler := range v.onNextCaptureState { - eventHandler() - } - v.onNextCaptureState = nil - - // If a log operation occurred, add a deferred operation to capture it. - if op == vm.LOG0 || op == vm.LOG1 || op == vm.LOG2 || op == vm.LOG3 || op == vm.LOG4 { - v.onNextCaptureState = append(v.onNextCaptureState, func() { - logs := v.evm.StateDB.(*state.StateDB).Logs() - if len(logs) > 0 { - v.currentCallFrame.Operations = append(v.currentCallFrame.Operations, logs[len(logs)-1]) - } - }) - } -} - -func (v *ValueGenerationTracer) CaptureFault(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, depth int, err error) { -} - // CaptureTxEndSetAdditionalResults can be used to set additional results captured from execution tracing. If this // tracer is used during transaction execution (block creation), the results can later be queried from the block. // This method will only be called on the added tracer if it implements the extended TestChainTracer interface. @@ -275,12 +249,55 @@ func (v *ValueGenerationTracer) CaptureTxEndSetAdditionalResults(results *types. } +// OnTxStart is called upon the start of transaction execution, as defined by tracers.Tracer. +func (t *ValueGenerationTracer) OnTxStart(vm *tracing.VMContext, tx *coretypes.Transaction, from common.Address) { + t.trace = newValueGenerationTrace(t.contractDefinitions) + t.currentCallFrame = nil + t.onNextCaptureState = nil + // Store our evm reference + t.evmContext = vm +} + +func (t *ValueGenerationTracer) OnEnter(depth int, typ byte, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { + t.captureEnteredCallFrame(from, to, input, typ == byte(vm.CREATE) || typ == byte(vm.CREATE2), value) +} + +func (t *ValueGenerationTracer) OnTxEnd(receipt *coretypes.Receipt, err error) { + +} + +func (t *ValueGenerationTracer) OnExit(depth int, output []byte, used uint64, err error, reverted bool) { + t.trace.transactionOutputValues = append(t.trace.transactionOutputValues, t.captureExitedCallFrame(output, err)) +} + +func (t *ValueGenerationTracer) OnOpcode(pc uint64, op byte, gas uint64, cost uint64, scope tracing.OpContext, data []byte, depth int, err error) { + + // TODO: look for RET opcode (for now try getting them from currentCallFrame.ReturnData) + // Execute all "on next capture state" events and clear them. + for _, eventHandler := range t.onNextCaptureState { + eventHandler() + } + t.onNextCaptureState = nil + +} + +func (t *ValueGenerationTracer) OnLog(log *coretypes.Log) { + + // If a log operation occurred, add a deferred operation to capture it. + t.onNextCaptureState = append(t.onNextCaptureState, func() { + logs := t.evmContext.StateDB.(*state.StateDB).Logs() + if len(logs) > 0 { + t.currentCallFrame.Operations = append(t.currentCallFrame.Operations, logs[len(logs)-1]) + } + }) +} + func (t *ValueGenerationTrace) generateEvents(currentCallFrame *utils.CallFrame, events []any) []any { for _, operation := range currentCallFrame.Operations { if childCallFrame, ok := operation.(*utils.CallFrame); ok { // If this is a call frame being entered, generate information recursively. t.generateEvents(childCallFrame, events) - } else if eventLog, ok := operation.(*coreTypes.Log); ok { + } else if eventLog, ok := operation.(*coretypes.Log); ok { // If an event log was emitted, add a message for it. events = append(events, t.getEventsGenerated(currentCallFrame, eventLog)...) //t.getEventsGenerated(currentCallFrame, eventLog) @@ -290,7 +307,7 @@ func (t *ValueGenerationTrace) generateEvents(currentCallFrame *utils.CallFrame, return events } -func (t *ValueGenerationTrace) getEventsGenerated(callFrame *utils.CallFrame, eventLog *coreTypes.Log) []any { +func (t *ValueGenerationTrace) getEventsGenerated(callFrame *utils.CallFrame, eventLog *coretypes.Log) []any { // Try to unpack our event data eventInputs := make([]any, 0) event, eventInputValues := abiutils.UnpackEventAndValues(callFrame.CodeContractAbi, eventLog)