diff --git a/consensus/spos/errors.go b/consensus/spos/errors.go index c8b5cede565..f5f069d3394 100644 --- a/consensus/spos/errors.go +++ b/consensus/spos/errors.go @@ -243,3 +243,6 @@ var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") // ErrNilFunctionHandler signals that a nil function handler was provided var ErrNilFunctionHandler = errors.New("nil function handler") + +// ErrWrongHashForHeader signals that the hash of the header is not the expected one +var ErrWrongHashForHeader = errors.New("wrong hash for header") diff --git a/consensus/spos/worker.go b/consensus/spos/worker.go index f7159454f2a..f11e40d3089 100644 --- a/consensus/spos/worker.go +++ b/consensus/spos/worker.go @@ -1,6 +1,7 @@ package spos import ( + "bytes" "context" "encoding/hex" "errors" @@ -17,6 +18,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" errorsErd "github.com/multiversx/mx-chain-go/errors" @@ -485,6 +487,11 @@ func (wrk *Worker) doJobOnMessageWithHeader(cnsMsg *consensus.Message) error { "nbTxs", header.GetTxCount(), "val stats root hash", valStatsRootHash) + if !wrk.verifyHeaderHash(headerHash, cnsMsg.Header) { + return fmt.Errorf("%w : received header from consensus with wrong hash", + ErrWrongHashForHeader) + } + err = wrk.headerIntegrityVerifier.Verify(header) if err != nil { return fmt.Errorf("%w : verify header integrity from consensus topic failed", err) @@ -509,6 +516,11 @@ func (wrk *Worker) doJobOnMessageWithHeader(cnsMsg *consensus.Message) error { return nil } +func (wrk *Worker) verifyHeaderHash(hash []byte, marshalledHeader []byte) bool { + computedHash := wrk.hasher.Compute(string(marshalledHeader)) + return bytes.Equal(hash, computedHash) +} + func (wrk *Worker) doJobOnMessageWithSignature(cnsMsg *consensus.Message, p2pMsg p2p.MessageP2P) { wrk.mutDisplayHashConsensusMessage.Lock() defer wrk.mutDisplayHashConsensusMessage.Unlock() diff --git a/consensus/spos/worker_test.go b/consensus/spos/worker_test.go index 59d155e2117..b179fdf0db8 100644 --- a/consensus/spos/worker_test.go +++ b/consensus/spos/worker_test.go @@ -16,6 +16,9 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/mock" @@ -27,8 +30,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const roundTimeDuration = 100 * time.Millisecond @@ -1252,6 +1253,64 @@ func TestWorker_ProcessReceivedMessageWithABadOriginatorShouldErr(t *testing.T) assert.True(t, errors.Is(err, spos.ErrOriginatorMismatch)) } +func TestWorker_ProcessReceivedMessageWithHeaderAndWrongHash(t *testing.T) { + t.Parallel() + + workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) + wrk, _ := spos.NewWorker(workerArgs) + + wrk.SetBlockProcessor( + &testscommon.BlockProcessorStub{ + DecodeBlockHeaderCalled: func(dta []byte) data.HeaderHandler { + return &testscommon.HeaderHandlerStub{ + CheckChainIDCalled: func(reference []byte) error { + return nil + }, + GetPrevHashCalled: func() []byte { + return make([]byte, 0) + }, + } + }, + RevertCurrentBlockCalled: func() { + }, + DecodeBlockBodyCalled: func(dta []byte) data.BodyHandler { + return nil + }, + }, + ) + + hdr := &block.Header{ChainID: chainID} + hdrHash := make([]byte, 32) // wrong hash + hdrStr, _ := mock.MarshalizerMock{}.Marshal(hdr) + cnsMsg := consensus.NewConsensusMessage( + hdrHash, + nil, + nil, + hdrStr, + []byte(wrk.ConsensusState().ConsensusGroup()[0]), + signature, + int(bls.MtBlockHeader), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + buff, _ := wrk.Marshalizer().Marshal(cnsMsg) + msg := &p2pmocks.P2PMessageMock{ + DataField: buff, + PeerField: currentPid, + SignatureField: []byte("signature"), + } + err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + time.Sleep(time.Second) + + assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockHeader])) + assert.ErrorIs(t, err, spos.ErrWrongHashForHeader) +} + func TestWorker_ProcessReceivedMessageOkValsShouldWork(t *testing.T) { t.Parallel()