Skip to content

Commit

Permalink
Merge pull request #28 from rollkit/tux/client-context
Browse files Browse the repository at this point in the history
  • Loading branch information
Manav-Aggarwal committed Jan 12, 2024
2 parents 755acd2 + bff650e commit 82f5296
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 46 deletions.
14 changes: 8 additions & 6 deletions da.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
package da

import "context"

// DA defines very generic interface for interaction with Data Availability layers.
type DA interface {
// MaxBlobSize returns the max blob size
MaxBlobSize() (uint64, error)
MaxBlobSize(ctx context.Context) (uint64, error)

// Get returns Blob for each given ID, or an error.
//
// Error should be returned if ID is not formatted properly, there is no Blob for given ID or any other client-level
// error occurred (dropped connection, timeout, etc).
Get(ids []ID) ([]Blob, error)
Get(ctx context.Context, ids []ID) ([]Blob, error)

// GetIDs returns IDs of all Blobs located in DA at given height.
GetIDs(height uint64) ([]ID, error)
GetIDs(ctx context.Context, height uint64) ([]ID, error)

// Commit creates a Commitment for each given Blob.
Commit(blobs []Blob) ([]Commitment, error)
Commit(ctx context.Context, blobs []Blob) ([]Commitment, error)

// Submit submits the Blobs to Data Availability layer.
//
// This method is synchronous. Upon successful submission to Data Availability layer, it returns ID identifying blob
// in DA and Proof of inclusion.
// If options is nil, default options are used.
Submit(blobs []Blob, gasPrice float64) ([]ID, []Proof, error)
Submit(ctx context.Context, blobs []Blob, gasPrice float64) ([]ID, []Proof, error)

// Validate validates Commitments against the corresponding Proofs. This should be possible without retrieving the Blobs.
Validate(ids []ID, proofs []Proof) ([]bool, error)
Validate(ctx context.Context, ids []ID, proofs []Proof) ([]bool, error)
}

// Blob is the data submitted/received from DA interface.
Expand Down
24 changes: 12 additions & 12 deletions proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,24 @@ func (c *Client) Stop() error {
}

// MaxBlobSize returns the DA MaxBlobSize
func (c *Client) MaxBlobSize() (uint64, error) {
func (c *Client) MaxBlobSize(ctx context.Context) (uint64, error) {
req := &pbda.MaxBlobSizeRequest{}
resp, err := c.client.MaxBlobSize(context.TODO(), req)
resp, err := c.client.MaxBlobSize(ctx, req)
if err != nil {
return 0, err
}
return resp.MaxBlobSize, nil
}

// Get returns Blob for each given ID, or an error.
func (c *Client) Get(ids []da.ID) ([]da.Blob, error) {
func (c *Client) Get(ctx context.Context, ids []da.ID) ([]da.Blob, error) {
req := &pbda.GetRequest{
Ids: make([]*pbda.ID, len(ids)),
}
for i := range ids {
req.Ids[i] = &pbda.ID{Value: ids[i]}
}
resp, err := c.client.Get(context.TODO(), req)
resp, err := c.client.Get(ctx, req)
if err != nil {
return nil, err
}
Expand All @@ -64,9 +64,9 @@ func (c *Client) Get(ids []da.ID) ([]da.Blob, error) {
}

// GetIDs returns IDs of all Blobs located in DA at given height.
func (c *Client) GetIDs(height uint64) ([]da.ID, error) {
func (c *Client) GetIDs(ctx context.Context, height uint64) ([]da.ID, error) {
req := &pbda.GetIDsRequest{Height: height}
resp, err := c.client.GetIDs(context.TODO(), req)
resp, err := c.client.GetIDs(ctx, req)
if err != nil {
return nil, err
}
Expand All @@ -75,12 +75,12 @@ func (c *Client) GetIDs(height uint64) ([]da.ID, error) {
}

// Commit creates a Commitment for each given Blob.
func (c *Client) Commit(blobs []da.Blob) ([]da.Commitment, error) {
func (c *Client) Commit(ctx context.Context, blobs []da.Blob) ([]da.Commitment, error) {
req := &pbda.CommitRequest{
Blobs: blobsDA2PB(blobs),
}

resp, err := c.client.Commit(context.TODO(), req)
resp, err := c.client.Commit(ctx, req)
if err != nil {
return nil, err
}
Expand All @@ -89,13 +89,13 @@ func (c *Client) Commit(blobs []da.Blob) ([]da.Commitment, error) {
}

// Submit submits the Blobs to Data Availability layer.
func (c *Client) Submit(blobs []da.Blob, gasPrice float64) ([]da.ID, []da.Proof, error) {
func (c *Client) Submit(ctx context.Context, blobs []da.Blob, gasPrice float64) ([]da.ID, []da.Proof, error) {
req := &pbda.SubmitRequest{
Blobs: blobsDA2PB(blobs),
GasPrice: gasPrice,
}

resp, err := c.client.Submit(context.TODO(), req)
resp, err := c.client.Submit(ctx, req)
if err != nil {
return nil, nil, err
}
Expand All @@ -111,11 +111,11 @@ func (c *Client) Submit(blobs []da.Blob, gasPrice float64) ([]da.ID, []da.Proof,
}

// Validate validates Commitments against the corresponding Proofs. This should be possible without retrieving the Blobs.
func (c *Client) Validate(ids []da.ID, proofs []da.Proof) ([]bool, error) {
func (c *Client) Validate(ctx context.Context, ids []da.ID, proofs []da.Proof) ([]bool, error) {
req := &pbda.ValidateRequest{
Ids: idsDA2PB(ids),
Proofs: proofsDA2PB(proofs),
}
resp, err := c.client.Validate(context.TODO(), req)
resp, err := c.client.Validate(ctx, req)
return resp.Results, err
}
12 changes: 6 additions & 6 deletions proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ type proxySrv struct {
}

func (p *proxySrv) MaxBlobSize(ctx context.Context, request *pbda.MaxBlobSizeRequest) (*pbda.MaxBlobSizeResponse, error) {
maxBlobSize, err := p.target.MaxBlobSize()
maxBlobSize, err := p.target.MaxBlobSize(ctx)
return &pbda.MaxBlobSizeResponse{MaxBlobSize: maxBlobSize}, err
}

func (p *proxySrv) Get(ctx context.Context, request *pbda.GetRequest) (*pbda.GetResponse, error) {
ids := idsPB2DA(request.Ids)
blobs, err := p.target.Get(ids)
blobs, err := p.target.Get(ctx, ids)
return &pbda.GetResponse{Blobs: blobsDA2PB(blobs)}, err
}

func (p *proxySrv) GetIDs(ctx context.Context, request *pbda.GetIDsRequest) (*pbda.GetIDsResponse, error) {
ids, err := p.target.GetIDs(request.Height)
ids, err := p.target.GetIDs(ctx, request.Height)
if err != nil {
return nil, err
}
Expand All @@ -46,7 +46,7 @@ func (p *proxySrv) GetIDs(ctx context.Context, request *pbda.GetIDsRequest) (*pb

func (p *proxySrv) Commit(ctx context.Context, request *pbda.CommitRequest) (*pbda.CommitResponse, error) {
blobs := blobsPB2DA(request.Blobs)
commits, err := p.target.Commit(blobs)
commits, err := p.target.Commit(ctx, blobs)
if err != nil {
return nil, err
}
Expand All @@ -57,7 +57,7 @@ func (p *proxySrv) Commit(ctx context.Context, request *pbda.CommitRequest) (*pb
func (p *proxySrv) Submit(ctx context.Context, request *pbda.SubmitRequest) (*pbda.SubmitResponse, error) {
blobs := blobsPB2DA(request.Blobs)

ids, proofs, err := p.target.Submit(blobs, request.GasPrice)
ids, proofs, err := p.target.Submit(ctx, blobs, request.GasPrice)
if err != nil {
return nil, err
}
Expand All @@ -79,7 +79,7 @@ func (p *proxySrv) Validate(ctx context.Context, request *pbda.ValidateRequest)
ids := idsPB2DA(request.Ids)
proofs := proofsPB2DA(request.Proofs)
//TODO implement me
validity, err := p.target.Validate(ids, proofs)
validity, err := p.target.Validate(ctx, ids, proofs)
if err != nil {
return nil, err
}
Expand Down
13 changes: 7 additions & 6 deletions test/dummy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package test

import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
Expand Down Expand Up @@ -49,12 +50,12 @@ func NewDummyDA(opts ...func(*DummyDA) *DummyDA) *DummyDA {
var _ da.DA = &DummyDA{}

// MaxBlobSize returns the max blob size in bytes.
func (d *DummyDA) MaxBlobSize() (uint64, error) {
func (d *DummyDA) MaxBlobSize(ctx context.Context) (uint64, error) {
return d.maxBlobSize, nil
}

// Get returns Blobs for given IDs.
func (d *DummyDA) Get(ids []da.ID) ([]da.Blob, error) {
func (d *DummyDA) Get(ctx context.Context, ids []da.ID) ([]da.Blob, error) {
d.mu.Lock()
defer d.mu.Unlock()
blobs := make([]da.Blob, len(ids))
Expand All @@ -78,7 +79,7 @@ func (d *DummyDA) Get(ids []da.ID) ([]da.Blob, error) {
}

// GetIDs returns IDs of Blobs at given DA height.
func (d *DummyDA) GetIDs(height uint64) ([]da.ID, error) {
func (d *DummyDA) GetIDs(ctx context.Context, height uint64) ([]da.ID, error) {
d.mu.Lock()
defer d.mu.Unlock()
kvps := d.data[height]
Expand All @@ -90,7 +91,7 @@ func (d *DummyDA) GetIDs(height uint64) ([]da.ID, error) {
}

// Commit returns cryptographic Commitments for given blobs.
func (d *DummyDA) Commit(blobs []da.Blob) ([]da.Commitment, error) {
func (d *DummyDA) Commit(ctx context.Context, blobs []da.Blob) ([]da.Commitment, error) {
commits := make([]da.Commitment, len(blobs))
for i, blob := range blobs {
commits[i] = d.getHash(blob)
Expand All @@ -99,7 +100,7 @@ func (d *DummyDA) Commit(blobs []da.Blob) ([]da.Commitment, error) {
}

// Submit stores blobs in DA layer.
func (d *DummyDA) Submit(blobs []da.Blob, gasPrice float64) ([]da.ID, []da.Proof, error) {
func (d *DummyDA) Submit(ctx context.Context, blobs []da.Blob, gasPrice float64) ([]da.ID, []da.Proof, error) {
d.mu.Lock()
defer d.mu.Unlock()
ids := make([]da.ID, len(blobs))
Expand All @@ -116,7 +117,7 @@ func (d *DummyDA) Submit(blobs []da.Blob, gasPrice float64) ([]da.ID, []da.Proof
}

// Validate checks the Proofs for given IDs.
func (d *DummyDA) Validate(ids []da.ID, proofs []da.Proof) ([]bool, error) {
func (d *DummyDA) Validate(ctx context.Context, ids []da.ID, proofs []da.Proof) ([]bool, error) {
if len(ids) != len(proofs) {
return nil, errors.New("number of IDs doesn't equal to number of proofs")
}
Expand Down
38 changes: 22 additions & 16 deletions test/test_suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package test

import (
"bytes"
"context"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -40,58 +41,59 @@ func BasicDATest(t *testing.T, da da.DA) {
msg1 := []byte("message 1")
msg2 := []byte("message 2")

id1, proof1, err := da.Submit([]Blob{msg1}, -1)
ctx := context.TODO()
id1, proof1, err := da.Submit(ctx, []Blob{msg1}, -1)
assert.NoError(t, err)
assert.NotEmpty(t, id1)
assert.NotEmpty(t, proof1)

id2, proof2, err := da.Submit([]Blob{msg2}, -1)
id2, proof2, err := da.Submit(ctx, []Blob{msg2}, -1)
assert.NoError(t, err)
assert.NotEmpty(t, id2)
assert.NotEmpty(t, proof2)

id3, proof3, err := da.Submit([]Blob{msg1}, -1)
id3, proof3, err := da.Submit(ctx, []Blob{msg1}, -1)
assert.NoError(t, err)
assert.NotEmpty(t, id3)
assert.NotEmpty(t, proof3)

assert.NotEqual(t, id1, id2)
assert.NotEqual(t, id1, id3)

ret, err := da.Get(id1)
ret, err := da.Get(ctx, id1)
assert.NoError(t, err)
assert.Equal(t, []Blob{msg1}, ret)

commitment1, err := da.Commit([]Blob{msg1})
commitment1, err := da.Commit(ctx, []Blob{msg1})
assert.NoError(t, err)
assert.NotEmpty(t, commitment1)

commitment2, err := da.Commit([]Blob{msg2})
commitment2, err := da.Commit(ctx, []Blob{msg2})
assert.NoError(t, err)
assert.NotEmpty(t, commitment2)

oks, err := da.Validate(id1, proof1)
oks, err := da.Validate(ctx, id1, proof1)
assert.NoError(t, err)
assert.NotEmpty(t, oks)
for _, ok := range oks {
assert.True(t, ok)
}

oks, err = da.Validate(id2, proof2)
oks, err = da.Validate(ctx, id2, proof2)
assert.NoError(t, err)
assert.NotEmpty(t, oks)
for _, ok := range oks {
assert.True(t, ok)
}

oks, err = da.Validate(id1, proof2)
oks, err = da.Validate(ctx, id1, proof2)
assert.NoError(t, err)
assert.NotEmpty(t, oks)
for _, ok := range oks {
assert.False(t, ok)
}

oks, err = da.Validate(id2, proof1)
oks, err = da.Validate(ctx, id2, proof1)
assert.NoError(t, err)
assert.NotEmpty(t, oks)
for _, ok := range oks {
Expand All @@ -101,7 +103,8 @@ func BasicDATest(t *testing.T, da da.DA) {

// CheckErrors ensures that errors are handled properly by DA.
func CheckErrors(t *testing.T, da da.DA) {
blob, err := da.Get([]ID{[]byte("invalid")})
ctx := context.TODO()
blob, err := da.Get(ctx, []ID{[]byte("invalid")})
assert.Error(t, err)
assert.Empty(t, blob)
}
Expand All @@ -110,7 +113,8 @@ func CheckErrors(t *testing.T, da da.DA) {
func GetIDsTest(t *testing.T, da da.DA) {
msgs := [][]byte{[]byte("msg1"), []byte("msg2"), []byte("msg3")}

ids, proofs, err := da.Submit(msgs, -1)
ctx := context.TODO()
ids, proofs, err := da.Submit(ctx, msgs, -1)
assert.NoError(t, err)
assert.Len(t, ids, len(msgs))
assert.Len(t, proofs, len(msgs))
Expand All @@ -122,12 +126,12 @@ func GetIDsTest(t *testing.T, da da.DA) {
// As we're the only user, we don't need to handle external data (that could be submitted in real world).
// There is no notion of height, so we need to scan the DA to get test data back.
for i := uint64(1); !found && !time.Now().After(end); i++ {
ret, err := da.GetIDs(i)
ret, err := da.GetIDs(ctx, i)
if err != nil {
t.Error("failed to get IDs:", err)
}
if len(ret) > 0 {
blobs, err := da.Get(ret)
blobs, err := da.Get(ctx, ret)
assert.NoError(t, err)

// Submit ensures atomicity of batch, so it makes sense to compare actual blobs (bodies) only when lengths
Expand All @@ -151,18 +155,20 @@ func ConcurrentReadWriteTest(t *testing.T, da da.DA) {
var wg sync.WaitGroup
wg.Add(2)

ctx := context.TODO()

go func() {
defer wg.Done()
for i := uint64(1); i <= 100; i++ {
_, err := da.GetIDs(i)
_, err := da.GetIDs(ctx, i)
assert.NoError(t, err)
}
}()

go func() {
defer wg.Done()
for i := uint64(1); i <= 100; i++ {
_, _, err := da.Submit([][]byte{[]byte("test")}, -1)
_, _, err := da.Submit(ctx, [][]byte{[]byte("test")}, -1)
assert.NoError(t, err)
}
}()
Expand Down

0 comments on commit 82f5296

Please sign in to comment.