Skip to content

Commit

Permalink
Call compilation and append output to metadata. Add entry point name …
Browse files Browse the repository at this point in the history
…to custom ops and plugin defined custom op id.

PiperOrigin-RevId: 676194738
  • Loading branch information
LukeBoyer authored and tensorflower-gardener committed Sep 19, 2024
1 parent 7da892f commit 0735e02
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 20 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/lite/experimental/lrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cc_binary(
"//tensorflow/compiler/mlir/lite/experimental/lrt/core:api_internal",
"//tensorflow/compiler/mlir/lite/experimental/lrt/core:lite_rt_model_init",
"//tensorflow/compiler/mlir/lite/experimental/lrt/core:model",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:Support",
],
)
74 changes: 63 additions & 11 deletions tensorflow/compiler/mlir/lite/experimental/lrt/apply_plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <iostream>
#include <string>
#include <string_view>
#include <vector>

#include "llvm/Support/CommandLine.h"
#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h"
Expand All @@ -39,14 +40,15 @@ static llvm::cl::opt<std::string> model_path(
static llvm::cl::opt<std::string> soc_manufacturer(
"soc_man",
llvm::cl::desc("String identifier of SoC backend (pixel, qcc, darwinn)."),
llvm::cl::init("Example"));
llvm::cl::init("ExampleSocManufacturer"));

// NOLINTNEXTLINE
static llvm::cl::opt<std::string> soc_model(
"soc_model",
llvm::cl::desc("Compilation configuration identifier (chip type)."),
llvm::cl::init("DummyMulOp"));

// TODO swap "dry_run" for optional "don't delete partitioned subgraphs".
// NOLINTNEXTLINE
static llvm::cl::opt<bool> dry_run(
"dry_run",
Expand All @@ -61,6 +63,7 @@ static llvm::cl::opt<bool> dry_run(
}

void DumpSubgraph(const LrtSubgraphT& subgraph, std::string_view label) {
#ifndef NDEBUG
std::cerr << "===== " << label << " =====\n";
for (auto op : subgraph.ops) {
debug::DumpOp(*op);
Expand All @@ -72,6 +75,7 @@ void DumpSubgraph(const LrtSubgraphT& subgraph, std::string_view label) {
for (auto tensor : subgraph.outputs) {
std::cerr << "SG_OUT " << tensor << "\n";
}
#endif
}

bool IsSocModelSupported(LrtCompilerPlugin plugin,
Expand All @@ -91,8 +95,8 @@ bool IsSocModelSupported(LrtCompilerPlugin plugin,

// TODO: b/366821557 - Replace loading pre-compiled plugin.
UniqueLrtCompilerPlugin LoadPlugin() {
if (soc_manufacturer != "Example") {
std::cerr << "Only Example currently supported";
if (soc_manufacturer != LrtPluginSocManufacturer()) {
std::cerr << "Only ExampleSocManufacturer currently supported";
return nullptr;
}

Expand All @@ -115,6 +119,9 @@ UniqueLrtModel LoadModel(std::string_view filename) {
}

LrtStatus ApplyPlugin(LrtModel model, LrtCompilerPlugin plugin) {
LRT_RETURN_STATUS_IF_NOT_OK(
RegisterCustomOpCode(model, LrtPluginSocManufacturer()));

LrtOpListT selected_ops;
LRT_RETURN_STATUS_IF_NOT_OK(
LrtPluginPartitionModel(plugin, model, &selected_ops));
Expand All @@ -124,26 +131,71 @@ LrtStatus ApplyPlugin(LrtModel model, LrtCompilerPlugin plugin) {

// TODO: b/366821557 - Support multiple subgraphs in plugin application.
auto& main_subgraph = model->subgraphs.front();
DumpSubgraph(main_subgraph, "Main subgraph before partioning.");

std::vector<LrtSubgraph> slices;
std::vector<LrtOp> custom_ops;
slices.reserve(partitions.size());
custom_ops.reserve(partitions.size());

for (auto& partition : partitions) {
LrtSubgraph new_subgraph = &model->subgraphs.emplace_back();
algo::GraphSlicer::SlicePartitionFromGraph(main_subgraph, new_subgraph,
partition);

LrtOp custom_op = algo::GraphSlicer::SlicePartitionFromGraph(
main_subgraph, new_subgraph, partition);
custom_ops.push_back(custom_op);
slices.push_back(new_subgraph);

DumpSubgraph(*new_subgraph, "New subgraph");
}

DumpSubgraph(main_subgraph, "Main subgraph");
DumpSubgraph(main_subgraph, "Main subgraph after partioning.");

if (dry_run) {
return StatusOk();
}

LrtCompiledResult compiled_result;
LRT_RETURN_STATUS_IF_NOT_OK(
LrtPluginCompile(plugin, slices.data(), slices.size(), &compiled_result));

lrt_param_index_t num_calls_compiled;
LRT_RETURN_STATUS_IF_NOT_OK(
LrtCompiledResultGetNumCalls(compiled_result, &num_calls_compiled));

if (num_calls_compiled != slices.size()) {
std::cerr
<< "Plugin must provide and entry point for each compiled partition\n";
return StatusCreate(kLrtStatusErrorNotFound);
}

for (int i = 0; i < num_calls_compiled; ++i) {
const void* call_info;
size_t call_info_size;

LRT_RETURN_STATUS_IF_NOT_OK(LrtCompiledResultGetCallInfo(
compiled_result, i, &call_info, &call_info_size));

auto* custom_op = custom_ops.at(i);
custom_op->custom_options.assign(reinterpret_cast<const char*>(call_info),
call_info_size);
}

const void* byte_code;
size_t byte_code_size;

LRT_RETURN_STATUS_IF_NOT_OK(LrtCompiledResultGetByteCode(
compiled_result, &byte_code, &byte_code_size));

LRT_RETURN_STATUS_IF_NOT_OK(AppendMetadata(model, byte_code, byte_code_size,
LrtPluginSocManufacturer()));

return StatusOk();
}

int main(int argc, char** argv) {
llvm::cl::ParseCommandLineOptions(argc, argv);

if (!dry_run) {
std::cerr << "Only dry run currently supported" << "\n";
return 1;
}

auto model = LoadModel(model_path);
EXIT_IF_NULL(model, "Failed to load model");

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ cc_test(
":graph_tools",
":lite_rt_model_init",
"//tensorflow/compiler/mlir/lite/experimental/lrt/test_data:test_data_util",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_googletest//:gtest_main",
"@flatbuffers//:runtime_cc",
"@llvm-project//llvm:Support",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ bool HasValidGeneralTopology(LrtSubgraph subgraph) {
}

for (auto tensor : subgraph->outputs) {
if (!implied_subgraph_outs.contains(tensor)) {
if (implied_subgraph_outs.find(tensor) == implied_subgraph_outs.end()) {
_LRT_D_MSG("Mismatched subgraph outs");
return false;
}
Expand All @@ -71,7 +71,7 @@ bool HasValidGeneralTopology(LrtSubgraph subgraph) {
}

for (auto tensor : subgraph->inputs) {
if (!implied_subgraph_ins.contains(tensor)) {
if (implied_subgraph_ins.find(tensor) == implied_subgraph_ins.end()) {
_LRT_D_MSG("Mismatched subgraph ins");
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ LrtStatus SetDefaultOptions(tflite::BuiltinOptionsUnion& opts, LrtOpCode code) {
return StatusOk();
}

// Copies `options_data` into the op's custom options and tags the payload as
// flexbuffer-encoded, matching the TFLite custom op serialization convention.
void SetCustomOptions(tflite::OperatorT& op, std::string_view options_data) {
  const auto* begin = reinterpret_cast<const uint8_t*>(options_data.data());
  const auto* end = begin + options_data.size();
  op.custom_options = std::vector<uint8_t>(begin, end);
  op.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS;
}

//===----------------------------------------------------------------------===//
// Load //
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -287,6 +293,11 @@ LrtStatus ModelUnpacker::Unpack(LrtModel model) {
return StatusOk();
}

// Records `new_op_code` as the custom op code string that will be attached to
// all custom ops when this model is re-serialized.
LrtStatus RegisterCustomOpCode(LrtModel model, const char* new_op_code) {
  model->custom_op_code = new_op_code;
  return StatusOk();
}

LrtStatus LoadModel(std::unique_ptr<tflite::ModelT> flatbuffer,
LrtModel* model) {
auto lrt_model = std::make_unique<LrtModelT>();
Expand All @@ -297,6 +308,9 @@ LrtStatus LoadModel(std::unique_ptr<tflite::ModelT> flatbuffer,

lrt_model->flatbuffer_model->subgraphs.clear();

// Set as empty string in case it's not set explicitly.
LRT_RETURN_STATUS_IF_NOT_OK(RegisterCustomOpCode(lrt_model.get(), ""));

*model = lrt_model.release();

return StatusOk();
Expand Down Expand Up @@ -357,16 +371,19 @@ class ModelRepacker {

// Builds a map from LrtOpCode to its index in the flatbuffer operator code
// table, after appending a single CUSTOM code that carries the model's
// registered custom op code string.
void ModelRepacker::BuildOpCodeMap(
    LrtModel model, std::unordered_map<LrtOpCode, uint32_t>& map) {
  // Add the user set custom code to the flatbuffers known codes. Appending it
  // exactly once here fixes the duplicate CUSTOM entry: a second, string-less
  // OperatorCodeT was previously emplaced after the loop, and its
  // `map.insert({kLrtOpCodeTflCustom, ...})` was a silent no-op because the
  // loop below had already inserted that key.
  auto& custom_code = model->flatbuffer_model->operator_codes.emplace_back(
      std::make_unique<tflite::OperatorCodeT>());
  custom_code->builtin_code = tflite::BuiltinOperator_CUSTOM;
  custom_code->custom_code = model->custom_op_code;
  custom_code->version = 1;

  auto& codes = model->flatbuffer_model->operator_codes;

  // The CUSTOM code appended above is the last element, so the loop maps
  // kLrtOpCodeTflCustom to it along with every builtin code.
  for (uint32_t i = 0; i < codes.size(); ++i) {
    const auto tfl_code = codes[i]->builtin_code;
    map.insert({static_cast<LrtOpCode>(tfl_code), i});
  }
}

LrtStatus ModelRepacker::SerializeTensor(LrtTensor tensor,
Expand Down Expand Up @@ -403,6 +420,10 @@ LrtStatus ModelRepacker::SerializeOp(
LRT_RETURN_STATUS_IF_NOT_OK_MSG(
SetDefaultOptions(target.builtin_options, op->op_code),
"Failed serializing options");

if (!op->custom_options.empty()) {
SetCustomOptions(target, op->custom_options);
}
// TODO: b/365299994 - Support exotic op fields in serialize.

return StatusOk();
Expand Down Expand Up @@ -474,6 +495,24 @@ LrtStatus ModelRepacker::Repack(LrtModel model) {
return StatusOk();
}

// Registers `metadata` as a named metadata entry on the model: the bytes are
// copied into a fresh flatbuffer buffer, and a MetadataT record pointing at
// that buffer is appended under `metadata_name`.
LrtStatus AppendMetadata(LrtModel model, const void* metadata,
                         size_t metadata_size, const char* metadata_name) {
  auto& fb = *model->flatbuffer_model;

  // The new buffer will live at the current end of the buffer table.
  const auto buffer_index = fb.buffers.size();

  const auto* bytes = reinterpret_cast<const uint8_t*>(metadata);
  auto buffer = std::make_unique<tflite::BufferT>();
  buffer->data = std::vector<uint8_t>(bytes, bytes + metadata_size);
  fb.buffers.push_back(std::move(buffer));
  fb.metadata_buffer.push_back(buffer_index);

  auto entry = std::make_unique<tflite::MetadataT>();
  entry->name = metadata_name;
  entry->buffer = buffer_index;
  fb.metadata.push_back(std::move(entry));

  return StatusOk();
}

LrtStatus SerializeModel(LrtModel model, uint8_t** buf, size_t* size,
size_t* offset) {
// Destroy model before return.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,19 @@ LrtStatus LoadModelFromFile(const char* path, LrtModel* model);
// Load model from flatbuffer memory.
LrtStatus LoadModel(const uint8_t* buf, size_t buf_size, LrtModel* model);

// Add a new custom code to the registry in this model. This will be associated
// with all custom ops and should only be set once.
// TODO consider expanding this to allow for "custom op builder" hook.
LrtStatus RegisterCustomOpCode(LrtModel model, const char* new_op_code);

// Destroy model and any associated storage.
void ModelDestroy(LrtModel model);

// Adds given metadata buffer to be serialized with the flatbuffer. Buffer can
// be retrieved at runtime under `metadata_name`.
LrtStatus AppendMetadata(LrtModel model, const void* metadata,
size_t metadata_size, const char* metadata_name);

// Serializes model to bytes. NOTE this destroys the model before it returns.
// NOTE: Caller takes ownership of `buf`. Flatbuffers are packed into their
// arrays back to front, so the valid flatbuffer is buf[offset, size].
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ struct LrtOpT {

LrtOpCode op_code;

// This is a placeholder to be used by just custom ops for now.
std::string custom_options;

// TODO: b/365299994 - Add support for op options.
};

Expand Down Expand Up @@ -122,6 +125,10 @@ struct LrtModelT {

// Initial flatbuffer loaded in. "Subgraphs" field has been invalidated.
std::unique_ptr<tflite::ModelT> flatbuffer_model;

// Custom code associated with all custom ops emitted during
// re-serialization.
std::string custom_op_code;
};

//
Expand Down
Loading

0 comments on commit 0735e02

Please sign in to comment.