Skip to content

Commit

Permalink
Rewrite NeuralNetworkModelManager
Browse files Browse the repository at this point in the history
  • Loading branch information
Alextopher committed Jul 2, 2024
1 parent fab7591 commit 7a68919
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,87 +18,250 @@
package org.photonvision.common.configuration;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.List;
import org.photonvision.common.logging.LogGroup;
import org.photonvision.common.logging.Logger;
import org.photonvision.rknn.RknnJNI;

/**
* Manages the loading of neural network models.
*
* <p>Models are loaded from the filesystem at the <code>modelsFolder</code> location. PhotonVision
* also supports shipping pre-trained models as resources in the JAR. If the model is not found on
* the filesystem, it will be extracted from the JAR to the filesystem.
*
* <p>Each model must have a corresponding <code>labels</code> file. The labels file format is
* simply a list of string names per label, one label per line. The labels file must have the same
* name as the model file, but with the suffix <code>-labels.txt</code> instead of <code>.rknn
* </code>.
*
* <p>Note: PhotonVision currently only supports YOLOv5 and YOLOv8 models in the <code>.rknn</code>
* format.
*/
public class NeuralNetworkModelManager {
/** Singleton instance of the NeuralNetworkModelManager */
private static NeuralNetworkModelManager INSTANCE;
private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config);

private final String MODEL_NAME = "note-640-640-yolov5s.rknn";
private final RknnJNI.ModelVersion modelVersion = RknnJNI.ModelVersion.YOLO_V5;
private File defaultModelFile;
private List<String> labels;
/**
* Private constructor to prevent instantiation
*
* @return The NeuralNetworkModelManager instance
*/
private NeuralNetworkModelManager() {}

/**
* Returns the singleton instance of the NeuralNetworkModelManager
*
* @return The singleton instance
*/
public static NeuralNetworkModelManager getInstance() {
if (INSTANCE == null) {
INSTANCE = new NeuralNetworkModelManager();
}
return INSTANCE;
}

/** Logger for the NeuralNetworkModelManager */
private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config);

/**
* Determines the model version based on the model's filename.
*
* <p>"yolov5" -> "YOLO_V5"
*
* <p>"yolov8" -> "YOLO_V8"
*
* @param modelName The model's filename
* @return The model version
*/
private static RknnJNI.ModelVersion getModelVersion(String modelName)
throws IllegalArgumentException {
if (modelName.contains("yolov5")) {
return RknnJNI.ModelVersion.YOLO_V5;
} else if (modelName.contains("yolov8")) {
return RknnJNI.ModelVersion.YOLO_V8;
} else {
throw new IllegalArgumentException("Unknown model version for model " + modelName);
}
}

/** This class represents a model that can be loaded by the RknnJNI. */
public class Model {
public final File modelFile;
public final RknnJNI.ModelVersion version;
public final List<String> labels;

public Model(String model, String labels) throws IllegalArgumentException {
this.version = getModelVersion(model);
this.modelFile = new File(model);
try {
this.labels = Files.readAllLines(Paths.get(labels));
} catch (IOException e) {
throw new IllegalArgumentException("Error reading labels file " + labels, e);
}
}

public String getPath() {
return modelFile.getAbsolutePath();
}
}

/**
* Stores model information, such as the model file, labels, and version.
*
* <p>The first model in the list is the default model.
*/
private List<Model> models;

/**
* Perform initial setup and extract default model from JAR to the filesystem
* Returns the default rknn model. This is simply the first model in the list.
*
* @param modelsFolder Where models live
* @return The default model
*/
public void initialize(File modelsFolder) {
var modelResourcePath = "/models/" + MODEL_NAME;
this.defaultModelFile = new File(modelsFolder, MODEL_NAME);
extractResource(modelResourcePath, defaultModelFile);
public Model getDefaultRknnModel() {
return models.get(0);
}

File labelsFile = new File(modelsFolder, "labels_v5.txt");
var labelResourcePath = "/models/" + labelsFile.getName();
extractResource(labelResourcePath, labelsFile);
/**
* Enumerates the names of all models.
*
* @return A list of model names
*/
public List<String> getModels() {
return models.stream().map(model -> model.modelFile.getName()).toList();
}

/**
* Returns the model with the given name.
*
* <p>TODO: Java 17 This should return an Optional<Model> instead of null.
*
* @param modelName The model name
* @return The model
*/
public Model getModel(String modelName) {
Model m =
models.stream()
.filter(model -> model.modelFile.getName().equals(modelName))
.findFirst()
.orElse(null);

if (m == null) {
logger.error("Model " + modelName + " not found.");
}

return m;
}

/**
* Loads models from the specified folder.
*
* @param modelsFolder The folder where the models are stored
*/
public void loadModels(File modelsFolder) {
if (!modelsFolder.exists()) {
logger.error("Models folder " + modelsFolder.getAbsolutePath() + " does not exist.");
return;
}

if (models == null) {
models = new ArrayList<>();
}

try {
labels = Files.readAllLines(Paths.get(labelsFile.getPath()));
Files.walk(modelsFolder.toPath())
.filter(Files::isRegularFile)
.filter(path -> path.toString().endsWith(".rknn"))
.forEach(
modelPath -> {
String model = modelPath.toString();
String labels = model.replace(".rknn", "-labels.txt");

try {
models.add(new Model(model, labels));
} catch (IllegalArgumentException e) {
logger.error("Failed to load model " + model, e);
}
});
} catch (IOException e) {
logger.error("Error reading labels.txt", e);
logger.error("Failed to load models from " + modelsFolder.getAbsolutePath(), e);
}

// Log the loaded models
StringBuilder sb = new StringBuilder();
sb.append("Loaded models: ");
for (Model model : models) {
sb.append(model.modelFile.getName()).append(", ");
}
sb.setLength(sb.length() - 2);
logger.info(sb.toString());
}

private void extractResource(String resourcePath, File outputFile) {
try (var in = NeuralNetworkModelManager.class.getResourceAsStream(resourcePath)) {
if (in == null) {
/**
* Extracts models from a JAR resource and copies them to the specified folder.
*
* @param modelsFolder the folder where the models will be copied to
*/
public void extractModels(File modelsFolder) {
if (!modelsFolder.exists()) {
modelsFolder.mkdirs();
}

String resourcePath = "models"; // Adjust path if necessary
try {
URL resourceURL = NeuralNetworkModelManager.class.getClassLoader().getResource(resourcePath);
if (resourceURL == null) {
logger.error("Failed to find jar resource at " + resourcePath);
return;
}

if (!outputFile.exists()) {
try (FileOutputStream fos = new FileOutputStream(outputFile)) {
int read = -1;
byte[] buffer = new byte[1024];
while ((read = in.read(buffer)) != -1) {
fos.write(buffer, 0, read);
Path resourcePathResolved = Paths.get(resourceURL.toURI());
Files.walk(resourcePathResolved)
.forEach(sourcePath -> copyResource(sourcePath, resourcePathResolved, modelsFolder));
} catch (Exception e) {
logger.error("Failed to extract models from JAR", e);
}
}

/**
* Copies a resource from the source path to the target path.
*
* @param sourcePath The path of the resource to be copied.
* @param resourcePathResolved The resolved path of the resource.
* @param modelsFolder The folder where the resource will be copied to.
*/
private void copyResource(Path sourcePath, Path resourcePathResolved, File modelsFolder) {
Path targetPath =
Paths.get(
modelsFolder.getAbsolutePath(), resourcePathResolved.relativize(sourcePath).toString());
try {
if (Files.isDirectory(sourcePath)) {
Files.createDirectories(targetPath);
} else {
Path parentDir = targetPath.getParent();
if (parentDir != null && !Files.exists(parentDir)) {
Files.createDirectories(parentDir);
}

if (!Files.exists(targetPath)) {
Files.copy(sourcePath, targetPath);
} else {
long sourceSize = Files.size(sourcePath);
long targetSize = Files.size(targetPath);
if (sourceSize != targetSize) {
Files.copy(sourcePath, targetPath, StandardCopyOption.REPLACE_EXISTING);
}
} catch (IOException e) {
logger.error("Error extracting resource to " + outputFile.toPath().toString(), e);
}
} else {
logger.info(
"File " + outputFile.toPath().toString() + " already exists. Skipping extraction.");
}
} catch (IOException e) {
logger.error("Error finding jar resource " + resourcePath, e);
logger.error("Failed to copy " + sourcePath + " to " + targetPath, e);
}
}

public File getDefaultRknnModel() {
return defaultModelFile;
}

public List<String> getLabels() {
return labels;
}

public RknnJNI.ModelVersion getModelVersion() {
return modelVersion;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import org.opencv.core.Mat;
import org.photonvision.common.configuration.NeuralNetworkModelManager;
import org.photonvision.common.logging.LogGroup;
import org.photonvision.common.logging.Logger;
import org.photonvision.common.util.TestUtils;
Expand Down Expand Up @@ -70,6 +71,10 @@ public static class RknnObjectDetector {

static volatile boolean hook = false;

public RknnObjectDetector(NeuralNetworkModelManager.Model model) {
this(model.getPath(), model.labels, model.version);
}

public RknnObjectDetector(String modelPath, List<String> labels, RknnJNI.ModelVersion version) {
synchronized (lock) {
objPointer = RknnJNI.create(modelPath, labels.size(), version.ordinal(), -1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,10 @@ public class RknnDetectionPipe

public RknnDetectionPipe() {
// For now this is hard-coded to defaults. Should be refactored into set pipe
// params, though.
// And ideally a little wrapper helper for only changing native stuff on content
// params, though. And ideally a little wrapper helper for only changing native stuff on content
// change created.
this.detector =
new RknnObjectDetector(
NeuralNetworkModelManager.getInstance().getDefaultRknnModel().getAbsolutePath(),
NeuralNetworkModelManager.getInstance().getLabels(),
NeuralNetworkModelManager.getInstance().getModelVersion());
new RknnObjectDetector(NeuralNetworkModelManager.getInstance().getDefaultRknnModel());
}

private static class Letterbox {
Expand Down
5 changes: 3 additions & 2 deletions photon-server/src/main/java/org/photonvision/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,9 @@ public static void main(String[] args) {
.setConfig(ConfigManager.getInstance().getConfig().getNetworkConfig());

logger.info("Loading ML models");
NeuralNetworkModelManager.getInstance()
.initialize(ConfigManager.getInstance().getModelsDirectory());
var modelManager = NeuralNetworkModelManager.getInstance();
modelManager.extractModels(ConfigManager.getInstance().getModelsDirectory());
modelManager.loadModels(ConfigManager.getInstance().getModelsDirectory());

if (isSmoketest) {
logger.info("PhotonVision base functionality loaded -- smoketest complete");
Expand Down

0 comments on commit 7a68919

Please sign in to comment.