Commit 1d2fd99

Implement MNIST model and inference

Signed-off-by: Aisuko <[email protected]>
Aisuko committed Nov 4, 2023
1 parent fb67c91 commit 1d2fd99
Showing 10 changed files with 267 additions and 94 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -43,3 +43,4 @@ prepare
/ggml-metal.metal
target/
Cargo.lock
model.bin
10 changes: 10 additions & 0 deletions backend/rust/Makefile
@@ -34,6 +34,16 @@ burn:
@echo "Burning..."
@cargo run --bin server --package backend-burn


############################################################################################################
# gRPC testing commands


.PHONY: list
list:
@echo "Burning..."
@grpcurl -plaintext -import-path ../../../pkg/grpc/proto -proto backend.proto list backend.Backend

.PHONY: health
health:
@echo "Burning..."
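A hypothetical `predict` target could exercise the new RPC end to end; the server address (localhost:50051) and the JSON payload shape are illustrative assumptions, not part of this commit:

.PHONY: predict
predict:
	@echo "Predicting..."
	@grpcurl -plaintext -import-path ../../../pkg/grpc/proto -proto backend.proto -d '{"prompt": "..."}' localhost:50051 backend.Backend/Predict
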
1 change: 1 addition & 0 deletions backend/rust/backend-burn/Cargo.toml
@@ -13,6 +13,7 @@ path = "src/main.rs"

# import bunker here
bunker = { path = "../bunker" }
models = { path = "../models" }

tokio = "1.33.0"
async-trait = "0.1.74"
21 changes: 20 additions & 1 deletion backend/rust/backend-burn/src/main.rs
@@ -14,6 +14,7 @@ use async_trait::async_trait;

use tracing::{event, span, Level};

use models::*;
// implement BackendService trait in bunker

#[derive(Default, Debug)]
@@ -35,7 +36,25 @@ impl BackendService for BurnBackend {

#[tracing::instrument]
async fn predict(&self, request: Request<PredictOptions>) -> Result<Response<Reply>, Status> {
- todo!()
let mut models: Vec<Box<dyn LLM>> = vec![Box::new(models::MNINST::new())];
let result = models[0].predict(request.into_inner());

match result {
Ok(res) => {
let reply = Reply {
message: res.into(),
};
let res = Response::new(reply);
Ok(res)
}
Err(e) => {
let reply = Reply {
message: e.to_string().into(),
};
let res = Response::new(reply);
Ok(res)
}
}
}

#[tracing::instrument]
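As written, `predict` reports model failures as successful gRPC replies whose message is the error string. A possible refinement, sketched below under the assumption that `Status` here is tonic's status type re-exported through bunker, would surface them as proper gRPC errors instead:

match result {
    Ok(res) => Ok(Response::new(Reply { message: res.into() })),
    // Propagate model failures as a gRPC error rather than an Ok reply.
    Err(e) => Err(Status::internal(e.to_string())),
}
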
9 changes: 8 additions & 1 deletion backend/rust/models/Cargo.toml
@@ -5,6 +5,13 @@ edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
default = ["ndarray"]

ndarray = ["burn/ndarray"]
wgpu = ["burn/wgpu"]

[dependencies]
burn = { version="0.10.0", features=["ndarray"] } # https://github.com/mudler/LocalAI/discussions/1219
bunker = { path = "../bunker" }
burn = { version="0.10.0", features=["ndarray","wgpu"] } # https://github.com/mudler/LocalAI/discussions/1219
serde = "1.0.190"
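With these feature flags, the compute backend is chosen at build time; for instance (command lines are illustrative, run from backend/rust/models):

cargo build                                        # default: ndarray backend
cargo build --no-default-features --features wgpu  # GPU via wgpu
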
10 changes: 9 additions & 1 deletion backend/rust/models/src/lib.rs
@@ -1 +1,9 @@
pub(crate) mod onnx;
pub(crate) mod mnist;
pub use mnist::mnist::MNINST;

use bunker::pb::{ModelOptions, PredictOptions};

pub trait LLM {
fn load_model(&mut self, request: ModelOptions) -> Result<String, Box<dyn std::error::Error>>;
fn predict(&mut self, request: PredictOptions) -> Result<String, Box<dyn std::error::Error>>;
}
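Any model implementing `LLM` can then be driven uniformly through a trait object; a minimal sketch of a hypothetical caller:

fn run_model(model: &mut dyn LLM, opts: PredictOptions) -> Result<String, Box<dyn std::error::Error>> {
    // Dispatch through the trait; the concrete model type is erased here.
    model.predict(opts)
}
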
185 changes: 185 additions & 0 deletions backend/rust/models/src/mnist/mnist.rs
@@ -0,0 +1,185 @@
//! Definition of the MNIST model and its configuration.
//! The source code is from https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/src/model.rs
//! The license is Apache-2.0 and MIT.
//! Adapted by Aisuko

use burn::{
backend::wgpu::{compute::init_async, AutoGraphicsApi, WgpuDevice},
module::Module,
nn::{self, BatchNorm, PaddingConfig2d},
record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
tensor::{backend::Backend, Tensor},
};

// https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/model.bin
static STATE_ENCODED: &[u8] = include_bytes!("model.bin");

const NUM_CLASSES: usize = 10;

#[derive(Module, Debug)]
/// A struct representing an MNINST model.
pub struct MNINST<B: Backend> {
/// The first convolutional block of the model.
conv1: ConvBlock<B>,
/// The second convolutional block of the model.
conv2: ConvBlock<B>,
/// The third convolutional block of the model.
conv3: ConvBlock<B>,
/// A dropout layer used in the model.
dropout: nn::Dropout,
/// The first fully connected layer of the model.
fc1: nn::Linear<B>,
/// The second fully connected layer of the model.
fc2: nn::Linear<B>,
/// The activation function used in the model.
activation: nn::GELU,
}

impl<B: Backend> MNINST<B> {
pub fn new() -> Self {
let conv1 = ConvBlock::new([1, 8], [3, 3]); // 1 input channel, 8 output channels, 3x3 kernel size
let conv2 = ConvBlock::new([8, 16], [3, 3]); // 8 input channels, 16 output channels, 3x3 kernel size
let conv3 = ConvBlock::new([16, 24], [3, 3]); // 16 input channels, 24 output channels, 3x3 kernel size
let hidden_size = 24 * 22 * 22;
let fc1 = nn::LinearConfig::new(hidden_size, 32)
.with_bias(false)
.init();
let fc2 = nn::LinearConfig::new(32, NUM_CLASSES)
.with_bias(false)
.init();

let dropout = nn::DropoutConfig::new(0.5).init();

let instance = Self {
conv1,
conv2,
conv3,
dropout,
fc1,
fc2,
activation: nn::GELU::new(),
};
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
.load(STATE_ENCODED.to_vec())
.expect("Failed to decode state");

instance.load_record(record)
}

/// Applies the forward pass of the neural network on the given input tensor.
///
/// # Arguments
///
/// * `input` - A 3-dimensional tensor of shape [batch_size, height, width].
///
/// # Returns
///
/// A 2-dimensional tensor of shape [batch_size, num_classes] containing the output of the neural network.
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
// Get the dimensions of the input tensor
let [batch_size, height, width] = input.dims();
// Reshape the input tensor to have a shape of [batch_size, 1, height, width] and detach it
let x = input.reshape([batch_size, 1, height, width]).detach();
// Apply the first convolutional layer to the input tensor
let x = self.conv1.forward(x);
// Apply the second convolutional layer to the output of the first convolutional layer
let x = self.conv2.forward(x);
// Apply the third convolutional layer to the output of the second convolutional layer
let x = self.conv3.forward(x);

// Get the dimensions of the output tensor from the third convolutional layer
let [batch_size, channels, height, width] = x.dims();
// Reshape the output tensor to have a shape of [batch_size, channels*height*width]
let x = x.reshape([batch_size, channels * height * width]);

// Apply dropout to the output of the third convolutional layer
let x = self.dropout.forward(x);
// Apply the first fully connected layer to the output of the dropout layer
let x = self.fc1.forward(x);
// Apply the activation function to the output of the first fully connected layer
let x = self.activation.forward(x);

// Apply the second fully connected layer to the output of the activation function
self.fc2.forward(x)
}

pub fn inference(&mut self, input: &[f32]) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
// Reshape the 1D input slice into a 3D tensor of shape [batch, height, width]
let input: Tensor<B, 3> = Tensor::from_floats(input).reshape([1, 28, 28]);

// Normalize the input: scale pixel values to [0, 1], then standardize with the
// MNIST statistics mean=0.1307 and std=0.3081
// Source: https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
let input = ((input / 255) - 0.1307) / 0.3081;

// Run the tensor input through the model
let output: Tensor<B, 2> = self.forward(input);

// Convert the model output into a probability distribution using softmax
let output = burn::tensor::activation::softmax(output, 1);

// Flatten the [1, 10] output tensor into a Vec<f32>
let output = output.into_data().convert::<f32>().value;

Ok(output)
}
}

/// A struct representing a convolutional block in a neural network model.
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
/// A 2D convolutional layer.
conv: nn::conv::Conv2d<B>,
/// A batch normalization layer.
norm: BatchNorm<B, 2>,
/// A GELU activation function.
activation: nn::GELU,
}

/// A convolutional block with batch normalization and GELU activation.
impl<B: Backend> ConvBlock<B> {
/// Creates a new `ConvBlock` with the given number of output channels and kernel size.
pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self {
// Initialize a 2D convolutional layer with the given output channels and kernel size,
// and set the padding to "valid".
let conv = nn::conv::Conv2dConfig::new(channels, kernel_size)
.with_padding(PaddingConfig2d::Valid)
.init();

// Initialize a batch normalization layer with the number of channels in the second dimension of the output.
let norm = nn::BatchNormConfig::new(channels[1]).init();

// Create a new `ConvBlock` with the initialized convolutional and batch normalization layers,
// and a GELU activation function.
Self {
conv,
norm,
activation: nn::GELU::new(),
}
}

/// Applies the convolutional block to the given input tensor.
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
// Apply the convolutional layer to the input tensor.
let x = self.conv.forward(input);

// Apply the batch normalization layer to the output of the convolutional layer.
let x = self.norm.forward(x);

// Apply the GELU activation function to the output of the batch normalization layer.
self.activation.forward(x)
}
}

#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "ndarray")]
pub type Backend = burn::backend::NdArrayBackend<f32>;
#[test]
fn test_inference() {
let mut model = MNINST::<Backend>::new();
let output = model.inference(&[0.0; 28 * 28]).unwrap();
assert_eq!(output.len(), 10);
}
}
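Since `inference` returns the ten class probabilities, a caller would typically take the argmax to recover the predicted digit. A minimal sketch, assuming `pixels` holds 784 grayscale values:

let mut model = MNINST::<Backend>::new();
let probs = model.inference(&pixels)?;
// The index of the largest probability is the predicted digit (0-9).
let digit = probs
    .iter()
    .enumerate()
    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
    .map(|(i, _)| i)
    .unwrap();
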
33 changes: 33 additions & 0 deletions backend/rust/models/src/mnist/mod.rs
@@ -0,0 +1,33 @@
use crate::LLM;
use bunker::pb::{ModelOptions, PredictOptions};

pub(crate) mod mnist;

#[cfg(feature = "ndarray")]
pub type Backend = burn::backend::NdArrayBackend<f32>;

impl LLM for mnist::MNINST<Backend> {
fn load_model(&mut self, request: ModelOptions) -> Result<String, Box<dyn std::error::Error>> {
todo!("load model")
}

fn predict(&mut self, pre_ops: PredictOptions) -> Result<String, Box<dyn std::error::Error>> {
// Convert the prompt string's bytes into f32 pixel values, one byte per pixel
let input = pre_ops.prompt.as_bytes();
let input = input.iter().map(|x| *x as f32).collect::<Vec<f32>>();

let result = self.inference(&input);

match result {
Ok(output) => {
let output = output
.iter()
.map(|f| f.to_string())
.collect::<Vec<String>>()
.join(",");
Ok(output)
}
Err(e) => Err(e),
}
}
}
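Note that `predict` treats each byte of the prompt as one pixel, so callers are expected to send exactly 28 x 28 = 784 bytes. A hypothetical request, assuming `PredictOptions` derives `Default` as prost-generated messages do:

let opts = PredictOptions {
    // 784 raw pixel bytes packed into the prompt string; remaining fields default.
    prompt: String::from_utf8_lossy(&[0u8; 784]).into_owned(),
    ..Default::default()
};
let reply = model.predict(opts)?;
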
1 change: 0 additions & 1 deletion backend/rust/models/src/onnx/inference.rs

This file was deleted.
