add abstractions to support more neural network backends
Alextopher committed Sep 11, 2024
1 parent 0b2d30c commit be116a7
Showing 15 changed files with 551 additions and 339 deletions.
@@ -28,6 +28,12 @@ const interactiveCols = computed(() =>
? 9
: 8
);
// Filters out models whose backend is not in the list of supported backends, and returns a flattened list of model names.
const supportedModels = computed(() => {
const { availableModels, supportedBackends } = useSettingsStore().general;
return supportedBackends.flatMap(backend => availableModels[backend] || []);
});
</script>

<template>
@@ -37,7 +43,7 @@
label="Model"
tooltip="The model used to detect objects in the camera feed"
:select-cols="interactiveCols"
:items="useSettingsStore().general.availableModels"
:items="supportedModels"
@input="(value) => useCameraSettingsStore().changeCurrentPipelineSetting({ model: value }, false)"
/>
<pv-slider
8 changes: 4 additions & 4 deletions photon-client/src/stores/settings/GeneralSettingsStore.ts
@@ -28,8 +28,8 @@ export const useSettingsStore = defineStore("settings", {
hardwareModel: undefined,
hardwarePlatform: undefined,
mrCalWorking: true,
rknnSupported: false,
availableModels: []
availableModels: {},
supportedBackends: []
},
network: {
ntServerAddress: "",
@@ -106,8 +106,8 @@ export const useSettingsStore = defineStore("settings", {
hardwarePlatform: data.general.hardwarePlatform || undefined,
gpuAcceleration: data.general.gpuAcceleration || undefined,
mrCalWorking: data.general.mrCalWorking,
rknnSupported: data.general.rknnSupported,
availableModels: data.general.availableModels || []
availableModels: data.general.availableModels || undefined,
supportedBackends: data.general.supportedBackends || []
};
this.lighting = data.lighting;
this.network = data.networkSettings;
4 changes: 2 additions & 2 deletions photon-client/src/types/SettingTypes.ts
@@ -7,8 +7,8 @@ export interface GeneralSettings {
hardwareModel?: string;
hardwarePlatform?: string;
mrCalWorking: boolean;
rknnSupported: boolean;
availableModels: string[];
availableModels: Record<string, string[]>;
supportedBackends: string[];
}

export interface MetricData {
@@ -25,11 +25,16 @@
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.opencv.core.Size;
import java.util.Map;
import java.util.Optional;
import org.photonvision.common.hardware.Platform;
import org.photonvision.common.logging.LogGroup;
import org.photonvision.common.logging.Logger;
import org.photonvision.rknn.RknnJNI;
import org.photonvision.vision.objects.Model;
import org.photonvision.vision.objects.RknnModel;

/**
* Manages the loading of neural network models.
@@ -42,9 +47,6 @@
* simply a list of string names per label, one label per line. The labels file must have the same
* name as the model file, but with the suffix <code>-labels.txt</code> instead of <code>.rknn
* </code>.
*
* <p>Note: PhotonVision currently only supports YOLOv5 and YOLOv8 models in the <code>.rknn</code>
* format.
*/
public class NeuralNetworkModelManager {
/** Singleton instance of the NeuralNetworkModelManager */
@@ -55,7 +57,15 @@ public class NeuralNetworkModelManager {
*
* @return The NeuralNetworkModelManager instance
*/
private NeuralNetworkModelManager() {}
private NeuralNetworkModelManager() {
ArrayList<NeuralNetworkBackend> backends = new ArrayList<>();

if (Platform.isRK3588()) {
backends.add(NeuralNetworkBackend.RKNN);
}

supportedBackends = backends;
}
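
The constructor is the single place where backend support is detected at startup; in this commit only the RKNN backend is registered, and only on RK3588 hardware. A minimal sketch of the constructor as the extension point for further backends, where ANOTHER_BACKEND and its platform check are placeholders rather than anything present in this commit:

    // Sketch of the constructor as the extension point for new backends.
    // ANOTHER_BACKEND and its platform check are placeholders, not part of this commit.
    private NeuralNetworkModelManager() {
      ArrayList<NeuralNetworkBackend> backends = new ArrayList<>();

      if (Platform.isRK3588()) {
        backends.add(NeuralNetworkBackend.RKNN);
      }
      // if (/* platform supports ANOTHER_BACKEND */) {
      //   backends.add(NeuralNetworkBackend.ANOTHER_BACKEND);
      // }

      supportedBackends = backends;
    }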

/**
* Returns the singleton instance of the NeuralNetworkModelManager
@@ -72,112 +82,142 @@ public static NeuralNetworkModelManager getInstance() {
/** Logger for the NeuralNetworkModelManager */
private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config);

/**
* Determines the model version based on the model's filename.
*
* <p>"yolov5" -> "YOLO_V5"
*
* <p>"yolov8" -> "YOLO_V8"
*
* @param modelName The model's filename
* @return The model version
*/
private static RknnJNI.ModelVersion getModelVersion(String modelName)
throws IllegalArgumentException {
if (modelName.contains("yolov5")) {
return RknnJNI.ModelVersion.YOLO_V5;
} else if (modelName.contains("yolov8")) {
return RknnJNI.ModelVersion.YOLO_V8;
} else {
throw new IllegalArgumentException("Unknown model version for model " + modelName);
}
}

/** This class represents a model that can be loaded by the RknnJNI. */
public class Model {
public final File modelFile;
public final RknnJNI.ModelVersion version;
public final List<String> labels;
public final Size inputSize;

/**
* Model constructor.
*
* @param model format `name-width-height-model.format`
* @param labels
* @throws IllegalArgumentException
*/
public Model(String model, String labels) throws IllegalArgumentException {
this.modelFile = new File(model);

String[] parts = modelFile.getName().split("-");
if (parts.length != 4) {
throw new IllegalArgumentException("Invalid model file name: " + model);
}

// TODO: model 'version' need to be replaced the by the product of 'Version' x 'Format'
this.version = getModelVersion(parts[3]);
public enum NeuralNetworkBackend {
RKNN(".rknn");

int width = Integer.parseInt(parts[1]);
int height = Integer.parseInt(parts[2]);
this.inputSize = new Size(width, height);
private String format;

try {
this.labels = Files.readAllLines(Paths.get(labels));
} catch (IOException e) {
throw new IllegalArgumentException("Failed to read labels file " + labels, e);
}

logger.info("Loaded model " + modelFile.getName());
private NeuralNetworkBackend(String format) {
this.format = format;
}
}

public String getName() {
return modelFile.getName();
}
private final List<NeuralNetworkBackend> supportedBackends;

/**
* Retrieves the list of supported backends.
*
* @return the list of supported backend names
*/
public List<String> getSupportedBackends() {
return supportedBackends.stream().map(Enum::toString).toList();
}

/**
* Stores model information, such as the model file, labels, and version.
*
* <p>The first model in the list is the default model.
*/
private List<Model> models;
private Map<NeuralNetworkBackend, ArrayList<Model>> models;

/**
* Returns the default rknn model. This is simply the first model in the list.
* Retrieves the deep neural network models available, in a format that can be used by the
* frontend.
*
* @return The default model
* @return A map containing the available models, where the key is the backend and the value is a
* list of model names.
*/
public Model getDefaultRknnModel() {
return models.get(0);
public Map<String, ArrayList<String>> getModels() {
Map<String, ArrayList<String>> modelMap = new HashMap<>();
if (models == null) {
return modelMap;
}

models.forEach(
(backend, backendModels) -> {
ArrayList<String> modelNames = new ArrayList<>();
backendModels.forEach(model -> modelNames.add(model.getName()));
modelMap.put(backend.toString(), modelNames);
});

return modelMap;
}
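
Taken together with getSupportedBackends() above, this is the data the frontend stores as supportedBackends and availableModels. A minimal usage sketch of the shapes involved, assuming an RK3588 with a single RKNN model on disk (the model name is the default referenced later in this file):

    // Illustrative shapes only; assumes an RK3588 with one .rknn model installed.
    NeuralNetworkModelManager manager = NeuralNetworkModelManager.getInstance();
    List<String> backends = manager.getSupportedBackends(); // e.g. ["RKNN"]
    Map<String, ArrayList<String>> models = manager.getModels();
    // e.g. {"RKNN": ["note-640-640-yolov5s.rknn"]}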

/**
* Enumerates the names of all models.
* Retrieves the model with the specified name, assuming it is available under a supported
* backend.
*
* <p>If this method returns `Optional.of(..)` then the model should be safe to load.
*
* @return A list of model names
* @param modelName the name of the model to retrieve
* @return an Optional containing the model if found, or an empty Optional if not found
*/
public List<String> getModels() {
return models.stream().map(model -> model.getName()).toList();
public Optional<Model> getModel(String modelName) {
if (models == null) {
return Optional.empty();
}

// Check if the model exists in any supported backend
for (NeuralNetworkBackend backend : supportedBackends) {
if (models.containsKey(backend)) {
Optional<Model> model =
models.get(backend).stream().filter(m -> m.getName().equals(modelName)).findFirst();
if (model.isPresent()) {
return model;
}
}
}

return Optional.empty();
}
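
A short sketch of how a caller might use the new Optional-based lookup, falling back to the default model; the model name is the example default used by getDefaultModel() below, and the caller itself is hypothetical:

    // Hypothetical caller: resolve a model by name, falling back to the default model.
    NeuralNetworkModelManager manager = NeuralNetworkModelManager.getInstance();
    Model model =
        manager.getModel("note-640-640-yolov5s.rknn")
            .or(manager::getDefaultModel)
            .orElseThrow(() -> new IllegalStateException("No object detection models available"));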

/**
* Returns the model with the given name.
*
* <p>TODO: Java 17 This should return an Optional Model instead of null.
* Returns the default model, used when no model is explicitly specified.
*
* @param modelName The model name
* @return The model
* @return an Optional containing the default model, or an empty Optional if no models are loaded
*/
public Model getModel(String modelName) {
Model m =
models.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElse(null);
public Optional<Model> getDefaultModel() {
if (models == null) {
return Optional.empty();
}

if (supportedBackends.isEmpty()) {
return Optional.empty();
}

if (supportedBackends.contains(NeuralNetworkBackend.RKNN)
&& models.containsKey(NeuralNetworkBackend.RKNN)) {
return models.get(NeuralNetworkBackend.RKNN).stream()
.filter(model -> model.getName().equals("note-640-640-yolov5s.rknn"))
.findFirst();
}

return models.get(supportedBackends.get(0)).stream().findFirst();
}

if (m == null) {
logger.error("Model " + modelName + " not found.");
private void loadModel(File model) {
if (models == null) {
models = new HashMap<>();
}

return m;
// Get the model extension and check if it is supported
String modelExtension = model.getName().substring(model.getName().lastIndexOf('.'));
Optional<NeuralNetworkBackend> backend =
Arrays.stream(NeuralNetworkBackend.values())
.filter(b -> b.format.equals(modelExtension))
.findFirst();

if (!backend.isPresent()) {
logger.warn("Model " + model.getName() + " has an unknown extension.");
return;
}

String labels = model.getAbsolutePath().replace(backend.get().format, "-labels.txt");
ArrayList<Model> models = this.models.getOrDefault(backend.get(), new ArrayList<>());

try {
switch (backend.get()) {
case RKNN:
models.add(new RknnModel(model, labels));
break;
default:
break;
}
} catch (IllegalArgumentException e) {
logger.error("Failed to load model " + model.getName(), e);
} catch (IOException e) {
logger.error("Failed to read labels for model " + model.getName(), e);
}
}
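
The labels path is derived purely from the model's filename, following the naming convention described in the class javadoc. A small sketch of that derivation for an RKNN model (the directory is illustrative):

    // Illustrative only: how loadModel derives the labels path for an .rknn model.
    File model = new File("/path/to/models/note-640-640-yolov5s.rknn");
    String extension = model.getName().substring(model.getName().lastIndexOf('.')); // ".rknn"
    String labels = model.getAbsolutePath().replace(".rknn", "-labels.txt");
    // labels -> "/path/to/models/note-640-640-yolov5s-labels.txt"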

/**
@@ -192,46 +232,41 @@ public void loadModels(File modelsFolder) {
}

if (models == null) {
models = new ArrayList<>();
models = new HashMap<>();
}

try {
Files.walk(modelsFolder.toPath())
.filter(Files::isRegularFile)
.filter(path -> path.toString().endsWith(".rknn"))
.forEach(
modelPath -> {
String model = modelPath.toString();
String labels = model.replace(".rknn", "-labels.txt");

try {
models.add(new Model(model, labels));
} catch (IllegalArgumentException e) {
logger.error("Failed to load model " + model, e);
}
});
.forEach(path -> loadModel(path.toFile()));
} catch (IOException e) {
logger.error("Failed to load models from " + modelsFolder.getAbsolutePath(), e);
}

// After loading all of the models, sort them by name to ensure a consistent ordering each time
models.forEach(
(backend, backendModels) ->
backendModels.sort((a, b) -> a.getName().compareTo(b.getName())));

// Log the loaded models
StringBuilder sb = new StringBuilder();
sb.append("Loaded models: ");
for (Model model : models) {
sb.append(model.modelFile.getName()).append(", ");
}
sb.setLength(sb.length() - 2);
logger.info(sb.toString());
models.forEach(
(backend, backendModels) -> {
sb.append(backend).append(" [");
backendModels.forEach(model -> sb.append(model.getName()).append(", "));
sb.append("] ");
});
}

/**
* Extracts models from a JAR resource and copies them to the specified folder.
* Extracts models from the JAR and copies them to disk.
*
* @param modelsFolder the folder where the models will be copied to
* @param modelsDirectory the directory on disk to save models
*/
public void extractModels(File modelsFolder) {
if (!modelsFolder.exists()) {
modelsFolder.mkdirs();
public void extractModels(File modelsDirectory) {
if (!modelsDirectory.exists()) {
modelsDirectory.mkdirs();
}

String resourcePath = "models";
Expand All @@ -244,7 +279,7 @@ public void extractModels(File modelsFolder) {

Path resourcePathResolved = Paths.get(resourceURL.toURI());
Files.walk(resourcePathResolved)
.forEach(sourcePath -> copyResource(sourcePath, resourcePathResolved, modelsFolder));
.forEach(sourcePath -> copyResource(sourcePath, resourcePathResolved, modelsDirectory));
} catch (Exception e) {
logger.error("Failed to extract models from JAR", e);
}