Skip to content

Commit

Permalink
refactor printErrors
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Sep 30, 2024
1 parent 44fb786 commit 3cd8163
Showing 1 changed file with 34 additions and 58 deletions.
92 changes: 34 additions & 58 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,32 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
}

/* ************************************************************************ */
static void printFactor(const std::shared_ptr<Factor> &factor,
const DiscreteValues &assignment,
const KeyFormatter &keyFormatter) {
if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
hgf->operator()(assignment)
->print("HybridGaussianFactor, component:", keyFormatter);
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
factor->print("GaussianFactor:\n", keyFormatter);
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
factor->print("DiscreteFactor:\n", keyFormatter);
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (hc->isContinuous()) {
factor->print("GaussianConditional:\n", keyFormatter);
} else if (hc->isDiscrete()) {
factor->print("DiscreteConditional:\n", keyFormatter);
} else {
hc->asHybrid()
->choose(assignment)
->print("HybridConditional, component:\n", keyFormatter);
}
} else {
factor->print("Unknown factor type\n", keyFormatter);
}
}

/* ************************************************************************ */
void HybridGaussianFactorGraph::printErrors(
const HybridValues &values, const std::string &str,
Expand All @@ -83,69 +109,19 @@ void HybridGaussianFactorGraph::printErrors(
&printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;

std::stringstream ss;

for (size_t i = 0; i < factors_.size(); i++) {
auto &&factor = factors_[i];
std::cout << "Factor " << i << ": ";

// Clear the stringstream
ss.str(std::string());

if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
hgf->operator()(values.discrete())->print(ss.str(), keyFormatter);
std::cout << "error = " << factor->error(values) << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
if (hc->isContinuous()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
<< "\n";
} else {
// Is hybrid
auto conditionalComponent =
hc->asHybrid()->operator()(values.discrete());
conditionalComponent->print(ss.str(), keyFormatter);
std::cout << "error = " << conditionalComponent->error(values)
<< "\n";
}
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << df->error(values.discrete()) << std::endl;
}

} else {
if (factor == nullptr) {
std::cout << "Factor " << i << ": nullptr\n";
continue;
}
const double errorValue = factor->error(values);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

// Print the factor
std::cout << "Factor " << i << ", error = " << errorValue << "\n";
printFactor(factor, values.discrete(), keyFormatter);
std::cout << "\n";
}
std::cout.flush();
Expand Down

0 comments on commit 3cd8163

Please sign in to comment.