Skip to content

Commit

Permalink
Remove dependency on activation/wire from sql/identities
Browse files Browse the repository at this point in the history
  • Loading branch information
fasmat committed Sep 13, 2024
1 parent a28991f commit 7a38fc7
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 146 deletions.
16 changes: 9 additions & 7 deletions activation/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ func testHandler_PostMalfeasanceProofs(t *testing.T, synced bool) {
sig, err := signing.NewEdSigner()
require.NoError(t, err)

_, err = identities.GetMalfeasanceProof(atxHdlr.cdb, sig.NodeID())
var blob sql.Blob
err = identities.LoadMalfeasanceBlob(context.Background(), atxHdlr.cdb, sig.NodeID().Bytes(), &blob)
require.ErrorIs(t, err, sql.ErrNotFound)

atx := newInitialATXv1(t, goldenATXID)
Expand Down Expand Up @@ -280,10 +281,10 @@ func testHandler_PostMalfeasanceProofs(t *testing.T, synced bool) {
require.ErrorIs(t, atxHdlr.HandleGossipAtx(context.Background(), p2p.NoPeer, msg), errMaliciousATX)
}

proof, err := identities.GetMalfeasanceProof(atxHdlr.cdb, atx.SmesherID)
err = identities.LoadMalfeasanceBlob(context.Background(), atxHdlr.cdb, atx.SmesherID.Bytes(), &blob)
require.NoError(t, err)
require.NotNil(t, proof.Received())
proof.SetReceived(time.Time{})
proof := &mwire.MalfeasanceProof{}
codec.MustDecode(blob.Bytes, proof)
if !synced {
require.Equal(t, got.MalfeasanceProof, *proof)
require.Equal(t, atx.PublishEpoch.FirstLayer(), got.MalfeasanceProof.Layer)
Expand Down Expand Up @@ -435,11 +436,12 @@ func testHandler_HandleDoublePublish(t *testing.T, synced bool) {
require.ErrorIs(t, hdlr.HandleGossipAtx(context.Background(), p2p.NoPeer, msg), errMaliciousATX)
}

proof, err := identities.GetMalfeasanceProof(hdlr.cdb, sig.NodeID())
var blob sql.Blob
err = identities.LoadMalfeasanceBlob(context.Background(), hdlr.cdb, sig.NodeID().Bytes(), &blob)
require.NoError(t, err)
require.NotNil(t, proof)
proof := &mwire.MalfeasanceProof{}
codec.MustDecode(blob.Bytes, proof)
if !synced {
proof.SetReceived(time.Time{})
require.Equal(t, got.MalfeasanceProof, *proof)
}
}
Expand Down
2 changes: 1 addition & 1 deletion api/grpcserver/activation_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (s *activationService) Get(ctx context.Context, request *pb.GetRequest) (*p
return nil, status.Error(codes.Internal, "couldn't get previous ATXs")
}

proof, err := s.atxProvider.GetMalfeasanceProof(atx.SmesherID)
proof, err := s.atxProvider.MalfeasanceProof(atx.SmesherID)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
ctxzap.Error(ctx, "failed to get malfeasance proof",
zap.Stringer("smesher", atx.SmesherID),
Expand Down
4 changes: 2 additions & 2 deletions api/grpcserver/activation_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestGet_HappyPath(t *testing.T) {
}
atx.SetID(id)
atxProvider.EXPECT().GetAtx(id).Return(&atx, nil)
atxProvider.EXPECT().GetMalfeasanceProof(gomock.Any()).Return(nil, sql.ErrNotFound)
atxProvider.EXPECT().MalfeasanceProof(gomock.Any()).Return(nil, sql.ErrNotFound)
atxProvider.EXPECT().Previous(id).Return(previous, nil)

response, err := activationService.Get(context.Background(), &pb.GetRequest{Id: id.Bytes()})
Expand Down Expand Up @@ -169,7 +169,7 @@ func TestGet_IdentityCanceled(t *testing.T) {
}
atx.SetID(id)
atxProvider.EXPECT().GetAtx(id).Return(&atx, nil)
atxProvider.EXPECT().GetMalfeasanceProof(smesher).Return(proof, nil)
atxProvider.EXPECT().MalfeasanceProof(smesher).Return(proof, nil)
atxProvider.EXPECT().Previous(id).Return([]types.ATXID{previous}, nil)

response, err := activationService.Get(context.Background(), &pb.GetRequest{Id: id.Bytes()})
Expand Down
2 changes: 1 addition & 1 deletion api/grpcserver/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type atxProvider interface {
GetAtx(id types.ATXID) (*types.ActivationTx, error)
Previous(id types.ATXID) ([]types.ATXID, error)
MaxHeightAtx() (types.ATXID, error)
GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error)
MalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error)
}

type postState interface {
Expand Down
2 changes: 1 addition & 1 deletion api/grpcserver/mesh_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ func (s *MeshService) MalfeasanceQuery(
fmt.Sprintf("invalid smesher id length (%d), expected (%d)", l, types.NodeIDSize))
}
id := types.BytesToNodeID(parsed)
proof, err := s.cdb.GetMalfeasanceProof(id)
proof, err := s.cdb.MalfeasanceProof(id)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return nil, status.Error(codes.Internal, err.Error())
}
Expand Down
24 changes: 12 additions & 12 deletions api/grpcserver/mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 7 additions & 3 deletions datastore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"go.uber.org/zap"

"github.com/spacemeshos/go-spacemesh/atxsdata"
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/malfeasance/wire"
"github.com/spacemeshos/go-spacemesh/proposals/store"
Expand Down Expand Up @@ -119,7 +120,7 @@ func (db *CachedDB) MalfeasanceCacheSize() int {
}

// GetMalfeasanceProof gets the malfeasance proof associated with the NodeID.
func (db *CachedDB) GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) {
func (db *CachedDB) MalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) {
if id == types.EmptyNodeID {
panic("invalid argument to GetMalfeasanceProof")
}
Expand All @@ -133,10 +134,13 @@ func (db *CachedDB) GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof
return proof, nil
}

proof, err := identities.GetMalfeasanceProof(db.Database, id)
var blob sql.Blob
err := identities.LoadMalfeasanceBlob(context.Background(), db.Database, id.Bytes(), &blob)
if err != nil && err != sql.ErrNotFound {
return nil, err
}
proof := &wire.MalfeasanceProof{}
codec.MustDecode(blob.Bytes, proof)
db.malfeasanceCache.Add(id, proof)
return proof, err
}
Expand Down Expand Up @@ -203,7 +207,7 @@ func (db *CachedDB) IterateMalfeasanceProofs(
return err
}
for _, id := range ids {
proof, err := db.GetMalfeasanceProof(id)
proof, err := db.MalfeasanceProof(id)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion datastore/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestMalfeasanceProof_Dishonest(t *testing.T) {
cdb.CacheMalfeasanceProof(nodeID1, proof)
require.Equal(t, 1, cdb.MalfeasanceCacheSize())

got, err := cdb.GetMalfeasanceProof(nodeID1)
got, err := cdb.MalfeasanceProof(nodeID1)
require.NoError(t, err)
require.EqualValues(t, proof, got)
}
Expand Down
46 changes: 20 additions & 26 deletions malfeasance/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) {
err := h.HandleMalfeasanceProof(context.Background(), "peer", codec.MustEncode(gossip))
require.NoError(t, err)

malProof, err := identities.GetMalfeasanceProof(h.db, nodeID)
require.NoError(t, err)
require.NotEqual(t, gossip.MalfeasanceProof, *malProof)
var blob sql.Blob
require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob))
require.Equal(t, codec.MustEncode(&gossip.MalfeasanceProof), blob.Bytes)
})

t.Run("new proof is noop", func(t *testing.T) {
Expand Down Expand Up @@ -187,9 +187,10 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) {
err := h.HandleMalfeasanceProof(context.Background(), "peer", codec.MustEncode(gossip))
require.ErrorIs(t, ErrKnownProof, err)

malProof, err := identities.GetMalfeasanceProof(h.db, nodeID)
require.NoError(t, err)
malProof.SetReceived(time.Time{})
var blob sql.Blob
require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob))
malProof := &wire.MalfeasanceProof{}
codec.MustDecode(blob.Bytes, malProof)
require.Equal(t, proof, malProof)
})
}
Expand Down Expand Up @@ -318,19 +319,15 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) {
Data: &wire.AtxProof{},
},
}
proofBytes := codec.MustEncode(proof)

h.mockTrt.EXPECT().OnMalfeasance(nodeID)
err := h.HandleSyncedMalfeasanceProof(
context.Background(),
types.Hash32(nodeID),
"peer",
codec.MustEncode(proof),
)
err := h.HandleSyncedMalfeasanceProof(context.Background(), types.Hash32(nodeID), "peer", proofBytes)
require.NoError(t, err)

malProof, err := identities.GetMalfeasanceProof(h.db, nodeID)
require.NoError(t, err)
require.NotEqual(t, proof, *malProof)
var blob sql.Blob
require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob))
require.Equal(t, proofBytes, blob.Bytes)
})

t.Run("new proof is noop", func(t *testing.T) {
Expand All @@ -344,7 +341,8 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) {
Data: &wire.BallotProof{},
},
}
identities.SetMalicious(h.db, nodeID, codec.MustEncode(proof), time.Now())
proofBytes := codec.MustEncode(proof)
identities.SetMalicious(h.db, nodeID, proofBytes, time.Now())

ctrl := gomock.NewController(t)
handler := NewMockHandlerV1(ctrl)
Expand All @@ -363,18 +361,14 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) {
Data: &wire.AtxProof{},
},
}
newProofBytes := codec.MustEncode(newProof)
require.NotEqual(t, proofBytes, newProofBytes)

err := h.HandleSyncedMalfeasanceProof(
context.Background(),
types.Hash32(nodeID),
"peer",
codec.MustEncode(newProof),
)
err := h.HandleSyncedMalfeasanceProof(context.Background(), types.Hash32(nodeID), "peer", newProofBytes)
require.ErrorIs(t, ErrKnownProof, err)

malProof, err := identities.GetMalfeasanceProof(h.db, nodeID)
require.NoError(t, err)
malProof.SetReceived(time.Time{})
require.Equal(t, proof, malProof)
var blob sql.Blob
require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob))
require.Equal(t, proofBytes, blob.Bytes)
})
}
6 changes: 4 additions & 2 deletions mesh/mesh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,11 @@ func TestMesh_MaliciousBallots(t *testing.T) {
mal, err := identities.IsMalicious(tm.cdb, sig.NodeID())
require.NoError(t, err)
require.False(t, mal)
saved, err := identities.GetMalfeasanceProof(tm.cdb, sig.NodeID())

var blob sql.Blob
err = identities.LoadMalfeasanceBlob(context.Background(), tm.cdb, sig.NodeID().Bytes(), &blob)
require.ErrorIs(t, err, sql.ErrNotFound)
require.Nil(t, saved)
require.Nil(t, blob.Bytes)

// second one will create a MalfeasanceProof
tm.mockTortoise.EXPECT().OnMalfeasance(sig.NodeID())
Expand Down
53 changes: 13 additions & 40 deletions sql/identities/identities.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (

sqlite "github.com/go-llsqlite/crawshaw"

"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/malfeasance/wire"
"github.com/spacemeshos/go-spacemesh/sql"
)

Expand Down Expand Up @@ -49,44 +47,19 @@ func IsMalicious(db sql.Executor, nodeID types.NodeID) (bool, error) {
return rows > 0, nil
}

// GetMalfeasanceProof returns the malfeasance proof for the given identity.
func GetMalfeasanceProof(db sql.Executor, nodeID types.NodeID) (*wire.MalfeasanceProof, error) {
var (
data []byte
received time.Time
)
rows, err := db.Exec("select proof, received from identities where pubkey = ?1 AND proof IS NOT NULL;",
func(stmt *sql.Statement) {
stmt.BindBytes(1, nodeID.Bytes())
}, func(stmt *sql.Statement) bool {
data = make([]byte, stmt.ColumnLen(0))
stmt.ColumnBytes(0, data[:])
received = time.Unix(0, stmt.ColumnInt64(1)).Local()
return true
})
if err != nil {
return nil, fmt.Errorf("proof %v: %w", nodeID, err)
}
if rows == 0 {
return nil, sql.ErrNotFound
}
var proof wire.MalfeasanceProof
if err = codec.Decode(data, &proof); err != nil {
return nil, err
}
proof.SetReceived(received.Local())
return &proof, nil
}

// GetBlobSizes returns the sizes of the blobs corresponding to malfeasance proofs for the
// specified identities. For non-existent proofs, the corresponding items are set to -1.
func GetBlobSizes(db sql.Executor, ids [][]byte) (sizes []int, err error) {
return sql.GetBlobSizes(db, "select pubkey, length(proof) from identities where pubkey in", ids)
}

// LoadMalfeasanceBlob returns the malfeasance proof in raw bytes for the given identity.
func LoadMalfeasanceBlob(ctx context.Context, db sql.Executor, nodeID []byte, blob *sql.Blob) error {
return sql.LoadBlob(db, "select proof from identities where pubkey = ?1;", nodeID, blob)
func LoadMalfeasanceBlob(_ context.Context, db sql.Executor, nodeID []byte, blob *sql.Blob) error {
err := sql.LoadBlob(db, "select proof from identities where pubkey = ?1;", nodeID, blob)
if err == nil && len(blob.Bytes) == 0 {
return sql.ErrNotFound
}
return err
}

// IterateMalicious invokes the specified callback for each malicious node ID.
Expand Down Expand Up @@ -117,20 +90,20 @@ func IterateMalicious(
}

// GetMalicious retrieves malicious node IDs from the database.
func GetMalicious(db sql.Executor) (nids []types.NodeID, err error) {
func GetMalicious(db sql.Executor) (ids []types.NodeID, err error) {
if err = IterateMalicious(db, func(total int, nid types.NodeID) error {
if nids == nil {
nids = make([]types.NodeID, 0, total)
if ids == nil {
ids = make([]types.NodeID, 0, total)
}
nids = append(nids, nid)
ids = append(ids, nid)
return nil
}); err != nil {
return nil, err
}
if len(nids) != cap(nids) {
if len(ids) != cap(ids) {
panic("BUG: bad malicious node ID count")
}
return nids, nil
return ids, nil
}

// MarriageATX obtains the marriage ATX for given ID.
Expand Down Expand Up @@ -186,7 +159,7 @@ func Marriage(db sql.Executor, id types.NodeID) (*MarriageData, error) {
}

// Set marriage inserts marriage data for given identity.
// If identitty doesn't exist - create it.
// If identity doesn't exist - create it.
func SetMarriage(db sql.Executor, id types.NodeID, m *MarriageData) error {
_, err := db.Exec(`
INSERT INTO identities (pubkey, marriage_atx, marriage_idx, marriage_target, marriage_signature)
Expand Down
Loading

0 comments on commit 7a38fc7

Please sign in to comment.