Skip to content

Commit

Permalink
Pass around Eigen Maps rather than raw pointers
Browse files Browse the repository at this point in the history
Eigen Maps carry size information.
  • Loading branch information
eguiraud committed Oct 2, 2023
1 parent 9912a74 commit ac2610c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 23 deletions.
49 changes: 32 additions & 17 deletions src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,22 @@ float Sum::eval(const Inputs& inputs) const noexcept {
return op1->eval(inputs) + op2->eval(inputs);
}

float Sum::eval_grad(const Inputs& inputs, float* grad_out) const noexcept {
// we want it row-major so we can give you views over the different rows to
// callees
float Sum::eval_grad(const Inputs& inputs,
Eigen::Map<Eigen::RowVectorXf>& grad_out) const noexcept {
// row-major so we can give out views over the different rows to callees
Eigen::Matrix<float, 2, Eigen::Dynamic, Eigen::RowMajorBit> jacobian(
2, inputs.size());
const float value1 = op1->eval_grad(inputs, jacobian.data());
const float value2 = op2->eval_grad(inputs, jacobian.data() + inputs.size());

Eigen::Map<Eigen::RowVectorXf> grad_out_as_eigen(grad_out, inputs.size());
Eigen::Map<Eigen::RowVectorXf> first_row_view(jacobian.data(),
jacobian.cols());
const float value1 = op1->eval_grad(inputs, first_row_view);

Eigen::Map<Eigen::RowVectorXf> second_row_view(
jacobian.data() + jacobian.rowStride(), jacobian.cols());
const float value2 = op2->eval_grad(inputs, second_row_view);

// the vector in the vector-jacobian product is just Ones() for a Sum
grad_out_as_eigen = Eigen::RowVector2f::Ones() * jacobian;
grad_out = Eigen::RowVector2f::Ones() * jacobian;

return value1 + value2;
}
Expand Down Expand Up @@ -127,18 +132,24 @@ float Mul::eval(const Inputs& inputs) const noexcept {
return op1->eval(inputs) * op2->eval(inputs);
}

float Mul::eval_grad(const Inputs& inputs, float* grad_out) const noexcept {
// we want it row-major so we can give you views over the different rows to
// callees
float Mul::eval_grad(const Inputs& inputs,
Eigen::Map<Eigen::RowVectorXf>& grad_out) const noexcept {
// row-major so we can give out views over the different rows to callees
Eigen::Matrix<float, 2, Eigen::Dynamic, Eigen::RowMajorBit> jacobian(
2, inputs.size());
const float value1 = op1->eval_grad(inputs, jacobian.data());
const float value2 = op2->eval_grad(inputs, jacobian.data() + inputs.size());

Eigen::Map<Eigen::RowVectorXf> grad_out_as_eigen(grad_out, inputs.size());
Eigen::Map<Eigen::RowVectorXf> first_row_view(jacobian.data(),
jacobian.cols());
const float value1 = op1->eval_grad(inputs, first_row_view);

Eigen::Map<Eigen::RowVectorXf> second_row_view(
jacobian.data() + jacobian.rowStride(), jacobian.cols());
const float value2 = op2->eval_grad(inputs, second_row_view);

// dMul/dvalue_i for value1*value2 is (value2, value1)
Eigen::RowVector2f dmul_dops;
dmul_dops << value2, value1;
grad_out_as_eigen = dmul_dops * jacobian;
grad_out = dmul_dops * jacobian;

return value1 * value2;
}
Expand Down Expand Up @@ -174,10 +185,13 @@ std::pair<float, Eigen::RowVectorXf> Graph::eval_grad(
assert(op);

Eigen::RowVectorXf grads(inputs.size());
return {op->eval_grad(inputs, grads.data()), grads};
Eigen::Map<Eigen::RowVectorXf> grads_view(grads.data(), grads.size());
return {op->eval_grad(inputs, grads_view), grads};
}

float Const::eval_grad(const Inputs& inputs, float* grad_out) const noexcept {
float Const::eval_grad(
const Inputs& inputs,
Eigen::Map<Eigen::RowVectorXf>& grad_out) const noexcept {
// derivatives of a constant are all zero
for (std::size_t i = 0; i < inputs.size(); ++i) grad_out[i] = 0;
return value;
Expand All @@ -201,7 +215,8 @@ float Var::eval(const Inputs& inputs) const noexcept {
return var_it->second;
}

float Var::eval_grad(const Inputs& inputs, float* grad_out) const noexcept {
float Var::eval_grad(const Inputs& inputs,
Eigen::Map<Eigen::RowVectorXf>& grad_out) const noexcept {
// the gradient of a variable w.r.t. all variables is a one-hot vector:
// the only 1. is at the position of the variable itself
for (std::size_t i = 0; i < inputs.size(); ++i) grad_out[i] = 0;
Expand Down
21 changes: 15 additions & 6 deletions src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ class Op {
/// The `grad_out` argument is a caller-allocated buffer of sufficient size
/// that should be filled with the values of the gradients.
/// See Graph::eval_grad() for more information.
virtual float eval_grad(const Inputs& inputs,
float* grad_out) const noexcept = 0;
virtual float eval_grad(
const Inputs& inputs,
Eigen::Map<Eigen::RowVectorXf>& grad_out) const noexcept = 0;

/// Retrieve a type-erased protobuf representation of the operation.
/// The first element of the pair is the protobuf enum that specifies
Expand All @@ -71,7 +72,9 @@ class Sum : public Op {

float eval(const Inputs& inputs) const noexcept final;

float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final;
float eval_grad(
const Inputs& inputs,
Eigen::Map<Eigen::RowVectorXf>& grad_out) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;

Expand All @@ -90,7 +93,9 @@ class Mul : public Op {

float eval(const Inputs& inputs) const noexcept final;

float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final;
float eval_grad(
const Inputs& inputs,
Eigen::Map<Eigen::RowVectorXf>& grad_out) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;

Expand Down Expand Up @@ -174,7 +179,9 @@ class Const : public Op {

float eval(const Inputs&) const noexcept final { return value; }

float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final;
float eval_grad(
const Inputs& inputs,
Eigen::Map<Eigen::RowVectorXf>& grad_out) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;
static std::unique_ptr<Const> from_proto(
Expand Down Expand Up @@ -239,7 +246,9 @@ class Var : public Op {

float eval(const Inputs& inputs) const noexcept final;

float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final;
float eval_grad(
const Inputs& inputs,
Eigen::Map<Eigen::RowVectorXf>& grad_out) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;

Expand Down

0 comments on commit ac2610c

Please sign in to comment.