diff --git a/benchmarks/bls_signature.nim b/benchmarks/bls_signature.nim index 218a112..f5c1a28 100644 --- a/benchmarks/bls_signature.nim +++ b/benchmarks/bls_signature.nim @@ -203,7 +203,7 @@ when BLS_BACKEND == BLST: var secureBlindingBytes: array[32, byte] secureBlindingBytes.bls_sha256_digest("Mr F was here") - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init() bench("Serial batch verify " & $numSigs & " msgs by "& $numSigs & " pubkeys (with blinding)", iters): secureBlindingBytes.bls_sha256_digest(secureBlindingBytes) @@ -223,7 +223,7 @@ when BLS_BACKEND == BLST: hashedMsg.bls_sha256_digest("msg" & $i) batch.add((pk, hashedMsg, sk.sign(hashedMsg))) - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var secureBlindingBytes: array[32, byte] secureBlindingBytes.bls_sha256_digest("Mr F was here") diff --git a/blscurve/bls_batch_verifier.nim b/blscurve/bls_batch_verifier.nim index b75d881..8d70e47 100644 --- a/blscurve/bls_batch_verifier.nim +++ b/blscurve/bls_batch_verifier.nim @@ -8,6 +8,7 @@ # those terms. import + stew/ptrops, ./bls_backend, ./bls_sig_min_pubkey when compileOption("threads"): @@ -43,7 +44,7 @@ type ## `pubkey` and `signature` are assumed to be grouped checked ## which is guaranteed at deserialization from bytes or hex - BatchedBLSVerifierCache* = object + BatchedBLSVerifierCache* {.requiresInit.} = object ## This types hold temporary contexts ## to batch BLS multi signatures (aggregated or individual) ## verification. @@ -51,11 +52,33 @@ type # Per-batch contexts for multithreaded batch verification batchContexts: seq[ContextMultiAggregateVerify[DST]] - updateResults: seq[tuple[ok: bool, padCacheLine: array[64, byte]]] + when compileOption("threads"): + flows: seq[Flowvar[bool]] # Serial Batch Verifier # ---------------------------------------------------------------------- +func init*(T: type BatchedBLSVerifierCache): T = + ## Initialise the cache for single-threaded usage + when compileOption("threads"): + BatchedBLSVerifierCache( + batchContexts: newSeq[ContextMultiAggregateVerify[DST]](1), + flows: @[] + ) + else: + BatchedBLSVerifierCache( + batchContexts: newSeq[ContextMultiAggregateVerify[DST]](1), + ) + + +when compileOption("threads"): + func init*(T: type BatchedBLSVerifierCache, tp: Taskpool): T = + ## Initialise the cache for multi-threaded usage + BatchedBLSVerifierCache( + batchContexts: newSeq[ContextMultiAggregateVerify[DST]](tp.numThreads), + flows: newSeq[Flowvar[bool]](tp.numThreads) + ) + func batchVerifySerial*( cache: var BatchedBLSVerifierCache, input: openArray[SignatureSet], @@ -69,11 +92,15 @@ func batchVerifySerial*( ## - One or more of the inputs was invalid on aggregation ## - One or more of the inputs had an invalid signature ## If knowing which input was problematic is required, they must be checked one by one. + ## + ## The `cache` argument is required to have been initialized via `init`. + ## This function performs no memory allocations. if input.len == 0: # Spec precondition return false - cache.batchContexts.setLen(1) + doAssert cache.batchContexts.len >= 1 + template ctx: untyped = cache.batchContexts[0] ctx.init(secureRandomBytes, "") @@ -91,7 +118,7 @@ func batchVerifySerial*( ctx.commit() # Final exponentiation - return ctx.finalVerify() + ctx.finalVerify() func batchVerifySerial*( input: openArray[SignatureSet], @@ -107,7 +134,7 @@ func batchVerifySerial*( ## If knowing which input was problematic is required, they must be checked one by one. # Don't {.noinit.} this or seq capacity will be != 0. - var batcher: BatchedBLSVerifierCache + var batcher = BatchedBLSVerifierCache.init() return batcher.batchVerifySerial(input, secureRandomBytes) when compileOption("threads"): @@ -229,9 +256,10 @@ when compileOption("threads"): proc batchVerifyParallel*( tp: Taskpool, - cache: var BatchedBLSVerifierCache, - input: openArray[SignatureSet], - secureRandomBytes: array[32, byte] + cache: ptr BatchedBLSVerifierCache, + setsPtr: ptr UncheckedArray[SignatureSet], + numSets: int, + secureRandomBytes: ptr array[32, byte] ): bool {.sideEffect.} = ## Multithreaded batch verification ## If multithreaded with -d:openmp requires OpenMP 3.0 (GCC 4.4, 2008) @@ -242,7 +270,6 @@ when compileOption("threads"): ## - One or more of the inputs was invalid on aggregation ## - One or more of the inputs had an invalid signature ## If knowing which input was problematic is required, they must be checked one by one. - let numSets = input.len if numSets == 0: # Spec precondition return false @@ -250,52 +277,44 @@ when compileOption("threads"): let numBatches = min(numSets, tp.numThreads) # Stage 0: Accumulators - setLen for noinit of seq - cache.batchContexts.setLen(numBatches) - cache.updateResults.setLen(numBatches) + doAssert cache[].batchContexts.len >= numBatches # No GC in a parallel section # Hence we use raw ptr UncheckedArray instead of seq - let contextsPtr = cache.batchContexts.toPtrUncheckedArray() - let setsPtr = input.toPtrUncheckedArray() - let updateResultsPtr = cache.updateResults.toPtrUncheckedArray() + let contextsPtr = cache[].batchContexts.toPtrUncheckedArray() # Stage 1: Accumulate partial pairings proc processSingleChunk( contextsPtr: ptr UncheckedArray[ContextMultiAggregateVerify[DST]], setsPtr: ptr UncheckedArray[SignatureSet], - updateResultsPtr: ptr UncheckedArray[tuple[ok: bool, padCacheLine: array[64, byte]]], secureRandomBytes: ptr array[32, byte], chunkID: int, - chunkStart, chunkLen: int) {.gcsafe, nimcall.}= + chunkStart, chunkLen: int): bool {.gcsafe, nimcall.} = contextsPtr[chunkID].init( secureRandomBytes[], threadSepTag = cast[array[sizeof(chunkID), byte]](chunkID) ) - updateResultsPtr[chunkID].ok = - accumPairingLines( - setsPtr, contextsPtr, - chunkID, - chunkStart, (chunkStart+chunkLen) - ) - + accumPairingLines( + setsPtr, contextsPtr, chunkID, + chunkStart, (chunkStart+chunkLen) + ) + var results: seq[Flowvar[bool]] for chunkID in 0 ..< numBatches: parallel_chunks(numBatches, numSets, chunkID, chunkStart, chunkLen): # Partition work into even chunks # Each thread receives a different start+len to process # chunkStart and chunkLen are set per-thread by the template - tp.spawn processSingleChunk( - contextsPtr, setsPtr, updateResultsPtr, - secureRandomBytes.unsafeAddr, + results.add(tp.spawn processSingleChunk( + contextsPtr, setsPtr, + secureRandomBytes, chunkID, chunkStart, chunkLen - ) + )) - tp.syncAll() - - for i in 0 ..< cache.updateResults.len: - if not updateResultsPtr[i].ok: + for res in results.mitems: + if not sync(res): return false # Stage 2: Reduce partial pairings @@ -312,6 +331,32 @@ when compileOption("threads"): return cache.batchContexts[0].finalVerify() + proc batchVerifyParallel*( + tp: Taskpool, + cache: var BatchedBLSVerifierCache, + input: openArray[SignatureSet], + secureRandomBytes: array[32, byte] + ): bool {.sideEffect.} = + ## Multithreaded batch verification + ## If multithreaded with -d:openmp requires OpenMP 3.0 (GCC 4.4, 2008) + ## This will verify all the inputs (PublicKey, message, Signature) triplets + ## at once and return true if verification is successful. + ## If unsuccessful: + ## - The input was empty + ## - One or more of the inputs was invalid on aggregation + ## - One or more of the inputs had an invalid signature + ## If knowing which input was problematic is required, they must be checked one by one. + ## + ## If called from multiple threads, each call must have its own + ## `BatchedBLSVerifierCache` instance initialised with the + ## threadpool-specific initializer. + ## + ## This function does not allocate garbage-collected memory. + + batchVerifyParallel( + tp, addr cache, makeUncheckedArray(baseAddr input), input.len, + unsafeAddr secureRandomBytes) + proc batchVerifyParallel*( tp: Taskpool, input: openArray[SignatureSet], @@ -328,11 +373,39 @@ when compileOption("threads"): ## If knowing which input was problematic is required, they must be checked one by one. # Don't {.noinit.} this or seq capacity will be != 0. - var batcher: BatchedBLSVerifierCache + var batcher = BatchedBLSVerifierCache.init(tp) return tp.batchVerifyParallel(batcher, input, secureRandomBytes) # Autoselect Batch Verifier # ---------------------------------------------------------------------- + proc batchVerify*( + tp: Taskpool, + cache: ptr BatchedBLSVerifierCache, + setsPtr: ptr UncheckedArray[SignatureSet], + numSets: int, + secureRandomBytes: ptr array[32, byte] + ): bool = + ## Verify all signatures in batch at once. + ## Returns true if all signatures are correct + ## Returns false if there is at least one incorrect signature + ## + ## This requires securely generated random bytes + ## for scalar blinding + ## to defend against forged signatures that would not + ## verify individually but would verify while aggregated. + ## + ## The blinding scheme also assumes that the attacker cannot + ## resubmit 2^64 times forged (publickey, message, signature) triplets + ## against the same `secureRandomBytes` + when compileOption("threads"): + if tp.numThreads > 1 and numSets >= 3: + return tp.batchVerifyParallel(cache, setsPtr, numSets, secureRandomBytes) + else: + return cache[].batchVerifySerial( + setsPtr.toOpenArray(0, numSets - 1), secureRandomBytes[]) + else: + return cache.batchVerifySerial( + setsPtr.toOpenArray(0, numSets - 1), secureRandomBytes[]) proc batchVerify*( tp: Taskpool, @@ -379,5 +452,5 @@ when compileOption("threads"): ## against the same `secureRandomBytes` # Don't {.noinit.} this or seq capacity will be != 0. - var batcher: BatchedBLSVerifierCache + var batcher = BatchedBLSVerifierCache.init(tp) return tp.batchVerify(batcher, input, secureRandomBytes) diff --git a/tests/eth2_vectors.nim b/tests/eth2_vectors.nim index 8c24127..b3490df 100644 --- a/tests/eth2_vectors.nim +++ b/tests/eth2_vectors.nim @@ -363,7 +363,7 @@ when BLS_BACKEND == BLST and compileOption("threads"): signatures = seq[Signature].aggFrom(test, "signatures") var tp = Taskpool.new(numThreads = 4) - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var batch: seq[SignatureSet] diff --git a/tests/t_batch_verifier.nim b/tests/t_batch_verifier.nim index b9acfd7..9239dcf 100644 --- a/tests/t_batch_verifier.nim +++ b/tests/t_batch_verifier.nim @@ -67,7 +67,7 @@ suite "Batch verification": let (pubkey, seckey) = keyGen(123) let sig = seckey.sign(msg) - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var batch: seq[SignatureSet] batch.add((pubkey, msg, sig)) @@ -76,7 +76,7 @@ suite "Batch verification": tp.batchVerify(batch, fakeRandomBytes) wrappedTest "Verify 2 (pubkey, message, signature) triplets": - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var batch: seq[SignatureSet] batch.addExample(1, "msg1") @@ -87,7 +87,7 @@ suite "Batch verification": tp.batchVerify(batch, fakeRandomBytes) wrappedTest "Verify 2^4 - 1 = 15 (pubkey, message, signature) triplets": - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var batch: seq[SignatureSet] for i in 0 ..< 15: @@ -98,7 +98,7 @@ suite "Batch verification": tp.batchVerify(batch, fakeRandomBytes) wrappedTest "Verify 2^4 = 16 (pubkey, message, signature) triplets": - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var batch: seq[SignatureSet] for i in 0 ..< 16: @@ -109,7 +109,7 @@ suite "Batch verification": tp.batchVerify(batch, fakeRandomBytes) wrappedTest "Verify 2^4 + 1 = 17 (pubkey, message, signature) triplets": - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var batch: seq[SignatureSet] for i in 0 ..< 17: @@ -127,7 +127,7 @@ suite "Batch verification": let (pubkey2, seckey2) = keyGen(2) - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var batch: seq[SignatureSet] batch.add((pubkey1, msg1, sig1)) @@ -204,7 +204,7 @@ suite "Batch forged signatures": var tp = Taskpool.new(numThreads = 4) wrappedTest "Single forged pair": - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var batch: seq[SignatureSet] batch.genForgedPair(1, "msg1", 2, "msg2") @@ -214,7 +214,7 @@ suite "Batch forged signatures": not tp.batchVerify(batch, fakeRandomBytes) wrappedTest "One forgery among many signatures": - var cache: BatchedBLSVerifierCache + var cache = BatchedBLSVerifierCache.init(tp) var batch: seq[SignatureSet] var rng = initRand(1234)