Skip to content

Commit

Permalink
Add ndarray overloads for more dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul committed Jul 20, 2024
1 parent 0a3a8a9 commit 19a078b
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 1 deletion.
53 changes: 52 additions & 1 deletion jormungandr/cpp/autodiff/BindVariableBlock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@ void BindVariableBlock(nb::module_& autodiff,
[](VariableBlock<VariableMatrix>& self,
nb::DRef<Eigen::MatrixXd> values) { self.SetValue(values); },
"values"_a, DOC(sleipnir, VariableBlock, SetValue, 2));
cls.def(
"set_value",
[](VariableBlock<VariableMatrix>& self,
nb::DRef<Eigen::MatrixXf> values) {
self.SetValue(values.cast<double>());
},
"values"_a, DOC(sleipnir, VariableBlock, SetValue, 2));
cls.def(
"set_value",
[](VariableBlock<VariableMatrix>& self,
nb::DRef<Eigen::Matrix<int64_t, Eigen::Dynamic, Eigen::Dynamic>>
values) { self.SetValue(values.cast<double>()); },
"values"_a, DOC(sleipnir, VariableBlock, SetValue, 2));
cls.def(
"set_value",
[](VariableBlock<VariableMatrix>& self,
nb::DRef<Eigen::Matrix<int32_t, Eigen::Dynamic, Eigen::Dynamic>>
values) { self.SetValue(values.cast<double>()); },
"values"_a, DOC(sleipnir, VariableBlock, SetValue, 2));
cls.def(
"__setitem__",
[](VariableBlock<VariableMatrix>& self, int row, const Variable& value) {
Expand Down Expand Up @@ -103,6 +122,8 @@ void BindVariableBlock(nb::module_& autodiff,
nb::cast<VariableBlock<VariableMatrix>>(value);
} else if (auto rhs = TryCastToEigen<double>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<float>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<int64_t>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<int32_t>(value)) {
Expand Down Expand Up @@ -233,6 +254,8 @@ void BindVariableBlock(nb::module_& autodiff,
if (ufunc_name == "<ufunc 'matmul'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() * self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() * self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() * self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
Expand All @@ -247,12 +270,16 @@ void BindVariableBlock(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'add'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() + self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() + self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() + self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() + self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self + rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self + rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self + rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -261,12 +288,16 @@ void BindVariableBlock(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'subtract'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() - self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() - self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() - self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() - self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self - rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self - rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self - rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -275,12 +306,16 @@ void BindVariableBlock(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'equal'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() == self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() == self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() == self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() == self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self == rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self == rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self == rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -289,12 +324,16 @@ void BindVariableBlock(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'less'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() < self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() < self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() < self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() < self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self < rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self < rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self < rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -303,12 +342,16 @@ void BindVariableBlock(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'less_equal'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() <= self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() <= self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() <= self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() <= self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self <= rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self <= rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self <= rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -317,12 +360,16 @@ void BindVariableBlock(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'greater'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() > self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() > self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() > self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() > self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self > rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self > rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self > rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -331,12 +378,16 @@ void BindVariableBlock(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'greater_equal'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() >= self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() >= self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() >= self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() >= self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self >= rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self >= rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self >= rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand Down Expand Up @@ -505,6 +556,6 @@ void BindVariableBlock(nb::module_& autodiff,
return CwiseReduce(lhs, rhs, func);
},
"lhs"_a, "rhs"_a, "func"_a, DOC(sleipnir, CwiseReduce));
}
} // NOLINT(readability/fn_size)

} // namespace sleipnir
52 changes: 52 additions & 0 deletions jormungandr/cpp/autodiff/BindVariableMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ void BindVariableMatrix(nb::module_& autodiff,
self.SetValue(values);
},
"values"_a, DOC(sleipnir, VariableMatrix, SetValue));
cls.def(
"set_value",
[](VariableMatrix& self, nb::DRef<Eigen::MatrixXf> values) {
self.SetValue(values.cast<double>());
},
"values"_a, DOC(sleipnir, VariableMatrix, SetValue));
cls.def(
"set_value",
[](VariableMatrix& self,
nb::DRef<Eigen::Matrix<int64_t, Eigen::Dynamic, Eigen::Dynamic>>
values) { self.SetValue(values.cast<double>()); },
"values"_a, DOC(sleipnir, VariableMatrix, SetValue));
cls.def(
"set_value",
[](VariableMatrix& self,
nb::DRef<Eigen::Matrix<int32_t, Eigen::Dynamic, Eigen::Dynamic>>
values) { self.SetValue(values.cast<double>()); },
"values"_a, DOC(sleipnir, VariableMatrix, SetValue));
cls.def(
"__setitem__",
[](VariableMatrix& self, int row, const Variable& value) {
Expand Down Expand Up @@ -94,6 +112,8 @@ void BindVariableMatrix(nb::module_& autodiff,
nb::cast<VariableMatrix>(value);
} else if (auto rhs = TryCastToEigen<double>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<float>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<int64_t>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<int32_t>(value)) {
Expand Down Expand Up @@ -227,12 +247,16 @@ void BindVariableMatrix(nb::module_& autodiff,
if (ufunc_name == "<ufunc 'matmul'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() * self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() * self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() * self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() * self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self * rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self * rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self * rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -241,12 +265,16 @@ void BindVariableMatrix(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'add'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() + self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() + self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() + self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() + self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self + rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self + rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self + rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -255,12 +283,16 @@ void BindVariableMatrix(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'subtract'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() - self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() - self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() - self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() - self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self - rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self - rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self - rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -269,12 +301,16 @@ void BindVariableMatrix(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'equal'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() == self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() == self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() == self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() == self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self == rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self == rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self == rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -283,12 +319,16 @@ void BindVariableMatrix(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'less'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() < self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() < self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() < self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() < self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self < rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self < rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self < rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -297,12 +337,16 @@ void BindVariableMatrix(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'less_equal'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() <= self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() <= self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() <= self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() <= self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self <= rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self <= rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self <= rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -311,12 +355,16 @@ void BindVariableMatrix(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'greater'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() > self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() > self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() > self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() > self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self > rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self > rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self > rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand All @@ -325,12 +373,16 @@ void BindVariableMatrix(nb::module_& autodiff,
} else if (ufunc_name == "<ufunc 'greater_equal'>") {
if (auto lhs = TryCastToEigen<double>(inputs[0])) {
return nb::cast(lhs.value() >= self);
} else if (auto lhs = TryCastToEigen<float>(inputs[0])) {
return nb::cast(lhs.value() >= self);
} else if (auto lhs = TryCastToEigen<int64_t>(inputs[0])) {
return nb::cast(lhs.value() >= self);
} else if (auto lhs = TryCastToEigen<int32_t>(inputs[0])) {
return nb::cast(lhs.value() >= self);
} else if (auto rhs = TryCastToEigen<double>(inputs[1])) {
return nb::cast(self >= rhs.value());
} else if (auto rhs = TryCastToEigen<float>(inputs[1])) {
return nb::cast(self >= rhs.value());
} else if (auto rhs = TryCastToEigen<int64_t>(inputs[1])) {
return nb::cast(self >= rhs.value());
} else if (auto rhs = TryCastToEigen<int32_t>(inputs[1])) {
Expand Down

0 comments on commit 19a078b

Please sign in to comment.