From 8c35ca2abdd763dd0a6f8803fd1e153882f84367 Mon Sep 17 00:00:00 2001 From: Jim Zhang Date: Tue, 17 Sep 2024 13:28:25 -0400 Subject: [PATCH] Make merkle tree leaf addition operations atomic Signed-off-by: Jim Zhang --- go-sdk/go.mod | 4 +- go-sdk/integration-test/db_test.go | 72 ++-- go-sdk/integration-test/e2e_test.go | 326 ++++++++---------- go-sdk/integration-test/smt_test.go | 96 ++++++ go-sdk/internal/log/log.go | 91 +++++ .../sparse-merkle-tree/smt/merkletree.go | 61 +++- .../sparse-merkle-tree/smt/merkletree_test.go | 9 + .../sparse-merkle-tree/smt/smt_test.go | 220 ++++++------ .../sparse-merkle-tree/storage/memory.go | 59 ---- .../sparse-merkle-tree/storage/memory_test.go | 79 ----- .../sparse-merkle-tree/storage/sql.go | 66 +++- go-sdk/pkg/sparse-merkle-tree/core/storage.go | 11 + .../pkg/sparse-merkle-tree/storage/storage.go | 4 - 13 files changed, 625 insertions(+), 473 deletions(-) create mode 100644 go-sdk/integration-test/smt_test.go create mode 100644 go-sdk/internal/log/log.go delete mode 100644 go-sdk/internal/sparse-merkle-tree/storage/memory.go delete mode 100644 go-sdk/internal/sparse-merkle-tree/storage/memory_test.go diff --git a/go-sdk/go.mod b/go-sdk/go.mod index e1a90ef..cd1ad81 100644 --- a/go-sdk/go.mod +++ b/go-sdk/go.mod @@ -4,7 +4,9 @@ go 1.22.0 require ( github.com/iden3/go-rapidsnark/witness/wasmer v0.0.0-20230524142950-0986cf057d4e + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 + github.com/x-cray/logrus-prefixed-formatter v0.5.2 ) require ( @@ -29,8 +31,6 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 // indirect - github.com/sirupsen/logrus v1.9.3 // indirect - github.com/x-cray/logrus-prefixed-formatter v0.5.2 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/term v0.24.0 // indirect golang.org/x/text v0.18.0 // indirect diff --git a/go-sdk/integration-test/db_test.go b/go-sdk/integration-test/db_test.go index 93b29b9..72e4ac1 100644 --- a/go-sdk/integration-test/db_test.go +++ b/go-sdk/integration-test/db_test.go @@ -17,7 +17,9 @@ package integration_test import ( + "fmt" "math/big" + "math/rand" "os" "testing" @@ -27,65 +29,89 @@ import ( "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/node" "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/smt" "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/storage" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" ) +type SqliteTestSuite struct { + suite.Suite + db core.Storage + dbfile *os.File + gormDB *gorm.DB + smtName string +} + type testSqlProvider struct { db *gorm.DB } -func (s *testSqlProvider) DB() *gorm.DB { - return s.db +func (p *testSqlProvider) DB() *gorm.DB { + return p.db } -func (s *testSqlProvider) Close() {} +func (p *testSqlProvider) Close() {} -func TestSqliteStorage(t *testing.T) { - dbfile, err := os.CreateTemp("", "gorm.db") +func newSqliteStorage(t *testing.T) (*os.File, core.Storage, *gorm.DB, string) { + seq := rand.Intn(1000) + testName := fmt.Sprintf("test_%d", seq) + dbfile, err := os.CreateTemp("", fmt.Sprintf("gorm-%s.db", testName)) assert.NoError(t, err) - defer func() { - err := os.Remove(dbfile.Name()) - assert.NoError(t, err) - }() db, err := gorm.Open(sqlite.Open(dbfile.Name()), &gorm.Config{}) assert.NoError(t, err) err = db.Table(core.TreeRootsTable).AutoMigrate(&core.SMTRoot{}) assert.NoError(t, err) - err = db.Table(core.NodesTablePrefix + "test_1").AutoMigrate(&core.SMTNode{}) + err = db.Table(core.NodesTablePrefix + testName).AutoMigrate(&core.SMTNode{}) assert.NoError(t, err) provider := &testSqlProvider{db: db} - s, err := storage.NewSqlStorage(provider, "test_1") + sqlStorage, err := storage.NewSqlStorage(provider, testName) assert.NoError(t, err) + return dbfile, sqlStorage, db, testName +} - mt, err := smt.NewMerkleTree(s, MAX_HEIGHT) - assert.NoError(t, err) +func (s *SqliteTestSuite) SetupTest() { + logrus.SetLevel(logrus.DebugLevel) + s.dbfile, s.db, s.gormDB, s.smtName = newSqliteStorage(s.T()) +} + +func (s *SqliteTestSuite) TearDownTest() { + os.Remove(s.dbfile.Name()) +} + +func (s *SqliteTestSuite) TestSqliteStorage() { + mt, err := smt.NewMerkleTree(s.db, MAX_HEIGHT) + assert.NoError(s.T(), err) tokenId := big.NewInt(1001) uriString := "https://example.com/token/1001" - assert.NoError(t, err) + assert.NoError(s.T(), err) sender := testutils.NewKeypair() salt1 := crypto.NewSalt() utxo1 := node.NewNonFungible(tokenId, uriString, sender.PublicKey, salt1) n1, err := node.NewLeafNode(utxo1) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n1) - assert.NoError(t, err) + assert.NoError(s.T(), err) root := mt.Root() - dbRoot := core.SMTRoot{Name: "test_1"} - err = db.Table(core.TreeRootsTable).First(&dbRoot).Error - assert.NoError(t, err) - assert.Equal(t, root.Hex(), dbRoot.RootIndex) + dbRoot := core.SMTRoot{Name: s.smtName} + err = s.gormDB.Table(core.TreeRootsTable).First(&dbRoot).Error + assert.NoError(s.T(), err) + assert.Equal(s.T(), root.Hex(), dbRoot.RootIndex) dbNode := core.SMTNode{RefKey: n1.Ref().Hex()} - err = db.Table(core.NodesTablePrefix + "test_1").First(&dbNode).Error - assert.NoError(t, err) - assert.Equal(t, n1.Ref().Hex(), dbNode.RefKey) + err = s.gormDB.Table(core.NodesTablePrefix + s.smtName).First(&dbNode).Error + assert.NoError(s.T(), err) + assert.Equal(s.T(), n1.Ref().Hex(), dbNode.RefKey) +} + +func TestSqliteStorage(t *testing.T) { + suite.Run(t, new(SqliteTestSuite)) } func TestPostgresStorage(t *testing.T) { diff --git a/go-sdk/integration-test/e2e_test.go b/go-sdk/integration-test/e2e_test.go index bfd25a9..f8fb626 100644 --- a/go-sdk/integration-test/e2e_test.go +++ b/go-sdk/integration-test/e2e_test.go @@ -19,7 +19,6 @@ package integration_test import ( "fmt" "math/big" - "math/rand" "os" "path" "testing" @@ -29,18 +28,20 @@ import ( "github.com/hyperledger-labs/zeto/go-sdk/pkg/crypto" keyscore "github.com/hyperledger-labs/zeto/go-sdk/pkg/key-manager/core" "github.com/hyperledger-labs/zeto/go-sdk/pkg/key-manager/key" + "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/core" "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/node" "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/smt" - "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/storage" "github.com/hyperledger-labs/zeto/go-sdk/pkg/utxo" "github.com/hyperledger/firefly-signer/pkg/keystorev3" "github.com/hyperledger/firefly-signer/pkg/secp256k1" - "github.com/iden3/go-iden3-crypto/babyjub" "github.com/iden3/go-iden3-crypto/poseidon" "github.com/iden3/go-rapidsnark/prover" "github.com/iden3/go-rapidsnark/witness/v2" "github.com/iden3/go-rapidsnark/witness/wasmer" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" ) const MAX_HEIGHT = 64 @@ -114,12 +115,28 @@ func testKeyFromKeyStorev3(t *testing.T) *keyscore.KeyEntry { return key.NewKeyEntryFromPrivateKeyBytes([32]byte(keypair.PrivateKeyBytes())) } -func TestZeto_1_SuccessfulProving(t *testing.T) { +type E2ETestSuite struct { + suite.Suite + db core.Storage + dbfile *os.File + gormDB *gorm.DB +} + +func (s *E2ETestSuite) SetupSuite() { + logrus.SetLevel(logrus.DebugLevel) + s.dbfile, s.db, s.gormDB, _ = newSqliteStorage(s.T()) +} + +func (s *E2ETestSuite) TearDownSuite() { + os.Remove(s.dbfile.Name()) +} + +func (s *E2ETestSuite) TestZeto_1_SuccessfulProving() { calc, provingKey, err := loadCircuit("anon") - assert.NoError(t, err) - assert.NotNil(t, calc) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), calc) - sender := testKeyFromKeyStorev3(t) + sender := testKeyFromKeyStorev3(s.T()) receiver := testutils.NewKeypair() inputValues := []*big.Int{big.NewInt(30), big.NewInt(40)} @@ -150,37 +167,37 @@ func TestZeto_1_SuccessfulProving(t *testing.T) { // calculate the witness object for checking correctness witness, err := calc.CalculateWitness(witnessInputs, true) - assert.NoError(t, err) - assert.NotNil(t, witness) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), witness) - assert.Equal(t, 0, witness[0].Cmp(big.NewInt(1))) - assert.Equal(t, 0, witness[1].Cmp(inputCommitments[0])) - assert.Equal(t, 0, witness[2].Cmp(inputCommitments[1])) - assert.Equal(t, 0, witness[3].Cmp(outputCommitments[0])) - assert.Equal(t, 0, witness[4].Cmp(outputCommitments[1])) + assert.Equal(s.T(), 0, witness[0].Cmp(big.NewInt(1))) + assert.Equal(s.T(), 0, witness[1].Cmp(inputCommitments[0])) + assert.Equal(s.T(), 0, witness[2].Cmp(inputCommitments[1])) + assert.Equal(s.T(), 0, witness[3].Cmp(outputCommitments[0])) + assert.Equal(s.T(), 0, witness[4].Cmp(outputCommitments[1])) // generate the witness binary to feed into the prover startTime := time.Now() witnessBin, err := calc.CalculateWTNSBin(witnessInputs, true) - assert.NoError(t, err) - assert.NotNil(t, witnessBin) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), witnessBin) proof, err := prover.Groth16Prover(provingKey, witnessBin) elapsedTime := time.Since(startTime) fmt.Printf("Proving time: %s\n", elapsedTime) - assert.NoError(t, err) - assert.Equal(t, 3, len(proof.Proof.A)) - assert.Equal(t, 3, len(proof.Proof.B)) - assert.Equal(t, 3, len(proof.Proof.C)) - assert.Equal(t, 4, len(proof.PubSignals)) + assert.NoError(s.T(), err) + assert.Equal(s.T(), 3, len(proof.Proof.A)) + assert.Equal(s.T(), 3, len(proof.Proof.B)) + assert.Equal(s.T(), 3, len(proof.Proof.C)) + assert.Equal(s.T(), 4, len(proof.PubSignals)) } -func TestZeto_2_SuccessfulProving(t *testing.T) { +func (s *E2ETestSuite) TestZeto_2_SuccessfulProving() { calc, provingKey, err := loadCircuit("anon_enc") - assert.NoError(t, err) - assert.NotNil(t, calc) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), calc) - sender := testKeyFromKeyStorev3(t) + sender := testKeyFromKeyStorev3(s.T()) receiver := testutils.NewKeypair() inputValues := []*big.Int{big.NewInt(30), big.NewInt(40)} @@ -214,24 +231,24 @@ func TestZeto_2_SuccessfulProving(t *testing.T) { startTime := time.Now() witnessBin, err := calc.CalculateWTNSBin(witnessInputs, true) - assert.NoError(t, err) - assert.NotNil(t, witnessBin) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), witnessBin) proof, err := prover.Groth16Prover(provingKey, witnessBin) elapsedTime := time.Since(startTime) fmt.Printf("Proving time: %s\n", elapsedTime) - assert.NoError(t, err) - assert.Equal(t, 3, len(proof.Proof.A)) - assert.Equal(t, 3, len(proof.Proof.B)) - assert.Equal(t, 3, len(proof.Proof.C)) - assert.Equal(t, 9, len(proof.PubSignals)) + assert.NoError(s.T(), err) + assert.Equal(s.T(), 3, len(proof.Proof.A)) + assert.Equal(s.T(), 3, len(proof.Proof.B)) + assert.Equal(s.T(), 3, len(proof.Proof.C)) + assert.Equal(s.T(), 9, len(proof.PubSignals)) // the receiver would be able to get the encrypted values and salts // from the transaction events encryptedValues := make([]*big.Int, 4) for i := 0; i < 4; i++ { v, ok := new(big.Int).SetString(proof.PubSignals[i], 10) - assert.True(t, ok) + assert.True(s.T(), ok) encryptedValues[i] = v } @@ -240,21 +257,21 @@ func TestZeto_2_SuccessfulProving(t *testing.T) { // the UTXO hash secret := crypto.GenerateECDHSharedSecret(receiver.PrivateKey, sender.PublicKey) decrypted, err := crypto.PoseidonDecrypt(encryptedValues, []*big.Int{secret.X, secret.Y}, encryptionNonce, 2) - assert.NoError(t, err) - assert.Equal(t, outputValues[0].String(), decrypted[0].String()) - assert.Equal(t, salt3.String(), decrypted[1].String()) + assert.NoError(s.T(), err) + assert.Equal(s.T(), outputValues[0].String(), decrypted[0].String()) + assert.Equal(s.T(), salt3.String(), decrypted[1].String()) // as the receiver, to check if the decryption was successful, we hash the decrypted // value and salt and compare with the output commitment calculatedHash, err := poseidon.Hash([]*big.Int{decrypted[0], decrypted[1], receiver.PublicKey.X, receiver.PublicKey.Y}) - assert.NoError(t, err) - assert.Equal(t, output1.String(), calculatedHash.String()) + assert.NoError(s.T(), err) + assert.Equal(s.T(), output1.String(), calculatedHash.String()) } -func TestZeto_3_SuccessfulProving(t *testing.T) { +func (s *E2ETestSuite) TestZeto_3_SuccessfulProving() { calc, provingKey, err := loadCircuit("anon_nullifier") - assert.NoError(t, err) - assert.NotNil(t, calc) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), calc) sender := testutils.NewKeypair() receiver := testutils.NewKeypair() @@ -272,26 +289,26 @@ func TestZeto_3_SuccessfulProving(t *testing.T) { nullifier2, _ := poseidon.Hash([]*big.Int{inputValues[1], salt2, sender.PrivateKeyBigInt}) nullifiers := []*big.Int{nullifier1, nullifier2} - mt, err := smt.NewMerkleTree(storage.NewMemoryStorage(), MAX_HEIGHT) - assert.NoError(t, err) + mt, err := smt.NewMerkleTree(s.db, MAX_HEIGHT) + assert.NoError(s.T(), err) utxo1 := node.NewFungible(inputValues[0], sender.PublicKey, salt1) n1, err := node.NewLeafNode(utxo1) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n1) - assert.NoError(t, err) + assert.NoError(s.T(), err) utxo2 := node.NewFungible(inputValues[1], sender.PublicKey, salt2) n2, err := node.NewLeafNode(utxo2) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n2) - assert.NoError(t, err) + assert.NoError(s.T(), err) proof1, _, err := mt.GenerateProof(input1, nil) - assert.NoError(t, err) + assert.NoError(s.T(), err) circomProof1, err := proof1.ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT) - assert.NoError(t, err) + assert.NoError(s.T(), err) proof2, _, err := mt.GenerateProof(input2, nil) - assert.NoError(t, err) + assert.NoError(s.T(), err) circomProof2, err := proof2.ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT) - assert.NoError(t, err) + assert.NoError(s.T(), err) salt3 := crypto.NewSalt() output1, _ := poseidon.Hash([]*big.Int{outputValues[0], salt3, receiver.PublicKey.X, receiver.PublicKey.Y}) @@ -324,23 +341,23 @@ func TestZeto_3_SuccessfulProving(t *testing.T) { startTime := time.Now() witnessBin, err := calc.CalculateWTNSBin(witnessInputs, true) - assert.NoError(t, err) - assert.NotNil(t, witnessBin) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), witnessBin) proof, err := prover.Groth16Prover(provingKey, witnessBin) elapsedTime := time.Since(startTime) fmt.Printf("Proving time: %s\n", elapsedTime) - assert.NoError(t, err) - assert.Equal(t, 3, len(proof.Proof.A)) - assert.Equal(t, 3, len(proof.Proof.B)) - assert.Equal(t, 3, len(proof.Proof.C)) - assert.Equal(t, 7, len(proof.PubSignals)) + assert.NoError(s.T(), err) + assert.Equal(s.T(), 3, len(proof.Proof.A)) + assert.Equal(s.T(), 3, len(proof.Proof.B)) + assert.Equal(s.T(), 3, len(proof.Proof.C)) + assert.Equal(s.T(), 7, len(proof.PubSignals)) } -func TestZeto_4_SuccessfulProving(t *testing.T) { +func (s *E2ETestSuite) TestZeto_4_SuccessfulProving() { calc, provingKey, err := loadCircuit("anon_enc_nullifier") - assert.NoError(t, err) - assert.NotNil(t, calc) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), calc) sender := testutils.NewKeypair() receiver := testutils.NewKeypair() @@ -358,26 +375,26 @@ func TestZeto_4_SuccessfulProving(t *testing.T) { nullifier2, _ := poseidon.Hash([]*big.Int{inputValues[1], salt2, sender.PrivateKeyBigInt}) nullifiers := []*big.Int{nullifier1, nullifier2} - mt, err := smt.NewMerkleTree(storage.NewMemoryStorage(), MAX_HEIGHT) - assert.NoError(t, err) + mt, err := smt.NewMerkleTree(s.db, MAX_HEIGHT) + assert.NoError(s.T(), err) utxo1 := node.NewFungible(inputValues[0], sender.PublicKey, salt1) n1, err := node.NewLeafNode(utxo1) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n1) - assert.NoError(t, err) + assert.NoError(s.T(), err) utxo2 := node.NewFungible(inputValues[1], sender.PublicKey, salt2) n2, err := node.NewLeafNode(utxo2) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n2) - assert.NoError(t, err) + assert.NoError(s.T(), err) proof1, _, err := mt.GenerateProof(input1, nil) - assert.NoError(t, err) + assert.NoError(s.T(), err) circomProof1, err := proof1.ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT) - assert.NoError(t, err) + assert.NoError(s.T(), err) proof2, _, err := mt.GenerateProof(input2, nil) - assert.NoError(t, err) + assert.NoError(s.T(), err) circomProof2, err := proof2.ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT) - assert.NoError(t, err) + assert.NoError(s.T(), err) salt3 := crypto.NewSalt() output1, _ := poseidon.Hash([]*big.Int{outputValues[0], salt3, receiver.PublicKey.X, receiver.PublicKey.Y}) @@ -413,38 +430,38 @@ func TestZeto_4_SuccessfulProving(t *testing.T) { startTime := time.Now() witnessBin, err := calc.CalculateWTNSBin(witnessInputs, true) - assert.NoError(t, err) - assert.NotNil(t, witnessBin) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), witnessBin) proof, err := prover.Groth16Prover(provingKey, witnessBin) elapsedTime := time.Since(startTime) fmt.Printf("Proving time: %s\n", elapsedTime) - assert.NoError(t, err) - assert.Equal(t, 3, len(proof.Proof.A)) - assert.Equal(t, 3, len(proof.Proof.B)) - assert.Equal(t, 3, len(proof.Proof.C)) - assert.Equal(t, 12, len(proof.PubSignals)) + assert.NoError(s.T(), err) + assert.Equal(s.T(), 3, len(proof.Proof.A)) + assert.Equal(s.T(), 3, len(proof.Proof.B)) + assert.Equal(s.T(), 3, len(proof.Proof.C)) + assert.Equal(s.T(), 12, len(proof.PubSignals)) } -func TestZeto_5_SuccessfulProving(t *testing.T) { +func (s *E2ETestSuite) TestZeto_5_SuccessfulProving() { calc, provingKey, err := loadCircuit("nf_anon") - assert.NoError(t, err) - assert.NotNil(t, calc) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), calc) sender := testutils.NewKeypair() receiver := testutils.NewKeypair() tokenId := big.NewInt(1001) tokenUri, err := utxo.HashTokenUri("https://example.com/token/1001") - assert.NoError(t, err) + assert.NoError(s.T(), err) salt1 := crypto.NewSalt() input1, err := poseidon.Hash([]*big.Int{tokenId, tokenUri, salt1, sender.PublicKey.X, sender.PublicKey.Y}) - assert.NoError(t, err) + assert.NoError(s.T(), err) salt3 := crypto.NewSalt() output1, err := poseidon.Hash([]*big.Int{tokenId, tokenUri, salt3, receiver.PublicKey.X, receiver.PublicKey.Y}) - assert.NoError(t, err) + assert.NoError(s.T(), err) witnessInputs := map[string]interface{}{ "tokenIds": []*big.Int{tokenId}, @@ -459,35 +476,35 @@ func TestZeto_5_SuccessfulProving(t *testing.T) { // calculate the witness object for checking correctness witness, err := calc.CalculateWitness(witnessInputs, true) - assert.NoError(t, err) - assert.NotNil(t, witness) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), witness) - assert.Equal(t, 0, witness[0].Cmp(big.NewInt(1))) - assert.Equal(t, 0, witness[1].Cmp(input1)) - assert.Equal(t, 0, witness[2].Cmp(output1)) - assert.Equal(t, 0, witness[3].Cmp(tokenId)) - assert.Equal(t, 0, witness[4].Cmp(tokenUri)) + assert.Equal(s.T(), 0, witness[0].Cmp(big.NewInt(1))) + assert.Equal(s.T(), 0, witness[1].Cmp(input1)) + assert.Equal(s.T(), 0, witness[2].Cmp(output1)) + assert.Equal(s.T(), 0, witness[3].Cmp(tokenId)) + assert.Equal(s.T(), 0, witness[4].Cmp(tokenUri)) // generate the witness binary to feed into the prover startTime := time.Now() witnessBin, err := calc.CalculateWTNSBin(witnessInputs, true) - assert.NoError(t, err) - assert.NotNil(t, witnessBin) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), witnessBin) proof, err := prover.Groth16Prover(provingKey, witnessBin) elapsedTime := time.Since(startTime) fmt.Printf("Proving time: %s\n", elapsedTime) - assert.NoError(t, err) - assert.Equal(t, 3, len(proof.Proof.A)) - assert.Equal(t, 3, len(proof.Proof.B)) - assert.Equal(t, 3, len(proof.Proof.C)) - assert.Equal(t, 2, len(proof.PubSignals)) + assert.NoError(s.T(), err) + assert.Equal(s.T(), 3, len(proof.Proof.A)) + assert.Equal(s.T(), 3, len(proof.Proof.B)) + assert.Equal(s.T(), 3, len(proof.Proof.C)) + assert.Equal(s.T(), 2, len(proof.PubSignals)) } -func TestZeto_6_SuccessfulProving(t *testing.T) { +func (s *E2ETestSuite) TestZeto_6_SuccessfulProving() { calc, provingKey, err := loadCircuit("nf_anon_nullifier") - assert.NoError(t, err) - assert.NotNil(t, calc) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), calc) sender := testutils.NewKeypair() receiver := testutils.NewKeypair() @@ -495,25 +512,25 @@ func TestZeto_6_SuccessfulProving(t *testing.T) { tokenId := big.NewInt(1001) uriString := "https://example.com/token/1001" tokenUri, err := utxo.HashTokenUri(uriString) - assert.NoError(t, err) + assert.NoError(s.T(), err) salt1 := crypto.NewSalt() input1, err := poseidon.Hash([]*big.Int{tokenId, tokenUri, salt1, sender.PublicKey.X, sender.PublicKey.Y}) - assert.NoError(t, err) + assert.NoError(s.T(), err) nullifier1, _ := poseidon.Hash([]*big.Int{tokenId, tokenUri, salt1, sender.PrivateKeyBigInt}) - mt, err := smt.NewMerkleTree(storage.NewMemoryStorage(), MAX_HEIGHT) - assert.NoError(t, err) + mt, err := smt.NewMerkleTree(s.db, MAX_HEIGHT) + assert.NoError(s.T(), err) utxo1 := node.NewNonFungible(tokenId, uriString, sender.PublicKey, salt1) n1, err := node.NewLeafNode(utxo1) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n1) - assert.NoError(t, err) + assert.NoError(s.T(), err) proof1, _, err := mt.GenerateProof(input1, nil) - assert.NoError(t, err) + assert.NoError(s.T(), err) circomProof1, err := proof1.ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT) - assert.NoError(t, err) + assert.NoError(s.T(), err) proof1Siblings := make([]*big.Int, len(circomProof1.Siblings)-1) for i, s := range circomProof1.Siblings[0 : len(circomProof1.Siblings)-1] { proof1Siblings[i] = s.BigInt() @@ -521,7 +538,7 @@ func TestZeto_6_SuccessfulProving(t *testing.T) { salt3 := crypto.NewSalt() output1, err := poseidon.Hash([]*big.Int{tokenId, tokenUri, salt3, receiver.PublicKey.X, receiver.PublicKey.Y}) - assert.NoError(t, err) + assert.NoError(s.T(), err) witnessInputs := map[string]interface{}{ "tokenId": tokenId, @@ -539,88 +556,39 @@ func TestZeto_6_SuccessfulProving(t *testing.T) { // calculate the witness object for checking correctness witness, err := calc.CalculateWitness(witnessInputs, true) - assert.NoError(t, err) - assert.NotNil(t, witness) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), witness) - assert.Equal(t, 0, witness[0].Cmp(big.NewInt(1))) - assert.Equal(t, 0, witness[1].Cmp(nullifier1)) + assert.Equal(s.T(), 0, witness[0].Cmp(big.NewInt(1))) + assert.Equal(s.T(), 0, witness[1].Cmp(nullifier1)) // generate the witness binary to feed into the prover startTime := time.Now() witnessBin, err := calc.CalculateWTNSBin(witnessInputs, true) - assert.NoError(t, err) - assert.NotNil(t, witnessBin) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), witnessBin) proof, err := prover.Groth16Prover(provingKey, witnessBin) elapsedTime := time.Since(startTime) fmt.Printf("Proving time: %s\n", elapsedTime) - assert.NoError(t, err) - assert.Equal(t, 3, len(proof.Proof.A)) - assert.Equal(t, 3, len(proof.Proof.B)) - assert.Equal(t, 3, len(proof.Proof.C)) - assert.Equal(t, 3, len(proof.PubSignals)) -} - -func TestConcurrentLeafnodesInsertion(t *testing.T) { - x, _ := new(big.Int).SetString("9198063289874244593808956064764348354864043212453245695133881114917754098693", 10) - y, _ := new(big.Int).SetString("3600411115173311692823743444460566395943576560299970643507632418781961416843", 10) - alice := &babyjub.PublicKey{ - X: x, - Y: y, - } - - values := []int{10, 20, 30, 40} - salts := []string{ - "43c49e8ba68a9b8a6bb5c230a734d8271a83d2f63722e7651272ebeef5446e", - "19b965f7629e4f0c4bd0b8f9c87f17580f18a32a31b4641550071ee4916bbbfc", - "9b0b93df975547e430eabff085a77831b8fcb6b5396e6bb815fda8d14125370", - "194ec10ec96a507c7c9b60df133d13679b874b0bd6ab89920135508f55b3f064", - } - - // run the test 10 times - for i := 0; i < 100; i++ { - // shuffle the utxos for this run - for i := range values { - j := rand.Intn(i + 1) - values[i], values[j] = values[j], values[i] - salts[i], salts[j] = salts[j], salts[i] - } - - testConcurrentInsertion(t, alice, values, salts) - } + assert.NoError(s.T(), err) + assert.Equal(s.T(), 3, len(proof.Proof.A)) + assert.Equal(s.T(), 3, len(proof.Proof.B)) + assert.Equal(s.T(), 3, len(proof.Proof.C)) + assert.Equal(s.T(), 3, len(proof.PubSignals)) } -func testConcurrentInsertion(t *testing.T, alice *babyjub.PublicKey, values []int, salts []string) { - mt, err := smt.NewMerkleTree(storage.NewMemoryStorage(), MAX_HEIGHT) - assert.NoError(t, err) - done := make(chan bool, len(values)) - - for i, v := range values { - go func(i, v int) { - salt, _ := new(big.Int).SetString(salts[i], 16) - utxo := node.NewFungible(big.NewInt(int64(v)), alice, salt) - n, err := node.NewLeafNode(utxo) - assert.NoError(t, err) - err = mt.AddLeaf(n) - assert.NoError(t, err) - done <- true - }(i, v) - } +func (s *E2ETestSuite) TestKeyManager() { + keypair := decryptKeyStorev3(s.T()) - for i := 0; i < len(values); i++ { - <-done - } + keyEntry := key.NewKeyEntryFromPrivateKeyBytes([32]byte(keypair.PrivateKeyBytes())) + assert.NotNil(s.T(), keyEntry) - assert.Equal(t, "abacf46f5217552ee28fe50b8fd7ca6aa46daeb9acf9f60928654c3b1a472f23", mt.Root().Hex()) + assert.NotNil(s.T(), keyEntry.PrivateKey) + assert.NotNil(s.T(), keyEntry.PublicKey) + assert.NotNil(s.T(), keyEntry.PrivateKeyForZkp) } -func TestKeyManager(t *testing.T) { - keypair := decryptKeyStorev3(t) - - keyEntry := key.NewKeyEntryFromPrivateKeyBytes([32]byte(keypair.PrivateKeyBytes())) - assert.NotNil(t, keyEntry) - - assert.NotNil(t, keyEntry.PrivateKey) - assert.NotNil(t, keyEntry.PublicKey) - assert.NotNil(t, keyEntry.PrivateKeyForZkp) +func TestE2ETestSuite(t *testing.T) { + suite.Run(t, new(E2ETestSuite)) } diff --git a/go-sdk/integration-test/smt_test.go b/go-sdk/integration-test/smt_test.go new file mode 100644 index 0000000..800bbe2 --- /dev/null +++ b/go-sdk/integration-test/smt_test.go @@ -0,0 +1,96 @@ +// Copyright © 2024 Kaleido, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "math/big" + "math/rand" + "os" + "testing" + + "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/node" + "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/smt" + "github.com/iden3/go-iden3-crypto/babyjub" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type SmtTestSuite struct { + suite.Suite +} + +func (s *SmtTestSuite) TestConcurrentLeafnodesInsertion() { + logrus.SetLevel(logrus.DebugLevel) + x, _ := new(big.Int).SetString("9198063289874244593808956064764348354864043212453245695133881114917754098693", 10) + y, _ := new(big.Int).SetString("3600411115173311692823743444460566395943576560299970643507632418781961416843", 10) + alice := &babyjub.PublicKey{ + X: x, + Y: y, + } + + values := []int{10, 20, 30, 40} + salts := []string{ + "43c49e8ba68a9b8a6bb5c230a734d8271a83d2f63722e7651272ebeef5446e", + "19b965f7629e4f0c4bd0b8f9c87f17580f18a32a31b4641550071ee4916bbbfc", + "9b0b93df975547e430eabff085a77831b8fcb6b5396e6bb815fda8d14125370", + "194ec10ec96a507c7c9b60df133d13679b874b0bd6ab89920135508f55b3f064", + } + + // run the test 10 times + for i := 0; i < 100; i++ { + // shuffle the utxos for this run + for i := range values { + j := rand.Intn(i + 1) + values[i], values[j] = values[j], values[i] + salts[i], salts[j] = salts[j], salts[i] + } + + testConcurrentInsertion(s.T(), alice, values, salts) + } +} + +func testConcurrentInsertion(t *testing.T, alice *babyjub.PublicKey, values []int, salts []string) { + dbfile, db, _, _ := newSqliteStorage(t) + defer os.Remove(dbfile.Name()) + + mt, err := smt.NewMerkleTree(db, MAX_HEIGHT) + assert.NoError(t, err) + done := make(chan bool, len(values)) + + for i, v := range values { + go func(i, v int) { + salt, _ := new(big.Int).SetString(salts[i], 16) + utxo := node.NewFungible(big.NewInt(int64(v)), alice, salt) + n, err := node.NewLeafNode(utxo) + assert.NoError(t, err) + err = mt.AddLeaf(n) + assert.NoError(t, err) + done <- true + }(i, v) + } + + for i := 0; i < len(values); i++ { + <-done + } + + assert.Equal(t, "abacf46f5217552ee28fe50b8fd7ca6aa46daeb9acf9f60928654c3b1a472f23", mt.Root().Hex()) +} + +func TestSmtTestSuite(t *testing.T) { + suite.Run(t, new(SmtTestSuite)) +} diff --git a/go-sdk/internal/log/log.go b/go-sdk/internal/log/log.go new file mode 100644 index 0000000..1addd20 --- /dev/null +++ b/go-sdk/internal/log/log.go @@ -0,0 +1,91 @@ +// Copyright © 2024 Kaleido, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +import ( + "os" + "sync/atomic" + + "github.com/sirupsen/logrus" + prefixed "github.com/x-cray/logrus-prefixed-formatter" +) + +var ( + rootLogger = logrus.NewEntry(logrus.StandardLogger()) + + // L accesses the current logger from the context + L = logger + + initAtLeastOnce atomic.Bool +) + +func InitConfig() { + initAtLeastOnce.Store(true) // must store before SetLevel + + logrus.SetLevel(logrus.InfoLevel) + logrus.SetOutput(os.Stdout) + var formatter logrus.Formatter + formatter = &prefixed.TextFormatter{ + DisableColors: false, + ForceColors: false, + TimestampFormat: "2006-01-02T15:04:05.000Z07:00", + DisableSorting: false, + FullTimestamp: true, + } + logrus.SetReportCaller(true) + formatter = &utcFormat{f: formatter} + logrus.SetFormatter(formatter) +} + +func IsDebugEnabled() bool { + return logrus.IsLevelEnabled(logrus.DebugLevel) +} + +func IsTraceEnabled() bool { + return logrus.IsLevelEnabled(logrus.TraceLevel) +} + +func ensureInit() { + // Called at a couple of strategic points to check we get log initialize in things like unit tests + // However NOT guaranteed to be called because we can't afford to do atomic load on every log line + if !initAtLeastOnce.Load() { + InitConfig() + } +} + +func logger() *logrus.Entry { + ensureInit() + return rootLogger +} + +// WithLogField adds the specified field to the logger in the context +func WithLogField(key, value string) *logrus.Entry { + ensureInit() + if len(value) > 61 { + value = value[0:61] + "..." + } + return rootLogger.WithField(key, value) +} + +type utcFormat struct { + f logrus.Formatter +} + +func (utc *utcFormat) Format(e *logrus.Entry) ([]byte, error) { + e.Time = e.Time.UTC() + return utc.f.Format(e) +} diff --git a/go-sdk/internal/sparse-merkle-tree/smt/merkletree.go b/go-sdk/internal/sparse-merkle-tree/smt/merkletree.go index e1f0a8b..73461cd 100644 --- a/go-sdk/internal/sparse-merkle-tree/smt/merkletree.go +++ b/go-sdk/internal/sparse-merkle-tree/smt/merkletree.go @@ -18,8 +18,10 @@ package smt import ( "math/big" + "strconv" "sync" + "github.com/hyperledger-labs/zeto/go-sdk/internal/log" "github.com/hyperledger-labs/zeto/go-sdk/internal/sparse-merkle-tree/node" "github.com/hyperledger-labs/zeto/go-sdk/internal/sparse-merkle-tree/storage" "github.com/hyperledger-labs/zeto/go-sdk/internal/sparse-merkle-tree/utils" @@ -78,18 +80,33 @@ func (mt *sparseMerkleTree) AddLeaf(node core.Node) error { // use up all the bits in the index's path. As soon as a unique path is found, // which may be only the first few bits of the index, the new leaf node is added. // One or more branch nodes may be created to accommodate the new leaf node. - newRootKey, err := mt.addLeaf(node, mt.rootKey, 0, path) + batch, err := mt.db.BeginBatch() if err != nil { return err } + newRootKey, err := mt.addLeaf(batch, node, mt.rootKey, 0, path) + if err != nil { + log.L().Errorf("Error adding leaf node %s: %v, rolling back", node.Ref().Hex(), err) + _ = batch.Rollback() + return err + } mt.rootKey = newRootKey // update the root node index in the storage - err = mt.db.UpsertRootNodeIndex(mt.rootKey) + log.L().Infof("Upserting root node index to %s", mt.rootKey.Hex()) + err = batch.UpsertRootNodeIndex(mt.rootKey) if err != nil { + log.L().Errorf("Error upserting root node %s: %v, rolling back", mt.rootKey.Hex(), err) + _ = batch.Rollback() + return err + } + log.L().Infof("Committing batch operations for adding leaf node %s", node.Ref().Hex()) + err = batch.Commit() + if err != nil { + log.L().Errorf("Error committing batch operations for adding leaf node %s: %v", node.Ref().Hex(), err) + _ = batch.Rollback() return err } - return nil } @@ -185,7 +202,8 @@ func (mt *sparseMerkleTree) getNode(key core.NodeIndex) (core.Node, error) { // as children of a new branch node. // - if the current node is a branch node, it will continue traversing the tree, using the // next bit of the new node's index to determine which child to go down to. -func (mt *sparseMerkleTree) addLeaf(newLeaf core.Node, currentNodeIndex core.NodeIndex, level int, path []bool) (core.NodeIndex, error) { +func (mt *sparseMerkleTree) addLeaf(batch core.Transaction, newLeaf core.Node, currentNodeIndex core.NodeIndex, level int, path []bool) (core.NodeIndex, error) { + log.WithLogField("level", strconv.Itoa(level)).Debugf("Adding leaf node %s", newLeaf.Ref().Hex()) if level > mt.maxLevels-1 { // we have exhausted all levels but could not find a unique path for the new leaf. // this happens when two leaf nodes have the same beginning bits of the index, of @@ -203,7 +221,8 @@ func (mt *sparseMerkleTree) addLeaf(newLeaf core.Node, currentNodeIndex core.Nod // We have searched to a level and have found a position in the // index's path where the tree is empty. This means we have found // the node that doesn't exist yet. We can add the new leaf node here - return mt.addNode(newLeaf) + log.WithLogField("level", strconv.Itoa(level)).Debugf("Found empty slot, inserting leaf node %s", newLeaf.Ref().Hex()) + return mt.addNode(batch, newLeaf) case core.NodeTypeLeaf: nIndex := currentNode.Index() // Check if leaf node found contains the leaf node we are @@ -216,20 +235,23 @@ func (mt *sparseMerkleTree) addLeaf(newLeaf core.Node, currentNodeIndex core.Nod // but we still have more bits in the index to use. We need to extend the // path of the existing leaf node and the new leaf node until they diverge. pathOldLeaf := nIndex.ToPath(mt.maxLevels) - return mt.extendPath(newLeaf, currentNode, level, path, pathOldLeaf) + log.WithLogField("level", strconv.Itoa(level)).Debug("Found occupied slot, extending path") + return mt.extendPath(batch, newLeaf, currentNode, level, path, pathOldLeaf) case core.NodeTypeBranch: // We need to go deeper, continue traversing the tree, left or // right depending on path var newBranchNode core.Node if path[level] { // go right - nextKey, err = mt.addLeaf(newLeaf, currentNode.RightChild(), level+1, path) + log.WithLogField("level", strconv.Itoa(level)).Debug("Found branch node, going right") + nextKey, err = mt.addLeaf(batch, newLeaf, currentNode.RightChild(), level+1, path) if err != nil { return nil, err } // replace the branch node with the new branch node, which now has a new right child newBranchNode, err = node.NewBranchNode(currentNode.LeftChild(), nextKey) } else { // go left - nextKey, err = mt.addLeaf(newLeaf, currentNode.LeftChild(), level+1, path) + log.WithLogField("level", strconv.Itoa(level)).Debug("Found branch node, going left") + nextKey, err = mt.addLeaf(batch, newLeaf, currentNode.LeftChild(), level+1, path) if err != nil { return nil, err } @@ -240,7 +262,8 @@ func (mt *sparseMerkleTree) addLeaf(newLeaf core.Node, currentNodeIndex core.Nod return nil, err } // persist the updated branch node - return mt.addNode(newBranchNode) + log.WithLogField("level", strconv.Itoa(level)).Debugf("Inserting new branch node %s (leftChild=%s, rightChild=%s)", newBranchNode.Ref().Hex(), newBranchNode.LeftChild().Hex(), newBranchNode.RightChild().Hex()) + return mt.addNode(batch, newBranchNode) default: return nil, ErrInvalidNodeFound } @@ -249,23 +272,23 @@ func (mt *sparseMerkleTree) addLeaf(newLeaf core.Node, currentNodeIndex core.Nod // must be called from inside a write lock // addNode adds a node into the MT. Empty nodes are not stored in the tree; // they are all the same and assumed to always exist. -func (mt *sparseMerkleTree) addNode(n core.Node) (core.NodeIndex, error) { +func (mt *sparseMerkleTree) addNode(batch core.Transaction, n core.Node) (core.NodeIndex, error) { if n.Type() == core.NodeTypeEmpty { return n.Ref(), nil } k := n.Ref() // Check that the node key doesn't already exist - if _, err := mt.db.GetNode(k); err == nil { + if _, err := batch.GetNode(k); err == nil { return nil, ErrNodeIndexAlreadyExists } - err := mt.db.InsertNode(n) + err := batch.InsertNode(n) return k, err } // must be called from inside a write lock // extendPath extends the path of two leaf nodes, which share the same beginnging part of // their indexes, until their paths diverge, creating ancestor branch nodes as needed. -func (mt *sparseMerkleTree) extendPath(newLeaf core.Node, oldLeaf core.Node, level int, pathNewLeaf []bool, pathOldLeaf []bool) (core.NodeIndex, error) { +func (mt *sparseMerkleTree) extendPath(batch core.Transaction, newLeaf core.Node, oldLeaf core.Node, level int, pathNewLeaf []bool, pathOldLeaf []bool) (core.NodeIndex, error) { if level > mt.maxLevels-2 { return nil, ErrReachedMaxLevel } @@ -274,7 +297,8 @@ func (mt *sparseMerkleTree) extendPath(newLeaf core.Node, oldLeaf core.Node, lev // If the next bit of the new leaf node's index is the same as the // next bit of the existing leaf node's index, we need to further extend // the path of both nodes. - nextKey, err := mt.extendPath(newLeaf, oldLeaf, level+1, pathNewLeaf, pathOldLeaf) + log.WithLogField("level", strconv.Itoa(level)).Debug("Found occupied slot, extending path") + nextKey, err := mt.extendPath(batch, newLeaf, oldLeaf, level+1, pathNewLeaf, pathOldLeaf) if err != nil { return nil, err } @@ -289,7 +313,8 @@ func (mt *sparseMerkleTree) extendPath(newLeaf core.Node, oldLeaf core.Node, lev return nil, err } // persist the new branch node. and return the key of the new branch node - return mt.addNode(newBranchNode) + log.WithLogField("level", strconv.Itoa(level)).Debugf("Inserting new branch node %s (leftChild=%s, rightChild=%s)", newBranchNode.Ref().Hex(), newBranchNode.LeftChild().Hex(), newBranchNode.RightChild().Hex()) + return mt.addNode(batch, newBranchNode) } // at the current level, the two nodes finally diverges. We can now create a @@ -310,12 +335,14 @@ func (mt *sparseMerkleTree) extendPath(newLeaf core.Node, oldLeaf core.Node, lev } // We can add newLeaf to the DB now. We don't need to add oldLeaf because it's // already in the DB. - _, err = mt.addNode(newLeaf) + log.WithLogField("level", strconv.Itoa(level)).Debugf("Inserting new leaf node %s", newLeaf.Ref().Hex()) + _, err = mt.addNode(batch, newLeaf) if err != nil { return nil, err } // finally don't forget to add the new branch node that is the parent of // the new leaf node to the DB. We also return this new branch node's key // to allow the caller to create branch nodes as needed. - return mt.addNode(newBranchNode) + log.WithLogField("level", strconv.Itoa(level)).Debugf("Inserting new branch node %s (leftChild=%s, rightChild=%s)", newBranchNode.Ref().Hex(), newBranchNode.LeftChild().Hex(), newBranchNode.RightChild().Hex()) + return mt.addNode(batch, newBranchNode) } diff --git a/go-sdk/internal/sparse-merkle-tree/smt/merkletree_test.go b/go-sdk/internal/sparse-merkle-tree/smt/merkletree_test.go index 5df24f3..738aa6d 100644 --- a/go-sdk/internal/sparse-merkle-tree/smt/merkletree_test.go +++ b/go-sdk/internal/sparse-merkle-tree/smt/merkletree_test.go @@ -44,6 +44,15 @@ func (ms *mockStorage) GetNode(core.NodeIndex) (core.Node, error) { func (ms *mockStorage) InsertNode(core.Node) error { return nil } +func (ms *mockStorage) BeginBatch() (core.Transaction, error) { + return ms, nil +} +func (ms *mockStorage) Commit() error { + return nil +} +func (ms *mockStorage) Rollback() error { + return nil +} func (ms *mockStorage) Close() {} func TestNewMerkleTreeFailures(t *testing.T) { diff --git a/go-sdk/internal/sparse-merkle-tree/smt/smt_test.go b/go-sdk/internal/sparse-merkle-tree/smt/smt_test.go index 11ab4d0..42acf3e 100644 --- a/go-sdk/internal/sparse-merkle-tree/smt/smt_test.go +++ b/go-sdk/internal/sparse-merkle-tree/smt/smt_test.go @@ -18,6 +18,7 @@ package smt import ( "fmt" + "log" "math/big" "math/rand" "os" @@ -29,22 +30,70 @@ import ( "github.com/hyperledger-labs/zeto/go-sdk/internal/testutils" "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/core" "github.com/iden3/go-iden3-crypto/babyjub" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "gorm.io/driver/sqlite" "gorm.io/gorm" + "gorm.io/gorm/logger" ) -func TestNewMerkleTree(t *testing.T) { - db := storage.NewMemoryStorage() - mt, err := NewMerkleTree(db, 64) - assert.NoError(t, err) - assert.Equal(t, 0, mt.Root().BigInt().Cmp(big.NewInt(0))) +type MerkleTreeTestSuite struct { + suite.Suite + db core.Storage + dbfile *os.File + gormDB *gorm.DB } -func TestAddNode(t *testing.T) { - db := storage.NewMemoryStorage() - mt, err := NewMerkleTree(db, 10) - assert.NoError(t, err) +type testSqlProvider struct { + db *gorm.DB +} + +func (p *testSqlProvider) DB() *gorm.DB { + return p.db +} + +func (p *testSqlProvider) Close() {} + +func (s *MerkleTreeTestSuite) SetupTest() { + logrus.SetLevel(logrus.DebugLevel) + dbfile, err := os.CreateTemp("", "gorm.db") + assert.NoError(s.T(), err) + s.dbfile = dbfile + newLogger := logger.New( + log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer + logger.Config{ + LogLevel: logger.Info, // Log level + IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger + ParameterizedQueries: false, // Don't include params in the SQL log + Colorful: true, // Disable color + }, + ) + db, err := gorm.Open(sqlite.Open(dbfile.Name()), &gorm.Config{Logger: newLogger}) + assert.NoError(s.T(), err) + err = db.Table(core.TreeRootsTable).AutoMigrate(&core.SMTRoot{}) + assert.NoError(s.T(), err) + err = db.Table(core.NodesTablePrefix + "test_1").AutoMigrate(&core.SMTNode{}) + assert.NoError(s.T(), err) + + provider := &testSqlProvider{db: db} + s.db = storage.NewSqlStorage(provider, "test_1") + s.gormDB = db +} + +func (s *MerkleTreeTestSuite) TearDownTest() { + os.Remove(s.dbfile.Name()) +} + +func (s *MerkleTreeTestSuite) TestNewMerkleTree() { + mt, err := NewMerkleTree(s.db, 64) + assert.NoError(s.T(), err) + assert.Equal(s.T(), 0, mt.Root().BigInt().Cmp(big.NewInt(0))) +} + +func (s *MerkleTreeTestSuite) TestAddNode() { + mt, err := NewMerkleTree(s.db, 10) + assert.NoError(s.T(), err) x, _ := new(big.Int).SetString("9198063289874244593808956064764348354864043212453245695133881114917754098693", 10) y, _ := new(big.Int).SetString("3600411115173311692823743444460566395943576560299970643507632418781961416843", 10) @@ -55,109 +104,107 @@ func TestAddNode(t *testing.T) { salt1, _ := new(big.Int).SetString("43c49e8ba68a9b8a6bb5c230a734d8271a83d2f63722e7651272ebeef5446e", 16) utxo1 := node.NewFungible(big.NewInt(10), alice, salt1) idx1, err := utxo1.CalculateIndex() - assert.NoError(t, err) - assert.Equal(t, "11a22e32f5010d3658d1da9c93f26b77afe7a84346f49eae3d1d4fc6cd0a36fd", idx1.BigInt().Text(16)) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "11a22e32f5010d3658d1da9c93f26b77afe7a84346f49eae3d1d4fc6cd0a36fd", idx1.BigInt().Text(16)) n1, err := node.NewLeafNode(utxo1) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n1) - assert.NoError(t, err) - assert.Equal(t, "525b60b382630ee7825bea84fb8808c13ede1fb827fe683cd5b14d76f6ac6d0b", mt.Root().Hex()) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "525b60b382630ee7825bea84fb8808c13ede1fb827fe683cd5b14d76f6ac6d0b", mt.Root().Hex()) // adding a 2nd node to test the tree update and branch nodes salt2, _ := new(big.Int).SetString("19b965f7629e4f0c4bd0b8f9c87f17580f18a32a31b4641550071ee4916bbbfc", 16) utxo2 := node.NewFungible(big.NewInt(20), alice, salt2) idx2, err := utxo2.CalculateIndex() - assert.NoError(t, err) - assert.Equal(t, "197b0dc3f167041e03d3eafacec1aa3ab12a0d7a606581af01447c269935e521", idx2.BigInt().Text(16)) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "197b0dc3f167041e03d3eafacec1aa3ab12a0d7a606581af01447c269935e521", idx2.BigInt().Text(16)) n2, err := node.NewLeafNode(utxo2) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n2) - assert.NoError(t, err) - assert.Equal(t, "c432caeb6448cb10bf8b449704f0fc79d84723b5aadeaf6f1b73cf00fe94c22f", mt.Root().Hex()) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "c432caeb6448cb10bf8b449704f0fc79d84723b5aadeaf6f1b73cf00fe94c22f", mt.Root().Hex()) // adding a 3rd node to test the tree update and branch nodes with a left/right child node salt3, _ := new(big.Int).SetString("9b0b93df975547e430eabff085a77831b8fcb6b5396e6bb815fda8d14125370", 16) utxo3 := node.NewFungible(big.NewInt(30), alice, salt3) idx3, err := utxo3.CalculateIndex() - assert.NoError(t, err) - assert.Equal(t, "2d46e23e813abf1fdabffe3ff22a38ebf6bb92d7c381463bee666eb010289fd5", idx3.BigInt().Text(16)) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "2d46e23e813abf1fdabffe3ff22a38ebf6bb92d7c381463bee666eb010289fd5", idx3.BigInt().Text(16)) n3, err := node.NewLeafNode(utxo3) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n3) - assert.NoError(t, err) - assert.Equal(t, "bf8409a4a6c7366bc64c154d3c2f40a8c3c5ddb0f1d47c41336d97ff27640502", mt.Root().Hex()) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "bf8409a4a6c7366bc64c154d3c2f40a8c3c5ddb0f1d47c41336d97ff27640502", mt.Root().Hex()) // adding a 4th node to test the tree update and branch nodes with the other left/right child node salt4, _ := new(big.Int).SetString("194ec10ec96a507c7c9b60df133d13679b874b0bd6ab89920135508f55b3f064", 16) utxo4 := node.NewFungible(big.NewInt(40), alice, salt4) idx4, err := utxo4.CalculateIndex() - assert.NoError(t, err) - assert.Equal(t, "887884c3421b72f8f1991c64808262da78732abf961118d02b0792bd421521f", idx4.BigInt().Text(16)) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "887884c3421b72f8f1991c64808262da78732abf961118d02b0792bd421521f", idx4.BigInt().Text(16)) n4, err := node.NewLeafNode(utxo4) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n4) - assert.NoError(t, err) - assert.Equal(t, "abacf46f5217552ee28fe50b8fd7ca6aa46daeb9acf9f60928654c3b1a472f23", mt.Root().Hex()) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "abacf46f5217552ee28fe50b8fd7ca6aa46daeb9acf9f60928654c3b1a472f23", mt.Root().Hex()) // test storage persistence rawDB := mt.(*sparseMerkleTree).db rootIdx, err := rawDB.GetRootNodeIndex() - assert.NoError(t, err) - assert.Equal(t, "abacf46f5217552ee28fe50b8fd7ca6aa46daeb9acf9f60928654c3b1a472f23", rootIdx.Hex()) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "abacf46f5217552ee28fe50b8fd7ca6aa46daeb9acf9f60928654c3b1a472f23", rootIdx.Hex()) // test storage persistence across tree creation - mt2, err := NewMerkleTree(db, 10) - assert.NoError(t, err) - assert.Equal(t, "abacf46f5217552ee28fe50b8fd7ca6aa46daeb9acf9f60928654c3b1a472f23", mt2.Root().Hex()) + mt2, err := NewMerkleTree(s.db, 10) + assert.NoError(s.T(), err) + assert.Equal(s.T(), "abacf46f5217552ee28fe50b8fd7ca6aa46daeb9acf9f60928654c3b1a472f23", mt2.Root().Hex()) } -func TestGenerateProof(t *testing.T) { +func (s *MerkleTreeTestSuite) TestGenerateProof() { const levels = 10 - db := storage.NewMemoryStorage() - mt, _ := NewMerkleTree(db, levels) + mt, _ := NewMerkleTree(s.db, levels) alice := testutils.NewKeypair() utxo1 := node.NewFungible(big.NewInt(10), alice.PublicKey, big.NewInt(12345)) node1, err := node.NewLeafNode(utxo1) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(node1) - assert.NoError(t, err) + assert.NoError(s.T(), err) utxo2 := node.NewFungible(big.NewInt(10), alice.PublicKey, big.NewInt(12346)) node2, err := node.NewLeafNode(utxo2) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(node2) - assert.NoError(t, err) + assert.NoError(s.T(), err) target1 := node1.Index().BigInt() proof1, foundValue1, err := mt.GenerateProof(target1, mt.Root()) - assert.NoError(t, err) - assert.Equal(t, target1, foundValue1) - assert.True(t, proof1.(*proof).existence) + assert.NoError(s.T(), err) + assert.Equal(s.T(), target1, foundValue1) + assert.True(s.T(), proof1.(*proof).existence) valid := VerifyProof(mt.Root(), proof1, node1) - assert.True(t, valid) + assert.True(s.T(), valid) utxo3 := node.NewFungible(big.NewInt(10), alice.PublicKey, big.NewInt(12347)) node3, err := node.NewLeafNode(utxo3) - assert.NoError(t, err) + assert.NoError(s.T(), err) target2 := node3.Index().BigInt() proof2, _, err := mt.GenerateProof(target2, mt.Root()) - assert.NoError(t, err) - assert.False(t, proof2.(*proof).existence) + assert.NoError(s.T(), err) + assert.False(s.T(), proof2.(*proof).existence) proof3, err := proof1.ToCircomVerifierProof(target1, foundValue1, mt.Root(), levels) - assert.NoError(t, err) - assert.False(t, proof3.IsOld0) + assert.NoError(s.T(), err) + assert.False(s.T(), proof3.IsOld0) } -func TestVerifyProof(t *testing.T) { +func (s *MerkleTreeTestSuite) TestVerifyProof() { const levels = 10 - db := storage.NewMemoryStorage() - mt, _ := NewMerkleTree(db, levels) + mt, _ := NewMerkleTree(s.db, levels) alice := testutils.NewKeypair() - values := []int{10, 20, 30, 40, 50} + values := []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100} done := make(chan bool, len(values)) startProving := make(chan core.Node, len(values)) for idx, value := range values { @@ -165,9 +212,9 @@ func TestVerifyProof(t *testing.T) { salt := rand.Intn(100000) utxo := node.NewFungible(big.NewInt(int64(v)), alice.PublicKey, big.NewInt(int64(salt))) node, err := node.NewLeafNode(utxo) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(node) - assert.NoError(t, err) + assert.NoError(s.T(), err) startProving <- node done <- true fmt.Printf("Added node %d\n", idx) @@ -182,11 +229,11 @@ func TestVerifyProof(t *testing.T) { target := n.Index().BigInt() root := mt.Root() p, _, err := mt.GenerateProof(target, root) - assert.NoError(t, err) - assert.True(t, p.(*proof).existence) + assert.NoError(s.T(), err) + assert.True(s.T(), p.(*proof).existence) valid := VerifyProof(root, p, n) - assert.True(t, valid) + assert.True(s.T(), valid) }() for i := 0; i < len(values); i++ { @@ -196,56 +243,29 @@ func TestVerifyProof(t *testing.T) { fmt.Println("All done") } -type testSqlProvider struct { - db *gorm.DB -} - -func (s *testSqlProvider) DB() *gorm.DB { - return s.db -} - -func (s *testSqlProvider) Close() {} - -func TestSqliteStorage(t *testing.T) { - dbfile, err := os.CreateTemp("", "gorm.db") - assert.NoError(t, err) - defer func() { - os.Remove(dbfile.Name()) - }() - db, err := gorm.Open(sqlite.Open(dbfile.Name()), &gorm.Config{}) - assert.NoError(t, err) - err = db.Table(core.TreeRootsTable).AutoMigrate(&core.SMTRoot{}) - assert.NoError(t, err) - err = db.Table(core.NodesTablePrefix + "test_1").AutoMigrate(&core.SMTNode{}) - assert.NoError(t, err) - - provider := &testSqlProvider{db: db} - s := storage.NewSqlStorage(provider, "test_1") - assert.NoError(t, err) - - mt, err := NewMerkleTree(s, 10) - assert.NoError(t, err) +func (s *MerkleTreeTestSuite) TestSqliteStorage() { + mt, err := NewMerkleTree(s.db, 10) + assert.NoError(s.T(), err) + assert.NotNil(s.T(), mt) tokenId := big.NewInt(1001) uriString := "https://example.com/token/1001" - assert.NoError(t, err) + assert.NoError(s.T(), err) sender := testutils.NewKeypair() salt1 := crypto.NewSalt() utxo1 := node.NewNonFungible(tokenId, uriString, sender.PublicKey, salt1) n1, err := node.NewLeafNode(utxo1) - assert.NoError(t, err) + assert.NoError(s.T(), err) err = mt.AddLeaf(n1) - assert.NoError(t, err) - - root := mt.Root() - dbRoot := core.SMTRoot{Name: "test_1"} - err = db.Table(core.TreeRootsTable).First(&dbRoot).Error - assert.NoError(t, err) - assert.Equal(t, root.Hex(), dbRoot.RootIndex) + assert.NoError(s.T(), err) dbNode := core.SMTNode{RefKey: n1.Ref().Hex()} - err = db.Table(core.NodesTablePrefix + "test_1").First(&dbNode).Error - assert.NoError(t, err) - assert.Equal(t, n1.Ref().Hex(), dbNode.RefKey) + err = s.gormDB.Table(core.NodesTablePrefix + "test_1").First(&dbNode).Error + assert.NoError(s.T(), err) + assert.Equal(s.T(), n1.Ref().Hex(), dbNode.RefKey) +} + +func TestMerkleTreeSuite(t *testing.T) { + suite.Run(t, new(MerkleTreeTestSuite)) } diff --git a/go-sdk/internal/sparse-merkle-tree/storage/memory.go b/go-sdk/internal/sparse-merkle-tree/storage/memory.go deleted file mode 100644 index fb45e5c..0000000 --- a/go-sdk/internal/sparse-merkle-tree/storage/memory.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright © 2024 Kaleido, Inc. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/core" - -type memoryStorage struct { - root core.NodeIndex - nodes map[core.NodeIndex]core.Node -} - -func NewMemoryStorage() *memoryStorage { - var nodes = make(map[core.NodeIndex]core.Node) - return &memoryStorage{ - nodes: nodes, - } -} - -func (m *memoryStorage) GetRootNodeIndex() (core.NodeIndex, error) { - if m.root == nil { - return nil, ErrNotFound - } - return m.root, nil -} - -func (m *memoryStorage) UpsertRootNodeIndex(root core.NodeIndex) error { - m.root = root - return nil -} - -func (m *memoryStorage) GetNode(idx core.NodeIndex) (core.Node, error) { - n, ok := m.nodes[idx] - if !ok { - return nil, ErrNotFound - } - return n, nil -} - -func (m *memoryStorage) InsertNode(node core.Node) error { - m.nodes[node.Ref()] = node - return nil -} - -func (m *memoryStorage) Close() { -} diff --git a/go-sdk/internal/sparse-merkle-tree/storage/memory_test.go b/go-sdk/internal/sparse-merkle-tree/storage/memory_test.go deleted file mode 100644 index ab7981b..0000000 --- a/go-sdk/internal/sparse-merkle-tree/storage/memory_test.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright © 2024 Kaleido, Inc. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "math/big" - "testing" - - "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/core" - "github.com/stretchr/testify/assert" -) - -type mockNodeIndex struct{} - -func (ni *mockNodeIndex) BigInt() *big.Int { return big.NewInt(0) } -func (ni *mockNodeIndex) Hex() string { return "0" } -func (ni *mockNodeIndex) IsZero() bool { return true } -func (ni *mockNodeIndex) Equal(n core.NodeIndex) bool { return true } -func (ni *mockNodeIndex) IsBitOn(uint) bool { return false } -func (ni *mockNodeIndex) ToPath(int) []bool { return []bool{true, false} } - -type mockNode struct { - idx core.NodeIndex - ref core.NodeIndex - lc core.NodeIndex - rc core.NodeIndex -} - -func (n *mockNode) Type() core.NodeType { return core.NodeTypeLeaf } -func (n *mockNode) Index() core.NodeIndex { return n.idx } -func (n *mockNode) Ref() core.NodeIndex { return n.ref } -func (n *mockNode) Value() core.Indexable { return nil } -func (n *mockNode) LeftChild() core.NodeIndex { return n.lc } -func (n *mockNode) RightChild() core.NodeIndex { return n.rc } - -func TestNewMemoryStorage(t *testing.T) { - s := NewMemoryStorage() - assert.Nil(t, s.root) - assert.NotNil(t, s.nodes) - assert.Empty(t, s.nodes) - - _, err := s.GetRootNodeIndex() - assert.Equal(t, ErrNotFound, err) - - err = s.UpsertRootNodeIndex(&mockNodeIndex{}) - assert.NoError(t, err) - - ni, err := s.GetRootNodeIndex() - assert.NoError(t, err) - assert.NotNil(t, ni) - - idx1 := &mockNodeIndex{} - _, err = s.GetNode(idx1) - assert.Equal(t, ErrNotFound, err) - - idx2 := &mockNodeIndex{} - n1 := &mockNode{idx: idx1, ref: idx2} - err = s.InsertNode(n1) - assert.NoError(t, err) - - found, err := s.GetNode(idx1) - assert.NoError(t, err) - assert.NotNil(t, found) - assert.Equal(t, n1, found) -} diff --git a/go-sdk/internal/sparse-merkle-tree/storage/sql.go b/go-sdk/internal/sparse-merkle-tree/storage/sql.go index 80e64d3..dc061e2 100644 --- a/go-sdk/internal/sparse-merkle-tree/storage/sql.go +++ b/go-sdk/internal/sparse-merkle-tree/storage/sql.go @@ -55,20 +55,70 @@ func (s *sqlStorage) GetRootNodeIndex() (core.NodeIndex, error) { } func (s *sqlStorage) UpsertRootNodeIndex(root core.NodeIndex) error { - err := s.p.DB().Table(core.TreeRootsTable).Save(&core.SMTRoot{ + return upsertRootNodeIndex(s.p.DB(), s.smtName, root) +} + +func (s *sqlStorage) GetNode(ref core.NodeIndex) (core.Node, error) { + return getNode(s.p.DB(), s.nodesTableName, ref) +} + +func (s *sqlStorage) InsertNode(n core.Node) error { + return insertNode(s.p.DB(), s.nodesTableName, n) +} + +func (s *sqlStorage) BeginBatch() (core.Transaction, error) { + return &sqlBatchStorage{ + tx: s.p.DB().Begin(), + smtName: s.smtName, + nodesTableName: s.nodesTableName, + }, nil +} + +type sqlBatchStorage struct { + tx *gorm.DB + smtName string + nodesTableName string +} + +func (b *sqlBatchStorage) UpsertRootNodeIndex(root core.NodeIndex) error { + return upsertRootNodeIndex(b.tx, b.smtName, root) +} + +func (b *sqlBatchStorage) GetNode(ref core.NodeIndex) (core.Node, error) { + return getNode(b.tx, b.nodesTableName, ref) +} + +func (b *sqlBatchStorage) InsertNode(n core.Node) error { + return insertNode(b.tx, b.nodesTableName, n) +} + +func (b *sqlBatchStorage) Commit() error { + return b.tx.Commit().Error +} + +func (b *sqlBatchStorage) Rollback() error { + return b.tx.Rollback().Error +} + +func (m *sqlStorage) Close() { + m.p.Close() +} + +func upsertRootNodeIndex(batchOrDb *gorm.DB, name string, root core.NodeIndex) error { + err := batchOrDb.Table(core.TreeRootsTable).Save(&core.SMTRoot{ RootIndex: root.Hex(), - Name: s.smtName, + Name: name, }).Error return err } -func (s *sqlStorage) GetNode(ref core.NodeIndex) (core.Node, error) { +func getNode(batchOrDb *gorm.DB, nodesTableName string, ref core.NodeIndex) (core.Node, error) { // the node's reference key (not the index) is used as the key to // store the node in the DB n := core.SMTNode{ RefKey: ref.Hex(), } - err := s.p.DB().Table(s.nodesTableName).First(&n).Error + err := batchOrDb.Table(nodesTableName).First(&n).Error if err == gorm.ErrRecordNotFound { return nil, ErrNotFound } else if err != nil { @@ -98,7 +148,7 @@ func (s *sqlStorage) GetNode(ref core.NodeIndex) (core.Node, error) { return newNode, err } -func (s *sqlStorage) InsertNode(n core.Node) error { +func insertNode(batchOrDb *gorm.DB, nodesTableName string, n core.Node) error { // we clone the node so that the value properties are not saved dbNode := &core.SMTNode{ RefKey: n.Ref().Hex(), @@ -114,10 +164,6 @@ func (s *sqlStorage) InsertNode(n core.Node) error { dbNode.Index = &idx } - err := s.p.DB().Table(s.nodesTableName).Create(dbNode).Error + err := batchOrDb.Table(nodesTableName).Create(dbNode).Error return err } - -func (m *sqlStorage) Close() { - m.p.Close() -} diff --git a/go-sdk/pkg/sparse-merkle-tree/core/storage.go b/go-sdk/pkg/sparse-merkle-tree/core/storage.go index e2a8802..dafc6c9 100644 --- a/go-sdk/pkg/sparse-merkle-tree/core/storage.go +++ b/go-sdk/pkg/sparse-merkle-tree/core/storage.go @@ -30,10 +30,21 @@ type Storage interface { // InsertNode inserts a node into the storage. Where the private values of a node are stored // is implementation-specific InsertNode(Node) error + // Batch executes a batch of operations in a single transaction. The semantics of the batch + // function follows the semantics of the gorm.DB.Transaction function. + BeginBatch() (Transaction, error) // Close closes the storage resource Close() } +type Transaction interface { + UpsertRootNodeIndex(NodeIndex) error + GetNode(NodeIndex) (Node, error) + InsertNode(Node) error + Commit() error + Rollback() error +} + const ( // we use a table to store the root node indexes for // all the merkle trees in the database diff --git a/go-sdk/pkg/sparse-merkle-tree/storage/storage.go b/go-sdk/pkg/sparse-merkle-tree/storage/storage.go index 4c9d212..4147da2 100644 --- a/go-sdk/pkg/sparse-merkle-tree/storage/storage.go +++ b/go-sdk/pkg/sparse-merkle-tree/storage/storage.go @@ -21,10 +21,6 @@ import ( "github.com/hyperledger-labs/zeto/go-sdk/pkg/sparse-merkle-tree/core" ) -func NewMemoryStorage() core.Storage { - return storage.NewMemoryStorage() -} - func NewSqlStorage(provider core.SqlDBProvider, smtName string) (core.Storage, error) { return storage.NewSqlStorage(provider, smtName), nil }