Skip to content

Commit

Permalink
Apply feedback from ML meeting
Browse files Browse the repository at this point in the history
  • Loading branch information
abrown committed Mar 29, 2024
1 parent 2d2f180 commit b2594a9
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions wit/wasi-nn.wit
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,20 @@ world ml {
import errors;
}

/// Inference is performed on a specific `device`.
interface device {
/// Define where tensors reside and graphs execute.
enum location {

This comment has been minimized.

Copy link
@geekbeast

geekbeast Apr 1, 2024

Contributor

We briefly discussed making this a string as each framework has varying support for different devices. A concrete example is multi-gpu settings, if you have multiple devices of a single type you will want to do things like torch.device(cuda:0).

cpu,
gpu,
tpu
}
}

/// All inputs and outputs to an ML inference are represented as `tensor`s.
interface tensor {
use device.{location};

/// The dimensions of a tensor.
///
/// The array length matches the tensor rank and each element in the array describes the size of
Expand Down Expand Up @@ -44,8 +56,8 @@ interface tensor {
type tensor-data = list<u8>;

resource tensor {
constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data,
location: option<execution-target>);
/// Construct a tensor that lives on the host CPU.
constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);

// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
// containing a single value, use `[1]` for the tensor dimensions.
Expand All @@ -55,7 +67,7 @@ interface tensor {
ty: func() -> tensor-type;

// Describe where the tensor is currently located (e.g., `cpu`, `gpu`, `tpu`).
location: func() -> execution-target;
location: func() -> location;

This comment has been minimized.

Copy link
@geekbeast

geekbeast Apr 1, 2024

Contributor

there should also be a to(...) function that takes a location and also location should actually just be device


// Return the tensor data. If the tensor is located on a device other than the CPU, this
// operation may result in an expensive data copy operation.
Expand All @@ -74,8 +86,9 @@ interface tensor {
/// framework (e.g., TensorFlow):
interface graph {
use errors.{error};
use tensor.{tensor};
use device.{location};
use inference.{graph-execution-context};
use tensor.{tensor};

/// An execution graph for performing inference (i.e., a model).
resource graph {
Expand All @@ -93,21 +106,15 @@ interface graph {
autodetect,
}

/// Define where the graph should be executed.
enum execution-target {
cpu,
gpu,
tpu
}

/// The graph initialization data.
///
/// This gets bundled up into an array of buffers because implementing backends may encode their
/// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
type graph-builder = list<u8>;

/// Load a `graph` from an opaque sequence of bytes to use for inference.
load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;
/// Load a `graph` from an opaque sequence of bytes to use for inference on the specified device
/// `location`.
load: func(builder: list<graph-builder>, encoding: graph-encoding, location: location) -> result<graph, error>;

/// Load a `graph` by name.
///
Expand All @@ -128,6 +135,11 @@ interface inference {
/// TODO: this may no longer be necessary in WIT
/// (https://github.com/WebAssembly/wasi-nn/issues/43)
resource graph-execution-context {
/// Load a tensor using the graph context. Unlike the `tensor` constructor, this function
/// will co-locate the tensor data on a specific device using the graph's underlying
/// backend; this may avoid some copies, improving performance.
load-tensor: func(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data) -> result<tensor, error>;

/// Define the inputs to use for inference.
set-input: func(name: string, tensor: tensor) -> result<_, error>;

Expand Down

0 comments on commit b2594a9

Please sign in to comment.