diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 89a724e8e31a..c5c013c7d5c2 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -251,21 +251,21 @@ def _get_auth_chain_ids_using_cover_index_txn(
         include_given: bool,
     ) -> Set[str]:
         """Calculates the auth chain IDs using the chain index."""
-
-        # Optimize various function calls by side-stepping property look up on dict
-        # getter and setter calls.
+        # Optimize various function calls by side-stepping property lookup on dict and set
+        # getter, setter, add, etc. calls by taking references; these are used throughout.
 
         # First we look up the chain ID/sequence numbers for the given events.
         initial_events = set(event_ids)
 
+        if include_given:
+            results = initial_events
+        else:
+            results = set()
+
         # All the events that we've found that are reachable from the events.
         seen_events: Set[str] = set()
         seen_events_add = seen_events.add
 
-        # A map from chain ID to max sequence number of the given events.
-        event_chains: Dict[int, int] = {}
-        event_chains_get = event_chains.get
-
         sql_1 = """
             SELECT event_id, chain_id, sequence_number
             FROM event_auth_chains
@@ -289,9 +289,15 @@ def _get_auth_chain_ids_using_cover_index_txn(
         g1 = self._authchain_event_id_to_chain_info.get
         s1 = self._authchain_event_id_to_chain_info.set
 
+        g2 = self._authchain_links_list.get
+        s2 = self._authchain_links_list.set
+
+        g3 = self._authchain_chain_info_to_event_id.get
+        s3 = self._authchain_chain_info_to_event_id.set
+
         with Measure(
             self._clock,
-            "_get_auth_chain_ids_using_cover_index_txn.section_1_all",
+            "_get_auth_chain_ids_using_cover_index_txn.all",
         ):
             initial_events_copy = set(initial_events)
 
@@ -310,6 +316,24 @@ def _get_auth_chain_ids_using_cover_index_txn(
                 # whittled down as we go and re-evaluated by the while loop after a
                 # complete iteration.
                 for event_id in set(initial_events_copy):
+                    logger.debug(f"JASON: start, looking at event_id {event_id}")
+                    # Right off the bat, we check the compiled cache to see if a loop
+                    # can be avoided.
+                    compiled_cache_entry = self._authchain_final_compilation.get(
+                        event_id
+                    )
+                    if compiled_cache_entry is not None:
+                        logger.debug(f"JASON: found final compiled entry for event_id {event_id}: {compiled_cache_entry}")
+                        results = results.union(compiled_cache_entry)
+                        initial_events_copy.discard(event_id)
+                        continue
+
+                    # Looks like we have to do the work, let's get started.
+
+                    # A map from chain ID to max sequence number of the given events.
+                    event_chains: Dict[int, int] = {}
+                    event_chains_get = event_chains.get
+
                     section_1_rows: Set[Tuple[str, int, int]] = set()
                     s1_cache_entry = g1(event_id)
                     if s1_cache_entry is not None:
@@ -335,194 +359,236 @@ def _get_auth_chain_ids_using_cover_index_txn(
                             )
                             s1(txn_event_id, (txn_chain_id, txn_sequence_number))
+                        # Remove the lock, in case we hit the next exception
                         self._authchain_s1_lookup_lock.discard(event_id)
+
+                    if event_id not in seen_events:
+                        # Check that we actually have a chain ID for this event.
+
+                        # This can happen due to e.g. downgrade/upgrade of
+                        # the server. We raise an exception and fall back to
+                        # the previous algorithm.
+                        logger.info(
+                            "Unexpectedly found an event that doesn't have a "
+                            "chain ID in room %s: %s",
+                            room_id,
+                            event_id,
+                        )
+                        raise _NoChainCoverIndex(room_id)
+                    # Only toss the event because it was found
                     initial_events_copy.discard(event_id)
                     self._clock.sleep(0)
 
                     # And now process the data that was retrieved
+
+                    # The first spot in the tuple would be the event_id, which we don't
+                    # need here. It was used above to apply to the incremental cache.
                     for _, chain_id, sequence_number in section_1_rows:
                         max_sequence_result = event_chains_get(chain_id, 0)
                         if sequence_number > max_sequence_result:
                             event_chains[chain_id] = sequence_number
 
-        # Check that we actually have a chain ID for all the events.
-        events_missing_chain_info = initial_events.difference(seen_events)
-        if events_missing_chain_info:
-            # This can happen due to e.g. downgrade/upgrade of the server. We
-            # raise an exception and fall back to the previous algorithm.
-            logger.info(
-                "Unexpectedly found that events don't have chain IDs in room %s: %s",
-                room_id,
-                events_missing_chain_info,
-            )
-            raise _NoChainCoverIndex(room_id)
-
-        # Now we look up all links for the chains we have, adding chains that
-        # are reachable from any event.
-        # A map from chain ID to max sequence number *reachable* from any event ID.
-        chains: Dict[int, int] = {}
-        chains_get = chains.get
-
-        g2 = self._authchain_links_list.get
-        s2 = self._authchain_links_list.set
-
-        with Measure(
-            self._clock,
-            "_get_auth_chain_ids_using_cover_index_txn.section_2_all",
-        ):
-            # 'event_chains' represents the origin chain_id and maximum sequence number
-            # found from the initial events. From these, retrieve the chain links
-            while len(event_chains) != 0:
-                for chain_id, seq_no in dict(event_chains).items():
-                    section_2_rows = set()
-                    # Add the initial set of chains, excluding the sequence corresponding to
-                    # initial event. But only if the seq_no isn't 0, as then it won't exist.
-                    max_sequence_result = max(seq_no - 1, chains_get(chain_id, 0))
-                    if max_sequence_result > 0:
-                        chains[chain_id] = max_sequence_result
-
-                    s2_cache_entry = g2(chain_id)
-                    # the seq_no above references a specific set of chains to start
-                    # processing at. The cache will contain(if an entry is there at all) all
-                    # chains referenced by origin chain_id.
-                    if s2_cache_entry is not None:
-                        for origin_seq_no, target_set_info in s2_cache_entry.items():
-                            # Prefilter out origin sequence numbers GREATER than what will
-                            # even be looked at
-                            if origin_seq_no <= seq_no:
-                                for target_chain_id, target_seq_no in target_set_info:
-                                    section_2_rows.add(
-                                        (
-                                            chain_id,
-                                            origin_seq_no,
-                                            target_chain_id,
-                                            target_seq_no,
-                                        )
-                                    )
-                        del event_chains[chain_id]
-
-                    else:
-                        if chain_id not in self._authchain_s2_lookup_lock:
-                            self._authchain_s2_lookup_lock.add(chain_id)
-                            # Using the origin chain_id here, this will pull all data about this
-                            # origin chain from the database. There is likely to be more
-                            # information here than we need, so we will filter it out below
-                            txn.execute(sql_2, (chain_id,))
-                            cache_entries: Dict[
-                                int, Dict[int, Set[Tuple[int, int]]]
-                            ] = {}
-                            for (
-                                origin_chain_id,
-                                origin_sequence_number,
-                                target_chain_id,
-                                target_sequence_number,
-                            ) in txn:
-                                section_2_rows.add(
-                                    (
-                                        origin_chain_id,
-                                        origin_sequence_number,
-                                        target_chain_id,
-                                        target_sequence_number,
-                                    )
-                                )
-                                # Batch up the new cache entries
-                                cache_entries.setdefault(
-                                    origin_chain_id, {}
-                                ).setdefault(origin_sequence_number, set()).add(
-                                    (target_chain_id, target_sequence_number)
-                                )
-
-                            # By not setting the cache entries into the cache while
-                            # processing above, we avoid multiple cache hits and complicated
-                            # updating brittleness.
-                            for origin_chain_id, cache_entry in cache_entries.items():
-                                s2(origin_chain_id, cache_entry)
-
-                            self._authchain_s2_lookup_lock.discard(chain_id)
-                            del event_chains[chain_id]
-                    self._clock.sleep(0)
-
-                    for (
-                        _,
-                        origin_sequence_number,
-                        target_chain_id,
-                        target_sequence_number,
-                    ) in section_2_rows:
-                        if origin_sequence_number <= seq_no:
-                            # Target chains are only reachable if the origin sequence
-                            # number of the link is less than the max sequence number in
-                            # the origin chain.
-
-                            # This is slightly more optimized than using max()
-                            target_seq_max_result = chains_get(target_chain_id, 0)
-                            if target_sequence_number > target_seq_max_result:
-                                chains[target_chain_id] = target_sequence_number
-
-        # Now for each chain we figure out the maximum sequence number reachable
-        # from *any* event ID. Events with a sequence less than that are in the
-        # auth chain.
-        if include_given:
-            results = initial_events
-        else:
-            results = set()
-
-        g3 = self._authchain_chain_info_to_event_id.get
-        s3 = self._authchain_chain_info_to_event_id.set
-
-        with Measure(
-            self._clock,
-            "_get_auth_chain_ids_using_cover_index_txn.section_3_all",
-        ):
-            while len(chains) != 0:
-                for chain_id, max_sequence_number in dict(chains).items():
-                    section_3_rows: Set[Tuple[str, int, int]] = set()
-                    # On the spot invalidation
-                    seq_found = False
-                    s3_cache_entry = g3(chain_id)
-                    if s3_cache_entry is not None:
-                        # Reload the entries if the max_sequence_number isn't
-                        # present, which usually means there is newer data that isn't
-                        # loaded.
-                        if max_sequence_number in s3_cache_entry:
-                            seq_found = True
-                            for (
-                                seq_number_key,
-                                result_event_id,
-                            ) in s3_cache_entry.items():
-                                if seq_number_key <= max_sequence_number:
-                                    section_3_rows.add(
-                                        (result_event_id, chain_id, seq_number_key)
-                                    )
-                                    # results.add(result_event_id)
-                            del chains[chain_id]
-
-                    if not seq_found:
-                        if chain_id not in self._authchain_s3_lookup_lock:
-                            self._authchain_s3_lookup_lock.add(chain_id)
-                            txn.execute(sql_3, (chain_id,))
-                            for s3_event_id, s3_chain_id, s3_seq_no in txn:
-                                section_3_rows.add(
-                                    (s3_event_id, s3_chain_id, s3_seq_no)
-                                )
-
-                            new_s3_cache_entry: Dict[int, str] = {}
-                            for result_event_id, _, sequence_number in section_3_rows:
-                                new_s3_cache_entry.setdefault(
-                                    sequence_number, result_event_id
-                                )
-                            s3(chain_id, new_s3_cache_entry)
-                            self._authchain_s3_lookup_lock.discard(chain_id)
-                            del chains[chain_id]
-                    self._clock.sleep(0)
-
-                    for (
-                        result_event_id,
-                        _,
-                        result_sequence_number,
-                    ) in section_3_rows:
-                        if result_sequence_number <= max_sequence_number:
-                            results.add(result_event_id)
+                    # Now we look up all links for the chain ID we have, adding chains that
+                    # are reachable from this event.
+
+                    # A map from chain ID to max sequence number *reachable* from this event ID.
+                    chains: Dict[int, int] = {}
+                    chains_get = chains.get
+                    final_results_entry: Set[str] = set()
+
+                    # 'event_chains' represents the origin chain_id and maximum
+                    # sequence number found from the initial events. From these,
+                    # retrieve the chain links.
+                    while len(event_chains) != 0:
+                        for chain_id, seq_no in dict(event_chains).items():
+                            section_2_rows = set()
+                            # Add the initial set of chains, excluding the sequence
+                            # corresponding to initial event. But only if the seq_no
+                            # isn't 0, as then it won't exist.
+                            max_sequence_result = max(
+                                seq_no - 1, chains_get(chain_id, 0)
+                            )
+                            if max_sequence_result > 0:
+                                chains[chain_id] = max_sequence_result
+
+                            s2_cache_entry = g2(chain_id)
+                            # The seq_no above references a specific set of chains to
+                            # start processing at. The cache will contain (if an entry
+                            # is there at all) all chains referenced by origin chain_id.
+                            if s2_cache_entry is not None:
+                                for (
+                                    origin_seq_no,
+                                    target_set_info,
+                                ) in s2_cache_entry.items():
+                                    # Prefilter out origin sequence numbers GREATER
+                                    # than what will even be looked at
+                                    if origin_seq_no <= seq_no:
+                                        for (
+                                            target_chain_id,
+                                            target_seq_no,
+                                        ) in target_set_info:
+                                            section_2_rows.add(
+                                                (
+                                                    chain_id,
+                                                    origin_seq_no,
+                                                    target_chain_id,
+                                                    target_seq_no,
+                                                )
+                                            )
+                                del event_chains[chain_id]
+
+                            else:
+                                if chain_id not in self._authchain_s2_lookup_lock:
+                                    self._authchain_s2_lookup_lock.add(chain_id)
+                                    # Using the origin chain_id here, this will pull
+                                    # all data about this origin chain from the
+                                    # database. There is likely to be more
+                                    # information here than we need, so we will
+                                    # filter it out below
+                                    txn.execute(sql_2, (chain_id,))
+                                    cache_entries: Dict[
+                                        int, Dict[int, Set[Tuple[int, int]]]
+                                    ] = {}
+                                    for (
+                                        origin_chain_id,
+                                        origin_sequence_number,
+                                        target_chain_id,
+                                        target_sequence_number,
+                                    ) in txn:
+                                        section_2_rows.add(
+                                            (
+                                                origin_chain_id,
+                                                origin_sequence_number,
+                                                target_chain_id,
+                                                target_sequence_number,
+                                            )
+                                        )
+                                        # Batch up the new cache entries
+                                        cache_entries.setdefault(
+                                            origin_chain_id, {}
+                                        ).setdefault(origin_sequence_number, set()).add(
+                                            (target_chain_id, target_sequence_number)
+                                        )
+
+                                    # By not setting the cache entries into the cache
+                                    # while processing above, we avoid multiple cache
+                                    # hits and complicated updating brittleness.
+                                    for (
+                                        origin_chain_id,
+                                        cache_entry,
+                                    ) in cache_entries.items():
+                                        s2(origin_chain_id, cache_entry)
+
+                                    self._authchain_s2_lookup_lock.discard(chain_id)
+                                    del event_chains[chain_id]
+                            self._clock.sleep(0)
+
+                            for (
+                                _,
+                                origin_sequence_number,
+                                target_chain_id,
+                                target_sequence_number,
+                            ) in section_2_rows:
+                                if origin_sequence_number <= seq_no:
+                                    # Target chains are only reachable if the origin sequence
+                                    # number of the link is less than the max sequence number in
+                                    # the origin chain.
+
+                                    # This is slightly more optimized than using max()
+                                    target_seq_max_result = chains_get(
+                                        target_chain_id, 0
+                                    )
+                                    if target_sequence_number > target_seq_max_result:
+                                        chains[target_chain_id] = target_sequence_number
+
+                    # Now for each chain we figure out the maximum sequence number reachable
+                    # from *this* event ID. Events with a sequence less than that are in the
+                    # auth chain.
+
+                    while len(chains) != 0:
+                        for s3_chain_id, max_sequence_number in dict(
+                            chains
+                        ).items():
+                            section_3_rows: Set[Tuple[str, int, int]] = set()
+                            # On the spot invalidation
+                            seq_found = False
+                            s3_cache_entry = g3(s3_chain_id)
+                            if s3_cache_entry is not None:
+                                # Reload the entries if the max_sequence_number isn't
+                                # present, which usually means there is newer data that isn't
+                                # loaded.
+                                if max_sequence_number in s3_cache_entry:
+                                    seq_found = True
+                                    for (
+                                        seq_number_key,
+                                        result_event_id,
+                                    ) in s3_cache_entry.items():
+                                        if (
+                                            seq_number_key
+                                            <= max_sequence_number
+                                        ):
+                                            section_3_rows.add(
+                                                (
+                                                    result_event_id,
+                                                    s3_chain_id,
+                                                    seq_number_key,
+                                                )
+                                            )
+                                    del chains[s3_chain_id]
+
+                            if not seq_found:
+                                if (
+                                    s3_chain_id
+                                    not in self._authchain_s3_lookup_lock
+                                ):
+                                    self._authchain_s3_lookup_lock.add(
+                                        s3_chain_id
+                                    )
+                                    txn.execute(sql_3, (s3_chain_id,))
+                                    for (
+                                        txn_s3_event_id,
+                                        txn_s3_chain_id,
+                                        txn_s3_seq_no,
+                                    ) in txn:
+                                        section_3_rows.add(
+                                            (
+                                                txn_s3_event_id,
+                                                txn_s3_chain_id,
+                                                txn_s3_seq_no,
+                                            )
+                                        )
+
+                                    new_s3_cache_entry: Dict[int, str] = {}
+                                    for (
+                                        result_event_id,
+                                        _,
+                                        sequence_number,
+                                    ) in section_3_rows:
+                                        new_s3_cache_entry.setdefault(
+                                            sequence_number, result_event_id
+                                        )
+                                    s3(s3_chain_id, new_s3_cache_entry)
+                                    self._authchain_s3_lookup_lock.discard(
+                                        s3_chain_id
+                                    )
+                                    del chains[s3_chain_id]
+                            self._clock.sleep(0)
+
+                            for (
+                                result_event_id,
+                                _,
+                                result_sequence_number,
+                            ) in section_3_rows:
+                                if (
+                                    result_sequence_number
+                                    <= max_sequence_number
+                                ):
+                                    results.add(result_event_id)
+                                    final_results_entry.add(result_event_id)
+                    self._authchain_final_compilation.set(
+                        event_id, final_results_entry
+                    )
 
         return results
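Note on the approach: the patch turns the chain-cover walk into a per-event computation and memoises each finished result in self._authchain_final_compilation, so a later request for the same event can skip sections 2 and 3 entirely. Below is a minimal, self-contained sketch of that idea, assuming the link data is transitively closed (which the chain cover index maintains); the plain dicts chain_of_event, links, event_of_chain and final_compilation are illustrative stand-ins for the database tables and cache classes used in the patch, not real Synapse APIs.

from typing import Dict, Set, Tuple

# event_id -> (chain_id, sequence_number)                                ("section 1")
chain_of_event: Dict[str, Tuple[int, int]] = {}
# origin chain_id -> {origin_seq: {(target_chain_id, target_seq), ...}}  ("section 2")
links: Dict[int, Dict[int, Set[Tuple[int, int]]]] = {}
# chain_id -> {sequence_number: event_id}                                ("section 3")
event_of_chain: Dict[int, Dict[int, str]] = {}
# event_id -> fully computed auth chain (the "final compilation" cache)
final_compilation: Dict[str, Set[str]] = {}


def auth_chain_for_event(event_id: str) -> Set[str]:
    """Return the auth chain of event_id via the chain cover index."""
    cached = final_compilation.get(event_id)
    if cached is not None:
        return cached  # the whole walk below is skipped on a repeat call

    chain_id, seq = chain_of_event[event_id]

    # chain_id -> max sequence number reachable from this event.  A single
    # pass over the event's own chain is enough because the links are
    # assumed to be transitively closed.
    reachable: Dict[int, int] = {}
    if seq > 1:
        reachable[chain_id] = seq - 1  # earlier events in the same chain
    for origin_seq, targets in links.get(chain_id, {}).items():
        if origin_seq > seq:
            continue  # this link starts later in the chain than our event
        for target_chain, target_seq in targets:
            if target_seq > reachable.get(target_chain, 0):
                reachable[target_chain] = target_seq

    # Every event at or below the reachable sequence number is in the chain.
    result: Set[str] = set()
    for target_chain, max_seq in reachable.items():
        for seq_no, chain_event_id in event_of_chain.get(target_chain, {}).items():
            if seq_no <= max_seq:
                result.add(chain_event_id)

    final_compilation[event_id] = result  # memoise for the next caller
    return result


# Example: $B's auth chain contains $A (chain 1, seq 1) via a link from chain 2.
chain_of_event.update({"$A": (1, 1), "$B": (2, 1)})
links[2] = {1: {(1, 1)}}
event_of_chain.update({1: {1: "$A"}, 2: {1: "$B"}})
assert auth_chain_for_event("$B") == {"$A"}
assert "$B" in final_compilation  # the second call would be a pure cache hit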