Skip to content

Commit

Permalink
🎉 adds efficientnet
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Sep 26, 2023
1 parent 1cd0f9a commit d40f02f
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 57 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repository = "https://github.com/lenna-project/birds-plugin"
crate-type = ["cdylib", "rlib"]

[features]
default = ["plugin"]
default = ["plugin", "mobilenet"]
python = [
"lenna_core/python",
"ndarray",
Expand All @@ -21,6 +21,8 @@ python = [
"pythonize",
]
plugin = []
mobilenet = []
efficientnet = []

[dependencies]
console_error_panic_hook = "0.1"
Expand Down
Binary file added assets/birds_efficientnetb2.onnx
Binary file not shown.
61 changes: 61 additions & 0 deletions src/efficientnet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use image::DynamicImage;
use std::io::Cursor;
use tract_onnx::prelude::*;

use crate::{Birds, ModelType};

const SIZE: usize = 260;

impl Birds {
pub fn model() -> Result<ModelType, Box<dyn std::error::Error>> {
let data = include_bytes!("../assets/birds_efficientnetb2.onnx");
let mut cursor = Cursor::new(data);
let model = tract_onnx::onnx()
.model_for_read(&mut cursor)?
.with_input_fact(
0,
InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, SIZE, SIZE)),
)?
.into_optimized()?
.into_runnable()?;
Ok(model)
}

pub fn labels() -> Vec<String> {
let collect = include_str!("../assets/birds_labels.txt")
.to_string()
.lines()
.map(|s| s.to_string())
.collect();
collect
}

pub fn detect_label(
&self,
image: &Box<DynamicImage>,
) -> Result<Option<String>, Box<dyn std::error::Error>> {
let image_rgb = image.to_rgb8();
let resized = image::imageops::resize(
&image_rgb,
SIZE as u32,
SIZE as u32,
::image::imageops::FilterType::Triangle,
);
let tensor: Tensor =
tract_ndarray::Array4::from_shape_fn((1, 3, SIZE, SIZE), |(_, c, y, x)| {
(resized[(x as _, y as _)][c] as f32 / 255.0)
})
.into();

let result = self.model.run(tvec!(tensor.into())).unwrap();
let best = result[0]
.to_array_view::<f32>()?
.iter()
.cloned()
.zip(0..)
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let index = best.unwrap().1;
let label = Self::labels()[index].to_string();
Ok(Some(label))
}
}
62 changes: 6 additions & 56 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ use lenna_core::plugins::PluginRegistrar;
use lenna_core::ProcessorConfig;
use lenna_core::{core::processor::ExifProcessor, core::processor::ImageProcessor, Processor};
use rusttype::{Font, Scale};
use std::io::Cursor;
use tract_onnx::prelude::*;

#[cfg(feature = "efficientnet")]
mod efficientnet;

#[cfg(feature = "mobilenet")]
mod mobilenet;

extern "C" fn register(registrar: &mut dyn PluginRegistrar) {
registrar.add_plugin(Box::new(Birds::default()));
}
Expand All @@ -16,7 +21,6 @@ extern "C" fn register(registrar: &mut dyn PluginRegistrar) {
lenna_core::export_plugin!(register);

type ModelType = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;
const SIZE: usize = 224;

#[derive(Clone)]
pub struct Birds {
Expand All @@ -25,60 +29,6 @@ pub struct Birds {
label: Option<String>,
}

impl Birds {
pub fn model() -> Result<ModelType, Box<dyn std::error::Error>> {
let data = include_bytes!("../assets/birds_mobilenetv2.onnx");
let mut cursor = Cursor::new(data);
let model = tract_onnx::onnx()
.model_for_read(&mut cursor)?
.with_input_fact(
0,
InferenceFact::dt_shape(f32::datum_type(), tvec!(1, SIZE, SIZE, 3)),
)?
.into_optimized()?
.into_runnable()?;
Ok(model)
}

pub fn labels() -> Vec<String> {
let collect = include_str!("../assets/birds_labels.txt")
.to_string()
.lines()
.map(|s| s.to_string())
.collect();
collect
}

pub fn detect_label(
&self,
image: &Box<DynamicImage>,
) -> Result<Option<String>, Box<dyn std::error::Error>> {
let image_rgb = image.to_rgb8();
let resized = image::imageops::resize(
&image_rgb,
SIZE as u32,
SIZE as u32,
::image::imageops::FilterType::Triangle,
);
let tensor: Tensor =
tract_ndarray::Array4::from_shape_fn((1, SIZE, SIZE, 3), |(_, y, x, c)| {
resized[(x as _, y as _)][c] as f32 / 255.0
})
.into();

let result = self.model.run(tvec!(tensor.into())).unwrap();
let best = result[0]
.to_array_view::<f32>()?
.iter()
.cloned()
.zip(0..)
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let index = best.unwrap().1;
let label = Self::labels()[index].to_string();
Ok(Some(label))
}
}

impl Default for Birds {
fn default() -> Self {
Birds {
Expand Down
63 changes: 63 additions & 0 deletions src/mobilenet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use image::DynamicImage;
use std::io::Cursor;
use tract_onnx::prelude::*;

use crate::{Birds, ModelType};

const SIZE: usize = 226;

impl Birds {
pub fn model() -> Result<ModelType, Box<dyn std::error::Error>> {
let data = include_bytes!("../assets/birds_mobilenetv2.onnx");
let mut cursor = Cursor::new(data);
let model = tract_onnx::onnx()
.model_for_read(&mut cursor)?
.with_input_fact(
0,
InferenceFact::dt_shape(f32::datum_type(), tvec!(1, SIZE, SIZE, 3)),
)?
.into_optimized()?
.into_runnable()?;
Ok(model)
}

pub fn labels() -> Vec<String> {
let collect = include_str!("../assets/birds_labels.txt")
.to_string()
.lines()
.map(|s| s.to_string())
.collect();
collect
}

pub fn detect_label(
&self,
image: &Box<DynamicImage>,
) -> Result<Option<String>, Box<dyn std::error::Error>> {
let image_rgb = image.to_rgb8();
let resized = image::imageops::resize(
&image_rgb,
SIZE as u32,
SIZE as u32,
::image::imageops::FilterType::Triangle,
);
let tensor: Tensor =
tract_ndarray::Array4::from_shape_fn((1, SIZE, SIZE, 3), |(_, y, x, c)| {
let mean = [0.485, 0.456, 0.406][c];
let std = [0.229, 0.224, 0.225][c];
(resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std
})
.into();

let result = self.model.run(tvec!(tensor.into())).unwrap();
let best = result[0]
.to_array_view::<f32>()?
.iter()
.cloned()
.zip(0..)
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let index = best.unwrap().1;
let label = Self::labels()[index].to_string();
Ok(Some(label))
}
}

0 comments on commit d40f02f

Please sign in to comment.