-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Develop a FastAPI for serving the classifier of mussel and oyster lar…
…val stages
- Loading branch information
Showing
4 changed files
with
234 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
############################################################################################ | ||
# Imports (for creating a REST API for the model) | ||
############################################################################################ | ||
from fastapi import FastAPI, UploadFile, File | ||
from pydantic import BaseModel | ||
from io import BytesIO | ||
from PIL import Image | ||
import uvicorn | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torchvision.transforms as transforms | ||
import torchvision.models as models | ||
|
||
############################################################################################ | ||
# Functions to be used by the FastAPI server | ||
############################################################################################ | ||
# Define the model loading function | ||
def load_model(model_path, num_classes=6): | ||
# Load the model checkpoint (remove map_location if you have a GPU) | ||
loaded_cpt = torch.load(model_path, map_location=torch.device('cpu')) | ||
# Define the model architecture and modify the number of output classes (by default, no pre-trained weights are used) | ||
model = models.efficientnet_v2_s() | ||
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes) | ||
# Load the state_dict in order to load the trained parameters | ||
model.load_state_dict(loaded_cpt) | ||
# Set the model to evaluation mode | ||
model.eval() | ||
return model | ||
|
||
# Define the image reading function | ||
def read_imagefile(file) -> Image.Image: | ||
image = Image.open(BytesIO(file)) | ||
return image | ||
|
||
# Define the image transformation function | ||
def transform_image(image, image_size=256): | ||
transform = transforms.Compose([ | ||
transforms.Resize((image_size, image_size)), | ||
transforms.ToTensor(), | ||
]) | ||
# Apply image transformations and add batch dimension | ||
image = transform(image).unsqueeze(0) | ||
return image | ||
|
||
# Define the image prediction function | ||
def predict_image(model, image: Image.Image): | ||
image_tensor = transform_image(image) | ||
with torch.no_grad(): | ||
outputs = model(image_tensor) | ||
classif_scores = F.softmax(outputs, dim=1) | ||
_, predicted = torch.max(outputs, 1) | ||
return predicted, classif_scores | ||
|
||
############################################################################################ | ||
# Entry point of the FastAPI server | ||
############################################################################################ | ||
# Define the FastAPI app object and configure it with the required routes and models | ||
api_server = FastAPI(title='Mussel and Oyster Larvae Classification API', | ||
description="Obtain model predictions for mussel and oyster larvae images.", | ||
version='1.0') | ||
|
||
# Load the pre-trained model | ||
model_path = "C:/Users/Wassim/Downloads/FairScope/Projet de fin d'études/Dashboard/streamlit classification app with fastapi/models/Effv2s_DA2+fill_256x256_cosan_mussel_oyster.pth" | ||
model = load_model(model_path) | ||
|
||
# Define the response model of the API for each image | ||
class Prediction(BaseModel): | ||
filename: str | ||
prediction: int | ||
scores: list | ||
|
||
@api_server.post("/predict", response_model=Prediction) | ||
async def get_prediction(file: UploadFile = File(...)): | ||
image = read_imagefile(await file.read()) | ||
predicted_class_index, predicted_classif_scores = predict_image(model, image) | ||
predicted_classif_scores = predicted_classif_scores.tolist()[0] # Convert tensor to list for response | ||
return {"filename": file.filename, "prediction": predicted_class_index, "scores": predicted_classif_scores} | ||
|
||
if __name__ == "__main__": | ||
uvicorn.run(api_server, host="0.0.0.0", port=8000) |
Binary file added
BIN
+77.9 MB
Fastapi model serving/models/Effv2s_DA2+fill_256x256_cosan_mussel_oyster.pth
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
############################################################################################ | ||
# Imports (for Streamlit app & model prediction) | ||
############################################################################################ | ||
import streamlit as st # https://docs.streamlit.io/develop/api-reference | ||
import requests | ||
import torch | ||
import numpy as np | ||
import seaborn as sns | ||
import matplotlib.pyplot as plt | ||
import plotly.graph_objects as go | ||
from PIL import Image | ||
from io import BytesIO | ||
from itertools import cycle | ||
|
||
############################################################################################ | ||
# Functions to be used in the Streamlit app | ||
############################################################################################ | ||
# Define the function to save the probabilities to a file | ||
def save_probabilities(probas, filename='classification_scores.txt'): | ||
with open(filename, 'w') as f: | ||
for prob in probas: | ||
f.write(f'{prob[0]}: {prob[1].tolist()}\n') | ||
|
||
# Function to generate and display the graph of detected objects | ||
def display_distribution_plot(class_counts, sns_palette="pastel"): | ||
# Generate seaborn color palette | ||
seaborn_palette = sns.color_palette(sns_palette) | ||
# Convert seaborn colors to Plotly-compatible RGBA format | ||
plotly_colors = ['rgba' + str(tuple(int(255 * c) for c in color[:3]) + (1,)) for color in seaborn_palette] | ||
# Create a Plotly figure | ||
fig = go.Figure(data=[go.Bar(y=list(class_counts.values()), x=list(class_counts.keys()), orientation='v', marker_color=plotly_colors, text=list(class_counts.values()), textposition='auto')]) | ||
fig.update_layout( | ||
title='Distribution of Detected Objects', | ||
title_font=dict(size=20), | ||
xaxis_title='Object', | ||
yaxis_title='Count', | ||
xaxis=dict( | ||
title_font=dict(color='black', size=18), | ||
tickfont=dict(color='black'), | ||
showline=True | ||
), | ||
yaxis=dict( | ||
title_font=dict(color='black', size=18), | ||
tickfont=dict(color='black'), | ||
showline=True | ||
), | ||
height=600, | ||
width=800, | ||
paper_bgcolor="lightgray", | ||
margin=dict(pad=0, r=20, t=50, b=60, l=60) | ||
) | ||
st.plotly_chart(fig, use_container_width=True) | ||
|
||
############################################################################################ | ||
# Body of the Streamlit app | ||
############################################################################################ | ||
def main(): | ||
# Set the page configuration | ||
st.set_page_config( | ||
page_title="Larval Stage Classification", | ||
page_icon="fairscope_favicon.png", | ||
layout = 'wide', | ||
) | ||
|
||
# Set the title of the Streamlit app | ||
st.title("Classification of Mussel & Oyster Larval Stages") | ||
|
||
# Set the default theme of the plotly chart to light | ||
st.markdown( | ||
""" | ||
<style> | ||
.stPlotlyChart {{ | ||
outline: 10px solid #FFFFFF; | ||
border-radius: 5px; | ||
box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.20), 0 6px 20px 0 rgba(0, 0, 0, 0.30); | ||
}} | ||
</style> | ||
""", unsafe_allow_html=True | ||
) | ||
|
||
# Initialize an empty list to store predicted probabilities | ||
if 'probabilities' not in st.session_state: | ||
st.session_state.probabilities = [] | ||
|
||
# Define the class labels | ||
class_labels = ["Oyster - DV", "Oyster - PV", "Oyster - UV", "Mussel - DV", "Mussel - PV", "Mussel - UV"] | ||
|
||
# File uploader for image selection | ||
uploaded_files = st.file_uploader("Upload images", type=["jpg", "jpeg"], accept_multiple_files=True) | ||
|
||
# List to store predicted class labels | ||
predicted_class_labels = [] | ||
|
||
# Display a message indicating images classification part | ||
text = "<span style='font-size: 14px;'>Classify images</span>" | ||
st.markdown(text, unsafe_allow_html=True) | ||
|
||
# Create a grid layout of images | ||
cols = cycle(st.columns(4)) # Ref: https://discuss.streamlit.io/t/grid-of-images-with-the-same-height/10668/8 | ||
|
||
if uploaded_files is not None: | ||
# Iterate over uploaded images and predict their classes | ||
for (i, uploaded_file) in enumerate(uploaded_files): | ||
# Read the uploaded image | ||
image = Image.open(uploaded_file) | ||
|
||
# Convert image to bytes | ||
img_bytes = BytesIO() | ||
image.save(img_bytes, format=image.format) | ||
img_bytes = img_bytes.getvalue() | ||
|
||
# Make API request to perform image classification | ||
response = requests.post("http://localhost:8000/predict", files={"file": (uploaded_file.name, img_bytes, uploaded_file.type)}) | ||
predicted_class_index = response.json()["prediction"] | ||
predicted_classif_scores = response.json()["scores"] | ||
|
||
# Perform image classification | ||
file_name = f"{i}. {uploaded_file.name}" | ||
st.session_state.probabilities.append((file_name, dict(zip(class_labels, predicted_classif_scores)))) | ||
predicted_class_label = class_labels[predicted_class_index] | ||
predicted_class_labels.append(predicted_class_label) | ||
|
||
# Display the uploaded image with the predicted class | ||
next(cols).image(image, width=150, caption=f"{i}. {predicted_class_label} ({max(predicted_classif_scores):.4f})", use_column_width=True) | ||
|
||
# Determine the number of detected objects | ||
num_objects = len(predicted_class_labels) | ||
|
||
# Display the number of detected objects | ||
st.write(f"Number of detected objects: {num_objects}") | ||
|
||
# Count the occurrences of each class label | ||
class_counts = {label: predicted_class_labels.count(label) for label in class_labels} | ||
|
||
# Convert probabilities to string format | ||
probabilities_str = '\n'.join([f"{name}: {scores}" for name, scores in st.session_state.probabilities]) | ||
|
||
# Plot the distribution of detected objects | ||
if num_objects > 0: | ||
|
||
# Download the classification scores file with streamlit | ||
with st.sidebar: | ||
st.download_button(label="Download classification scores", data=probabilities_str, file_name="classification_scores.txt", mime="text/plain") | ||
|
||
# Display the distribution of detected objects | ||
display_distribution_plot(class_counts) | ||
|
||
############################################################################################ | ||
# Entry point of the Streamlit app | ||
############################################################################################ | ||
if __name__ == "__main__": | ||
main() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.