diff --git a/activation/handler_test.go b/activation/handler_test.go index 29a0f6e3e78..76e8bd22788 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -224,79 +224,94 @@ func newTestHandler(tb testing.TB, goldenATXID types.ATXID, opts ...HandlerOptio } } -func testHandler_PostMalfeasanceProofs(t *testing.T, synced bool) { - goldenATXID := types.ATXID{2, 3, 4} - atxHdlr := newTestHandler(t, goldenATXID) - - sig, err := signing.NewEdSigner() - require.NoError(t, err) +func TestHandler_PostMalfeasanceProofs(t *testing.T) { + t.Run("produced but not published during sync", func(t *testing.T) { + t.Parallel() + goldenATXID := types.ATXID{2, 3, 4} + atxHdlr := newTestHandler(t, goldenATXID) + sig, err := signing.NewEdSigner() + require.NoError(t, err) - _, err = identities.GetMalfeasanceProof(atxHdlr.cdb, sig.NodeID()) - require.ErrorIs(t, err, sql.ErrNotFound) + malicious, err := identities.IsMalicious(atxHdlr.cdb, sig.NodeID()) + require.NoError(t, err) + require.False(t, malicious) - atx := newInitialATXv1(t, goldenATXID) - atx.Sign(sig) + atx := newInitialATXv1(t, goldenATXID) + atx.Sign(sig) - var got mwire.MalfeasanceGossip - atxHdlr.mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) - atxHdlr.mValidator.EXPECT().VRFNonce(atx.SmesherID, goldenATXID, *atx.VRFNonce, gomock.Any(), atx.NumUnits) - atxHdlr.mValidator.EXPECT(). - Post(gomock.Any(), gomock.Any(), *atx.CommitmentATXID, gomock.Any(), gomock.Any(), atx.NumUnits) - atxHdlr.mockFetch.EXPECT().RegisterPeerHashes(gomock.Any(), gomock.Any()) - atxHdlr.mockFetch.EXPECT().GetPoetProof(gomock.Any(), gomock.Any()) - atxHdlr.mValidator.EXPECT().InitialNIPostChallengeV1(gomock.Any(), gomock.Any(), goldenATXID) - atxHdlr.mValidator.EXPECT().PositioningAtx(atx.PositioningATXID, gomock.Any(), goldenATXID, atx.PublishEpoch) - atxHdlr.mValidator.EXPECT(). - NIPost(gomock.Any(), atx.SmesherID, goldenATXID, gomock.Any(), gomock.Any(), atx.NumUnits, gomock.Any()). - Return(0, &verifying.ErrInvalidIndex{Index: 2}) - atxHdlr.mtortoise.EXPECT().OnMalfeasance(gomock.Any()) - - msg := codec.MustEncode(atx) - if synced { + atxHdlr.mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) + atxHdlr.mValidator.EXPECT().VRFNonce(atx.SmesherID, goldenATXID, *atx.VRFNonce, gomock.Any(), atx.NumUnits) + atxHdlr.mValidator.EXPECT(). + Post(gomock.Any(), gomock.Any(), *atx.CommitmentATXID, gomock.Any(), gomock.Any(), atx.NumUnits) + atxHdlr.mockFetch.EXPECT().RegisterPeerHashes(gomock.Any(), gomock.Any()) + atxHdlr.mockFetch.EXPECT().GetPoetProof(gomock.Any(), gomock.Any()) + atxHdlr.mValidator.EXPECT().InitialNIPostChallengeV1(gomock.Any(), gomock.Any(), goldenATXID) + atxHdlr.mValidator.EXPECT().PositioningAtx(atx.PositioningATXID, gomock.Any(), goldenATXID, atx.PublishEpoch) + atxHdlr.mValidator.EXPECT(). + NIPost(gomock.Any(), atx.SmesherID, goldenATXID, gomock.Any(), gomock.Any(), atx.NumUnits, gomock.Any()). + Return(0, &verifying.ErrInvalidIndex{Index: 2}) + atxHdlr.mtortoise.EXPECT().OnMalfeasance(gomock.Any()) + + msg := codec.MustEncode(atx) require.NoError(t, atxHdlr.HandleSyncedAtx(context.Background(), types.Hash32{}, p2p.NoPeer, msg)) - } else { + + // identity is still marked as malicious + malicious, err = identities.IsMalicious(atxHdlr.cdb, sig.NodeID()) + require.NoError(t, err) + require.True(t, malicious) + }) + + t.Run("produced and published during gossip", func(t *testing.T) { + t.Parallel() + goldenATXID := types.ATXID{2, 3, 4} + atxHdlr := newTestHandler(t, goldenATXID) + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + malicious, err := identities.IsMalicious(atxHdlr.cdb, sig.NodeID()) + require.NoError(t, err) + require.False(t, malicious) + + atx := newInitialATXv1(t, goldenATXID) + atx.Sign(sig) + + atxHdlr.mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) + atxHdlr.mValidator.EXPECT().VRFNonce(atx.SmesherID, goldenATXID, *atx.VRFNonce, gomock.Any(), atx.NumUnits) + atxHdlr.mValidator.EXPECT(). + Post(gomock.Any(), gomock.Any(), *atx.CommitmentATXID, gomock.Any(), gomock.Any(), atx.NumUnits) + atxHdlr.mockFetch.EXPECT().RegisterPeerHashes(gomock.Any(), gomock.Any()) + atxHdlr.mockFetch.EXPECT().GetPoetProof(gomock.Any(), gomock.Any()) + atxHdlr.mValidator.EXPECT().InitialNIPostChallengeV1(gomock.Any(), gomock.Any(), goldenATXID) + atxHdlr.mValidator.EXPECT().PositioningAtx(atx.PositioningATXID, gomock.Any(), goldenATXID, atx.PublishEpoch) + atxHdlr.mValidator.EXPECT(). + NIPost(gomock.Any(), atx.SmesherID, goldenATXID, gomock.Any(), gomock.Any(), atx.NumUnits, gomock.Any()). + Return(0, &verifying.ErrInvalidIndex{Index: 2}) + atxHdlr.mtortoise.EXPECT().OnMalfeasance(gomock.Any()) + msg := codec.MustEncode(atx) + postVerifier := NewMockPostVerifier(gomock.NewController(t)) - mh := NewInvalidPostIndexHandler(atxHdlr.cdb, - atxHdlr.logger, - atxHdlr.edVerifier, - postVerifier, - ) - atxHdlr.mpub.EXPECT().Publish(gomock.Any(), pubsub.MalfeasanceProof, gomock.Any()).DoAndReturn( - func(_ context.Context, _ string, data []byte) error { + mh := NewInvalidPostIndexHandler(atxHdlr.cdb, atxHdlr.logger, atxHdlr.edVerifier, postVerifier) + atxHdlr.mpub.EXPECT().Publish(gomock.Any(), pubsub.MalfeasanceProof, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, data []byte) error { + var got mwire.MalfeasanceGossip require.NoError(t, codec.Decode(data, &got)) require.Equal(t, mwire.InvalidPostIndex, got.Proof.Type) - - postVerifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + postVerifier.EXPECT(). + Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(errors.New("invalid")) nodeID, err := mh.Validate(context.Background(), got.Proof.Data) require.NoError(t, err) require.Equal(t, sig.NodeID(), nodeID) - p, ok := got.Proof.Data.(*mwire.InvalidPostIndexProof) require.True(t, ok) require.EqualValues(t, 2, p.InvalidIdx) return nil }) require.ErrorIs(t, atxHdlr.HandleGossipAtx(context.Background(), p2p.NoPeer, msg), errMaliciousATX) - } - - proof, err := identities.GetMalfeasanceProof(atxHdlr.cdb, atx.SmesherID) - require.NoError(t, err) - require.NotNil(t, proof.Received()) - proof.SetReceived(time.Time{}) - if !synced { - require.Equal(t, got.MalfeasanceProof, *proof) - require.Equal(t, atx.PublishEpoch.FirstLayer(), got.MalfeasanceProof.Layer) - } -} - -func TestHandler_PostMalfeasanceProofs(t *testing.T) { - t.Run("produced but not published during sync", func(t *testing.T) { - testHandler_PostMalfeasanceProofs(t, true) - }) - t.Run("produced and published during gossip", func(t *testing.T) { - testHandler_PostMalfeasanceProofs(t, false) + malicious, err = identities.IsMalicious(atxHdlr.cdb, sig.NodeID()) + require.NoError(t, err) + require.True(t, malicious) }) } @@ -400,31 +415,67 @@ func TestHandler_HandleParallelGossipAtxV1(t *testing.T) { require.NoError(t, eg.Wait()) } -func testHandler_HandleDoublePublish(t *testing.T, synced bool) { - t.Parallel() - goldenATXID := types.ATXID{2, 3, 4} - sig, err := signing.NewEdSigner() - require.NoError(t, err) - hdlr := newTestHandler(t, goldenATXID) +func TestHandler_HandleMaliciousAtx(t *testing.T) { + t.Run("produced but not published during sync", func(t *testing.T) { + t.Parallel() + goldenATXID := types.ATXID{2, 3, 4} + sig, err := signing.NewEdSigner() + require.NoError(t, err) - atx1 := newInitialATXv1(t, goldenATXID) - atx1.Sign(sig) - hdlr.expectAtxV1(atx1, sig.NodeID()) - require.NoError(t, hdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(atx1))) + atxHdlr := newTestHandler(t, goldenATXID) + atx1 := newInitialATXv1(t, goldenATXID) + atx1.Sign(sig) + atxHdlr.expectAtxV1(atx1, sig.NodeID()) + require.NoError(t, atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(atx1))) - atx2 := newInitialATXv1(t, goldenATXID, func(a *wire.ActivationTxV1) { a.NumUnits = atx1.NumUnits + 1 }) - atx2.Sign(sig) - hdlr.expectAtxV1(atx2, sig.NodeID()) - hdlr.mtortoise.EXPECT().OnMalfeasance(sig.NodeID()) + malicious, err := identities.IsMalicious(atxHdlr.cdb, sig.NodeID()) + require.NoError(t, err) + require.False(t, malicious) - msg := codec.MustEncode(atx2) - var got mwire.MalfeasanceGossip - if synced { - require.NoError(t, hdlr.HandleSyncedAtx(context.Background(), types.Hash32{}, "", msg)) - } else { - mh := NewMalfeasanceHandler(hdlr.cdb, hdlr.logger, hdlr.edVerifier) - hdlr.mpub.EXPECT().Publish(gomock.Any(), pubsub.MalfeasanceProof, gomock.Any()).DoAndReturn( - func(_ context.Context, _ string, data []byte) error { + atx2 := newInitialATXv1(t, goldenATXID, func(a *wire.ActivationTxV1) { + a.NumUnits = atx1.NumUnits + 1 + }) + atx2.Sign(sig) + atxHdlr.expectAtxV1(atx2, sig.NodeID()) + + atxHdlr.mtortoise.EXPECT().OnMalfeasance(sig.NodeID()) + msg := codec.MustEncode(atx2) + require.NoError(t, atxHdlr.HandleSyncedAtx(context.Background(), types.Hash32{}, "", msg)) + + // identity is still marked as malicious + malicious, err = identities.IsMalicious(atxHdlr.cdb, sig.NodeID()) + require.NoError(t, err) + require.True(t, malicious) + }) + + t.Run("produced and published during gossip", func(t *testing.T) { + t.Parallel() + goldenATXID := types.ATXID{2, 3, 4} + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + atxHdlr := newTestHandler(t, goldenATXID) + atx1 := newInitialATXv1(t, goldenATXID) + atx1.Sign(sig) + atxHdlr.expectAtxV1(atx1, sig.NodeID()) + require.NoError(t, atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(atx1))) + + malicious, err := identities.IsMalicious(atxHdlr.cdb, sig.NodeID()) + require.NoError(t, err) + require.False(t, malicious) + + atx2 := newInitialATXv1(t, goldenATXID, func(a *wire.ActivationTxV1) { + a.NumUnits = atx1.NumUnits + 1 + }) + atx2.Sign(sig) + atxHdlr.expectAtxV1(atx2, sig.NodeID()) + atxHdlr.mtortoise.EXPECT().OnMalfeasance(sig.NodeID()) + msg := codec.MustEncode(atx2) + + mh := NewMalfeasanceHandler(atxHdlr.cdb, atxHdlr.logger, atxHdlr.edVerifier) + atxHdlr.mpub.EXPECT().Publish(gomock.Any(), pubsub.MalfeasanceProof, gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, data []byte) error { + var got mwire.MalfeasanceGossip require.NoError(t, codec.Decode(data, &got)) require.Equal(t, mwire.MultipleATXs, got.Proof.Type) nodeID, err := mh.Validate(context.Background(), got.Proof.Data) @@ -432,25 +483,11 @@ func testHandler_HandleDoublePublish(t *testing.T, synced bool) { require.Equal(t, sig.NodeID(), nodeID) return nil }) - require.ErrorIs(t, hdlr.HandleGossipAtx(context.Background(), p2p.NoPeer, msg), errMaliciousATX) - } - - proof, err := identities.GetMalfeasanceProof(hdlr.cdb, sig.NodeID()) - require.NoError(t, err) - require.NotNil(t, proof) - if !synced { - proof.SetReceived(time.Time{}) - require.Equal(t, got.MalfeasanceProof, *proof) - } -} - -func TestHandler_HandleMaliciousAtx(t *testing.T) { - t.Run("produced but not published during sync", func(t *testing.T) { - testHandler_HandleDoublePublish(t, true) - }) + require.ErrorIs(t, atxHdlr.HandleGossipAtx(context.Background(), p2p.NoPeer, msg), errMaliciousATX) - t.Run("produced and published during gossip", func(t *testing.T) { - testHandler_HandleDoublePublish(t, false) + malicious, err = identities.IsMalicious(atxHdlr.cdb, sig.NodeID()) + require.NoError(t, err) + require.True(t, malicious) }) } diff --git a/api/grpcserver/activation_service.go b/api/grpcserver/activation_service.go index 4ba005a9706..4b5bb0c95ec 100644 --- a/api/grpcserver/activation_service.go +++ b/api/grpcserver/activation_service.go @@ -72,7 +72,7 @@ func (s *activationService) Get(ctx context.Context, request *pb.GetRequest) (*p return nil, status.Error(codes.Internal, "couldn't get previous ATXs") } - proof, err := s.atxProvider.GetMalfeasanceProof(atx.SmesherID) + proof, err := s.atxProvider.MalfeasanceProof(atx.SmesherID) if err != nil && !errors.Is(err, sql.ErrNotFound) { ctxzap.Error(ctx, "failed to get malfeasance proof", zap.Stringer("smesher", atx.SmesherID), diff --git a/api/grpcserver/activation_service_test.go b/api/grpcserver/activation_service_test.go index 3339cdb4dba..756501ac856 100644 --- a/api/grpcserver/activation_service_test.go +++ b/api/grpcserver/activation_service_test.go @@ -133,7 +133,7 @@ func TestGet_HappyPath(t *testing.T) { } atx.SetID(id) atxProvider.EXPECT().GetAtx(id).Return(&atx, nil) - atxProvider.EXPECT().GetMalfeasanceProof(gomock.Any()).Return(nil, sql.ErrNotFound) + atxProvider.EXPECT().MalfeasanceProof(gomock.Any()).Return(nil, sql.ErrNotFound) atxProvider.EXPECT().Previous(id).Return(previous, nil) response, err := activationService.Get(context.Background(), &pb.GetRequest{Id: id.Bytes()}) @@ -169,7 +169,7 @@ func TestGet_IdentityCanceled(t *testing.T) { } atx.SetID(id) atxProvider.EXPECT().GetAtx(id).Return(&atx, nil) - atxProvider.EXPECT().GetMalfeasanceProof(smesher).Return(proof, nil) + atxProvider.EXPECT().MalfeasanceProof(smesher).Return(proof, nil) atxProvider.EXPECT().Previous(id).Return([]types.ATXID{previous}, nil) response, err := activationService.Get(context.Background(), &pb.GetRequest{Id: id.Bytes()}) diff --git a/api/grpcserver/interface.go b/api/grpcserver/interface.go index bfca33b78b9..d8a4b5f5bbb 100644 --- a/api/grpcserver/interface.go +++ b/api/grpcserver/interface.go @@ -58,7 +58,7 @@ type atxProvider interface { GetAtx(id types.ATXID) (*types.ActivationTx, error) Previous(id types.ATXID) ([]types.ATXID, error) MaxHeightAtx() (types.ATXID, error) - GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) + MalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) } type postState interface { diff --git a/api/grpcserver/mesh_service.go b/api/grpcserver/mesh_service.go index 762e7891de1..7d08ff40374 100644 --- a/api/grpcserver/mesh_service.go +++ b/api/grpcserver/mesh_service.go @@ -611,7 +611,7 @@ func (s *MeshService) MalfeasanceQuery( fmt.Sprintf("invalid smesher id length (%d), expected (%d)", l, types.NodeIDSize)) } id := types.BytesToNodeID(parsed) - proof, err := s.cdb.GetMalfeasanceProof(id) + proof, err := s.cdb.MalfeasanceProof(id) if err != nil && !errors.Is(err, sql.ErrNotFound) { return nil, status.Error(codes.Internal, err.Error()) } diff --git a/api/grpcserver/mocks.go b/api/grpcserver/mocks.go index 1dab284f2ba..a9eab2e4d6b 100644 --- a/api/grpcserver/mocks.go +++ b/api/grpcserver/mocks.go @@ -912,41 +912,41 @@ func (c *MockatxProviderGetAtxCall) DoAndReturn(f func(types.ATXID) (*types.Acti return c } -// GetMalfeasanceProof mocks base method. -func (m *MockatxProvider) GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) { +// MalfeasanceProof mocks base method. +func (m *MockatxProvider) MalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMalfeasanceProof", id) + ret := m.ctrl.Call(m, "MalfeasanceProof", id) ret0, _ := ret[0].(*wire.MalfeasanceProof) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetMalfeasanceProof indicates an expected call of GetMalfeasanceProof. -func (mr *MockatxProviderMockRecorder) GetMalfeasanceProof(id any) *MockatxProviderGetMalfeasanceProofCall { +// MalfeasanceProof indicates an expected call of MalfeasanceProof. +func (mr *MockatxProviderMockRecorder) MalfeasanceProof(id any) *MockatxProviderMalfeasanceProofCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMalfeasanceProof", reflect.TypeOf((*MockatxProvider)(nil).GetMalfeasanceProof), id) - return &MockatxProviderGetMalfeasanceProofCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MalfeasanceProof", reflect.TypeOf((*MockatxProvider)(nil).MalfeasanceProof), id) + return &MockatxProviderMalfeasanceProofCall{Call: call} } -// MockatxProviderGetMalfeasanceProofCall wrap *gomock.Call -type MockatxProviderGetMalfeasanceProofCall struct { +// MockatxProviderMalfeasanceProofCall wrap *gomock.Call +type MockatxProviderMalfeasanceProofCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockatxProviderGetMalfeasanceProofCall) Return(arg0 *wire.MalfeasanceProof, arg1 error) *MockatxProviderGetMalfeasanceProofCall { +func (c *MockatxProviderMalfeasanceProofCall) Return(arg0 *wire.MalfeasanceProof, arg1 error) *MockatxProviderMalfeasanceProofCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockatxProviderGetMalfeasanceProofCall) Do(f func(types.NodeID) (*wire.MalfeasanceProof, error)) *MockatxProviderGetMalfeasanceProofCall { +func (c *MockatxProviderMalfeasanceProofCall) Do(f func(types.NodeID) (*wire.MalfeasanceProof, error)) *MockatxProviderMalfeasanceProofCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockatxProviderGetMalfeasanceProofCall) DoAndReturn(f func(types.NodeID) (*wire.MalfeasanceProof, error)) *MockatxProviderGetMalfeasanceProofCall { +func (c *MockatxProviderMalfeasanceProofCall) DoAndReturn(f func(types.NodeID) (*wire.MalfeasanceProof, error)) *MockatxProviderMalfeasanceProofCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/datastore/store.go b/datastore/store.go index 073c7bacdc9..3eaa6db83cf 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -10,6 +10,7 @@ import ( "go.uber.org/zap" "github.com/spacemeshos/go-spacemesh/atxsdata" + "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/proposals/store" @@ -119,7 +120,7 @@ func (db *CachedDB) MalfeasanceCacheSize() int { } // GetMalfeasanceProof gets the malfeasance proof associated with the NodeID. -func (db *CachedDB) GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) { +func (db *CachedDB) MalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) { if id == types.EmptyNodeID { panic("invalid argument to GetMalfeasanceProof") } @@ -133,10 +134,13 @@ func (db *CachedDB) GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof return proof, nil } - proof, err := identities.GetMalfeasanceProof(db.Database, id) + var blob sql.Blob + err := identities.LoadMalfeasanceBlob(context.Background(), db.Database, id.Bytes(), &blob) if err != nil && err != sql.ErrNotFound { return nil, err } + proof := &wire.MalfeasanceProof{} + codec.MustDecode(blob.Bytes, proof) db.malfeasanceCache.Add(id, proof) return proof, err } @@ -203,7 +207,7 @@ func (db *CachedDB) IterateMalfeasanceProofs( return err } for _, id := range ids { - proof, err := db.GetMalfeasanceProof(id) + proof, err := db.MalfeasanceProof(id) if err != nil { return err } diff --git a/datastore/store_test.go b/datastore/store_test.go index 043bc3d0d4d..a5f474a3422 100644 --- a/datastore/store_test.go +++ b/datastore/store_test.go @@ -74,7 +74,7 @@ func TestMalfeasanceProof_Dishonest(t *testing.T) { cdb.CacheMalfeasanceProof(nodeID1, proof) require.Equal(t, 1, cdb.MalfeasanceCacheSize()) - got, err := cdb.GetMalfeasanceProof(nodeID1) + got, err := cdb.MalfeasanceProof(nodeID1) require.NoError(t, err) require.EqualValues(t, proof, got) } diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index 86532ce6771..0a9d09d40f3 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -146,9 +146,9 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) { err := h.HandleMalfeasanceProof(context.Background(), "peer", codec.MustEncode(gossip)) require.NoError(t, err) - malProof, err := identities.GetMalfeasanceProof(h.db, nodeID) - require.NoError(t, err) - require.NotEqual(t, gossip.MalfeasanceProof, *malProof) + var blob sql.Blob + require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob)) + require.Equal(t, codec.MustEncode(&gossip.MalfeasanceProof), blob.Bytes) }) t.Run("new proof is noop", func(t *testing.T) { @@ -187,10 +187,9 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) { err := h.HandleMalfeasanceProof(context.Background(), "peer", codec.MustEncode(gossip)) require.ErrorIs(t, ErrKnownProof, err) - malProof, err := identities.GetMalfeasanceProof(h.db, nodeID) - require.NoError(t, err) - malProof.SetReceived(time.Time{}) - require.Equal(t, proof, malProof) + var blob sql.Blob + require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob)) + require.Equal(t, codec.MustEncode(proof), blob.Bytes) }) } @@ -318,19 +317,15 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { Data: &wire.AtxProof{}, }, } + proofBytes := codec.MustEncode(proof) h.mockTrt.EXPECT().OnMalfeasance(nodeID) - err := h.HandleSyncedMalfeasanceProof( - context.Background(), - types.Hash32(nodeID), - "peer", - codec.MustEncode(proof), - ) + err := h.HandleSyncedMalfeasanceProof(context.Background(), types.Hash32(nodeID), "peer", proofBytes) require.NoError(t, err) - malProof, err := identities.GetMalfeasanceProof(h.db, nodeID) - require.NoError(t, err) - require.NotEqual(t, proof, *malProof) + var blob sql.Blob + require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob)) + require.Equal(t, proofBytes, blob.Bytes) }) t.Run("new proof is noop", func(t *testing.T) { @@ -344,7 +339,8 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { Data: &wire.BallotProof{}, }, } - identities.SetMalicious(h.db, nodeID, codec.MustEncode(proof), time.Now()) + proofBytes := codec.MustEncode(proof) + identities.SetMalicious(h.db, nodeID, proofBytes, time.Now()) ctrl := gomock.NewController(t) handler := NewMockHandlerV1(ctrl) @@ -363,18 +359,14 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { Data: &wire.AtxProof{}, }, } + newProofBytes := codec.MustEncode(newProof) + require.NotEqual(t, proofBytes, newProofBytes) - err := h.HandleSyncedMalfeasanceProof( - context.Background(), - types.Hash32(nodeID), - "peer", - codec.MustEncode(newProof), - ) + err := h.HandleSyncedMalfeasanceProof(context.Background(), types.Hash32(nodeID), "peer", newProofBytes) require.ErrorIs(t, ErrKnownProof, err) - malProof, err := identities.GetMalfeasanceProof(h.db, nodeID) - require.NoError(t, err) - malProof.SetReceived(time.Time{}) - require.Equal(t, proof, malProof) + var blob sql.Blob + require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob)) + require.Equal(t, proofBytes, blob.Bytes) }) } diff --git a/mesh/mesh_test.go b/mesh/mesh_test.go index f32362948ca..a355ff590d1 100644 --- a/mesh/mesh_test.go +++ b/mesh/mesh_test.go @@ -379,9 +379,10 @@ func TestMesh_MaliciousBallots(t *testing.T) { mal, err := identities.IsMalicious(tm.cdb, sig.NodeID()) require.NoError(t, err) require.False(t, mal) - saved, err := identities.GetMalfeasanceProof(tm.cdb, sig.NodeID()) - require.ErrorIs(t, err, sql.ErrNotFound) - require.Nil(t, saved) + + malicious, err := identities.IsMalicious(tm.cdb, sig.NodeID()) + require.NoError(t, err) + require.False(t, malicious) // second one will create a MalfeasanceProof tm.mockTortoise.EXPECT().OnMalfeasance(sig.NodeID()) diff --git a/sql/identities/identities.go b/sql/identities/identities.go index ff4d9d804e0..257bec03ee6 100644 --- a/sql/identities/identities.go +++ b/sql/identities/identities.go @@ -7,9 +7,7 @@ import ( sqlite "github.com/go-llsqlite/crawshaw" - "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/sql" ) @@ -49,35 +47,6 @@ func IsMalicious(db sql.Executor, nodeID types.NodeID) (bool, error) { return rows > 0, nil } -// GetMalfeasanceProof returns the malfeasance proof for the given identity. -func GetMalfeasanceProof(db sql.Executor, nodeID types.NodeID) (*wire.MalfeasanceProof, error) { - var ( - data []byte - received time.Time - ) - rows, err := db.Exec("select proof, received from identities where pubkey = ?1 AND proof IS NOT NULL;", - func(stmt *sql.Statement) { - stmt.BindBytes(1, nodeID.Bytes()) - }, func(stmt *sql.Statement) bool { - data = make([]byte, stmt.ColumnLen(0)) - stmt.ColumnBytes(0, data[:]) - received = time.Unix(0, stmt.ColumnInt64(1)).Local() - return true - }) - if err != nil { - return nil, fmt.Errorf("proof %v: %w", nodeID, err) - } - if rows == 0 { - return nil, sql.ErrNotFound - } - var proof wire.MalfeasanceProof - if err = codec.Decode(data, &proof); err != nil { - return nil, err - } - proof.SetReceived(received.Local()) - return &proof, nil -} - // GetBlobSizes returns the sizes of the blobs corresponding to malfeasance proofs for the // specified identities. For non-existent proofs, the corresponding items are set to -1. func GetBlobSizes(db sql.Executor, ids [][]byte) (sizes []int, err error) { @@ -85,8 +54,12 @@ func GetBlobSizes(db sql.Executor, ids [][]byte) (sizes []int, err error) { } // LoadMalfeasanceBlob returns the malfeasance proof in raw bytes for the given identity. -func LoadMalfeasanceBlob(ctx context.Context, db sql.Executor, nodeID []byte, blob *sql.Blob) error { - return sql.LoadBlob(db, "select proof from identities where pubkey = ?1;", nodeID, blob) +func LoadMalfeasanceBlob(_ context.Context, db sql.Executor, nodeID []byte, blob *sql.Blob) error { + err := sql.LoadBlob(db, "select proof from identities where pubkey = ?1;", nodeID, blob) + if err == nil && len(blob.Bytes) == 0 { + return sql.ErrNotFound + } + return err } // IterateMalicious invokes the specified callback for each malicious node ID. @@ -117,20 +90,20 @@ func IterateMalicious( } // GetMalicious retrieves malicious node IDs from the database. -func GetMalicious(db sql.Executor) (nids []types.NodeID, err error) { +func GetMalicious(db sql.Executor) (ids []types.NodeID, err error) { if err = IterateMalicious(db, func(total int, nid types.NodeID) error { - if nids == nil { - nids = make([]types.NodeID, 0, total) + if ids == nil { + ids = make([]types.NodeID, 0, total) } - nids = append(nids, nid) + ids = append(ids, nid) return nil }); err != nil { return nil, err } - if len(nids) != cap(nids) { + if len(ids) != cap(ids) { panic("BUG: bad malicious node ID count") } - return nids, nil + return ids, nil } // MarriageATX obtains the marriage ATX for given ID. @@ -186,7 +159,7 @@ func Marriage(db sql.Executor, id types.NodeID) (*MarriageData, error) { } // Set marriage inserts marriage data for given identity. -// If identitty doesn't exist - create it. +// If identity doesn't exist - create it. func SetMarriage(db sql.Executor, id types.NodeID, m *MarriageData) error { _, err := db.Exec(` INSERT INTO identities (pubkey, marriage_atx, marriage_idx, marriage_target, marriage_signature) diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index 0d18ef5b79b..555a9be270b 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -1,4 +1,4 @@ -package identities +package identities_test import ( "context" @@ -11,6 +11,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) @@ -18,7 +19,7 @@ func TestMalicious(t *testing.T) { db := statesql.InMemory() nodeID := types.NodeID{1, 1, 1, 1} - mal, err := IsMalicious(db, nodeID) + mal, err := identities.IsMalicious(db, nodeID) require.NoError(t, err) require.False(t, mal) @@ -40,29 +41,26 @@ func TestMalicious(t *testing.T) { Data: &ballotProof, }, } - now := time.Now() - data, err := codec.Encode(proof) - require.NoError(t, err) - require.NoError(t, SetMalicious(db, nodeID, data, now)) + require.NoError(t, identities.SetMalicious(db, nodeID, codec.MustEncode(proof), time.Now())) - mal, err = IsMalicious(db, nodeID) + mal, err = identities.IsMalicious(db, nodeID) require.NoError(t, err) require.True(t, mal) - mal, err = IsMalicious(db, types.RandomNodeID()) + mal, err = identities.IsMalicious(db, types.RandomNodeID()) require.NoError(t, err) require.False(t, mal) - got, err := GetMalfeasanceProof(db, nodeID) - require.NoError(t, err) - require.Equal(t, now.UTC(), got.Received().UTC()) - got.SetReceived(time.Time{}) - require.EqualValues(t, proof, got) + var blob sql.Blob + require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), db, nodeID.Bytes(), &blob)) + got := &wire.MalfeasanceProof{} + codec.MustDecode(blob.Bytes, got) + require.Equal(t, proof, got) } func Test_GetMalicious(t *testing.T) { db := statesql.InMemory() - got, err := GetMalicious(db) + got, err := identities.GetMalicious(db) require.NoError(t, err) require.Nil(t, got) @@ -71,9 +69,9 @@ func Test_GetMalicious(t *testing.T) { for i := 0; i < numBad; i++ { nid := types.NodeID{byte(i + 1)} bad = append(bad, nid) - require.NoError(t, SetMalicious(db, nid, types.RandomBytes(11), time.Now().Local())) + require.NoError(t, identities.SetMalicious(db, nid, types.RandomBytes(11), time.Now().Local())) } - got, err = GetMalicious(db) + got, err = identities.GetMalicious(db) require.NoError(t, err) require.Equal(t, bad, got) } @@ -84,24 +82,24 @@ func TestLoadMalfeasanceBlob(t *testing.T) { nid1 := types.RandomNodeID() proof1 := types.RandomBytes(11) - SetMalicious(db, nid1, proof1, time.Now().Local()) + identities.SetMalicious(db, nid1, proof1, time.Now().Local()) var blob1 sql.Blob - require.NoError(t, LoadMalfeasanceBlob(ctx, db, nid1.Bytes(), &blob1)) + require.NoError(t, identities.LoadMalfeasanceBlob(ctx, db, nid1.Bytes(), &blob1)) require.Equal(t, proof1, blob1.Bytes) - blobSizes, err := GetBlobSizes(db, [][]byte{nid1.Bytes()}) + blobSizes, err := identities.GetBlobSizes(db, [][]byte{nid1.Bytes()}) require.NoError(t, err) require.Equal(t, []int{len(blob1.Bytes)}, blobSizes) nid2 := types.RandomNodeID() proof2 := types.RandomBytes(12) - SetMalicious(db, nid2, proof2, time.Now().Local()) + identities.SetMalicious(db, nid2, proof2, time.Now().Local()) var blob2 sql.Blob - require.NoError(t, LoadMalfeasanceBlob(ctx, db, nid2.Bytes(), &blob2)) + require.NoError(t, identities.LoadMalfeasanceBlob(ctx, db, nid2.Bytes(), &blob2)) require.Equal(t, proof2, blob2.Bytes) - blobSizes, err = GetBlobSizes(db, [][]byte{ + blobSizes, err = identities.GetBlobSizes(db, [][]byte{ nid1.Bytes(), nid2.Bytes(), }) @@ -109,9 +107,9 @@ func TestLoadMalfeasanceBlob(t *testing.T) { require.Equal(t, []int{len(blob1.Bytes), len(blob2.Bytes)}, blobSizes) noSuchID := types.RandomATXID() - require.ErrorIs(t, LoadMalfeasanceBlob(ctx, db, noSuchID[:], &sql.Blob{}), sql.ErrNotFound) + require.ErrorIs(t, identities.LoadMalfeasanceBlob(ctx, db, noSuchID[:], &sql.Blob{}), sql.ErrNotFound) - blobSizes, err = GetBlobSizes(db, [][]byte{ + blobSizes, err = identities.GetBlobSizes(db, [][]byte{ nid1.Bytes(), noSuchID.Bytes(), nid2.Bytes(), @@ -127,7 +125,7 @@ func TestMarriageATX(t *testing.T) { db := statesql.InMemory() id := types.RandomNodeID() - _, err := MarriageATX(db, id) + _, err := identities.MarriageATX(db, id) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("married", func(t *testing.T) { @@ -135,14 +133,14 @@ func TestMarriageATX(t *testing.T) { db := statesql.InMemory() id := types.RandomNodeID() - marriage := MarriageData{ + marriage := identities.MarriageData{ ATX: types.RandomATXID(), Signature: types.RandomEdSignature(), Index: 2, Target: types.RandomNodeID(), } - require.NoError(t, SetMarriage(db, id, &marriage)) - got, err := MarriageATX(db, id) + require.NoError(t, identities.SetMarriage(db, id, &marriage)) + got, err := identities.MarriageATX(db, id) require.NoError(t, err) require.Equal(t, marriage.ATX, got) }) @@ -154,14 +152,14 @@ func TestMarriage(t *testing.T) { db := statesql.InMemory() id := types.RandomNodeID() - marriage := MarriageData{ + marriage := identities.MarriageData{ ATX: types.RandomATXID(), Signature: types.RandomEdSignature(), Index: 2, Target: types.RandomNodeID(), } - require.NoError(t, SetMarriage(db, id, &marriage)) - got, err := Marriage(db, id) + require.NoError(t, identities.SetMarriage(db, id, &marriage)) + got, err := identities.Marriage(db, id) require.NoError(t, err) require.Equal(t, marriage, *got) } @@ -179,7 +177,7 @@ func TestEquivocationSet(t *testing.T) { types.RandomNodeID(), } for i, id := range ids { - err := SetMarriage(db, id, &MarriageData{ + err := identities.SetMarriage(db, id, &identities.MarriageData{ ATX: atx, Index: i, }) @@ -187,10 +185,10 @@ func TestEquivocationSet(t *testing.T) { } for _, id := range ids { - mAtx, err := MarriageATX(db, id) + mAtx, err := identities.MarriageATX(db, id) require.NoError(t, err) require.Equal(t, atx, mAtx) - set, err := EquivocationSet(db, id) + set, err := identities.EquivocationSet(db, id) require.NoError(t, err) require.ElementsMatch(t, ids, set) } @@ -199,7 +197,7 @@ func TestEquivocationSet(t *testing.T) { t.Parallel() db := statesql.InMemory() id := types.RandomNodeID() - set, err := EquivocationSet(db, id) + set, err := identities.EquivocationSet(db, id) require.NoError(t, err) require.Equal(t, []types.NodeID{id}, set) }) @@ -212,7 +210,7 @@ func TestEquivocationSet(t *testing.T) { types.RandomNodeID(), } for i, id := range ids { - err := SetMarriage(db, id, &MarriageData{ + err := identities.SetMarriage(db, id, &identities.MarriageData{ ATX: atx, Index: i, }) @@ -220,19 +218,19 @@ func TestEquivocationSet(t *testing.T) { } for _, id := range ids { - set, err := EquivocationSet(db, id) + set, err := identities.EquivocationSet(db, id) require.NoError(t, err) require.ElementsMatch(t, ids, set) } // try to marry via another random ATX // the set should remain intact - err := SetMarriage(db, ids[0], &MarriageData{ + err := identities.SetMarriage(db, ids[0], &identities.MarriageData{ ATX: types.RandomATXID(), }) require.NoError(t, err) for _, id := range ids { - set, err := EquivocationSet(db, id) + set, err := identities.EquivocationSet(db, id) require.NoError(t, err) require.ElementsMatch(t, ids, set) } @@ -241,17 +239,18 @@ func TestEquivocationSet(t *testing.T) { db := statesql.InMemory() atx := types.RandomATXID() id := types.RandomNodeID() - require.NoError(t, SetMarriage(db, id, &MarriageData{ATX: atx})) + require.NoError(t, identities.SetMarriage(db, id, &identities.MarriageData{ATX: atx})) - malicious, err := IsMalicious(db, id) + malicious, err := identities.IsMalicious(db, id) require.NoError(t, err) require.False(t, malicious) - proof, err := GetMalfeasanceProof(db, id) + var blob sql.Blob + err = identities.LoadMalfeasanceBlob(context.Background(), db, id.Bytes(), &blob) require.ErrorIs(t, err, sql.ErrNotFound) - require.Nil(t, proof) + require.Nil(t, blob.Bytes) - ids, err := GetMalicious(db) + ids, err := identities.GetMalicious(db) require.NoError(t, err) require.Empty(t, ids) }) @@ -264,13 +263,13 @@ func TestEquivocationSet(t *testing.T) { types.RandomNodeID(), } for i, id := range ids { - require.NoError(t, SetMarriage(db, id, &MarriageData{ATX: atx, Index: i})) + require.NoError(t, identities.SetMarriage(db, id, &identities.MarriageData{ATX: atx, Index: i})) } - require.NoError(t, SetMalicious(db, ids[0], []byte("proof"), time.Now())) + require.NoError(t, identities.SetMalicious(db, ids[0], []byte("proof"), time.Now())) for _, id := range ids { - malicious, err := IsMalicious(db, id) + malicious, err := identities.IsMalicious(db, id) require.NoError(t, err) require.True(t, malicious) } @@ -290,15 +289,15 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { } atx := types.RandomATXID() for i, id := range ids { - require.NoError(t, SetMarriage(db, id, &MarriageData{ATX: atx, Index: i})) + require.NoError(t, identities.SetMarriage(db, id, &identities.MarriageData{ATX: atx, Index: i})) } - set, err := EquivocationSetByMarriageATX(db, atx) + set, err := identities.EquivocationSetByMarriageATX(db, atx) require.NoError(t, err) require.Equal(t, ids, set) }) t.Run("empty set", func(t *testing.T) { db := statesql.InMemory() - set, err := EquivocationSetByMarriageATX(db, types.RandomATXID()) + set, err := identities.EquivocationSetByMarriageATX(db, types.RandomATXID()) require.NoError(t, err) require.Empty(t, set) })