Skip to content

Commit

Permalink
dnn: allow ReadNet() function to only pass model file, and remove tes…
Browse files Browse the repository at this point in the history
…ts for Caffe

Signed-off-by: deadprogram <[email protected]>
  • Loading branch information
deadprogram committed Sep 9, 2024
1 parent bf7356c commit ee743fa
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 181 deletions.
14 changes: 10 additions & 4 deletions dnn.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,17 @@ func ReadNetBytes(framework string, model []byte, config []byte) (Net, error) {
if err != nil {
return Net{}, err
}
bConfig, err := toByteArray(config)
if err != nil {
return Net{}, err

var bConfig C.ByteArray
if len(config) > 0 {
pbConfig, err := toByteArray(config)
if err != nil {
return Net{}, err
}
bConfig = *pbConfig
}
return Net{p: unsafe.Pointer(C.Net_ReadNetBytes(cFramework, *bModel, *bConfig))}, nil

return Net{p: unsafe.Pointer(C.Net_ReadNetBytes(cFramework, *bModel, bConfig))}, nil
}

// ReadNetFromCaffe reads a network model stored in Caffe framework's format.
Expand Down
210 changes: 33 additions & 177 deletions dnn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,238 +2,94 @@ package gocv

import (
"image"
"io/ioutil"
"os"
"path/filepath"
"testing"
)

func checkNet(t *testing.T, net Net) {
net.SetPreferableBackend(NetBackendDefault)
net.SetPreferableTarget(NetTargetCPU)

img := IMRead("images/space_shuttle.jpg", IMReadColor)
if img.Empty() {
t.Error("Invalid Mat in ReadNet test")
}
defer img.Close()

blob := BlobFromImage(img, 1.0, image.Pt(224, 224), NewScalar(0, 0, 0, 0), false, false)
if blob.Empty() {
t.Error("Invalid blob in ReadNet test")
}
defer blob.Close()

net.SetInput(blob, "data")

layer := net.GetLayer(0)
defer layer.Close()

if layer.InputNameToIndex("notthere") != -1 {
t.Error("Invalid layer in ReadNet test")
}
if layer.OutputNameToIndex("notthere") != -1 {
t.Error("Invalid layer in ReadNet test")
}
if layer.GetName() != "_input" {
t.Errorf("Invalid layer name in ReadNet test: %s\n", layer.GetName())
}
if layer.GetType() != "" {
t.Errorf("Invalid layer type in ReadNet test: %s\n", layer.GetType())
}

ids := net.GetUnconnectedOutLayers()
if len(ids) != 1 {
t.Errorf("Invalid len output layers in ReadNet test: %d\n", len(ids))
}

if len(ids) == 1 && ids[0] != 142 {
t.Errorf("Invalid unconnected output layers in ReadNet test: %d\n", ids[0])
}

lnames := net.GetLayerNames()
if len(lnames) != 142 {
t.Errorf("Invalid len layer names in ReadNet test: %d\n", len(lnames))
}

if len(lnames) == 142 && lnames[1] != "conv1/relu_7x7" {
t.Errorf("Invalid layer name in ReadNet test: %s\n", lnames[1])
}

prob := net.ForwardLayers([]string{"prob"})
if len(prob) == 0 {
t.Error("Invalid len prob in ReadNet test")
}

if prob[0].Empty() {
t.Error("Invalid prob[0] in ReadNet test")
}

probMat := prob[0].Reshape(1, 1)
defer probMat.Close()
_, maxVal, minLoc, maxLoc := MinMaxLoc(probMat)

if round(float64(maxVal), 0.00005) != 0.9998 {
t.Errorf("ReadNet maxVal incorrect: %v\n", round(float64(maxVal), 0.00005))
}

if minLoc.X != 955 || minLoc.Y != 0 {
t.Errorf("ReadNet minLoc incorrect: %v\n", minLoc)
}

if maxLoc.X != 812 || maxLoc.Y != 0 {
t.Errorf("ReadNet maxLoc incorrect: %v\n", maxLoc)
}

perf := net.GetPerfProfile()
if perf == 0 {
t.Error("ReadNet GetPerfProfile error")
}
for _, bl := range prob {
bl.Close()
}
}

func TestReadNetDisk(t *testing.T) {
path := os.Getenv("GOCV_CAFFE_TEST_FILES")
func TestReadNetDiskFromTensorflow(t *testing.T) {
path := os.Getenv("GOCV_TENSORFLOW_TEST_FILES")
if path == "" {
t.Skip("Unable to locate Caffe model files for tests")
t.Skip("Unable to locate Tensorflow model files for tests")
}

net := ReadNet(path+"/bvlc_googlenet.caffemodel", path+"/bvlc_googlenet.prototxt")
net := ReadNet(path+"/tensorflow_inception_graph.pb", "")
if net.Empty() {
t.Errorf("Unable to load Caffe model using ReadNet")
t.Errorf("Unable to load Tensorflow model using ReadNet")
}
defer net.Close()

checkNet(t, net)
checkTensorflowNet(t, net)
}

func TestReadNetMemory(t *testing.T) {
path := os.Getenv("GOCV_CAFFE_TEST_FILES")
func TestReadNetMemoryFromTensorflow(t *testing.T) {
path := os.Getenv("GOCV_TENSORFLOW_TEST_FILES")
if path == "" {
t.Skip("Unable to locate Caffe model files for tests")
t.Skip("Unable to locate Tensorflow model files for tests")
}

bModel, err := ioutil.ReadFile(path + "/bvlc_googlenet.caffemodel")
bModel, err := os.ReadFile(path + "/tensorflow_inception_graph.pb")
if err != nil {
t.Errorf("Failed to load model from file: %v", err)
}

_, err = ReadNetBytes("caffe", nil, nil)
_, err = ReadNetBytes("tensorflow", nil, nil)
if err == nil {
t.Errorf("Should have error for reading nil model bytes")
}

bConfig, err := ioutil.ReadFile(path + "/bvlc_googlenet.prototxt")
if err != nil {
t.Errorf("Failed to load config from file: %v", err)
}

_, err = ReadNetBytes("caffe", bModel, nil)
if err == nil {
t.Errorf("Should have error for reading nil config bytes")
}

net, err := ReadNetBytes("caffe", bModel, bConfig)
net, err := ReadNetBytes("tensorflow", bModel, nil)
if err != nil {
t.Errorf("Failed to read net bytes: %v", err)
}
if net.Empty() {
t.Errorf("Unable to load Caffe model using ReadNetBytes")
t.Errorf("Unable to load Tensorflow model using ReadNetBytes")
}
defer net.Close()

checkNet(t, net)
}

func checkCaffeNet(t *testing.T, net Net) {
img := IMRead("images/space_shuttle.jpg", IMReadColor)
if img.Empty() {
t.Error("Invalid Mat in Caffe test")
}
defer img.Close()

blob := BlobFromImage(img, 1.0, image.Pt(224, 224), NewScalar(0, 0, 0, 0), false, false)
if blob.Empty() {
t.Error("Invalid blob in Caffe test")
}
defer blob.Close()

net.SetInput(blob, "data")
prob := net.Forward("prob")
defer prob.Close()
if prob.Empty() {
t.Error("Invalid prob in Caffe test")
}

probMat := prob.Reshape(1, 1)
defer probMat.Close()
_, maxVal, minLoc, maxLoc := MinMaxLoc(probMat)

if round(float64(maxVal), 0.00005) != 0.9998 {
t.Errorf("Caffe maxVal incorrect: %v\n", round(float64(maxVal), 0.00005))
}

if minLoc.X != 955 || minLoc.Y != 0 {
t.Errorf("Caffe minLoc incorrect: %v\n", minLoc)
}

if maxLoc.X != 812 || maxLoc.Y != 0 {
t.Errorf("Caffe maxLoc incorrect: %v\n", maxLoc)
}
checkTensorflowNet(t, net)
}

func TestCaffeDisk(t *testing.T) {
path := os.Getenv("GOCV_CAFFE_TEST_FILES")
func TestReadNetDiskFromONNX(t *testing.T) {
path := os.Getenv("GOCV_ONNX_TEST_FILES")
if path == "" {
t.Skip("Unable to locate Caffe model files for tests")
t.Skip("Unable to locate ONNX model files for tests")
}

net := ReadNetFromCaffe(path+"/bvlc_googlenet.prototxt", path+"/bvlc_googlenet.caffemodel")
net := ReadNet(filepath.Join(path, "googlenet-9.onnx"), "")
if net.Empty() {
t.Errorf("Unable to load Caffe model")
t.Errorf("Unable to load ONNX model using ReadNet")
}
defer net.Close()

checkCaffeNet(t, net)
checkONNXNet(t, net)
}

func TestCaffeMemory(t *testing.T) {
path := os.Getenv("GOCV_CAFFE_TEST_FILES")
func TestReadNetMemoryFromONNX(t *testing.T) {
path := os.Getenv("GOCV_ONNX_TEST_FILES")
if path == "" {
t.Skip("Unable to locate Caffe model files for tests")
t.Skip("Unable to locate ONNX model files for tests")
}

_, err := ReadNetFromCaffeBytes(nil, nil)
if err == nil {
t.Errorf("Should have error for reading nil model bytes")
}

bPrototxt, err := ioutil.ReadFile(path + "/bvlc_googlenet.prototxt")
bModel, err := os.ReadFile(filepath.Join(path, "googlenet-9.onnx"))
if err != nil {
t.Errorf("Failed to load Caffe prototxt from file: %v", err)
t.Errorf("Failed to load model from file: %v", err)
}

_, err = ReadNetFromCaffeBytes(bPrototxt, nil)
_, err = ReadNetBytes("onnx", nil, nil)
if err == nil {
t.Errorf("Should have error for reading nil config bytes")
t.Errorf("Should have error for reading nil model bytes")
}

bCaffeModel, err := ioutil.ReadFile(path + "/bvlc_googlenet.caffemodel")
net, err := ReadNetBytes("onnx", bModel, nil)
if err != nil {
t.Errorf("Failed to load Caffe caffemodel from file: %v", err)
}
net, err := ReadNetFromCaffeBytes(bPrototxt, bCaffeModel)
if err != nil {
t.Errorf("Error reading caffe from bytes: %v", err)
t.Errorf("Failed to read net bytes: %v", err)
}
if net.Empty() {
t.Errorf("Unable to load Caffe model")
t.Errorf("Unable to load Caffe model using ReadNetBytes")
}
defer net.Close()

checkCaffeNet(t, net)
checkONNXNet(t, net)
}

func checkTensorflowNet(t *testing.T, net Net) {
Expand Down Expand Up @@ -294,7 +150,7 @@ func TestTensorflowMemory(t *testing.T) {
t.Skip("Unable to locate Tensorflow model file for tests")
}

b, err := ioutil.ReadFile(path + "/tensorflow_inception_graph.pb")
b, err := os.ReadFile(path + "/tensorflow_inception_graph.pb")
if err != nil {
t.Errorf("Failed to load tensorflow model from file: %v", err)
}
Expand All @@ -316,7 +172,7 @@ func TestOnnxMemory(t *testing.T) {
t.Skip("Unable to locate ONNX model file for tests")
}

b, err := ioutil.ReadFile(filepath.Join(path, "googlenet-9.onnx"))
b, err := os.ReadFile(filepath.Join(path, "googlenet-9.onnx"))
if err != nil {
t.Errorf("Failed to load ONNX from file: %v", err)
}
Expand Down

0 comments on commit ee743fa

Please sign in to comment.