diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h index 6f5ff425f1..04e4e0cac3 100644 --- a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -20,6 +20,7 @@ // \callgraph #pragma once +#include #include #include #include @@ -37,117 +38,288 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper { using Clique = typename BayesTree::Clique; using sharedClique = typename BayesTree::sharedClique; - /** Get the additional keys that need reelimination when marginalizing - * the variables in @p marginalizableKeys from the Bayes tree @p bayesTree. + /** + * This function identifies variables that need to be re-eliminated before + * performing marginalization. * - * @param[in] bayesTree The Bayes tree. - * @param[in] marginalizableKeys The keys to be marginalized. + * Re-elimination is necessary for a clique containing marginalizable + * variables if: * + * 1. Some non-marginalizable variables appear before marginalizable ones + * in that clique; + * 2. Or it has a child node depending on a marginalizable variable AND the + * subtree rooted at that child contains non-marginalizables. * - * When marginalizing a variable @f$ \theta @f$ from a Bayes tree, some - * nodes may need reelimination to ensure the variables to marginalize - * be eliminated first. - * - * We should consider two cases: - * - * 1. If a child node relies on @f$ \theta @f$ (i.e., @f$ \theta @f$ - * is a parent / separator of the node), then the frontal - * variables of the child node need to be reeliminated. In - * addition, all the descendants of the child node also need to - * be reeliminated. - * - * 2. If other frontal variables in the same node with @f$ \theta @f$ - * are in front of @f$ \theta @f$ but not to be marginalized, then - * these variables also need to be reeliminated. - * - * These variables were eliminated before @f$ \theta @f$ in the original - * Bayes tree, and after reelimination they will be eliminated after - * @f$ \theta @f$ so that @f$ \theta @f$ can be marginalized safely. + * In addition, the subtrees under the aforementioned cliques that require + * re-elimination, which contain non-marginalizable variables in their root + * node, also need to be re-eliminated. * + * @param[in] bayesTree The Bayes tree + * @param[in] marginalizableKeys Keys to be marginalized + * @return Set of additional keys that need to be re-eliminated */ - static void gatherAdditionalKeysToReEliminate( + static std::set gatherAdditionalKeysToReEliminate( const BayesTree& bayesTree, - const KeyVector& marginalizableKeys, - std::set& additionalKeys) { + const KeyVector& marginalizableKeys) { const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); - std::set marginalizableKeySet(marginalizableKeys.begin(), marginalizableKeys.end()); - std::set checkedCliques; + std::set additionalKeys; + std::set marginalizableKeySet( + marginalizableKeys.begin(), marginalizableKeys.end()); + std::set dependentSubtrees; + CachedSearch cachedSearch; + + // Check each clique that contains a marginalizable key + for (const sharedClique& clique : + getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { + + if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { + // Add frontal variables from current clique + addCliqueToKeySet(clique, &additionalKeys); + + // Then gather dependent subtrees to be added later + gatherDependentSubtrees( + clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch); + } + } + + // Add the remaining dependent cliques + for (const sharedClique& subtree : dependentSubtrees) { + addSubtreeToKeySet(subtree, &additionalKeys); + } - std::set dependentCliques; - for (const Key& key : marginalizableKeySet) { - sharedClique clique = bayesTree[key]; - if (checkedCliques.count(clique)) { - continue; + if (debug) { + std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; + for (const Key& key : additionalKeys) { + std::cout << DefaultKeyFormatter(key) << " "; } - checkedCliques.insert(clique); - - bool need_reeliminate = false; - bool has_non_marginalizable_ahead = false; - for (Key i: clique->conditional()->frontals()) { - if (marginalizableKeySet.count(i)) { - if (has_non_marginalizable_ahead) { - // Case 2 in the docstring - need_reeliminate = true; + std::cout << std::endl; + } + + return additionalKeys; + } + + protected: + + /** + * Gather the cliques containing any of the given keys. + * + * @param[in] bayesTree The Bayes tree + * @param[in] keysOfInterest Set of keys of interest + * @return Set of cliques that contain any of the given keys + */ + static std::set getCliquesContainingKeys( + const BayesTree& bayesTree, + const std::set& keysOfInterest) { + std::set cliques; + for (const Key& key : keysOfInterest) { + cliques.insert(bayesTree[key]); + } + return cliques; + } + + /** + * A struct to cache the results of the below two functions. + */ + struct CachedSearch { + std::unordered_map wholeMarginalizableCliques; + std::unordered_map wholeMarginalizableSubtrees; + }; + + /** + * Check if all variables in the clique are marginalizable. + * + * Note we use a cache map to avoid repeated searches. + */ + static bool isWholeCliqueMarginalizable( + const sharedClique& clique, + const std::set& marginalizableKeys, + CachedSearch* cache) { + auto it = cache->wholeMarginalizableCliques.find(clique.get()); + if (it != cache->wholeMarginalizableCliques.end()) { + return it->second; + } else { + bool ret = true; + for (Key key : clique->conditional()->frontals()) { + if (!marginalizableKeys.count(key)) { + ret = false; + break; + } + } + cache->wholeMarginalizableCliques.insert({clique.get(), ret}); + return ret; + } + } + + /** + * Check if all variables in the subtree are marginalizable. + * + * Note we use a cache map to avoid repeated searches. + */ + static bool isWholeSubtreeMarginalizable( + const sharedClique& subtree, + const std::set& marginalizableKeys, + CachedSearch* cache) { + auto it = cache->wholeMarginalizableSubtrees.find(subtree.get()); + if (it != cache->wholeMarginalizableSubtrees.end()) { + return it->second; + } else { + bool ret = true; + if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { + for (const sharedClique& child : subtree->children) { + if (!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { + ret = false; break; - } else { - // Check whether there's a child node dependent on this key. - for(const sharedClique& child: clique->children) { - if (std::find(child->conditional()->beginParents(), - child->conditional()->endParents(), i) - != child->conditional()->endParents()) { - // Case 1 in the docstring - need_reeliminate = true; - break; - } - } } - } else { - has_non_marginalizable_ahead = true; } + } else { + ret = false; } + cache->wholeMarginalizableSubtrees.insert({subtree.get(), ret}); + return ret; + } + } + + /** + * Check if a clique contains variables that need reelimination due to + * elimination ordering conflicts. + * + * @param[in] clique The clique to check + * @param[in] marginalizableKeys Set of keys to be marginalized + * @return true if any variables in the clique need re-elimination + */ + static bool needsReelimination( + const sharedClique& clique, + const std::set& marginalizableKeys, + CachedSearch* cache) { + bool hasNonMarginalizableAhead = false; + + // Check each frontal variable in order + for (Key key : clique->conditional()->frontals()) { + if (marginalizableKeys.count(key)) { + // If we've seen non-marginalizable variables before this one, + // we need to reeliminate + if (hasNonMarginalizableAhead) { + return true; + } - if (!need_reeliminate) { - // No variable needs to be reeliminated - continue; + // Check if any child depends on this marginalizable key and the + // subtree rooted at that child contains non-marginalizables. + for (const sharedClique& child : clique->children) { + if (hasDependency(child, key) && + !isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { + return true; + } + } } else { - // Need to reeliminate the current clique and all its children - // that rely on a marginalizable key. - for (Key i: clique->conditional()->frontals()) { - additionalKeys.insert(i); - for (const sharedClique& child: clique->children) { - if (!dependentCliques.count(child) && - std::find(child->conditional()->beginParents(), - child->conditional()->endParents(), i) - != child->conditional()->endParents()) { - dependentCliques.insert(child); - } + hasNonMarginalizableAhead = true; + } + } + return false; + } + + /** + * Gather all subtrees that depend on a marginalizable key and contain + * non-marginalizable variables in their root. + * + * @param[in] rootClique The starting clique + * @param[in] marginalizableKeys Set of keys to be marginalized + * @param[out] dependentSubtrees Pointer to set storing dependent cliques + */ + static void gatherDependentSubtrees( + const sharedClique& rootClique, + const std::set& marginalizableKeys, + std::set* dependentSubtrees, + CachedSearch* cache) { + for (Key key : rootClique->conditional()->frontals()) { + if (marginalizableKeys.count(key)) { + // Find children that depend on this key + for (const sharedClique& child : rootClique->children) { + if (!dependentSubtrees->count(child) && + hasDependency(child, key)) { + getSubtreesContainingNonMarginalizables( + child, marginalizableKeys, cache, dependentSubtrees); } } } } + } - // Recursively add the dependent keys - while (!dependentCliques.empty()) { - auto begin = dependentCliques.begin(); - sharedClique clique = *begin; - dependentCliques.erase(begin); + /** + * Gather all subtrees that contain non-marginalizable variables in its root. + */ + static void getSubtreesContainingNonMarginalizables( + const sharedClique& rootClique, + const std::set& marginalizableKeys, + CachedSearch* cache, + std::set* subtreesContainingNonMarginalizables) { + // If the root clique itself contains non-marginalizable variables, we + // just add it to subtreesContainingNonMarginalizables; + if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) { + subtreesContainingNonMarginalizables->insert(rootClique); + return; + } - for (Key key : clique->conditional()->frontals()) { - additionalKeys.insert(key); - } + // Otherwise, we need to recursively check the children + for (const sharedClique& child : rootClique->children) { + getSubtreesContainingNonMarginalizables( + child, marginalizableKeys, cache, + subtreesContainingNonMarginalizables); + } + } - for (const sharedClique& child: clique->children) { - dependentCliques.insert(child); - } + /** + * Add all frontal variables from a clique to a key set. + * + * @param[in] clique Clique to add keys from + * @param[out] additionalKeys Pointer to the output key set + */ + static void addCliqueToKeySet( + const sharedClique& clique, + std::set* additionalKeys) { + for (Key key : clique->conditional()->frontals()) { + additionalKeys->insert(key); } + } - if (debug) { - std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; - for (const Key& key : additionalKeys) { - std::cout << DefaultKeyFormatter(key) << " "; + /** + * Add all frontal variables from a subtree to a key set. + * + * @param[in] subRoot Root clique of the subtree + * @param[out] additionalKeys Pointer to the output key set + */ + static void addSubtreeToKeySet( + const sharedClique& subRoot, + std::set* additionalKeys) { + std::set cliques; + cliques.insert(subRoot); + while(!cliques.empty()) { + auto begin = cliques.begin(); + sharedClique clique = *begin; + cliques.erase(begin); + addCliqueToKeySet(clique, additionalKeys); + for (const sharedClique& child : clique->children) { + cliques.insert(child); } - std::cout << std::endl; + } + } + + /** + * Check if the clique depends on the given key. + * + * @param[in] clique Clique to check + * @param[in] key Key to check for dependencies + * @return true if clique depends on the key + */ + static bool hasDependency( + const sharedClique& clique, Key key) { + auto conditional = clique->conditional(); + if (std::find(conditional->beginParents(), + conditional->endParents(), key) + != conditional->endParents()) { + return true; + } else { + return false; } } }; diff --git a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp index 9d27f5713d..afe4fb3de1 100644 --- a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp +++ b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp @@ -120,8 +120,8 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update( } // Mark additional keys between the marginalized keys and the leaves - std::set additionalKeys; #ifdef GTSAM_OLD_MARGINALIZATION + std::set additionalKeys; for(Key key: marginalizableKeys) { ISAM2Clique::shared_ptr clique = isam_[key]; for(const ISAM2Clique::shared_ptr& child: clique->children) { @@ -129,8 +129,9 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update( } } #else - BayesTreeMarginalizationHelper::gatherAdditionalKeysToReEliminate( - isam_, marginalizableKeys, additionalKeys); + std::set additionalKeys = + BayesTreeMarginalizationHelper::gatherAdditionalKeysToReEliminate( + isam_, marginalizableKeys); #endif KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end());