Skip to content

Commit

Permalink
great fun exposing boost::unordered_map for CostStack :D
Browse files Browse the repository at this point in the history
  • Loading branch information
ManifoldFR committed Sep 17, 2024
1 parent ac60bd1 commit 9b415a3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 34 deletions.
77 changes: 44 additions & 33 deletions bindings/python/src/modelling/expose-cost-stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "aligator/modelling/costs/sum-of-costs.hpp"

#include <eigenpy/std-pair.hpp>
#include <eigenpy/std-map.hpp>
#include <eigenpy/variant.hpp>

namespace aligator {
Expand All @@ -23,40 +24,50 @@ void exposeCostStack() {
eigenpy::StdPairConverter<CostItem>::registration();
eigenpy::VariantConverter<CostKey>::registration();

bp::class_<CostStack, bp::bases<CostAbstract>>(
"CostStack", "A weighted sum of other cost functions.", bp::no_init)
.def(bp::init<xyz::polymorphic<Manifold>, const int,
const std::vector<PolyCost> &, const std::vector<Scalar> &>(
("self"_a, "space", "nu", "components"_a = bp::list(),
"weights"_a = bp::list())))
.def(bp::init<const PolyCost &>(("self"_a, "cost")))
// .def_readwrite("components", &CostStack::components_,
// "Components of this cost stack.")
.def(
"addCost",
+[](CostStack &self, const PolyCost &cost, const Scalar weight) {
// return
self.addCost(cost, weight);
},
("self"_a, "cost", "weight"_a = 1.),
bp::return_internal_reference<>())
.def(
"addCost",
+[](CostStack &self, CostKey key, const PolyCost &cost,
const Scalar weight) {
// return
self.addCost(key, cost, weight);
},
("self"_a, "key", "cost", "weight"_a = 1.),
bp::return_internal_reference<>())
.def("size", &CostStack::size, "Get the number of cost components.")
.def(CopyableVisitor<CostStack>())
.def(PolymorphicMultiBaseVisitor<CostAbstract>());
{
bp::scope scope =
bp::class_<CostStack, bp::bases<CostAbstract>>(
"CostStack", "A weighted sum of other cost functions.", bp::no_init)
.def(bp::init<xyz::polymorphic<Manifold>, const int,
const std::vector<PolyCost> &,
const std::vector<Scalar> &>(
("self"_a, "space", "nu", "components"_a = bp::list(),
"weights"_a = bp::list())))
.def(bp::init<const PolyCost &>(("self"_a, "cost")))
.def_readwrite("components", &CostStack::components_,
"Components of this cost stack.")
.def(
"addCost",
+[](CostStack &self, const PolyCost &cost,
const Scalar weight) {
// return
self.addCost(cost, weight);
},
("self"_a, "cost", "weight"_a = 1.),
bp::return_internal_reference<>())
.def(
"addCost",
+[](CostStack &self, CostKey key, const PolyCost &cost,
const Scalar weight) {
// return
self.addCost(key, cost, weight);
},
("self"_a, "key", "cost", "weight"_a = 1.),
bp::return_internal_reference<>())
.def("size", &CostStack::size, "Get the number of cost components.")
.def(CopyableVisitor<CostStack>())
.def(PolymorphicMultiBaseVisitor<CostAbstract>());
eigenpy::GenericMapVisitor<CostMap, true>::expose("CostMap");
}

bp::register_ptr_to_python<shared_ptr<CostStackData>>();
bp::class_<CostStackData, bp::bases<CostData>>(
"CostStackData", "Data struct for CostStack.", bp::no_init)
.def_readonly("sub_cost_data", &CostStackData::sub_cost_data);
{
bp::register_ptr_to_python<shared_ptr<CostStackData>>();
bp::scope scope =
bp::class_<CostStackData, bp::bases<CostData>>(
"CostStackData", "Data struct for CostStack.", bp::no_init)
.def_readonly("sub_cost_data", &CostStackData::sub_cost_data);
eigenpy::GenericMapVisitor<CostStackData::DataMap, true>::expose("DataMap");
}
}

} // namespace python
Expand Down
5 changes: 4 additions & 1 deletion tests/python/test_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,14 @@ def test_stack_error():
rcost = QuadraticCost(Q, R)
cost_stack.addCost(rcost) # optional

cost_stack.components
print(cost_stack.components.todict())

rc2 = QuadraticCost(np.eye(3), np.eye(nu))
rc3 = QuadraticCost(np.eye(nx), np.eye(nu * 2))

cost_data = cost_stack.createData()
print(cost_data.sub_cost_data.tolist())
print(cost_data.sub_cost_data.todict())

with pytest.raises(Exception) as e_info:
cost_stack.addCost(rc2)
Expand Down

0 comments on commit 9b415a3

Please sign in to comment.