Add eval_grad: forward mode autodiff for the compute graph
eguiraud committed Oct 2, 2023
1 parent 4720b18 commit 264988b
Showing 5 changed files with 145 additions and 3 deletions.
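In short, the new Graph::eval_grad(inputs) evaluates the graph and, in the same forward pass, its gradient w.r.t. every input variable: each Op fills one row of its parent's local Jacobian, which the parent contracts via a vector-Jacobian product. A minimal usage sketch, mirroring the new SumGradient test below:

    const Var x{"x"};
    const Const c{2.};
    const Graph g = x + c;
    // value == 5., grads(0) == dg/dx == 1.
    const auto &[value, grads] = g.eval_grad({{"x", 3.}});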
1 change: 1 addition & 0 deletions MODULE.bazel
@@ -1,6 +1,7 @@
module(name = "compute-graph-autodiff", version = "0.1")

bazel_dep(name = "abseil-cpp", version = "20230802.0")
bazel_dep(name = "eigen", version = "3.4.0")
bazel_dep(name = "fmt", version = "10.1.1")
bazel_dep(name = "googletest", version = "1.14.0")
bazel_dep(name = "protobuf", version = "21.7")
3 changes: 2 additions & 1 deletion src/BUILD.bazel
@@ -9,7 +9,8 @@ cc_library(
"@abseil-cpp//absl/status:status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/container:flat_hash_map",
"@fmt//:fmt"
"@fmt//:fmt",
"@eigen//:eigen"
],
visibility = ["//visibility:public"]
)
68 changes: 68 additions & 0 deletions src/graph.cpp
@@ -7,6 +7,7 @@ under certain conditions: see LICENSE.
#include "graph.h"

#include <algorithm> // std::sort, std::find
#include <cassert>
#include <cstddef> // std::size_t
#include <fstream>
#include <iterator> // std::distance
#include <vector>

#include "absl/status/status.h"
@@ -65,13 +66,39 @@ std::unique_ptr<const Op> op_from_proto(const gpb::Graph& gproto) {

return op;
}

std::size_t find_var_idx(std::string_view name, const Inputs& inputs) {
// TODO this should _not_ be in the hot loop. do it once at the beginning.
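// gradient entries are ordered by variable name (alphabetically), so the
// index of `name` in the sorted list of input names is its gradient position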
std::vector<std::string> var_names;
var_names.reserve(inputs.size());
for (const auto& [var_name, value] : inputs) var_names.push_back(var_name);
std::sort(var_names.begin(), var_names.end());

auto it = std::find(var_names.begin(), var_names.end(), name);
assert(it != var_names.end() && "variable not found among the inputs");
return std::distance(var_names.begin(), it);
}
} // end of anonymous namespace

float Sum::eval(const Inputs& inputs) const noexcept {
assert(op1 && op2);
return op1->eval(inputs) + op2->eval(inputs);
}

float Sum::eval_grad(const Inputs& inputs, float* grad_out) const noexcept {
// row-major storage so that each row is contiguous in memory and we can
// hand callees views over the individual rows
Eigen::Matrix<float, 2, Eigen::Dynamic, Eigen::RowMajor> jacobian(
    2, inputs.size());
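// each child fills one row: jacobian(i, j) = d(child_i)/d(x_j), with the x_j
// being the input variables in alphabetical order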
const float value1 = op1->eval_grad(inputs, jacobian.data());
const float value2 = op2->eval_grad(inputs, jacobian.data() + inputs.size());

Eigen::Map<Eigen::RowVectorXf> grad_out_as_eigen(grad_out, inputs.size());
// the vector in the vector-jacobian product is just Ones() for a Sum
grad_out_as_eigen = Eigen::RowVector2f::Ones() * jacobian;
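// i.e. grad_out[j] = d(op1)/d(x_j) + d(op2)/d(x_j), the chain rule for a sum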

return value1 + value2;
}

std::pair<gpb::Graph::OpCase, void*> Sum::to_proto() const noexcept {
auto* sum = new gpb::Sum(); // caller will take ownership

@@ -98,6 +125,22 @@ float Mul::eval(const Inputs& inputs) const noexcept {
return op1->eval(inputs) * op2->eval(inputs);
}

float Mul::eval_grad(const Inputs& inputs, float* grad_out) const noexcept {
// row-major storage so that each row is contiguous in memory and we can
// hand callees views over the individual rows
Eigen::Matrix<float, 2, Eigen::Dynamic, Eigen::RowMajor> jacobian(
    2, inputs.size());
const float value1 = op1->eval_grad(inputs, jacobian.data());
const float value2 = op2->eval_grad(inputs, jacobian.data() + inputs.size());

Eigen::Map<Eigen::RowVectorXf> grad_out_as_eigen(grad_out, inputs.size());
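// product rule: the local derivatives of op1*op2 w.r.t. its two operands are
// d(op1*op2)/d(op1) = value2 and d(op1*op2)/d(op2) = value1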
Eigen::RowVector2f dmul_dops;
dmul_dops << value2, value1;
grad_out_as_eigen = dmul_dops * jacobian;

return value1 * value2;
}

std::pair<gpb::Graph::OpCase, void*> Mul::to_proto() const noexcept {
auto* mul = new gpb::Mul(); // caller will take ownership

@@ -124,6 +167,20 @@ float Graph::eval(const Inputs& inputs) const noexcept {
return op->eval(inputs);
}

std::pair<float, Eigen::RowVectorXf> Graph::eval_grad(
const Inputs& inputs) const noexcept {
assert(op);

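// allocate the output buffer once here at the root; the recursive eval_grad
// calls write the final gradient into it through the raw pointer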
Eigen::RowVectorXf grads(inputs.size());
return {op->eval_grad(inputs, grads.data()), grads};
}

float Const::eval_grad(const Inputs& inputs, float* grad_out) const noexcept {
// derivatives of a constant are all zero
for (std::size_t i = 0; i < inputs.size(); ++i) grad_out[i] = 0;
return value;
}

std::pair<gpb::Graph::OpCase, void*> Const::to_proto() const noexcept {
auto* c = new gpb::Const(); // caller will take ownership
c->set_value(value);
@@ -142,6 +199,17 @@ float Var::eval(const Inputs& inputs) const noexcept {
return var_it->second;
}

float Var::eval_grad(const Inputs& inputs, float* grad_out) const noexcept {
// the gradient of a variable w.r.t. all variables is a one-hot vector:
// the only 1. is at the position of the variable itself
for (std::size_t i = 0; i < inputs.size(); ++i) grad_out[i] = 0;
// TODO find a way to move this out of the hot loop
const std::size_t var_idx = find_var_idx(name, inputs);
grad_out[var_idx] = 1.;

return eval(inputs);
}

std::pair<gpb::Graph::OpCase, void*> Var::to_proto() const noexcept {
auto* var = new gpb::Var(); // caller will take ownership
var->set_name(name);
28 changes: 26 additions & 2 deletions src/graph.h
@@ -12,6 +12,7 @@ under certain conditions: see LICENSE.
#include <string_view>
#include <utility> // std::pair

#include "Eigen/Core"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@@ -43,6 +44,13 @@ class Op {
/// Evaluate this operation on the inputs provided.
virtual float eval(const Inputs& inputs) const noexcept = 0;

/// Evaluate this operation and its gradient w.r.t. the inputs provided.
/// `grad_out` is a caller-allocated buffer with room for one entry per input
/// variable; it is filled with the gradient values.
/// See Graph::eval_grad() for more information.
virtual float eval_grad(const Inputs& inputs,
float* grad_out) const noexcept = 0;

/// Retrieve a type-erased protobuf representation of the operation.
/// The first element of the pair is the protobuf enum that specifies
/// what operation has been serialized, the second element is a void
@@ -63,6 +71,8 @@ class Sum : public Op {

float eval(const Inputs& inputs) const noexcept final;

float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;

static std::unique_ptr<Sum> from_proto(
@@ -80,6 +90,8 @@ class Mul : public Op {

float eval(const Inputs& inputs) const noexcept final;

float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;

static std::unique_ptr<Mul> from_proto(
@@ -107,10 +119,18 @@ class Graph {

float eval(const Inputs& inputs) const noexcept;

// Serialize this Graph instance into a corresponding protobuf object.
/// Evaluate the graph and its gradient at the given point.
/// The elements of the gradient are the derivatives w.r.t. the input
/// variables, in alphabetical order.
/// The gradient is evaluated via automatic differentiation (forward mode).
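/// Example (a minimal sketch; mirrors the tests in tests/test.cpp):
///   const Var x{"x"};
///   const Const c{2.};
///   const Graph g = x * c;
///   const auto &[value, grads] = g.eval_grad({{"x", 3.}});
///   // value == 6., grads(0) == dg/dx == 2.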
// TODO let users specify w.r.t. which variables to derive.
std::pair<float, Eigen::RowVectorXf> eval_grad(
const Inputs& inputs) const noexcept;

/// Serialize this Graph instance into a corresponding protobuf object.
std::unique_ptr<const graph_proto::Graph> to_proto() const noexcept;

// Deserialize a protobuf object into a Graph instance.
/// Deserialize a protobuf object into a Graph instance.
static Graph from_proto(const graph_proto::Graph& gproto) noexcept;
};

@@ -154,6 +174,8 @@ class Const : public Op {

float eval(const Inputs&) const noexcept final { return value; }

float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;
static std::unique_ptr<Const> from_proto(
const graph_proto::Const& cproto) noexcept;
@@ -214,6 +236,8 @@ class Var : public Op {

float eval(const Inputs& inputs) const noexcept final;

float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;

static std::unique_ptr<Var> from_proto(
48 changes: 48 additions & 0 deletions tests/test.cpp
@@ -80,3 +80,51 @@ TEST(Tests, WriteAndReadGraph) {
const Graph &g2 = gs.value();
EXPECT_FLOAT_EQ(g2.eval(inputs), 42.);
}

TEST(Tests, SumGradient) {
const Var x{"x"};
const Const c{2.};
const Graph g = x + c;
const auto &[value, grads] = g.eval_grad({{"x", 3.}});
EXPECT_FLOAT_EQ(value, 5.);
EXPECT_EQ(grads.size(), 1);
EXPECT_FLOAT_EQ(grads(0), 1.);
}

TEST(Tests, MulGradient) {
const Var x{"x"};
const Const c{2.};
const Graph g = x * c;
const auto &[value, grads] = g.eval_grad({{"x", 3.}});
EXPECT_FLOAT_EQ(value, 6.);
EXPECT_EQ(grads.size(), 1);
EXPECT_FLOAT_EQ(grads(0), 2.);
}

TEST(Tests, MultipleVarGradient) {
const Var x{"x"};
const Var y{"y"};
const Var z{"z"};
const Const c{10.};

// clang-format off
// g(x,y,z) = yx^3 + xyz + cz(x + y) + c
// dg/dx = 3yx^2 + yz + cz
// dg/dy = x^3 + xz + cz
// dg/dz = xy + cx + cy
const Graph g = x*x*x*y + x*y*z + c*z*(x + y) + c;
// clang-format on

const Inputs inputs = {{"x", 2.}, {"y", 3.}, {"z", 4.}};
const auto &[value, grads] = g.eval_grad(inputs);

// g(2,3,4) = 24 + 24 + 200 + 10 = 258
EXPECT_FLOAT_EQ(value, 258.);

// dg/dx(2,3,4) = 36 + 12 + 40 = 88
EXPECT_FLOAT_EQ(grads(0), 88.);
// dg/dy(2,3,4) = 8 + 8 + 40 = 56
EXPECT_FLOAT_EQ(grads(1), 56.);
// dg/dz(2,3,4) = 6 + 20 + 30 = 56
EXPECT_FLOAT_EQ(grads(2), 56.);
}
