Skip to content

Commit

Permalink
[CHIA-1087] validate blocks in thread pool (instead of process pool) (#…
Browse files Browse the repository at this point in the history
…18681)

* bump chia_rs dependency and accommodate changes to SpendBundleConditions and run_block_generator()

* use ThreadPoolExecutor instead of ProcessPoolExecutor in pre_validate_blocks_multiprocessing

* don't serialize BlockRecords when passing into pre_validate_blocks_multiprocessing

* don't pickle the list of full blocks passed into batch_pre_validate_blocks

* don't serialize conditions when passed to batch_pre_validate_blocks

* don't serialize previous session blocks passed to batch_pre_validate_blocks

* don't serialize the return value from batch_pre_validate_blocks

* make the batch size 1, in pre_validate_blocks_multiprocessing(). With the jobs running in a thread pool, there's no serialization cost we need to amortize over a batch

* make batch_pre_validate_blocks() only validate a single block at a time. rename it to pre_validate_block()

* avoid copying recent_blocks into each validation job. since they run in a thread we can just use the blockchain object directly

* merge loops over the blocks in pre_validate_blocks_multiprocessing(), to simplify the code and to build fewer temporary lists
  • Loading branch information
arvidn authored Oct 15, 2024
1 parent fd4a59a commit d678131
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 192 deletions.
11 changes: 2 additions & 9 deletions chia/consensus/blockchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
import logging
import time
import traceback
from concurrent.futures import Executor
from concurrent.futures.process import ProcessPoolExecutor
from concurrent.futures import Executor, ThreadPoolExecutor
from enum import Enum
from multiprocessing.context import BaseContext
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Set, Tuple, cast

Expand Down Expand Up @@ -48,7 +46,6 @@
from chia.util.inline_executor import InlineExecutor
from chia.util.ints import uint16, uint32, uint64, uint128
from chia.util.priority_mutex import PriorityMutex
from chia.util.setproctitle import getproctitle, setproctitle

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -126,7 +123,6 @@ async def create(
consensus_constants: ConsensusConstants,
blockchain_dir: Path,
reserved_cores: int,
multiprocessing_context: Optional[BaseContext] = None,
*,
single_threaded: bool = False,
) -> Blockchain:
Expand All @@ -145,11 +141,8 @@ async def create(
else:
cpu_count = available_logical_cores()
num_workers = max(cpu_count - reserved_cores, 1)
self.pool = ProcessPoolExecutor(
self.pool = ThreadPoolExecutor(
max_workers=num_workers,
mp_context=multiprocessing_context,
initializer=setproctitle,
initargs=(f"{getproctitle()}_block_validation_worker",),
)
log.info(f"Started {num_workers} processes for block validation")

Expand Down
284 changes: 102 additions & 182 deletions chia/consensus/multiprocess_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from chia.consensus.block_header_validation import validate_finished_header_block
from chia.consensus.block_record import BlockRecord
from chia.consensus.blockchain_interface import BlocksProtocol
from chia.consensus.blockchain_interface import BlockRecordsProtocol, BlocksProtocol
from chia.consensus.constants import ConsensusConstants
from chia.consensus.cost_calculator import NPCResult
from chia.consensus.full_block_to_block_record import block_to_block_record
Expand All @@ -30,7 +30,6 @@
from chia.types.unfinished_block import UnfinishedBlock
from chia.types.validation_state import ValidationState
from chia.util.augmented_chain import AugmentedBlockchain
from chia.util.block_cache import BlockCache
from chia.util.condition_tools import pkm_pairs
from chia.util.errors import Err, ValidationError
from chia.util.generator_tools import get_block_header, tx_removals_and_additions
Expand All @@ -50,115 +49,88 @@ class PreValidationResult(Streamable):
timing: uint32 # the time (in milliseconds) it took to pre-validate the block


def batch_pre_validate_blocks(
def pre_validate_block(
constants: ConsensusConstants,
blocks_pickled: Dict[bytes, bytes],
full_blocks_pickled: List[bytes],
prev_transaction_generators: List[Optional[List[bytes]]],
conditions: Dict[uint32, bytes],
expected_difficulty: List[uint64],
expected_sub_slot_iters: List[uint64],
blockchain: BlockRecordsProtocol,
block: FullBlock,
prev_generators: Optional[List[bytes]],
conds: Optional[SpendBundleConditions],
vs: ValidationState,
validate_signatures: bool,
prev_ses_block_bytes: Optional[List[Optional[bytes]]] = None,
) -> List[bytes]:
blocks: Dict[bytes32, BlockRecord] = {}
for k, v in blocks_pickled.items():
blocks[bytes32(k)] = BlockRecord.from_bytes_unchecked(v)
results: List[PreValidationResult] = []
) -> PreValidationResult:

# In this case, we are validating full blocks, not headers
for i in range(len(full_blocks_pickled)):
try:
validation_start = time.monotonic()
block: FullBlock = FullBlock.from_bytes_unchecked(full_blocks_pickled[i])
tx_additions: List[Coin] = []
removals: List[bytes32] = []
conds: Optional[SpendBundleConditions] = None
if block.height in conditions:
conds = SpendBundleConditions.from_bytes(conditions[block.height])
removals, tx_additions = tx_removals_and_additions(conds)
elif block.transactions_generator is not None:
# TODO: this function would be simpler if conditions were
# required to be passed in for all transaction blocks. We would
# no longer need prev_transaction_generators
prev_generators = prev_transaction_generators[i]
assert prev_generators is not None
assert block.transactions_info is not None
block_generator = BlockGenerator(block.transactions_generator, prev_generators)
assert block_generator.program == block.transactions_generator
npc_result = get_name_puzzle_conditions(
block_generator,
min(constants.MAX_BLOCK_COST_CLVM, block.transactions_info.cost),
mempool_mode=False,
height=block.height,
constants=constants,
try:
validation_start = time.monotonic()
tx_additions: List[Coin] = []
removals: List[bytes32] = []
if conds is not None:
removals, tx_additions = tx_removals_and_additions(conds)
elif block.transactions_generator is not None:
# TODO: this function would be simpler if conds was
# required to be passed in for all transaction blocks. We would
# no longer need prev_generators
assert prev_generators is not None
assert block.transactions_info is not None
block_generator = BlockGenerator(block.transactions_generator, prev_generators)
assert block_generator.program == block.transactions_generator
npc_result = get_name_puzzle_conditions(
block_generator,
min(constants.MAX_BLOCK_COST_CLVM, block.transactions_info.cost),
mempool_mode=False,
height=block.height,
constants=constants,
)
if npc_result.error is not None:
validation_time = time.monotonic() - validation_start
return PreValidationResult(
uint16(npc_result.error), None, npc_result.conds, False, uint32(validation_time * 1000)
)
if npc_result.error is not None:
validation_time = time.monotonic() - validation_start
results.append(
PreValidationResult(
uint16(npc_result.error), None, npc_result.conds, False, uint32(validation_time * 1000)
)
)
continue
assert npc_result.conds is not None
conds = npc_result.conds
removals, tx_additions = tx_removals_and_additions(conds)
assert npc_result.conds is not None
conds = npc_result.conds
removals, tx_additions = tx_removals_and_additions(conds)

header_block = get_block_header(block, tx_additions, removals)
prev_ses_block = None
if prev_ses_block_bytes is not None and len(prev_ses_block_bytes) > 0:
buffer = prev_ses_block_bytes[i]
if buffer is not None:
prev_ses_block = BlockRecord.from_bytes_unchecked(buffer)
required_iters, error = validate_finished_header_block(
constants,
BlockCache(blocks),
header_block,
True, # check_filter
expected_difficulty[i],
expected_sub_slot_iters[i],
prev_ses_block=prev_ses_block,
)
error_int: Optional[uint16] = None
if error is not None:
error_int = uint16(error.code.value)
header_block = get_block_header(block, tx_additions, removals)
required_iters, error = validate_finished_header_block(
constants,
blockchain,
header_block,
True, # check_filter
vs.current_difficulty,
vs.current_ssi,
prev_ses_block=vs.prev_ses_block,
)
error_int: Optional[uint16] = None
if error is not None:
error_int = uint16(error.code.value)

successfully_validated_signatures = False
# If we failed header block validation, no need to validate
# signature, the block is already invalid If this is False, it means
# either we don't have a signature (not a tx block) or we have an
# invalid signature (which also puts in an error) or we didn't
# validate the signature because we want to validate it later.
# add_block will attempt to validate the signature later.
if error_int is None and validate_signatures and conds is not None:
assert block.transactions_info is not None
pairs_pks, pairs_msgs = pkm_pairs(conds, constants.AGG_SIG_ME_ADDITIONAL_DATA)
if not AugSchemeMPL.aggregate_verify(
pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature
):
error_int = uint16(Err.BAD_AGGREGATE_SIGNATURE.value)
else:
successfully_validated_signatures = True
successfully_validated_signatures = False
# If we failed header block validation, no need to validate
# signature, the block is already invalid If this is False, it means
# either we don't have a signature (not a tx block) or we have an
# invalid signature (which also puts in an error) or we didn't
# validate the signature because we want to validate it later.
# add_block will attempt to validate the signature later.
if error_int is None and validate_signatures and conds is not None:
assert block.transactions_info is not None
pairs_pks, pairs_msgs = pkm_pairs(conds, constants.AGG_SIG_ME_ADDITIONAL_DATA)
if not AugSchemeMPL.aggregate_verify(pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature):
error_int = uint16(Err.BAD_AGGREGATE_SIGNATURE.value)
else:
successfully_validated_signatures = True

validation_time = time.monotonic() - validation_start
results.append(
PreValidationResult(
error_int,
required_iters,
conds,
successfully_validated_signatures,
uint32(validation_time * 1000),
)
)
except Exception:
error_stack = traceback.format_exc()
log.error(f"Exception: {error_stack}")
validation_time = time.monotonic() - validation_start
results.append(
PreValidationResult(uint16(Err.UNKNOWN.value), None, None, False, uint32(validation_time * 1000))
)
return [bytes(r) for r in results]
validation_time = time.monotonic() - validation_start
return PreValidationResult(
error_int,
required_iters,
conds,
successfully_validated_signatures,
uint32(validation_time * 1000),
)
except Exception:
error_stack = traceback.format_exc()
log.error(f"Exception: {error_stack}")
validation_time = time.monotonic() - validation_start
return PreValidationResult(uint16(Err.UNKNOWN.value), None, None, False, uint32(validation_time * 1000))


async def pre_validate_blocks_multiprocessing(
Expand Down Expand Up @@ -187,49 +159,28 @@ async def pre_validate_blocks_multiprocessing(
"""
prev_b: Optional[BlockRecord] = None

# Collects all the recent blocks (up to the previous sub-epoch)
recent_blocks: Dict[bytes32, BlockRecord] = {}
num_sub_slots_found = 0
num_blocks_seen = 0

if blocks[0].height > 0:
curr = block_records.try_block_record(blocks[0].prev_header_hash)
if curr is None:
return [PreValidationResult(uint16(Err.INVALID_PREV_BLOCK_HASH.value), None, None, False, uint32(0))]
prev_b = curr
num_sub_slots_to_look_for = 3 if curr.overflow else 2
header_hash = curr.header_hash
while (
curr.sub_epoch_summary_included is None
or num_blocks_seen < constants.NUMBER_OF_TIMESTAMPS
or num_sub_slots_found < num_sub_slots_to_look_for
) and curr.height > 0:
if curr.first_in_sub_slot:
assert curr.finished_challenge_slot_hashes is not None
num_sub_slots_found += len(curr.finished_challenge_slot_hashes)
recent_blocks[header_hash] = curr
if curr.is_transaction_block:
num_blocks_seen += 1
header_hash = curr.prev_hash
curr = block_records.block_record(curr.prev_hash)
assert curr is not None
recent_blocks[header_hash] = curr

# the augmented blockchain object will let us add temporary block records;
# they won't actually be added to the underlying blockchain object
blockchain = AugmentedBlockchain(block_records)

diff_ssis: List[ValidationState] = []
prev_ses_block_list: List[Optional[BlockRecord]] = []
futures = []
# Pool of workers to validate blocks concurrently

for block in blocks:
assert isinstance(block, FullBlock)
if len(block.finished_sub_slots) > 0:
if block.finished_sub_slots[0].challenge_chain.new_difficulty is not None:
vs.current_difficulty = block.finished_sub_slots[0].challenge_chain.new_difficulty
if block.finished_sub_slots[0].challenge_chain.new_sub_slot_iters is not None:
vs.current_ssi = block.finished_sub_slots[0].challenge_chain.new_sub_slot_iters
overflow = is_overflow_block(constants, block.reward_chain_block.signage_point_index)
challenge = get_block_challenge(constants, block, BlockCache(recent_blocks), prev_b is None, overflow, False)
challenge = get_block_challenge(constants, block, blockchain, prev_b is None, overflow, False)
if block.reward_chain_block.challenge_chain_sp_vdf is None:
cc_sp_hash: bytes32 = challenge
else:
Expand Down Expand Up @@ -267,74 +218,43 @@ async def pre_validate_blocks_multiprocessing(
log.error("sub_epoch_summary does not match wp sub_epoch_summary list")
return [PreValidationResult(uint16(Err.INVALID_SUB_EPOCH_SUMMARY.value), None, None, False, uint32(0))]

recent_blocks[block_rec.header_hash] = block_rec
blockchain.add_extra_block(block, block_rec) # Temporarily add block to chain
prev_b = block_rec
diff_ssis.append(copy.copy(vs))
prev_ses_block_list.append(vs.prev_ses_block)
if block_rec.sub_epoch_summary_included is not None:
vs.prev_ses_block = block_rec

conditions_pickled = {}
for k, v in block_height_conds_map.items():
conditions_pickled[k] = bytes(v)
futures = []
# Pool of workers to validate blocks concurrently
recent_blocks_bytes = {bytes(k): bytes(v) for k, v in recent_blocks.items()} # convert to bytes
previous_generators: Optional[List[bytes]] = None

batch_size = 4
for i in range(0, len(blocks), batch_size):
end_i = min(i + batch_size, len(blocks))
blocks_to_validate = blocks[i:end_i]
b_pickled: List[bytes] = []
previous_generators: List[Optional[List[bytes]]] = []
for block in blocks_to_validate:
assert isinstance(block, FullBlock)
b_pickled.append(bytes(block))
try:
block_generator: Optional[BlockGenerator] = await get_block_generator(
blockchain.lookup_block_generators, block
)
except ValueError:
return [
PreValidationResult(
uint16(Err.FAILED_GETTING_GENERATOR_MULTIPROCESSING.value), None, None, False, uint32(0)
)
]
try:
block_generator: Optional[BlockGenerator] = await get_block_generator(
blockchain.lookup_block_generators, block
)
if block_generator is not None:
previous_generators.append(block_generator.generator_refs)
else:
previous_generators.append(None)

ses_blocks_bytes_list: List[Optional[bytes]] = []
for j in range(i, end_i):
ses_block_rec = prev_ses_block_list[j]
if ses_block_rec is None:
ses_blocks_bytes_list.append(None)
else:
ses_blocks_bytes_list.append(bytes(ses_block_rec))
previous_generators = block_generator.generator_refs
except ValueError:
return [
PreValidationResult(
uint16(Err.FAILED_GETTING_GENERATOR_MULTIPROCESSING.value), None, None, False, uint32(0)
)
]

futures.append(
asyncio.get_running_loop().run_in_executor(
pool,
batch_pre_validate_blocks,
pre_validate_block,
constants,
recent_blocks_bytes,
b_pickled,
blockchain,
block,
previous_generators,
conditions_pickled,
[diff_ssis[j].current_difficulty for j in range(i, end_i)],
[diff_ssis[j].current_ssi for j in range(i, end_i)],
block_height_conds_map.get(block.height),
copy.copy(vs),
validate_signatures,
ses_blocks_bytes_list,
)
)

if block_rec.sub_epoch_summary_included is not None:
vs.prev_ses_block = block_rec

# Collect all results into one flat list
return [
PreValidationResult.from_bytes(result)
for batch_result in (await asyncio.gather(*futures))
for result in batch_result
]
return list(await asyncio.gather(*futures))


def _run_generator(
Expand Down
Loading

0 comments on commit d678131

Please sign in to comment.