From c20235d910cf92576f0f8d3ed4f1d52688c27887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Toni=20Ram=C3=ADrez?= <58293609+ToniRamirezM@users.noreply.github.com> Date: Thu, 7 Nov 2024 12:50:32 +0100 Subject: [PATCH] fix: acc input hash calculation (#167) * fix: acc input hash calculation * feat: add unit test * feat: change comment * feat: improve coverage * feat: remove commented code --- aggregator/aggregator.go | 77 ++++++++++++++++++++--------------- aggregator/aggregator_test.go | 51 ++++++++++++++--------- scripts/local_config | 2 +- 3 files changed, 77 insertions(+), 53 deletions(-) diff --git a/aggregator/aggregator.go b/aggregator/aggregator.go index d3479fcf..64f667ad 100644 --- a/aggregator/aggregator.go +++ b/aggregator/aggregator.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "math/big" "net" "strings" @@ -174,7 +175,7 @@ func New( a.ctx, a.exit = context.WithCancel(a.ctx) } - // Set function to handle the batches from the data stream + // Set function to handle events on L1 if !cfg.SyncModeOnlyEnabled { a.l1Syncr.SetCallbackOnReorgDone(a.handleReorg) a.l1Syncr.SetCallbackOnRollbackBatches(a.handleRollbackBatches) @@ -183,6 +184,26 @@ func New( return a, nil } +func (a *Aggregator) getAccInputHash(batchNumber uint64) common.Hash { + a.accInputHashesMutex.Lock() + defer a.accInputHashesMutex.Unlock() + return a.accInputHashes[batchNumber] +} + +func (a *Aggregator) setAccInputHash(batchNumber uint64, accInputHash common.Hash) { + a.accInputHashesMutex.Lock() + defer a.accInputHashesMutex.Unlock() + a.accInputHashes[batchNumber] = accInputHash +} + +func (a *Aggregator) removeAccInputHashes(firstBatch, lastBatch uint64) { + a.accInputHashesMutex.Lock() + defer a.accInputHashesMutex.Unlock() + for i := firstBatch; i <= lastBatch; i++ { + delete(a.accInputHashes, i) + } +} + func (a *Aggregator) handleReorg(reorgData synchronizer.ReorgExecutionResult) { a.logger.Warnf("Reorg detected, reorgData: %+v", reorgData) @@ -246,9 +267,9 @@ func (a *Aggregator) handleRollbackBatches(rollbackData synchronizer.RollbackBat if err == nil { a.accInputHashesMutex.Lock() a.accInputHashes = make(map[uint64]common.Hash) - a.logger.Infof("Starting AccInputHash:%v", accInputHash.String()) - a.accInputHashes[lastVerifiedBatchNumber] = *accInputHash a.accInputHashesMutex.Unlock() + a.logger.Infof("Starting AccInputHash:%v", accInputHash.String()) + a.setAccInputHash(lastVerifiedBatchNumber, *accInputHash) } } @@ -334,9 +355,7 @@ func (a *Aggregator) Start() error { } a.logger.Infof("Starting AccInputHash:%v", accInputHash.String()) - a.accInputHashesMutex.Lock() - a.accInputHashes[lastVerifiedBatchNumber] = *accInputHash - a.accInputHashesMutex.Unlock() + a.setAccInputHash(lastVerifiedBatchNumber, *accInputHash) a.resetVerifyProofTime() @@ -1076,6 +1095,22 @@ func (a *Aggregator) getAndLockBatchToProve( return nil, nil, nil, err } + + if proofExists { + accInputHash := a.getAccInputHash(batchNumberToVerify - 1) + if accInputHash == (common.Hash{}) && batchNumberToVerify > 1 { + tmpLogger.Warnf("AccInputHash for batch %d is not in memory, "+ + "deleting proofs to regenerate acc input hash chain in memory", batchNumberToVerify) + + err := a.state.CleanupGeneratedProofs(ctx, math.MaxInt, nil) + if err != nil { + tmpLogger.Infof("Error cleaning up generated proofs for batch %d", batchNumberToVerify) + return nil, nil, nil, err + } + batchNumberToVerify-- + break + } + } } // Check if the batch has been sequenced @@ -1130,10 +1165,9 @@ func (a *Aggregator) getAndLockBatchToProve( } // Calculate acc input hash as the RPC is not returning the correct one at the moment - a.accInputHashesMutex.Lock() accInputHash := cdkcommon.CalculateAccInputHash( a.logger, - a.accInputHashes[batchNumberToVerify-1], + a.getAccInputHash(batchNumberToVerify-1), virtualBatch.BatchL2Data, *virtualBatch.L1InfoRoot, uint64(sequence.Timestamp.Unix()), @@ -1141,8 +1175,7 @@ func (a *Aggregator) getAndLockBatchToProve( rpcBatch.ForcedBlockHashL1(), ) // Store the acc input hash - a.accInputHashes[batchNumberToVerify] = accInputHash - a.accInputHashesMutex.Unlock() + a.setAccInputHash(batchNumberToVerify, accInputHash) // Log params to calculate acc input hash a.logger.Debugf("Calculated acc input hash for batch %d: %v", batchNumberToVerify, accInputHash) @@ -1473,21 +1506,10 @@ func (a *Aggregator) buildInputProver( } } - // Get Old Acc Input Hash - /* - rpcOldBatch, err := a.rpcClient.GetBatch(batchToVerify.BatchNumber - 1) - if err != nil { - return nil, err - } - */ - - a.accInputHashesMutex.Lock() inputProver := &prover.StatelessInputProver{ PublicInputs: &prover.StatelessPublicInputs{ - Witness: witness, - // Use calculated acc inputh hash as the RPC is not returning the correct one at the moment - // OldAccInputHash: rpcOldBatch.AccInputHash().Bytes(), - OldAccInputHash: a.accInputHashes[batchToVerify.BatchNumber-1].Bytes(), + Witness: witness, + OldAccInputHash: a.getAccInputHash(batchToVerify.BatchNumber - 1).Bytes(), OldBatchNum: batchToVerify.BatchNumber - 1, ChainId: batchToVerify.ChainID, ForkId: batchToVerify.ForkID, @@ -1500,7 +1522,6 @@ func (a *Aggregator) buildInputProver( ForcedBlockhashL1: forcedBlockhashL1.Bytes(), }, } - a.accInputHashesMutex.Unlock() printInputProver(a.logger, inputProver) return inputProver, nil @@ -1594,14 +1615,6 @@ func (a *Aggregator) handleMonitoredTxResult(result ethtxtypes.MonitoredTxResult a.removeAccInputHashes(firstBatch, lastBatch-1) } -func (a *Aggregator) removeAccInputHashes(firstBatch, lastBatch uint64) { - a.accInputHashesMutex.Lock() - for i := firstBatch; i <= lastBatch; i++ { - delete(a.accInputHashes, i) - } - a.accInputHashesMutex.Unlock() -} - func (a *Aggregator) cleanupLockedProofs() { for { select { diff --git a/aggregator/aggregator_test.go b/aggregator/aggregator_test.go index e7284972..506ce16c 100644 --- a/aggregator/aggregator_test.go +++ b/aggregator/aggregator_test.go @@ -1527,34 +1527,27 @@ func Test_tryGenerateBatchProof(t *testing.T) { batchL2Data, err := hex.DecodeString(codedL2Block1) require.NoError(err) l1InfoRoot := common.HexToHash("0x057e9950fbd39b002e323f37c2330d0c096e66919e24cc96fb4b2dfa8f4af782") - batch := state.Batch{ - BatchNumber: lastVerifiedBatchNum + 1, - BatchL2Data: batchL2Data, - L1InfoRoot: l1InfoRoot, - Timestamp: time.Now(), - Coinbase: common.Address{}, - ChainID: uint64(1), - ForkID: uint64(12), - } + virtualBatch := synchronizer.VirtualBatch{ BatchNumber: lastVerifiedBatchNum + 1, BatchL2Data: batchL2Data, L1InfoRoot: &l1InfoRoot, } - m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum+1).Return(&virtualBatch, nil).Once() + m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum).Return(&virtualBatch, nil).Once() m.etherman.On("GetLatestVerifiedBatchNum").Return(lastVerifiedBatchNum, nil).Once() - m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), mock.AnythingOfType("uint64"), nil).Return(false, nil).Once() + m.stateMock.On("CheckProofExistsForBatch", mock.MatchedBy(matchProverCtxFn), lastVerifiedBatchNum+1, nil).Return(true, nil).Once() + m.stateMock.On("CleanupGeneratedProofs", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() sequence := synchronizer.SequencedBatches{ FromBatchNumber: uint64(10), ToBatchNumber: uint64(20), } - m.synchronizerMock.On("GetSequenceByBatchNumber", mock.MatchedBy(matchProverCtxFn), lastVerifiedBatchNum+1).Return(&sequence, nil).Once() + m.synchronizerMock.On("GetSequenceByBatchNumber", mock.MatchedBy(matchProverCtxFn), lastVerifiedBatchNum).Return(&sequence, nil).Once() rpcBatch := rpctypes.NewRPCBatch(lastVerifiedBatchNum+1, common.Hash{}, []string{}, batchL2Data, common.Hash{}, common.BytesToHash([]byte("mock LocalExitRoot")), common.BytesToHash([]byte("mock StateRoot")), common.Address{}, false) rpcBatch.SetLastL2BLockTimestamp(uint64(time.Now().Unix())) - m.rpcMock.On("GetWitness", lastVerifiedBatchNum+1, false).Return([]byte("witness"), nil) - m.rpcMock.On("GetBatch", lastVerifiedBatchNum+1).Return(rpcBatch, nil) + m.rpcMock.On("GetWitness", lastVerifiedBatchNum, false).Return([]byte("witness"), nil) + m.rpcMock.On("GetBatch", lastVerifiedBatchNum).Return(rpcBatch, nil) m.stateMock.On("AddSequence", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Return(nil).Once() m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, nil).Run( func(args mock.Arguments) { @@ -1569,17 +1562,14 @@ func Test_tryGenerateBatchProof(t *testing.T) { assert.InDelta(time.Now().Unix(), proof.GeneratingSince.Unix(), float64(time.Second)) }, ).Return(nil).Once() - m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil).Twice() + m.synchronizerMock.On("GetLeafsByL1InfoRoot", mock.Anything, l1InfoRoot).Return(l1InfoTreeLeaf, nil) m.synchronizerMock.On("GetL1InfoTreeLeaves", mock.Anything, mock.Anything).Return(map[uint32]synchronizer.L1InfoTreeLeaf{ 1: { BlockNumber: uint64(35), }, - }, nil).Twice() - - expectedInputProver, err := a.buildInputProver(context.Background(), &batch, []byte("witness")) - require.NoError(err) + }, nil) - m.proverMock.On("BatchProof", expectedInputProver).Return(nil, errTest).Once() + m.proverMock.On("BatchProof", mock.Anything).Return(nil, errTest).Once() m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(nil).Once() }, asserts: func(result bool, a *Aggregator, err error) { @@ -1969,3 +1959,24 @@ func Test_tryGenerateBatchProof(t *testing.T) { }) } } + +func Test_accInputHashFunctions(t *testing.T) { + aggregator := Aggregator{ + accInputHashes: make(map[uint64]common.Hash), + accInputHashesMutex: &sync.Mutex{}, + } + + hash1 := common.BytesToHash([]byte("hash1")) + hash2 := common.BytesToHash([]byte("hash2")) + + aggregator.setAccInputHash(1, hash1) + aggregator.setAccInputHash(2, hash2) + + assert.Equal(t, 2, len(aggregator.accInputHashes)) + + hash3 := aggregator.getAccInputHash(1) + assert.Equal(t, hash1, hash3) + + aggregator.removeAccInputHashes(1, 2) + assert.Equal(t, 0, len(aggregator.accInputHashes)) +} diff --git a/scripts/local_config b/scripts/local_config index d1a47b2c..b65210ac 100755 --- a/scripts/local_config +++ b/scripts/local_config @@ -206,7 +206,7 @@ function export_portnum_from_kurtosis_or_fail(){ ############################################################################### function export_ports_from_kurtosis(){ export_portnum_from_kurtosis_or_fail l1_rpc_port el-1-geth-lighthouse rpc - export_portnum_from_kurtosis_or_fail zkevm_rpc_http_port cdk-erigon-node-001 http-rpc rpc + export_portnum_from_kurtosis_or_fail zkevm_rpc_http_port cdk-erigon-rpc-001 http-rpc rpc export_portnum_from_kurtosis_or_fail zkevm_data_streamer_port cdk-erigon-sequencer-001 data-streamer export_portnum_from_kurtosis_or_fail aggregator_db_port postgres-001 postgres export_portnum_from_kurtosis_or_fail agglayer_port agglayer agglayer