Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid blocking batchVerifyParallel #154

Merged
merged 3 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/bls_signature.nim
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ when BLS_BACKEND == BLST:
var secureBlindingBytes: array[32, byte]
secureBlindingBytes.bls_sha256_digest("Mr F was here")

var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init()

bench("Serial batch verify " & $numSigs & " msgs by "& $numSigs & " pubkeys (with blinding)", iters):
secureBlindingBytes.bls_sha256_digest(secureBlindingBytes)
Expand All @@ -223,7 +223,7 @@ when BLS_BACKEND == BLST:
hashedMsg.bls_sha256_digest("msg" & $i)
batch.add((pk, hashedMsg, sk.sign(hashedMsg)))

var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var secureBlindingBytes: array[32, byte]
secureBlindingBytes.bls_sha256_digest("Mr F was here")

Expand Down
139 changes: 106 additions & 33 deletions blscurve/bls_batch_verifier.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# those terms.

import
stew/ptrops,
./bls_backend, ./bls_sig_min_pubkey

when compileOption("threads"):
Expand Down Expand Up @@ -43,19 +44,41 @@ type
## `pubkey` and `signature` are assumed to be grouped checked
## which is guaranteed at deserialization from bytes or hex

BatchedBLSVerifierCache* = object
BatchedBLSVerifierCache* {.requiresInit.} = object
## This types hold temporary contexts
## to batch BLS multi signatures (aggregated or individual)
## verification.
## As the contexts are heavy, they can be reused

# Per-batch contexts for multithreaded batch verification
batchContexts: seq[ContextMultiAggregateVerify[DST]]
updateResults: seq[tuple[ok: bool, padCacheLine: array[64, byte]]]
when compileOption("threads"):
flows: seq[Flowvar[bool]]

# Serial Batch Verifier
# ----------------------------------------------------------------------

func init*(T: type BatchedBLSVerifierCache): T =
## Initialise the cache for single-threaded usage
when compileOption("threads"):
BatchedBLSVerifierCache(
batchContexts: newSeq[ContextMultiAggregateVerify[DST]](1),
flows: @[]
)
else:
BatchedBLSVerifierCache(
batchContexts: newSeq[ContextMultiAggregateVerify[DST]](1),
)


when compileOption("threads"):
func init*(T: type BatchedBLSVerifierCache, tp: Taskpool): T =
## Initialise the cache for multi-threaded usage
BatchedBLSVerifierCache(
batchContexts: newSeq[ContextMultiAggregateVerify[DST]](tp.numThreads),
flows: newSeq[Flowvar[bool]](tp.numThreads)
)

func batchVerifySerial*(
cache: var BatchedBLSVerifierCache,
input: openArray[SignatureSet],
Expand All @@ -69,11 +92,15 @@ func batchVerifySerial*(
## - One or more of the inputs was invalid on aggregation
## - One or more of the inputs had an invalid signature
## If knowing which input was problematic is required, they must be checked one by one.
##
## The `cache` argument is required to have been initialized via `init`.
## This function performs no memory allocations.
if input.len == 0:
# Spec precondition
return false

cache.batchContexts.setLen(1)
doAssert cache.batchContexts.len >= 1

template ctx: untyped = cache.batchContexts[0]
ctx.init(secureRandomBytes, "")

Expand All @@ -91,7 +118,7 @@ func batchVerifySerial*(
ctx.commit()

# Final exponentiation
return ctx.finalVerify()
ctx.finalVerify()

func batchVerifySerial*(
input: openArray[SignatureSet],
Expand All @@ -107,7 +134,7 @@ func batchVerifySerial*(
## If knowing which input was problematic is required, they must be checked one by one.

# Don't {.noinit.} this or seq capacity will be != 0.
var batcher: BatchedBLSVerifierCache
var batcher = BatchedBLSVerifierCache.init()
return batcher.batchVerifySerial(input, secureRandomBytes)

when compileOption("threads"):
Expand Down Expand Up @@ -229,9 +256,10 @@ when compileOption("threads"):

proc batchVerifyParallel*(
tp: Taskpool,
cache: var BatchedBLSVerifierCache,
input: openArray[SignatureSet],
secureRandomBytes: array[32, byte]
cache: ptr BatchedBLSVerifierCache,
setsPtr: ptr UncheckedArray[SignatureSet],
numSets: int,
secureRandomBytes: ptr array[32, byte]
): bool {.sideEffect.} =
## Multithreaded batch verification
## If multithreaded with -d:openmp requires OpenMP 3.0 (GCC 4.4, 2008)
Expand All @@ -242,60 +270,51 @@ when compileOption("threads"):
## - One or more of the inputs was invalid on aggregation
## - One or more of the inputs had an invalid signature
## If knowing which input was problematic is required, they must be checked one by one.
let numSets = input.len
if numSets == 0:
# Spec precondition
return false

let numBatches = min(numSets, tp.numThreads)

# Stage 0: Accumulators - setLen for noinit of seq
cache.batchContexts.setLen(numBatches)
cache.updateResults.setLen(numBatches)
doAssert cache[].batchContexts.len >= numBatches

# No GC in a parallel section
# Hence we use raw ptr UncheckedArray instead of seq
let contextsPtr = cache.batchContexts.toPtrUncheckedArray()
let setsPtr = input.toPtrUncheckedArray()
let updateResultsPtr = cache.updateResults.toPtrUncheckedArray()
let contextsPtr = cache[].batchContexts.toPtrUncheckedArray()

# Stage 1: Accumulate partial pairings
proc processSingleChunk(
contextsPtr: ptr UncheckedArray[ContextMultiAggregateVerify[DST]],
setsPtr: ptr UncheckedArray[SignatureSet],
updateResultsPtr: ptr UncheckedArray[tuple[ok: bool, padCacheLine: array[64, byte]]],
secureRandomBytes: ptr array[32, byte],
chunkID: int,
chunkStart, chunkLen: int) {.gcsafe, nimcall.}=
chunkStart, chunkLen: int): bool {.gcsafe, nimcall.} =

contextsPtr[chunkID].init(
secureRandomBytes[],
threadSepTag = cast[array[sizeof(chunkID), byte]](chunkID)
)

updateResultsPtr[chunkID].ok =
accumPairingLines(
setsPtr, contextsPtr,
chunkID,
chunkStart, (chunkStart+chunkLen)
)

accumPairingLines(
setsPtr, contextsPtr, chunkID,
chunkStart, (chunkStart+chunkLen)
)
var results: seq[Flowvar[bool]]
for chunkID in 0 ..< numBatches:
parallel_chunks(numBatches, numSets, chunkID, chunkStart, chunkLen):
# Partition work into even chunks
# Each thread receives a different start+len to process
# chunkStart and chunkLen are set per-thread by the template

tp.spawn processSingleChunk(
contextsPtr, setsPtr, updateResultsPtr,
secureRandomBytes.unsafeAddr,
results.add(tp.spawn processSingleChunk(
contextsPtr, setsPtr,
secureRandomBytes,
chunkID, chunkStart, chunkLen
)
))

tp.syncAll()

for i in 0 ..< cache.updateResults.len:
if not updateResultsPtr[i].ok:
for res in results.mitems:
if not sync(res):
return false

# Stage 2: Reduce partial pairings
Expand All @@ -312,6 +331,32 @@ when compileOption("threads"):

return cache.batchContexts[0].finalVerify()

proc batchVerifyParallel*(
tp: Taskpool,
cache: var BatchedBLSVerifierCache,
input: openArray[SignatureSet],
secureRandomBytes: array[32, byte]
): bool {.sideEffect.} =
## Multithreaded batch verification
## If multithreaded with -d:openmp requires OpenMP 3.0 (GCC 4.4, 2008)
## This will verify all the inputs (PublicKey, message, Signature) triplets
## at once and return true if verification is successful.
## If unsuccessful:
## - The input was empty
## - One or more of the inputs was invalid on aggregation
## - One or more of the inputs had an invalid signature
## If knowing which input was problematic is required, they must be checked one by one.
##
## If called from multiple threads, each call must have its own
## `BatchedBLSVerifierCache` instance initialised with the
## threadpool-specific initializer.
##
## This function does not allocate garbage-collected memory.

batchVerifyParallel(
tp, addr cache, makeUncheckedArray(baseAddr input), input.len,
unsafeAddr secureRandomBytes)

proc batchVerifyParallel*(
tp: Taskpool,
input: openArray[SignatureSet],
Expand All @@ -328,11 +373,39 @@ when compileOption("threads"):
## If knowing which input was problematic is required, they must be checked one by one.

# Don't {.noinit.} this or seq capacity will be != 0.
var batcher: BatchedBLSVerifierCache
var batcher = BatchedBLSVerifierCache.init(tp)
return tp.batchVerifyParallel(batcher, input, secureRandomBytes)

# Autoselect Batch Verifier
# ----------------------------------------------------------------------
proc batchVerify*(
tp: Taskpool,
cache: ptr BatchedBLSVerifierCache,
setsPtr: ptr UncheckedArray[SignatureSet],
numSets: int,
secureRandomBytes: ptr array[32, byte]
): bool =
## Verify all signatures in batch at once.
## Returns true if all signatures are correct
## Returns false if there is at least one incorrect signature
##
## This requires securely generated random bytes
## for scalar blinding
## to defend against forged signatures that would not
## verify individually but would verify while aggregated.
##
## The blinding scheme also assumes that the attacker cannot
## resubmit 2^64 times forged (publickey, message, signature) triplets
## against the same `secureRandomBytes`
when compileOption("threads"):
if tp.numThreads > 1 and numSets >= 3:
return tp.batchVerifyParallel(cache, setsPtr, numSets, secureRandomBytes)
else:
return cache[].batchVerifySerial(
setsPtr.toOpenArray(0, numSets - 1), secureRandomBytes[])
else:
return cache.batchVerifySerial(
setsPtr.toOpenArray(0, numSets - 1), secureRandomBytes[])

proc batchVerify*(
tp: Taskpool,
Expand Down Expand Up @@ -379,5 +452,5 @@ when compileOption("threads"):
## against the same `secureRandomBytes`

# Don't {.noinit.} this or seq capacity will be != 0.
var batcher: BatchedBLSVerifierCache
var batcher = BatchedBLSVerifierCache.init(tp)
return tp.batchVerify(batcher, input, secureRandomBytes)
2 changes: 1 addition & 1 deletion tests/eth2_vectors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ when BLS_BACKEND == BLST and compileOption("threads"):
signatures = seq[Signature].aggFrom(test, "signatures")

var tp = Taskpool.new(numThreads = 4)
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]


Expand Down
16 changes: 8 additions & 8 deletions tests/t_batch_verifier.nim
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ suite "Batch verification":
let (pubkey, seckey) = keyGen(123)
let sig = seckey.sign(msg)

var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

batch.add((pubkey, msg, sig))
Expand All @@ -76,7 +76,7 @@ suite "Batch verification":
tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "Verify 2 (pubkey, message, signature) triplets":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

batch.addExample(1, "msg1")
Expand All @@ -87,7 +87,7 @@ suite "Batch verification":
tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "Verify 2^4 - 1 = 15 (pubkey, message, signature) triplets":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

for i in 0 ..< 15:
Expand All @@ -98,7 +98,7 @@ suite "Batch verification":
tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "Verify 2^4 = 16 (pubkey, message, signature) triplets":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

for i in 0 ..< 16:
Expand All @@ -109,7 +109,7 @@ suite "Batch verification":
tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "Verify 2^4 + 1 = 17 (pubkey, message, signature) triplets":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

for i in 0 ..< 17:
Expand All @@ -127,7 +127,7 @@ suite "Batch verification":

let (pubkey2, seckey2) = keyGen(2)

var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

batch.add((pubkey1, msg1, sig1))
Expand Down Expand Up @@ -204,7 +204,7 @@ suite "Batch forged signatures":
var tp = Taskpool.new(numThreads = 4)

wrappedTest "Single forged pair":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

batch.genForgedPair(1, "msg1", 2, "msg2")
Expand All @@ -214,7 +214,7 @@ suite "Batch forged signatures":
not tp.batchVerify(batch, fakeRandomBytes)

wrappedTest "One forgery among many signatures":
var cache: BatchedBLSVerifierCache
var cache = BatchedBLSVerifierCache.init(tp)
var batch: seq[SignatureSet]

var rng = initRand(1234)
Expand Down