Showing 1 changed file with 25 additions and 13 deletions.
```diff
@@ -15,8 +15,20 @@ world ml {
     import errors;
 }
 
+/// Inference is performed on a specific `device`.
+interface device {
+    /// Define where tensors reside and graphs execute.
+    enum location {
+        cpu,
+        gpu,
+        tpu
+    }
+}
+
 /// All inputs and outputs to an ML inference are represented as `tensor`s.
 interface tensor {
+    use device.{location};
+
     /// The dimensions of a tensor.
     ///
     /// The array length matches the tensor rank and each element in the array describes the size of
```
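For orientation, here is a minimal sketch of how the new `device.location` enum might surface in guest code, assuming wit-bindgen-style Rust bindings; the `Location` name and its variants are illustrative stand-ins, not the actual generated API:

```rust
/// Stand-in for the Rust enum a bindings generator might emit for the WIT
/// `device.location` type above; the name and variants are assumptions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Location {
    Cpu,
    Gpu,
    Tpu,
}

/// Choose where tensors should reside and graphs should execute, falling
/// back to the host CPU when no accelerator is requested.
fn pick_location(want_accelerator: bool) -> Location {
    if want_accelerator { Location::Gpu } else { Location::Cpu }
}

fn main() {
    println!("selected location: {:?}", pick_location(true));
}
```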
```diff
@@ -44,8 +56,8 @@ interface tensor {
     type tensor-data = list<u8>;
 
     resource tensor {
-        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data,
-            location: option<execution-target>);
+        /// Construct a tensor that lives on the host CPU.
+        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);
 
         // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
         // containing a single value, use `[1]` for the tensor dimensions.
```
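To make the narrowed constructor concrete, here is a hedged sketch of building a host-CPU tensor; `Tensor` and `TensorType` are hypothetical stand-ins for generated bindings, not the real API:

```rust
/// Hypothetical stand-ins for the bindings the WIT above implies.
#[derive(Debug)]
enum TensorType { Fp32 }

#[derive(Debug)]
struct Tensor {
    dimensions: Vec<u32>,
    ty: TensorType,
    data: Vec<u8>,
}

impl Tensor {
    /// Mirrors the new `constructor(dimensions, ty, data)`: there is no
    /// longer a location parameter, so the tensor starts on the host CPU.
    fn new(dimensions: Vec<u32>, ty: TensorType, data: Vec<u8>) -> Self {
        Tensor { dimensions, ty, data }
    }
}

fn main() {
    // A 2x2 f32 tensor: the dimensions array length equals the tensor rank.
    let values: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    // `tensor-data` is a `list<u8>`, so the floats are flattened to bytes.
    let data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    let t = Tensor::new(vec![2, 2], TensorType::Fp32, data);
    println!("{:?} tensor, {} bytes, dimensions {:?}", t.ty, t.data.len(), t.dimensions);
}
```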
```diff
@@ -55,7 +67,7 @@ interface tensor {
         ty: func() -> tensor-type;
 
         // Describe where the tensor is currently located (e.g., `cpu`, `gpu`, `tpu`).
-        location: func() -> execution-target;
+        location: func() -> location;
 
         // Return the tensor data. If the tensor is located on a device other than the CPU, this
         // operation may result in an expensive data copy operation.
```
```diff
@@ -74,8 +86,9 @@ interface tensor {
 /// framework (e.g., TensorFlow):
 interface graph {
     use errors.{error};
-    use tensor.{tensor};
+    use device.{location};
     use inference.{graph-execution-context};
+    use tensor.{tensor};
 
     /// An execution graph for performing inference (i.e., a model).
     resource graph {
```
```diff
@@ -93,21 +106,15 @@ interface graph {
         autodetect,
     }
 
-    /// Define where the graph should be executed.
-    enum execution-target {
-        cpu,
-        gpu,
-        tpu
-    }
-
     /// The graph initialization data.
     ///
     /// This gets bundled up into an array of buffers because implementing backends may encode their
     /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
     type graph-builder = list<u8>;
 
-    /// Load a `graph` from an opaque sequence of bytes to use for inference.
-    load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;
+    /// Load a `graph` from an opaque sequence of bytes to use for inference on the specified device
+    /// `location`.
+    load: func(builder: list<graph-builder>, encoding: graph-encoding, location: location) -> result<graph, error>;
 
     /// Load a `graph` by name.
     ///
```
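A hedged sketch of the updated `load` call; every name below (`GraphBuilder`, `GraphEncoding`, `Location`, and `load` itself) is an assumed stand-in for generated bindings:

```rust
/// Mirrors `type graph-builder = list<u8>`.
type GraphBuilder = Vec<u8>;

/// Assumed stand-ins for the WIT types; names are illustrative only.
#[derive(Clone, Copy, Debug)]
enum GraphEncoding { Openvino }

#[derive(Clone, Copy, Debug)]
enum Location { Cpu, Gpu, Tpu }

#[derive(Debug)]
struct Graph;

/// Mirrors `load: func(builder, encoding, location) -> result<graph, error>`;
/// the third parameter is now a device `location` rather than the removed
/// `execution-target` enum.
fn load(builder: Vec<GraphBuilder>, encoding: GraphEncoding, location: Location) -> Result<Graph, String> {
    let _ = (builder, encoding, location); // a real host would initialize a backend graph here
    Ok(Graph)
}

fn main() {
    // OpenVINO encodes its graph IR in parts (IR plus weights), hence the
    // list of builders; the file names here are placeholders.
    let ir = std::fs::read("model.xml").unwrap_or_default();
    let weights = std::fs::read("model.bin").unwrap_or_default();
    let graph = load(vec![ir, weights], GraphEncoding::Openvino, Location::Gpu);
    println!("{graph:?}");
}
```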
```diff
@@ -128,6 +135,11 @@ interface inference {
     /// TODO: this may no longer be necessary in WIT
     /// (https://github.com/WebAssembly/wasi-nn/issues/43)
     resource graph-execution-context {
+        /// Load a tensor using the graph context. Unlike the `tensor` constructor, this function
+        /// will co-locate the tensor data on a specific device using the graph's underlying
+        /// backend; this may avoid some copies, improving performance.
+        load-tensor: func(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data) -> result<tensor, error>;
+
         /// Define the inputs to use for inference.
         set-input: func(name: string, tensor: tensor) -> result<_, error>;
 
```
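Finally, a sketch contrasting the two tensor paths: the plain constructor leaves data on the host CPU, while `load-tensor` lets the backend co-locate it with the graph, potentially avoiding a host-to-device copy at `set-input` time. All names are assumed stand-ins:

```rust
/// Assumed stand-ins for the execution-context bindings; not a real API.
#[derive(Debug)]
enum TensorType { Fp32 }

#[derive(Debug)]
struct Tensor { co_located: bool }

struct GraphExecutionContext;

impl GraphExecutionContext {
    /// Mirrors `load-tensor: func(dimensions, ty, data) -> result<tensor, error>`:
    /// the backend places the data wherever the graph executes, which may
    /// avoid a later copy.
    fn load_tensor(&self, dimensions: Vec<u32>, ty: TensorType, data: Vec<u8>) -> Result<Tensor, String> {
        let _ = (dimensions, ty, data);
        Ok(Tensor { co_located: true })
    }

    /// Mirrors `set-input: func(name, tensor) -> result<_, error>`; a
    /// host-resident tensor may be copied to the device at this point,
    /// while a co-located one should not need it.
    fn set_input(&self, name: &str, tensor: Tensor) -> Result<(), String> {
        println!("{name}: co-located = {}", tensor.co_located);
        Ok(())
    }
}

fn main() -> Result<(), String> {
    let ctx = GraphExecutionContext;
    // `[1]` represents a tensor holding a single value, per the WIT comment.
    let input = ctx.load_tensor(vec![1], TensorType::Fp32, vec![0u8; 4])?;
    ctx.set_input("input0", input)
}
```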
geekbeast (Contributor), commenting on `location: func() -> location;`:

> We briefly discussed making this a string, since each framework has varying support for different devices. A concrete example is multi-GPU settings: if you have multiple devices of a single type, you will want to do things like `torch.device("cuda:0")`.
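As a purely illustrative sketch of the string-based alternative the comment suggests (nothing like this exists in the WIT), a device spec such as `cuda:0` could be split into a kind and an index:

```rust
/// Parse a torch-style device string such as "cuda:0" into a device kind
/// and index; this is a hypothetical illustration of the comment's idea.
fn parse_device(spec: &str) -> (String, u32) {
    match spec.split_once(':') {
        Some((kind, idx)) => (kind.to_string(), idx.parse().unwrap_or(0)),
        None => (spec.to_string(), 0),
    }
}

fn main() {
    let (kind, index) = parse_device("cuda:0");
    println!("device kind = {kind}, index = {index}");
}
```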