diff --git a/jormungandr/cpp/autodiff/BindVariableBlock.cpp b/jormungandr/cpp/autodiff/BindVariableBlock.cpp index cdbb6fd5..fced0a34 100644 --- a/jormungandr/cpp/autodiff/BindVariableBlock.cpp +++ b/jormungandr/cpp/autodiff/BindVariableBlock.cpp @@ -50,6 +50,25 @@ void BindVariableBlock(nb::module_& autodiff, [](VariableBlock& self, nb::DRef values) { self.SetValue(values); }, "values"_a, DOC(sleipnir, VariableBlock, SetValue, 2)); + cls.def( + "set_value", + [](VariableBlock& self, + nb::DRef values) { + self.SetValue(values.cast()); + }, + "values"_a, DOC(sleipnir, VariableBlock, SetValue, 2)); + cls.def( + "set_value", + [](VariableBlock& self, + nb::DRef> + values) { self.SetValue(values.cast()); }, + "values"_a, DOC(sleipnir, VariableBlock, SetValue, 2)); + cls.def( + "set_value", + [](VariableBlock& self, + nb::DRef> + values) { self.SetValue(values.cast()); }, + "values"_a, DOC(sleipnir, VariableBlock, SetValue, 2)); cls.def( "__setitem__", [](VariableBlock& self, int row, const Variable& value) { @@ -103,6 +122,8 @@ void BindVariableBlock(nb::module_& autodiff, nb::cast>(value); } else if (auto rhs = TryCastToEigen(value)) { self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value(); + } else if (auto rhs = TryCastToEigen(value)) { + self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value(); } else if (auto rhs = TryCastToEigen(value)) { self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value(); } else if (auto rhs = TryCastToEigen(value)) { @@ -233,6 +254,8 @@ void BindVariableBlock(nb::module_& autodiff, if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() * self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() * self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() * self); } else if (auto lhs = TryCastToEigen(inputs[0])) { @@ -247,12 +270,16 @@ void BindVariableBlock(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() + self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() + self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() + self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() + self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self + rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self + rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self + rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -261,12 +288,16 @@ void BindVariableBlock(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() - self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() - self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() - self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() - self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self - rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self - rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self - rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -275,12 +306,16 @@ void BindVariableBlock(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() == self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() == self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() == self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() == self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self == rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self == rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self == rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -289,12 +324,16 @@ void BindVariableBlock(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() < self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() < self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() < self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() < self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self < rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self < rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self < rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -303,12 +342,16 @@ void BindVariableBlock(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() <= self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() <= self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() <= self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() <= self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self <= rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self <= rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self <= rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -317,12 +360,16 @@ void BindVariableBlock(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() > self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() > self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() > self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() > self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self > rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self > rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self > rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -331,12 +378,16 @@ void BindVariableBlock(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() >= self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() >= self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() >= self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() >= self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self >= rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self >= rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self >= rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -505,6 +556,6 @@ void BindVariableBlock(nb::module_& autodiff, return CwiseReduce(lhs, rhs, func); }, "lhs"_a, "rhs"_a, "func"_a, DOC(sleipnir, CwiseReduce)); -} +} // NOLINT(readability/fn_size) } // namespace sleipnir diff --git a/jormungandr/cpp/autodiff/BindVariableMatrix.cpp b/jormungandr/cpp/autodiff/BindVariableMatrix.cpp index 42e91ba4..30b74a9e 100644 --- a/jormungandr/cpp/autodiff/BindVariableMatrix.cpp +++ b/jormungandr/cpp/autodiff/BindVariableMatrix.cpp @@ -42,6 +42,24 @@ void BindVariableMatrix(nb::module_& autodiff, self.SetValue(values); }, "values"_a, DOC(sleipnir, VariableMatrix, SetValue)); + cls.def( + "set_value", + [](VariableMatrix& self, nb::DRef values) { + self.SetValue(values.cast()); + }, + "values"_a, DOC(sleipnir, VariableMatrix, SetValue)); + cls.def( + "set_value", + [](VariableMatrix& self, + nb::DRef> + values) { self.SetValue(values.cast()); }, + "values"_a, DOC(sleipnir, VariableMatrix, SetValue)); + cls.def( + "set_value", + [](VariableMatrix& self, + nb::DRef> + values) { self.SetValue(values.cast()); }, + "values"_a, DOC(sleipnir, VariableMatrix, SetValue)); cls.def( "__setitem__", [](VariableMatrix& self, int row, const Variable& value) { @@ -94,6 +112,8 @@ void BindVariableMatrix(nb::module_& autodiff, nb::cast(value); } else if (auto rhs = TryCastToEigen(value)) { self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value(); + } else if (auto rhs = TryCastToEigen(value)) { + self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value(); } else if (auto rhs = TryCastToEigen(value)) { self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value(); } else if (auto rhs = TryCastToEigen(value)) { @@ -227,12 +247,16 @@ void BindVariableMatrix(nb::module_& autodiff, if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() * self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() * self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() * self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() * self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self * rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self * rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self * rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -241,12 +265,16 @@ void BindVariableMatrix(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() + self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() + self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() + self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() + self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self + rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self + rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self + rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -255,12 +283,16 @@ void BindVariableMatrix(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() - self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() - self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() - self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() - self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self - rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self - rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self - rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -269,12 +301,16 @@ void BindVariableMatrix(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() == self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() == self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() == self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() == self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self == rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self == rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self == rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -283,12 +319,16 @@ void BindVariableMatrix(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() < self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() < self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() < self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() < self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self < rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self < rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self < rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -297,12 +337,16 @@ void BindVariableMatrix(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() <= self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() <= self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() <= self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() <= self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self <= rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self <= rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self <= rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -311,12 +355,16 @@ void BindVariableMatrix(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() > self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() > self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() > self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() > self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self > rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self > rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self > rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { @@ -325,12 +373,16 @@ void BindVariableMatrix(nb::module_& autodiff, } else if (ufunc_name == "") { if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() >= self); + } else if (auto lhs = TryCastToEigen(inputs[0])) { + return nb::cast(lhs.value() >= self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() >= self); } else if (auto lhs = TryCastToEigen(inputs[0])) { return nb::cast(lhs.value() >= self); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self >= rhs.value()); + } else if (auto rhs = TryCastToEigen(inputs[1])) { + return nb::cast(self >= rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) { return nb::cast(self >= rhs.value()); } else if (auto rhs = TryCastToEigen(inputs[1])) {