-
What is the code for inference for custom image dataset example. Here is my code, don't know what to do next. pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device) {
// Load the record stored at `{artifact_dir}/model`; panics with
// "Trained model should exist" if loading fails.
// NOTE(review): nothing here pins down the record's type - presumably this is
// what the compiler error mentioned below is about; confirm.
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into(), &device)
.expect("Trained model should exist");
// Take the item at index 0 of the CIFAR-10 test dataset. `value` is not used
// yet: no model is built and no forward pass is run in this snippet.
let dataset = ImageFolderDataset::cifar10_test();
let value = dataset.get(0).unwrap();
} It also gives error:
How can I convert the dataset to a tensor and run model.forward? If I add a type annotation with let record: Cnn<B>= CompactRecorder::new()
.load(format!("{artifact_dir}/model").into(), &device)
.expect("Trained model should exist"); it gives an error:
|
Beta Was this translation helpful? Give feedback.
Replies: 2 comments
-
Finally without error, but not tested by running the code: pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device) {
    // Runs the trained model on item 0 of the CIFAR-10 test dataset and prints
    // the predicted class next to the item's annotation.

    // Load the record stored at `{artifact_dir}/model`; panics with
    // "Trained model should exist" if loading fails.
    let record = CompactRecorder::new()
        .load(format!("{artifact_dir}/model").into(), &device)
        .expect("Trained model should exist");
    // Create a new model on `device` and load the record into it.
    let model: Cnn<B> = Cnn::new(NUM_CLASSES.into(), &device).load_record(record);

    let dataset = ImageFolderDataset::cifar10_test();
    let item = dataset.get(0).unwrap();
    // Clone only the annotation: `item` is moved into the batch below, and
    // cloning the whole item would copy its image data for nothing.
    let annotation = item.annotation.clone();

    let batcher = ClassificationBatcher::new(device);
    let batch = batcher.batch(vec![item]);
    let output = model.forward(batch.images);
    // argmax over dimension 1, flattened to rank 1 and read out as a scalar.
    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();
    println!("Predicted {} Expected {:?}", predicted, annotation);
} |
Beta Was this translation helpful? Give feedback.
-
Sorry I didn't see your question in time! Your code looks good. I'll provide a snippet with all modules required for future reference. use burn::{
data::{
dataloader::batcher::Batcher,
dataset::vision::{Annotation, ImageDatasetItem},
},
module::Module,
record::{CompactRecorder, Recorder},
tensor::backend::Backend,
};
use crate::{data::ClassificationBatcher, model::Cnn};
// The NUM_CLASSES const in training.rs is private right now, so it is
// defined again here for the inference code.
const NUM_CLASSES: u8 = 10;
/// Loads the trained model from `artifact_dir` and prints its prediction for `item`.
///
/// The record is read from `{artifact_dir}/model`. The expected label is taken
/// from `item.annotation` when it is an `Annotation::Label`; any other
/// annotation is reported as `0`.
///
/// # Panics
///
/// Panics with "Trained model should exist" if the record cannot be loaded.
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: ImageDatasetItem) {
    let model_path = format!("{artifact_dir}/model");
    let record = CompactRecorder::new()
        .load(model_path.into(), &device)
        .expect("Trained model should exist");
    let model: Cnn<B> = Cnn::new(NUM_CLASSES.into(), &device).load_record(record);

    // Read the label before `item` is moved into the batch.
    let label = match item.annotation {
        Annotation::Label(category) => category,
        _ => 0,
    };

    let batcher = ClassificationBatcher::new(device);
    let single_item_batch = batcher.batch(vec![item]);
    let outputs = model.forward(single_item_batch.images);
    // argmax over dimension 1, flattened to rank 1 and read out as a scalar.
    let predicted = outputs.argmax(1).flatten::<1>(0, 1).into_scalar();
    println!("Predicted {} Expected {:?}", predicted, label);
} And you can call the function with an item from the dataset, for example: infer::<MyBackend>(
artifact_dir,
device,
ImageFolderDataset::cifar10_test().get(0).unwrap()
); |
Beta Was this translation helpful? Give feedback.
Sorry I didn't see your question in time!
Your code looks good. I'll provide a snippet with all modules required for future reference.