From 264988b8685790348193ade82b5a6688ad500cae Mon Sep 17 00:00:00 2001 From: Enrico Guiraud Date: Sun, 1 Oct 2023 18:29:48 -0600 Subject: [PATCH] Add `eval_grad`: forward mode autodiff for the compute graph --- MODULE.bazel | 1 + src/BUILD.bazel | 3 ++- src/graph.cpp | 68 +++++++++++++++++++++++++++++++++++++++++++++++++ src/graph.h | 28 ++++++++++++++++++-- tests/test.cpp | 48 ++++++++++++++++++++++++++++++++++ 5 files changed, 145 insertions(+), 3 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 630d778..3ff66fe 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -1,6 +1,7 @@ module(name = "compute-graph-autodiff", version = "0.1") bazel_dep(name = "abseil-cpp", version = "20230802.0") +bazel_dep(name = "eigen", version = "3.4.0") bazel_dep(name = "fmt", version = "10.1.1") bazel_dep(name = "googletest", version = "1.14.0") bazel_dep(name = "protobuf", version = "21.7") diff --git a/src/BUILD.bazel b/src/BUILD.bazel index df7ed8c..779b1a0 100644 --- a/src/BUILD.bazel +++ b/src/BUILD.bazel @@ -9,7 +9,8 @@ cc_library( "@abseil-cpp//absl/status:status", "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/container:flat_hash_map", - "@fmt//:fmt" + "@fmt//:fmt", + "@eigen//:eigen" ], visibility = ["//visibility:public"] ) diff --git a/src/graph.cpp b/src/graph.cpp index 4e3cd87..94232a3 100644 --- a/src/graph.cpp +++ b/src/graph.cpp @@ -7,6 +7,7 @@ under certain conditions: see LICENSE. #include "graph.h" #include +#include // std::size_t #include #include "absl/status/status.h" @@ -65,6 +66,17 @@ std::unique_ptr op_from_proto(const gpb::Graph& gproto) { return op; } + +std::size_t find_var_idx(std::string_view name, const Inputs& inputs) { + // TODO this should _not_ be in the hot loop. do it once at the beginning. + std::vector var_names; + var_names.reserve(inputs.size()); + for (auto& [name, value] : inputs) var_names.push_back(name); + std::sort(var_names.begin(), var_names.end()); + + auto it = std::find(var_names.begin(), var_names.end(), name); + return std::distance(var_names.begin(), it); +} } // end of anonymous namespace float Sum::eval(const Inputs& inputs) const noexcept { @@ -72,6 +84,21 @@ float Sum::eval(const Inputs& inputs) const noexcept { return op1->eval(inputs) + op2->eval(inputs); } +float Sum::eval_grad(const Inputs& inputs, float* grad_out) const noexcept { + // we want it row-major so we can give you views over the different rows to + // callees + Eigen::Matrix jacobian( + 2, inputs.size()); + const float value1 = op1->eval_grad(inputs, jacobian.data()); + const float value2 = op2->eval_grad(inputs, jacobian.data() + inputs.size()); + + Eigen::Map grad_out_as_eigen(grad_out, inputs.size()); + // the vector in the vector-jacobian product is just Ones() for a Sum + grad_out_as_eigen = Eigen::RowVector2f::Ones() * jacobian; + + return value1 + value2; +} + std::pair Sum::to_proto() const noexcept { auto* sum = new gpb::Sum(); // caller will take ownership @@ -98,6 +125,22 @@ float Mul::eval(const Inputs& inputs) const noexcept { return op1->eval(inputs) * op2->eval(inputs); } +float Mul::eval_grad(const Inputs& inputs, float* grad_out) const noexcept { + // we want it row-major so we can give you views over the different rows to + // callees + Eigen::Matrix jacobian( + 2, inputs.size()); + const float value1 = op1->eval_grad(inputs, jacobian.data()); + const float value2 = op2->eval_grad(inputs, jacobian.data() + inputs.size()); + + Eigen::Map grad_out_as_eigen(grad_out, inputs.size()); + Eigen::RowVector2f dmul_dops; + dmul_dops << value2, value1; + grad_out_as_eigen = dmul_dops * jacobian; + + return value1 * value2; +} + std::pair Mul::to_proto() const noexcept { auto* mul = new gpb::Mul(); // caller will take ownership @@ -124,6 +167,20 @@ float Graph::eval(const Inputs& inputs) const noexcept { return op->eval(inputs); } +std::pair Graph::eval_grad( + const Inputs& inputs) const noexcept { + assert(op); + + Eigen::RowVectorXf grads(inputs.size()); + return {op->eval_grad(inputs, grads.data()), grads}; +} + +float Const::eval_grad(const Inputs& inputs, float* grad_out) const noexcept { + // derivatives of a constant are all zero + for (std::size_t i = 0; i < inputs.size(); ++i) grad_out[i] = 0; + return value; +} + std::pair Const::to_proto() const noexcept { auto* c = new gpb::Const(); // caller will take ownership c->set_value(value); @@ -142,6 +199,17 @@ float Var::eval(const Inputs& inputs) const noexcept { return var_it->second; } +float Var::eval_grad(const Inputs& inputs, float* grad_out) const noexcept { + // derivatives of a variable w.r.t. all variables is a one-hot vector: + // the only 1. is at the position of the variable itself + for (std::size_t i = 0; i < inputs.size(); ++i) grad_out[i] = 0; + // TODO find a way to move this out of the hot loop + const std::size_t var_idx = find_var_idx(name, inputs); + grad_out[var_idx] = 1.; + + return eval(inputs); +} + std::pair Var::to_proto() const noexcept { auto* var = new gpb::Var(); // caller will take ownership var->set_name(name); diff --git a/src/graph.h b/src/graph.h index e6f8b1f..80c5258 100644 --- a/src/graph.h +++ b/src/graph.h @@ -12,6 +12,7 @@ under certain conditions: see LICENSE. #include #include // std::pair +#include "Eigen/Core" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -43,6 +44,13 @@ class Op { /// Evaluate this operation on the inputs provided. virtual float eval(const Inputs& inputs) const noexcept = 0; + /// Evaluate this operation and its gradient w.r.t. the inputs provided. + /// The `grad_out` argument is a buffer of sufficient size that the caller + /// allocated that should be filled with the values of the gradients. + /// See Graph::eval_grad() for more information. + virtual float eval_grad(const Inputs& inputs, + float* grad_out) const noexcept = 0; + /// Retrieve a type-erased protobuf representation of the operation. /// The first element of the pair is the protobuf enum that specifies /// what operation has been serialized, the second element is a void @@ -63,6 +71,8 @@ class Sum : public Op { float eval(const Inputs& inputs) const noexcept final; + float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final; + std::pair to_proto() const noexcept final; static std::unique_ptr from_proto( @@ -80,6 +90,8 @@ class Mul : public Op { float eval(const Inputs& inputs) const noexcept final; + float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final; + std::pair to_proto() const noexcept final; static std::unique_ptr from_proto( @@ -107,10 +119,18 @@ class Graph { float eval(const Inputs& inputs) const noexcept; - // Serialize this Graph instance into a corresponding protobuf object. + /// Evaluate the graph and its gradient at the given point. + /// The elements of the gradient are the derivative w.r.t. the input variables + /// in alphabetical order. + /// The gradient is evaluated via automatic differentiation (forward mode). + // TODO let users specify w.r.t. which variables to derive. + std::pair eval_grad( + const Inputs& inputs) const noexcept; + + /// Serialize this Graph instance into a corresponding protobuf object. std::unique_ptr to_proto() const noexcept; - // Deserialize a protobuf object into a Graph instance. + /// Deserialize a protobuf object into a Graph instance. static Graph from_proto(const graph_proto::Graph& gproto) noexcept; }; @@ -154,6 +174,8 @@ class Const : public Op { float eval(const Inputs&) const noexcept final { return value; } + float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final; + std::pair to_proto() const noexcept final; static std::unique_ptr from_proto( const graph_proto::Const& cproto) noexcept; @@ -214,6 +236,8 @@ class Var : public Op { float eval(const Inputs& inputs) const noexcept final; + float eval_grad(const Inputs& inputs, float* grad_out) const noexcept final; + std::pair to_proto() const noexcept final; static std::unique_ptr from_proto( diff --git a/tests/test.cpp b/tests/test.cpp index aca7b2b..b6b5383 100644 --- a/tests/test.cpp +++ b/tests/test.cpp @@ -80,3 +80,51 @@ TEST(Tests, WriteAndReadGraph) { const Graph &g2 = gs.value(); EXPECT_FLOAT_EQ(g2.eval(inputs), 42.); } + +TEST(Tests, SumGradient) { + const Var x{"x"}; + const Const c{2.}; + const Graph g = x + c; + const auto &[value, grads] = g.eval_grad({{"x", 3.}}); + EXPECT_FLOAT_EQ(value, 5.); + EXPECT_EQ(grads.size(), 1); + EXPECT_FLOAT_EQ(grads(0), 1.); +} + +TEST(Tests, MulGradient) { + const Var x{"x"}; + const Const c{2.}; + const Graph g = x * c; + const auto &[value, grads] = g.eval_grad({{"x", 3.}}); + EXPECT_FLOAT_EQ(value, 6.); + EXPECT_EQ(grads.size(), 1); + EXPECT_FLOAT_EQ(grads(0), 2.); +} + +TEST(Tests, MultipleVarGradient) { + const Var x{"x"}; + const Var y{"y"}; + const Var z{"z"}; + const Const c{10.}; + + // clang-format off + // g(x,y,z) = yx^3 + xyz + cz(x + y) + c + // dg/dx = 3yx^2 + yz + cz + // dg/dy = x^3 + xz + cz + // dg/dz = xy + cx + cy + const Graph g = x*x*x*y + x*y*z + c*z*(x + y) + c; + // clang-format on + + const Inputs inputs = {{"x", 2.}, {"y", 3.}, {"z", 4.}}; + const auto &[value, grads] = g.eval_grad(inputs); + + // g(2,3,4) = 24 + 24 + 200 + 10 = 258 + EXPECT_FLOAT_EQ(value, 258.); + + // dg/dx(2,3,4) = 36 + 12 + 40 = 88 + EXPECT_FLOAT_EQ(grads(0), 88.); + // dg/dy(2,3,4) = 8 + 8 + 40 = 56 + EXPECT_FLOAT_EQ(grads(1), 56.); + // dg/dz(2,3,4) = 6 + 20 + 30 = 56 + EXPECT_FLOAT_EQ(grads(2), 56.); +}