Skip to content

Commit

Permalink
Refactor nanobind cast attempts (#616)
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul authored Jul 27, 2024
1 parent 5e67d8e commit 7e23825
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 60 deletions.
12 changes: 12 additions & 0 deletions jormungandr/cpp/NumPy.hpp → jormungandr/cpp/TryCast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@ namespace nb = nanobind;

namespace sleipnir {

/**
* Converts the given nb::object to a C++ type.
*/
template <typename T>
inline std::optional<T> TryCast(const nb::object& obj) {
if (nb::isinstance<T>(obj)) {
return nb::cast<T>(obj);
} else {
return std::nullopt;
}
}

/**
* Converts the given nb::ndarray to an Eigen matrix.
*/
Expand Down
54 changes: 24 additions & 30 deletions jormungandr/cpp/autodiff/BindVariableBlock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include <sleipnir/optimization/Constraints.hpp>

#include "Docstrings.hpp"
#include "NumPy.hpp"
#include "TryCast.hpp"

namespace nb = nanobind;

Expand Down Expand Up @@ -91,9 +91,9 @@ void BindVariableBlock(nb::class_<VariableBlock<VariableMatrix>>& cls) {

// Row slice
const auto& rowElem = slices[0];
if (nb::isinstance<nb::slice>(rowElem)) {
auto rowSlice = nb::cast<nb::slice>(rowElem);
auto [start, stop, step, sliceLength] = rowSlice.compute(self.Rows());
if (auto rowSlice = TryCast<nb::slice>(rowElem)) {
auto [start, stop, step, sliceLength] =
rowSlice.value().compute(self.Rows());
rowOffset = start;
blockRows = stop - start;
} else {
Expand All @@ -103,22 +103,20 @@ void BindVariableBlock(nb::class_<VariableBlock<VariableMatrix>>& cls) {

// Column slice
const auto& colElem = slices[1];
if (nb::isinstance<nb::slice>(colElem)) {
auto colSlice = nb::cast<nb::slice>(colElem);
auto [start, stop, step, sliceLength] = colSlice.compute(self.Cols());
if (auto colSlice = TryCast<nb::slice>(colElem)) {
auto [start, stop, step, sliceLength] =
colSlice.value().compute(self.Cols());
colOffset = start;
blockCols = stop - start;
} else {
colOffset = nb::cast<int>(colElem);
blockCols = 1;
}

if (nb::isinstance<VariableMatrix>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<VariableMatrix>(value);
} else if (nb::isinstance<VariableBlock<VariableMatrix>>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<VariableBlock<VariableMatrix>>(value);
if (auto rhs = TryCast<VariableMatrix>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCast<VariableBlock<VariableMatrix>>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<double>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<float>(value)) {
Expand All @@ -127,15 +125,12 @@ void BindVariableBlock(nb::class_<VariableBlock<VariableMatrix>>& cls) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<int32_t>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (nb::isinstance<Variable>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<Variable>(value);
} else if (nb::isinstance<nb::float_>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<double>(value);
} else if (nb::isinstance<nb::int_>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<int>(value);
} else if (auto rhs = TryCast<Variable>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCast<double>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCast<int>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else {
throw nb::value_error(
"VariableBlock.__setitem__ not implemented for value");
Expand Down Expand Up @@ -163,8 +158,7 @@ void BindVariableBlock(nb::class_<VariableBlock<VariableMatrix>>& cls) {

// If both indices are integers instead of slices, return Variable
// instead of VariableBlock
if (nb::isinstance<nb::int_>(slices[0]) &&
nb::isinstance<nb::int_>(slices[1])) {
if (nb::isinstance<int>(slices[0]) && nb::isinstance<int>(slices[1])) {
int row = nb::cast<int>(slices[0]);
int col = nb::cast<int>(slices[1]);

Expand All @@ -188,9 +182,9 @@ void BindVariableBlock(nb::class_<VariableBlock<VariableMatrix>>& cls) {

// Row slice
const auto& rowElem = slices[0];
if (nb::isinstance<nb::slice>(rowElem)) {
auto rowSlice = nb::cast<nb::slice>(rowElem);
auto [start, stop, step, sliceLength] = rowSlice.compute(self.Rows());
if (auto rowSlice = TryCast<nb::slice>(rowElem)) {
auto [start, stop, step, sliceLength] =
rowSlice.value().compute(self.Rows());
rowOffset = start;
blockRows = stop - start;
} else {
Expand All @@ -203,9 +197,9 @@ void BindVariableBlock(nb::class_<VariableBlock<VariableMatrix>>& cls) {

// Column slice
const auto& colElem = slices[1];
if (nb::isinstance<nb::slice>(colElem)) {
auto colSlice = nb::cast<nb::slice>(colElem);
auto [start, stop, step, sliceLength] = colSlice.compute(self.Cols());
if (auto colSlice = TryCast<nb::slice>(colElem)) {
auto [start, stop, step, sliceLength] =
colSlice.value().compute(self.Cols());
colOffset = start;
blockCols = stop - start;
} else {
Expand Down
54 changes: 24 additions & 30 deletions jormungandr/cpp/autodiff/BindVariableMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <sleipnir/optimization/Constraints.hpp>

#include "Docstrings.hpp"
#include "NumPy.hpp"
#include "TryCast.hpp"

namespace nb = nanobind;

Expand Down Expand Up @@ -82,9 +82,9 @@ void BindVariableMatrix(nb::module_& autodiff,

// Row slice
const auto& rowElem = slices[0];
if (nb::isinstance<nb::slice>(rowElem)) {
auto rowSlice = nb::cast<nb::slice>(rowElem);
auto [start, stop, step, sliceLength] = rowSlice.compute(self.Rows());
if (auto rowSlice = TryCast<nb::slice>(rowElem)) {
auto [start, stop, step, sliceLength] =
rowSlice.value().compute(self.Rows());
rowOffset = start;
blockRows = stop - start;
} else {
Expand All @@ -94,22 +94,20 @@ void BindVariableMatrix(nb::module_& autodiff,

// Column slice
const auto& colElem = slices[1];
if (nb::isinstance<nb::slice>(colElem)) {
auto colSlice = nb::cast<nb::slice>(colElem);
auto [start, stop, step, sliceLength] = colSlice.compute(self.Cols());
if (auto colSlice = TryCast<nb::slice>(colElem)) {
auto [start, stop, step, sliceLength] =
colSlice.value().compute(self.Cols());
colOffset = start;
blockCols = stop - start;
} else {
colOffset = nb::cast<int>(colElem);
blockCols = 1;
}

if (nb::isinstance<VariableMatrix>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<VariableMatrix>(value);
} else if (nb::isinstance<VariableBlock<VariableMatrix>>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<VariableMatrix>(value);
if (auto rhs = TryCast<VariableMatrix>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCast<VariableBlock<VariableMatrix>>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<double>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<float>(value)) {
Expand All @@ -118,15 +116,12 @@ void BindVariableMatrix(nb::module_& autodiff,
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCastToEigen<int32_t>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (nb::isinstance<Variable>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<Variable>(value);
} else if (nb::isinstance<nb::float_>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<double>(value);
} else if (nb::isinstance<nb::int_>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) =
nb::cast<int>(value);
} else if (auto rhs = TryCast<Variable>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCast<double>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else if (auto rhs = TryCast<int>(value)) {
self.Block(rowOffset, colOffset, blockRows, blockCols) = rhs.value();
} else {
throw nb::value_error(
"VariableMatrix.__setitem__ not implemented for value");
Expand Down Expand Up @@ -154,8 +149,7 @@ void BindVariableMatrix(nb::module_& autodiff,

// If both indices are integers instead of slices, return Variable
// instead of VariableBlock
if (nb::isinstance<nb::int_>(slices[0]) &&
nb::isinstance<nb::int_>(slices[1])) {
if (nb::isinstance<int>(slices[0]) && nb::isinstance<int>(slices[1])) {
int row = nb::cast<int>(slices[0]);
int col = nb::cast<int>(slices[1]);

Expand All @@ -179,9 +173,9 @@ void BindVariableMatrix(nb::module_& autodiff,

// Row slice
const auto& rowElem = slices[0];
if (nb::isinstance<nb::slice>(rowElem)) {
auto rowSlice = nb::cast<nb::slice>(rowElem);
auto [start, stop, step, sliceLength] = rowSlice.compute(self.Rows());
if (auto rowSlice = TryCast<nb::slice>(rowElem)) {
auto [start, stop, step, sliceLength] =
rowSlice.value().compute(self.Rows());
rowOffset = start;
blockRows = stop - start;
} else {
Expand All @@ -194,9 +188,9 @@ void BindVariableMatrix(nb::module_& autodiff,

// Column slice
const auto& colElem = slices[1];
if (nb::isinstance<nb::slice>(colElem)) {
auto colSlice = nb::cast<nb::slice>(colElem);
auto [start, stop, step, sliceLength] = colSlice.compute(self.Cols());
if (auto colSlice = TryCast<nb::slice>(colElem)) {
auto [start, stop, step, sliceLength] =
colSlice.value().compute(self.Cols());
colOffset = start;
blockCols = stop - start;
} else {
Expand Down

0 comments on commit 7e23825

Please sign in to comment.