Skip to content

Commit

Permalink
Add support for compute graph I/O via protobuf
Browse files Browse the repository at this point in the history
  • Loading branch information
eguiraud committed Oct 1, 2023
1 parent 6ed27d7 commit c413d98
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 3 deletions.
3 changes: 3 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Bazel module definition for the compute-graph-autodiff project.
module(name = "compute-graph-autodiff", version = "0.1")

# External dependencies, resolved through the Bazel module system.
bazel_dep(name = "abseil-cpp", version = "20230802.0")
bazel_dep(name = "fmt", version = "10.1.1")
bazel_dep(name = "googletest", version = "1.14.0")
# protobuf and rules_proto back the proto_library / cc_proto_library targets.
bazel_dep(name = "protobuf", version = "21.7")
bazel_dep(name = "rules_proto", version = "5.3.0-21.7")
16 changes: 16 additions & 0 deletions src/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
load("@rules_proto//proto:defs.bzl", "proto_library")

# The compute graph library: graph.h/graph.cpp plus the C++ code generated
# from graph.proto (":graph_cc_proto").
cc_library(
name = "compute_graph_ad",
srcs = ["graph.cpp"],
hdrs = ["graph.h"],
deps = [
":graph_cc_proto",
"@abseil-cpp//absl/status:status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/container:flat_hash_map",
"@fmt//:fmt"
],
# Only the //tests package may depend on this target.
visibility = ["//tests:__pkg__"]
)

# C++ bindings generated from the ":graph_proto" schema.
cc_proto_library(
name = "graph_cc_proto",
deps = [":graph_proto"],
)

# Language-agnostic protobuf schema target for graph.proto.
proto_library(
name = "graph_proto",
srcs = ["graph.proto"],
)
179 changes: 179 additions & 0 deletions src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,207 @@ under certain conditions: see LICENSE.
#include "graph.h"

#include <cassert>
#include <fstream>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "fmt/core.h"
#include "src/graph.pb.h"

using namespace compute_graph_ad;
namespace gpb = graph_proto;

namespace {
/// Wrap a type-erased protobuf operation into a newly allocated
/// graph_proto::Graph.
///
/// `optype` tells which concrete protobuf message `opptr` points to; the
/// pointer is cast accordingly and handed to the new Graph through the
/// matching set_allocated_* setter.
/// The caller takes ownership of the returned Graph.
/// Aborts the program if `optype` is OP_NOT_SET.
gpb::Graph* make_proto_graph(gpb::Graph::OpCase optype, void* opptr) {
  using OpCase = gpb::Graph::OpCase;

  auto* wrapper = new gpb::Graph();  // caller will take ownership

  switch (optype) {
    case OpCase::kSum: {
      auto* payload = static_cast<gpb::Sum*>(opptr);
      wrapper->set_allocated_sum(payload);
      break;
    }
    case OpCase::kMul: {
      auto* payload = static_cast<gpb::Mul*>(opptr);
      wrapper->set_allocated_mul(payload);
      break;
    }
    case OpCase::kVar: {
      auto* payload = static_cast<gpb::Var*>(opptr);
      wrapper->set_allocated_var(payload);
      break;
    }
    case OpCase::kConst: {
      auto* payload = static_cast<gpb::Const*>(opptr);
      wrapper->set_allocated_const_(payload);
      break;
    }
    case OpCase::OP_NOT_SET: {
      std::abort();  // TODO and log: this should never happen
      break;
    }
  }

  return wrapper;
}

/// Rebuild the operation stored in a protobuf Graph node.
///
/// Dispatches on which member of the `Op` oneof is set and delegates to the
/// matching `from_proto` factory. Aborts the program if no operation is set.
/// Returns a null pointer only if the oneof case matches none of the known
/// enumerators.
std::unique_ptr<const Op> op_from_proto(const gpb::Graph& gproto) {
  using OpCase = gpb::Graph::OpCase;

  switch (gproto.Op_case()) {
    case OpCase::kSum:
      return Sum::from_proto(gproto.sum());
    case OpCase::kMul:
      return Mul::from_proto(gproto.mul());
    case OpCase::kVar:
      return Var::from_proto(gproto.var());
    case OpCase::kConst:
      return Const::from_proto(gproto.const_());
    case OpCase::OP_NOT_SET:
      std::abort();  // TODO and log: this should never happen
  }

  return nullptr;
}
} // end of anonymous namespace

/// Evaluate both operands on `inputs` and return their sum.
float Sum::eval(const Inputs& inputs) const noexcept {
  assert(op1 && op2);
  const float lhs = op1->eval(inputs);
  const float rhs = op2->eval(inputs);
  return lhs + rhs;
}

/// Serialize this sum into a newly allocated graph_proto::Sum.
///
/// Each operand is serialized recursively, wrapped into a graph_proto::Graph
/// node and attached to the Sum message.
/// Returns the `kSum` tag together with a type-erased pointer to the message;
/// the caller takes ownership of that pointer.
std::pair<gpb::Graph::OpCase, void*> Sum::to_proto() const noexcept {
  // Same precondition as eval(): both operands must be present, otherwise the
  // dereferences below are undefined behavior.
  assert(op1 && op2);

  auto* sum = new gpb::Sum();  // caller will take ownership

  // op1
  auto [op1type, op1ptr] = op1->to_proto();
  gpb::Graph* g1 = make_proto_graph(op1type, op1ptr);
  sum->set_allocated_op1(g1);

  // op2
  auto [op2type, op2ptr] = op2->to_proto();
  gpb::Graph* g2 = make_proto_graph(op2type, op2ptr);
  sum->set_allocated_op2(g2);

  return {gpb::Graph::OpCase::kSum, sum};
}

/// Build a Sum from its protobuf representation by deserializing both
/// operands.
std::unique_ptr<Sum> Sum::from_proto(const gpb::Sum& sproto) noexcept {
  auto lhs = op_from_proto(sproto.op1());
  auto rhs = op_from_proto(sproto.op2());
  return std::make_unique<Sum>(std::move(lhs), std::move(rhs));
}

/// Evaluate both operands on `inputs` and return their product.
float Mul::eval(const Inputs& inputs) const noexcept {
  assert(op1 && op2);
  const float lhs = op1->eval(inputs);
  const float rhs = op2->eval(inputs);
  return lhs * rhs;
}

/// Serialize this multiplication into a newly allocated graph_proto::Mul.
///
/// Each operand is serialized recursively, wrapped into a graph_proto::Graph
/// node and attached to the Mul message.
/// Returns the `kMul` tag together with a type-erased pointer to the message;
/// the caller takes ownership of that pointer.
std::pair<gpb::Graph::OpCase, void*> Mul::to_proto() const noexcept {
  // Same precondition as eval(): both operands must be present, otherwise the
  // dereferences below are undefined behavior.
  assert(op1 && op2);

  auto* mul = new gpb::Mul();  // caller will take ownership

  // op1
  auto [op1type, op1ptr] = op1->to_proto();
  gpb::Graph* g1 = make_proto_graph(op1type, op1ptr);
  mul->set_allocated_op1(g1);

  // op2
  auto [op2type, op2ptr] = op2->to_proto();
  gpb::Graph* g2 = make_proto_graph(op2type, op2ptr);
  mul->set_allocated_op2(g2);

  return {gpb::Graph::OpCase::kMul, mul};
}

/// Build a Mul from its protobuf representation by deserializing both
/// operands.
std::unique_ptr<Mul> Mul::from_proto(const gpb::Mul& mproto) noexcept {
  auto lhs = op_from_proto(mproto.op1());
  auto rhs = op_from_proto(mproto.op2());
  return std::make_unique<Mul>(std::move(lhs), std::move(rhs));
}

/// Evaluate the graph's root operation on `inputs`.
float Graph::eval(const Inputs& inputs) const noexcept {
  assert(op);
  const float result = op->eval(inputs);
  return result;
}

std::pair<gpb::Graph::OpCase, void*> Const::to_proto() const noexcept {
auto* c = new gpb::Const(); // caller will take ownership
c->set_value(value);
return {gpb::Graph::OpCase::kConst, c};
}

/// Build a Const holding the value stored in the protobuf message.
std::unique_ptr<Const> Const::from_proto(const gpb::Const& cproto) noexcept {
  const float stored = cproto.value();
  return std::make_unique<Const>(stored);
}

/// Look up this variable's value in `inputs` by name.
/// Aborts the program if `inputs` has no entry for the variable.
float Var::eval(const Inputs& inputs) const noexcept {
  const auto entry = inputs.find(name);
  if (entry != inputs.end()) {
    return entry->second;
  }
  std::abort();  // TODO also log an error
}

std::pair<gpb::Graph::OpCase, void*> Var::to_proto() const noexcept {
auto* var = new gpb::Var(); // caller will take ownership
var->set_name(name);
return {gpb::Graph::OpCase::kVar, var};
}

/// Build a Var with the name stored in the protobuf message.
std::unique_ptr<Var> Var::from_proto(const gpb::Var& vproto) noexcept {
  const auto& var_name = vproto.name();
  return std::make_unique<Var>(var_name);
}

/// Serialize this Graph into a protobuf object.
///
/// The root operation is serialized recursively and wrapped into a
/// graph_proto::Graph node, which the returned unique_ptr owns.
std::unique_ptr<const gpb::Graph> Graph::to_proto() const noexcept {
  // Same precondition as eval(): a graph without a root operation cannot be
  // serialized, and dereferencing `op` would be undefined behavior.
  assert(op);

  auto [optype, opptr] = op->to_proto();
  return std::unique_ptr<const gpb::Graph>(make_proto_graph(optype, opptr));
}

/// Deserialize a protobuf object into a Graph whose root is the operation
/// stored in `gproto`.
Graph Graph::from_proto(const gpb::Graph& gproto) noexcept {
  auto root = op_from_proto(gproto);
  return Graph(std::move(root));
}

/// Serialize `graph` to the file at `path` in protobuf binary format.
///
/// Returns:
///  - OkStatus on success;
///  - InvalidArgumentError if the file cannot be opened for writing;
///  - AbortedError if serialization or the final flush to disk fails.
absl::Status compute_graph_ad::to_file(const Graph& graph, fs::path path) {
  GOOGLE_PROTOBUF_VERIFY_VERSION;

  const std::unique_ptr<const gpb::Graph> gproto = graph.to_proto();

  // The protobuf wire format is binary: open in binary mode so no platform
  // performs newline translation on the payload.
  std::ofstream out_file(path, std::ios::binary);
  if (!out_file.good()) {
    return absl::InvalidArgumentError(
        fmt::format("Could not open file {} for writing.", path.string()));
  }

  const bool serialized = gproto->SerializeToOstream(&out_file);
  // Flush explicitly so that buffered write errors are reported here instead
  // of being silently dropped by the stream's destructor.
  out_file.flush();
  if (!serialized || !out_file.good()) {
    return absl::AbortedError(
        fmt::format("Something went wrong while serializing Graph to file {}",
                    path.string()));
  }

  // NOTE: google::protobuf::ShutdownProtobufLibrary() is deliberately not
  // called here. It releases protobuf's global state, after which no protobuf
  // API may be used, so a later to_file()/from_file() call would be unsafe.
  return absl::OkStatus();
}

absl::StatusOr<Graph> compute_graph_ad::from_file(fs::path path) {
GOOGLE_PROTOBUF_VERIFY_VERSION;

gpb::Graph gproto;

{
std::ifstream in_file(path);
if (!in_file.good()) {
return absl::InvalidArgumentError(
fmt::format("Could not open file {} for reading.", path.string()));
}

const bool ok = gproto.ParseFromIstream(&in_file);
if (!ok) {
google::protobuf::ShutdownProtobufLibrary();
return absl::AbortedError(
fmt::format("Something went wrong while serializing Graph to file {}",
path.string()));
}
}

absl::StatusOr<Graph> ret = Graph::from_proto(gproto);
return ret;
}
55 changes: 52 additions & 3 deletions src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@ This program comes with ABSOLUTELY NO WARRANTY.
This is free software, and you are welcome to redistribute it
under certain conditions: see LICENSE.
*/

#include <filesystem> // std::filesystem::path
#include <memory>
#include <string>
#include <string_view>
#include <utility> // std::pair

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "src/graph.pb.h"

/* A note on the use of shared_ptr<const T>
Expand All @@ -35,6 +41,15 @@ class Op {

/// Evaluate this operation on the inputs provided.
virtual float eval(const Inputs& inputs) const noexcept = 0;

/// Retrieve a type-erased protobuf representation of the operation.
/// The first element of the pair is the protobuf enum that specifies
/// what operation has been serialized, the second element is a void
/// pointer to the corresponding protobuf class instance (e.g.
/// graph_proto::Sum).
/// The caller takes ownership of the pointer returned.
virtual std::pair<graph_proto::Graph::OpCase, void*> to_proto()
const noexcept = 0;
};

/// A sum operation, with two operands that can be operations themselves.
Expand All @@ -46,6 +61,11 @@ class Sum : public Op {
: op1(std::move(op1)), op2(std::move(op2)) {}

float eval(const Inputs& inputs) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;

static std::unique_ptr<Sum> from_proto(
const graph_proto::Sum& sproto) noexcept;
};

/// A multiplication operation, with two operands that can be operations
Expand All @@ -58,6 +78,11 @@ class Mul : public Op {
: op1(std::move(op1)), op2(std::move(op2)) {}

float eval(const Inputs& inputs) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;

static std::unique_ptr<Mul> from_proto(
const graph_proto::Mul& mproto) noexcept;
};

/// A compute graph.
Expand All @@ -80,6 +105,12 @@ class Graph {
}

float eval(const Inputs& inputs) const noexcept;

// Serialize this Graph instance into a corresponding protobuf object.
std::unique_ptr<const graph_proto::Graph> to_proto() const noexcept;

// Deserialize a protobuf object into a Graph instance.
static Graph from_proto(const graph_proto::Graph& gproto) noexcept;
};

/// A scalar constant.
Expand All @@ -89,8 +120,6 @@ class Const : public Op {
public:
Const(float value) : value(value) {}

float eval(const Inputs&) const noexcept final { return value; }

/* operator+ */
// using the hidden friend pattern to avoid polluting the global namespace:
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1601r0.pdf
Expand Down Expand Up @@ -121,6 +150,12 @@ class Const : public Op {
}

friend Graph operator*(const Const& c1, const Graph& g2) { return g2 * c1; }

float eval(const Inputs&) const noexcept final { return value; }

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;
static std::unique_ptr<Const> from_proto(
const graph_proto::Const& cproto) noexcept;
};

/// A scalar variable: a named placeholder for an input to `evaluate`.
Expand Down Expand Up @@ -177,5 +212,19 @@ class Var : public Op {
friend Graph operator*(const Var& v1, const Graph& g2) { return g2 * v1; }

float eval(const Inputs& inputs) const noexcept final;

std::pair<graph_proto::Graph::OpCase, void*> to_proto() const noexcept final;

static std::unique_ptr<Var> from_proto(
const graph_proto::Var& vproto) noexcept;
};
} // namespace compute_graph_ad

namespace fs = std::filesystem;

/// Serialize a compute graph to a protobuf file.
absl::Status to_file(const Graph& graph, fs::path path);

/// Deserialize a protobuf file into a Graph instance.
absl::StatusOr<Graph> from_file(fs::path path);

} // namespace compute_graph_ad
37 changes: 37 additions & 0 deletions src/graph.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
cpp-graph-autodiff Copyright (C) 2023 Enrico Guiraud
This program comes with ABSOLUTELY NO WARRANTY.
This is free software, and you are welcome to redistribute it
under certain conditions: see LICENSE.
*/

syntax = "proto2";

package graph_proto;

// A node of the serialized compute graph.
// Holds at most one of the supported operations; operands of Sum and Mul are
// themselves Graph messages, which makes the structure recursive.
message Graph {
oneof Op {
Sum sum = 1;
Mul mul = 2;
Var var = 3;
Const const = 4;
}
}

// A sum of two operands, each stored as a nested Graph node.
message Sum {
required Graph op1 = 1;
required Graph op2 = 2;
}

// A multiplication of two operands, each stored as a nested Graph node.
message Mul {
required Graph op1 = 1;
required Graph op2 = 2;
}

// A variable, identified by its name.
message Var {
required string name = 1;
}

// A constant holding a single float value.
message Const {
required float value = 1;
}

0 comments on commit c413d98

Please sign in to comment.