[WASI-NN] Add support for a PyTorch backend for wasi-nn #9234

Open
wants to merge 4 commits into
base: main

rahulchaphalkar
Contributor

This change adds a PyTorch backend for wasi-nn.
The tch crate provides the Libtorch bindings. I have added an image classification example, which uses a TorchScript model, to demonstrate its usage.
The backend is currently gated behind a wasi-nn feature flag (--features pytorch) because it links Libtorch dynamically: building requires a Libtorch v2.4.0 installation on the system, located via LIBTORCH=/path/to/libtorch.

@rahulchaphalkar rahulchaphalkar requested review from alexcrichton and removed request for a team September 12, 2024 18:18
@abrown abrown self-assigned this Sep 12, 2024
@alexcrichton alexcrichton requested review from abrown and removed request for a team and alexcrichton September 12, 2024 20:09
Collaborator

@abrown abrown left a comment

This is a good start. The main thing to fix is the handling of the input and output tensors.

crates/wasi-nn/src/backend/pytorch.rs (outdated)
) -> Result<Graph, BackendError> {
// Load the model from the file path
let compiled_module =
CModule::load_on_device(path, map_execution_target_to_string(target)).unwrap();
Collaborator

This seems incorrect: load_from_dir is going to pass the path to a directory and this code will try to use it as a file.

Contributor Author

Changed to path.join("model.pt")
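
A minimal sketch of the fix discussed in this thread, assuming the backend receives a directory and the TorchScript file inside it is named model.pt (as in the reply above). The constant and function names are hypothetical and not taken from the PR. The checked variant returns an error instead of panicking when the file is missing, which is one possible alternative to calling unwrap() on the load result.

```rust
use std::path::{Path, PathBuf};

/// File name looked up inside the model directory ("model.pt" is from the
/// review thread; the constant name itself is hypothetical).
const MODEL_FILE_NAME: &str = "model.pt";

/// Build the path of the model file inside `dir` without touching the disk.
fn model_file_in_dir(dir: &Path) -> PathBuf {
    dir.join(MODEL_FILE_NAME)
}

/// Same as above, but report an error if the file does not exist.
fn checked_model_file(dir: &Path) -> Result<PathBuf, String> {
    let file = model_file_in_dir(dir);
    if file.is_file() {
        Ok(file)
    } else {
        Err(format!("no {} found in {}", MODEL_FILE_NAME, dir.display()))
    }
}

fn main() {
    let dir = Path::new("fixtures");
    println!("expected model file: {}", model_file_in_dir(dir).display());
    match checked_model_file(dir) {
        Ok(file) => println!("found {}", file.display()),
        Err(err) => println!("error: {err}"),
    }
}
```

The resulting path is what would then be handed to the loader in place of the bare directory path.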

.iter()
.map(|&dim| dim as i64)
.collect::<Vec<_>>();
self.1 = TchTensor::from_data_size(&input_tensor.data, &dimensions, kind);
Collaborator

This function needs to handle the passed index: a model can have multiple inputs and the job this function needs to do is map the incoming tensor to the right one.

Contributor Author

This is one of the ways this backend differs. If the module supports multiple inputs, its forward method should handle them appropriately. Passing the vector of input tensors to forward should be sufficient; no index or name is needed.
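
For illustration only, here is one way the index passed to set_input could still be honored while forward receives a plain ordered vector. This is a sketch under assumptions, not code from the PR: `FakeTensor` stands in for a backend tensor, and `InputSlots` and `ordered_inputs` are hypothetical names.

```rust
/// Stand-in for a backend tensor; a real backend would hold a tensor type
/// from its bindings instead.
#[derive(Debug, Clone, PartialEq)]
struct FakeTensor(Vec<u8>);

/// Hypothetical storage that keeps one slot per model input.
#[derive(Default)]
struct InputSlots {
    slots: Vec<Option<FakeTensor>>,
}

impl InputSlots {
    /// Store `tensor` at `index`, growing the slot list if needed.
    fn set_input(&mut self, index: usize, tensor: FakeTensor) {
        if index >= self.slots.len() {
            self.slots.resize(index + 1, None);
        }
        self.slots[index] = Some(tensor);
    }

    /// Return the inputs in index order, or the first index that is unset.
    fn ordered_inputs(&self) -> Result<Vec<&FakeTensor>, usize> {
        self.slots
            .iter()
            .enumerate()
            .map(|(i, slot)| slot.as_ref().ok_or(i))
            .collect()
    }
}

fn main() {
    let mut ctx = InputSlots::default();
    // Inputs may arrive in any order; the index decides the position.
    ctx.set_input(1, FakeTensor(vec![2]));
    ctx.set_input(0, FakeTensor(vec![1]));
    println!("{:?}", ctx.ordered_inputs());
}
```

With this layout the caller can set inputs in any order, and a gap in the indices is reported before the forward pass rather than silently shifting later tensors.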


fn compute(&mut self) -> Result<(), BackendError> {
// Use forward method on the compiled module/model after locking the mutex, and pass the input tensor to it
self.1 = self.0.lock().unwrap().forward_ts(&[&self.1]).unwrap();
Collaborator

Storing the output tensor in the same location as the input tensor means that set_input followed immediately by get_output would return the input tensor... probably not what you want here. It looks like forward_ts only returns a single tensor so perhaps just create an output field for that and another input: Vec<Tensor> for the inputs.

Contributor Author

Good point. I added inputs as a vector of tensors and changed the structs to use named fields, which makes the code more readable.
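
The shape of that change can be sketched roughly as follows. This is a hypothetical, simplified context (plain `Vec<f32>` values instead of real tensors, and a stand-in "forward pass" that doubles the first input); the struct and field names are assumptions, not the PR's actual code. The point is only that get_output cannot return an input once the two are stored separately.

```rust
/// Hypothetical context with separate storage for inputs and the output.
struct ExecContext {
    inputs: Vec<Vec<f32>>,
    output: Option<Vec<f32>>,
}

impl ExecContext {
    fn new() -> Self {
        Self { inputs: Vec::new(), output: None }
    }

    fn set_input(&mut self, data: Vec<f32>) {
        self.inputs.push(data);
        // A new input invalidates any previously computed output.
        self.output = None;
    }

    /// Stand-in for the model's forward pass: doubles the first input.
    fn compute(&mut self) -> Result<(), String> {
        let first = self.inputs.first().ok_or("no input set")?;
        self.output = Some(first.iter().map(|x| x * 2.0).collect());
        Ok(())
    }

    /// Fails until `compute` has produced an output.
    fn get_output(&self) -> Result<&[f32], String> {
        self.output
            .as_deref()
            .ok_or_else(|| "compute has not run".to_string())
    }
}

fn main() {
    let mut ctx = ExecContext::new();
    ctx.set_input(vec![1.0, 2.0]);
    println!("before compute: {:?}", ctx.get_output());
    ctx.compute().expect("an input was set");
    println!("after compute: {:?}", ctx.get_output());
}
```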

Comment on lines 106 to 111
let data = vec![0f32; numel];
let mut data_u8: Vec<u8> = data
.iter()
.flat_map(|&x| x.to_le_bytes().to_vec())
.collect();
self.1.copy_data_u8(&mut data_u8, numel);
Collaborator

This is incorrect: we need to retrieve the data regardless of the type, so we need to first figure out how many bytes each Kind is before constructing the receiving buffer, like:

Suggested change
- let data = vec![0f32; numel];
- let mut data_u8: Vec<u8> = data
- .iter()
- .flat_map(|&x| x.to_le_bytes().to_vec())
- .collect();
- self.1.copy_data_u8(&mut data_u8, numel);
+ let data = vec![0u8; size_of(ty) * numel];
+ self.1.copy_data_u8(&mut data, numel);

Contributor Author

Done, added kind_to_size() for conversions.
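
A rough sketch of what such a helper might look like. Only the name kind_to_size comes from this thread; its signature, the `ElemKind` enum (a hypothetical stand-in for the bindings' element-kind type), and `output_buffer` are assumptions made to keep the example self-contained.

```rust
/// Hypothetical subset of tensor element types; a real backend would match
/// on the element-kind enum provided by its tensor bindings.
#[derive(Debug, Clone, Copy)]
enum ElemKind {
    U8,
    F16,
    F32,
    F64,
    I32,
    I64,
}

/// Number of bytes occupied by one element of the given kind.
fn kind_to_size(kind: ElemKind) -> usize {
    match kind {
        ElemKind::U8 => 1,
        ElemKind::F16 => 2,
        ElemKind::F32 | ElemKind::I32 => 4,
        ElemKind::F64 | ElemKind::I64 => 8,
    }
}

/// Allocate a zeroed byte buffer large enough for `numel` elements.
/// Returns `None` if the byte count would overflow `usize`.
fn output_buffer(kind: ElemKind, numel: usize) -> Option<Vec<u8>> {
    kind_to_size(kind)
        .checked_mul(numel)
        .map(|len| vec![0u8; len])
}

fn main() {
    let kinds = [
        ElemKind::U8,
        ElemKind::F16,
        ElemKind::F32,
        ElemKind::F64,
        ElemKind::I32,
        ElemKind::I64,
    ];
    for kind in kinds {
        let len = output_buffer(kind, 1000).map(|buf| buf.len());
        println!("{kind:?}: {len:?} bytes for 1000 elements");
    }
}
```

Sizing the buffer in bytes per element is what lets the same copy path serve every element type, rather than assuming f32 as the original code did.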

Collaborator

abrown commented Sep 20, 2024

The cargo vet situation is a bit much:

    cargo vet diff zstd-safe 5.0.1+zstd.1.5.2 5.0.2+zstd.1.5.2
                                                          gyscos         zstd                        2 files changed, 4 insertions(+), 4 deletions(-)
    cargo vet diff zstd 0.11.1+zstd.1.5.2 0.11.2+zstd.1.5.2
                                                          gyscos         zip                         3 files changed, 5 insertions(+), 5 deletions(-)
    cargo vet diff num-complex 0.4.2 0.4.6                cuviper        ndarray                     6 files changed, 188 insertions(+), 48 deletions(-)
      NOTE: this project trusts Josh Stone (cuviper) - consider cargo vet trust num-complex or cargo vet trust --all cuviper
    cargo vet inspect constant_time_eq 0.1.5              cesarb         zip                         311 lines
    cargo vet diff sha1 0.10.5 0.10.6                     newpavlov      zip                         7 files changed, 302 insertions(+), 20 deletions(-)
    cargo vet inspect rawpointer 0.2.1                    bluss          ndarray and matrixmultiply  559 lines
    cargo vet diff zip 0.6.4 0.6.6                        Plecra         tch and torch-sys           14 files changed, 604 insertions(+), 109 deletions(-)
    cargo vet inspect inout 0.1.3                         newpavlov      cipher                      1112 lines
      NOTE: cargo vet import zcash would eliminate this
    cargo vet inspect pbkdf2 0.9.0                        tarcieri       zip                         1120 lines
    cargo vet inspect bzip2 0.4.4                         alexcrichton   zip                         2094 lines
      NOTE: this project trusts Alex Crichton (alexcrichton) - consider cargo vet trust bzip2 or cargo vet trust --all alexcrichton
    cargo vet inspect safetensors 0.3.3                   Narsil         tch                         2200 lines
    cargo vet inspect cipher 0.4.4                        newpavlov      aes                         2635 lines
      NOTE: cargo vet import zcash would reduce this to a 1300-line diff
    cargo vet inspect password-hash 0.3.2                 tarcieri       pbkdf2                      3139 lines
    cargo vet inspect base64ct 1.6.0                      tarcieri       password-hash               3381 lines
    cargo vet diff half 1.8.2 2.4.1                       starkat99      tch                         19 files changed, 2546 insertions(+), 958 deletions(-)
    cargo vet inspect time 0.1.44                         jhpratt        zip                         3915 lines
    cargo vet inspect aes 0.7.5                           tarcieri       zip                         6822 lines
    cargo vet inspect matrixmultiply 0.3.8                bluss          ndarray                     7934 lines
    cargo vet inspect ndarray 0.15.6                      jturner314     tch                         41996 lines
    cargo vet inspect torch-sys 0.17.0                    LaurentMazare  tch                         52119 lines
    cargo vet inspect bzip2-sys 0.1.11+1.0.8              alexcrichton   bzip2                       264133 lines
      NOTE: this project trusts Alex Crichton (alexcrichton) - consider cargo vet trust bzip2-sys or cargo vet trust --all alexcrichton
    cargo vet inspect tch 0.17.0                          LaurentMazare  wasmtime-wasi-nn            2287297 lines

@rahulchaphalkar
Contributor Author

> This is a good start. The main thing to fix is the handling of the input and output tensors.

Thanks for the review, Andrew. I've marked the smaller nits as resolved. I've also addressed the other comments, but left them unresolved until you take a look.
