Skip to content

Commit

Permalink
avoid blocking batchVerifyParallel (#154)
Browse files Browse the repository at this point in the history
* avoid blocking batchVerifyParallel

The current version of `batchVerifyParallel` calls `syncAll` which syncs
on all executing tasks.

This PR changes this to syncing a Flowvar instead thus allowing
`batchVerifyParallel` to be called as a task itself.

Requires status-im/nim-taskpools#33

* autoselect too

---------

Co-authored-by: zah <[email protected]>
  • Loading branch information
arnetheduck and zah authored Jul 25, 2023
1 parent d8507ef commit 5937eb9
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 44 deletions.
4 changes: 2 additions & 2 deletions benchmarks/bls_signature.nim
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ when BLS_BACKEND == BLST:
var secureBlindingBytes: array[32, byte]
secureBlindingBytes.bls_sha256_digest("Mr F was here")

var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init()

bench("Serial batch verify " & $numSigs & " msgs by "& $numSigs & " pubkeys (with blinding)", iters):
secureBlindingBytes.bls_sha256_digest(secureBlindingBytes)
Expand All @@ -223,7 +223,7 @@ when BLS_BACKEND == BLST:
hashedMsg.bls_sha256_digest("msg" & $i)
batch.add((pk, hashedMsg, sk.sign(hashedMsg)))

var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var secureBlindingBytes: array[32, byte]
secureBlindingBytes.bls_sha256_digest("Mr F was here")

Expand Down
139 changes: 106 additions & 33 deletions blscurve/bls_batch_verifier.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# those terms.

import
stew/ptrops,
./bls_backend, ./bls_sig_min_pubkey

when compileOption("threads"):
Expand Down Expand Up @@ -43,19 +44,41 @@ type
## `pubkey` and `signature` are assumed to be grouped checked
## which is guaranteed at deserialization from bytes or hex

BatchedBLSVerifierCache* = object
BatchedBLSVerifierCache* {.requiresInit.} = object
## This types hold temporary contexts
## to batch BLS multi signatures (aggregated or individual)
## verification.
## As the contexts are heavy, they can be reused

# Per-batch contexts for multithreaded batch verification
batchContexts: seq[ContextMultiAggregateVerify[DST]]
updateResults: seq[tuple[ok: bool, padCacheLine: array[64, byte]]]
when compileOption("threads"):
flows: seq[Flowvar[bool]]

# Serial Batch Verifier
# ----------------------------------------------------------------------

func init*(T: type BatchedBLSVerifierCache): T =
## Initialise the cache for single-threaded usage
when compileOption("threads"):
BatchedBLSVerifierCache(
batchContexts: newSeq[ContextMultiAggregateVerify[DST]](1),
flows: @[]
)
else:
BatchedBLSVerifierCache(
batchContexts: newSeq[ContextMultiAggregateVerify[DST]](1),
)


when compileOption("threads"):
func init*(T: type BatchedBLSVerifierCache, tp: Taskpool): T =
## Initialise the cache for multi-threaded usage
BatchedBLSVerifierCache(
batchContexts: newSeq[ContextMultiAggregateVerify[DST]](tp.numThreads),
flows: newSeq[Flowvar[bool]](tp.numThreads)
)

func batchVerifySerial*(
cache: var BatchedBLSVerifierCache,
input: openArray[SignatureSet],
Expand All @@ -69,11 +92,15 @@ func batchVerifySerial*(
## - One or more of the inputs was invalid on aggregation
## - One or more of the inputs had an invalid signature
## If knowing which input was problematic is required, they must be checked one by one.
##
## The `cache` argument is required to have been initialized via `init`.
## This function performs no memory allocations.
if input.len == 0:
# Spec precondition
return false

cache.batchContexts.setLen(1)
doAssert cache.batchContexts.len >= 1

template ctx: untyped = cache.batchContexts[0]
ctx.init(secureRandomBytes, "")

Expand All @@ -91,7 +118,7 @@ func batchVerifySerial*(
ctx.commit()

# Final exponentiation
return ctx.finalVerify()
ctx.finalVerify()

func batchVerifySerial*(
input: openArray[SignatureSet],
Expand All @@ -107,7 +134,7 @@ func batchVerifySerial*(
## If knowing which input was problematic is required, they must be checked one by one.

# Don't {.noinit.} this or seq capacity will be != 0.
var batcher: BatchedBLSVerifierCache
var batcher = BatchedBLSVerifierCache.init()
return batcher.batchVerifySerial(input, secureRandomBytes)

when compileOption("threads"):
Expand Down Expand Up @@ -229,9 +256,10 @@ when compileOption("threads"):

proc batchVerifyParallel*(
tp: Taskpool,
cache: var BatchedBLSVerifierCache,
input: openArray[SignatureSet],
secureRandomBytes: array[32, byte]
cache: ptr BatchedBLSVerifierCache,
setsPtr: ptr UncheckedArray[SignatureSet],
numSets: int,
secureRandomBytes: ptr array[32, byte]
): bool {.sideEffect.} =
## Multithreaded batch verification
## If multithreaded with -d:openmp requires OpenMP 3.0 (GCC 4.4, 2008)
Expand All @@ -242,60 +270,51 @@ when compileOption("threads"):
## - One or more of the inputs was invalid on aggregation
## - One or more of the inputs had an invalid signature
## If knowing which input was problematic is required, they must be checked one by one.
let numSets = input.len
if numSets == 0:
# Spec precondition
return false

let numBatches = min(numSets, tp.numThreads)

# Stage 0: Accumulators - setLen for noinit of seq
cache.batchContexts.setLen(numBatches)
cache.updateResults.setLen(numBatches)
doAssert cache[].batchContexts.len >= numBatches

# No GC in a parallel section
# Hence we use raw ptr UncheckedArray instead of seq
let contextsPtr = cache.batchContexts.toPtrUncheckedArray()
let setsPtr = input.toPtrUncheckedArray()
let updateResultsPtr = cache.updateResults.toPtrUncheckedArray()
let contextsPtr = cache[].batchContexts.toPtrUncheckedArray()

# Stage 1: Accumulate partial pairings
proc processSingleChunk(
contextsPtr: ptr UncheckedArray[ContextMultiAggregateVerify[DST]],
setsPtr: ptr UncheckedArray[SignatureSet],
updateResultsPtr: ptr UncheckedArray[tuple[ok: bool, padCacheLine: array[64, byte]]],
secureRandomBytes: ptr array[32, byte],
chunkID: int,
chunkStart, chunkLen: int) {.gcsafe, nimcall.}=
chunkStart, chunkLen: int): bool {.gcsafe, nimcall.} =

contextsPtr[chunkID].init(
secureRandomBytes[],
threadSepTag = cast[array[sizeof(chunkID), byte]](chunkID)
)

updateResultsPtr[chunkID].ok =
accumPairingLines(
setsPtr, contextsPtr,
chunkID,
chunkStart, (chunkStart+chunkLen)
)

accumPairingLines(
setsPtr, contextsPtr, chunkID,
chunkStart, (chunkStart+chunkLen)
)
var results: seq[Flowvar[bool]]
for chunkID in 0 ..< numBatches:
parallel_chunks(numBatches, numSets, chunkID, chunkStart, chunkLen):
# Partition work into even chunks
# Each thread receives a different start+len to process
# chunkStart and chunkLen are set per-thread by the template

tp.spawn processSingleChunk(
contextsPtr, setsPtr, updateResultsPtr,
secureRandomBytes.unsafeAddr,
results.add(tp.spawn processSingleChunk(
contextsPtr, setsPtr,
secureRandomBytes,
chunkID, chunkStart, chunkLen
)
))

tp.syncAll()

for i in 0 ..< cache.updateResults.len:
if not updateResultsPtr[i].ok:
for res in results.mitems:
if not sync(res):
return false

# Stage 2: Reduce partial pairings
Expand All @@ -312,6 +331,32 @@ when compileOption("threads"):

return cache.batchContexts[0].finalVerify()

proc batchVerifyParallel*(
tp: Taskpool,
cache: var BatchedBLSVerifierCache,
input: openArray[SignatureSet],
secureRandomBytes: array[32, byte]
): bool {.sideEffect.} =
## Multithreaded batch verification
## If multithreaded with -d:openmp requires OpenMP 3.0 (GCC 4.4, 2008)
## This will verify all the inputs (PublicKey, message, Signature) triplets
## at once and return true if verification is successful.
## If unsuccessful:
## - The input was empty
## - One or more of the inputs was invalid on aggregation
## - One or more of the inputs had an invalid signature
## If knowing which input was problematic is required, they must be checked one by one.
##
## If called from multiple threads, each call must have its own
## `BatchedBLSVerifierCache` instance initialised with the
## threadpool-specific initializer.
##
## This function does not allocate garbage-collected memory.

batchVerifyParallel(
tp, addr cache, makeUncheckedArray(baseAddr input), input.len,
unsafeAddr secureRandomBytes)

proc batchVerifyParallel*(
tp: Taskpool,
input: openArray[SignatureSet],
Expand All @@ -328,11 +373,39 @@ when compileOption("threads"):
## If knowing which input was problematic is required, they must be checked one by one.

# Don't {.noinit.} this or seq capacity will be != 0.
var batcher: BatchedBLSVerifierCache
var batcher = BatchedBLSVerifierCache.init(tp)
return tp.batchVerifyParallel(batcher, input, secureRandomBytes)

# Autoselect Batch Verifier
# ----------------------------------------------------------------------
proc batchVerify*(
tp: Taskpool,
cache: ptr BatchedBLSVerifierCache,
setsPtr: ptr UncheckedArray[SignatureSet],
numSets: int,
secureRandomBytes: ptr array[32, byte]
): bool =
## Verify all signatures in batch at once.
## Returns true if all signatures are correct
## Returns false if there is at least one incorrect signature
##
## This requires securely generated random bytes
## for scalar blinding
## to defend against forged signatures that would not
## verify individually but would verify while aggregated.
##
## The blinding scheme also assumes that the attacker cannot
## resubmit 2^64 times forged (publickey, message, signature) triplets
## against the same `secureRandomBytes`
when compileOption("threads"):
if tp.numThreads > 1 and numSets >= 3:
return tp.batchVerifyParallel(cache, setsPtr, numSets, secureRandomBytes)
else:
return cache[].batchVerifySerial(
setsPtr.toOpenArray(0, numSets - 1), secureRandomBytes[])
else:
return cache.batchVerifySerial(
setsPtr.toOpenArray(0, numSets - 1), secureRandomBytes[])

proc batchVerify*(
tp: Taskpool,
Expand Down Expand Up @@ -379,5 +452,5 @@ when compileOption("threads"):
## against the same `secureRandomBytes`

# Don't {.noinit.} this or seq capacity will be != 0.
var batcher: BatchedBLSVerifierCache
var batcher = BatchedBLSVerifierCache.init(tp)
return tp.batchVerify(batcher, input, secureRandomBytes)
2 changes: 1 addition & 1 deletion tests/eth2_vectors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ when BLS_BACKEND == BLST and compileOption("threads"):
signatures = seq[Signature].aggFrom(test, "signatures")

var tp = Taskpool.new(numThreads = 4)
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]


Expand Down
16 changes: 8 additions & 8 deletions tests/t_batch_verifier.nim
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ suite "Batch verification":
let (pubkey, seckey) = keyGen(123)
let sig = seckey.sign(msg)

var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

batch.add((pubkey, msg, sig))
Expand All @@ -76,7 +76,7 @@ suite "Batch verification":
tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "Verify 2 (pubkey, message, signature) triplets":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

batch.addExample(1, "msg1")
Expand All @@ -87,7 +87,7 @@ suite "Batch verification":
tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "Verify 2^4 - 1 = 15 (pubkey, message, signature) triplets":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

for i in 0 ..< 15:
Expand All @@ -98,7 +98,7 @@ suite "Batch verification":
tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "Verify 2^4 = 16 (pubkey, message, signature) triplets":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

for i in 0 ..< 16:
Expand All @@ -109,7 +109,7 @@ suite "Batch verification":
tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "Verify 2^4 + 1 = 17 (pubkey, message, signature) triplets":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

for i in 0 ..< 17:
Expand All @@ -127,7 +127,7 @@ suite "Batch verification":

let (pubkey2, seckey2) = keyGen(2)

var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

batch.add((pubkey1, msg1, sig1))
Expand Down Expand Up @@ -204,7 +204,7 @@ suite "Batch forged signatures":
var tp = Taskpool.new(numThreads = 4)

wrappedTest "Single forged pair":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

batch.genForgedPair(1, "msg1", 2, "msg2")
Expand All @@ -214,7 +214,7 @@ suite "Batch forged signatures":
not tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "One forgery among many signatures":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

var rng = initRand(1234)
Expand Down

0 comments on commit 5937eb9

Please sign in to comment.