diff --git a/dnn.go b/dnn.go index da30e2ce..32818bd1 100644 --- a/dnn.go +++ b/dnn.go @@ -348,11 +348,17 @@ func ReadNetBytes(framework string, model []byte, config []byte) (Net, error) { if err != nil { return Net{}, err } - bConfig, err := toByteArray(config) - if err != nil { - return Net{}, err + + var bConfig C.ByteArray + if len(config) > 0 { + pbConfig, err := toByteArray(config) + if err != nil { + return Net{}, err + } + bConfig = *pbConfig } - return Net{p: unsafe.Pointer(C.Net_ReadNetBytes(cFramework, *bModel, *bConfig))}, nil + + return Net{p: unsafe.Pointer(C.Net_ReadNetBytes(cFramework, *bModel, bConfig))}, nil } // ReadNetFromCaffe reads a network model stored in Caffe framework's format. diff --git a/dnn_test.go b/dnn_test.go index 7ca04972..958c0922 100644 --- a/dnn_test.go +++ b/dnn_test.go @@ -2,238 +2,94 @@ package gocv import ( "image" - "io/ioutil" "os" "path/filepath" "testing" ) -func checkNet(t *testing.T, net Net) { - net.SetPreferableBackend(NetBackendDefault) - net.SetPreferableTarget(NetTargetCPU) - - img := IMRead("images/space_shuttle.jpg", IMReadColor) - if img.Empty() { - t.Error("Invalid Mat in ReadNet test") - } - defer img.Close() - - blob := BlobFromImage(img, 1.0, image.Pt(224, 224), NewScalar(0, 0, 0, 0), false, false) - if blob.Empty() { - t.Error("Invalid blob in ReadNet test") - } - defer blob.Close() - - net.SetInput(blob, "data") - - layer := net.GetLayer(0) - defer layer.Close() - - if layer.InputNameToIndex("notthere") != -1 { - t.Error("Invalid layer in ReadNet test") - } - if layer.OutputNameToIndex("notthere") != -1 { - t.Error("Invalid layer in ReadNet test") - } - if layer.GetName() != "_input" { - t.Errorf("Invalid layer name in ReadNet test: %s\n", layer.GetName()) - } - if layer.GetType() != "" { - t.Errorf("Invalid layer type in ReadNet test: %s\n", layer.GetType()) - } - - ids := net.GetUnconnectedOutLayers() - if len(ids) != 1 { - t.Errorf("Invalid len output layers in ReadNet test: %d\n", len(ids)) - } - - if len(ids) == 1 && ids[0] != 142 { - t.Errorf("Invalid unconnected output layers in ReadNet test: %d\n", ids[0]) - } - - lnames := net.GetLayerNames() - if len(lnames) != 142 { - t.Errorf("Invalid len layer names in ReadNet test: %d\n", len(lnames)) - } - - if len(lnames) == 142 && lnames[1] != "conv1/relu_7x7" { - t.Errorf("Invalid layer name in ReadNet test: %s\n", lnames[1]) - } - - prob := net.ForwardLayers([]string{"prob"}) - if len(prob) == 0 { - t.Error("Invalid len prob in ReadNet test") - } - - if prob[0].Empty() { - t.Error("Invalid prob[0] in ReadNet test") - } - - probMat := prob[0].Reshape(1, 1) - defer probMat.Close() - _, maxVal, minLoc, maxLoc := MinMaxLoc(probMat) - - if round(float64(maxVal), 0.00005) != 0.9998 { - t.Errorf("ReadNet maxVal incorrect: %v\n", round(float64(maxVal), 0.00005)) - } - - if minLoc.X != 955 || minLoc.Y != 0 { - t.Errorf("ReadNet minLoc incorrect: %v\n", minLoc) - } - - if maxLoc.X != 812 || maxLoc.Y != 0 { - t.Errorf("ReadNet maxLoc incorrect: %v\n", maxLoc) - } - - perf := net.GetPerfProfile() - if perf == 0 { - t.Error("ReadNet GetPerfProfile error") - } - for _, bl := range prob { - bl.Close() - } -} - -func TestReadNetDisk(t *testing.T) { - path := os.Getenv("GOCV_CAFFE_TEST_FILES") +func TestReadNetDiskFromTensorflow(t *testing.T) { + path := os.Getenv("GOCV_TENSORFLOW_TEST_FILES") if path == "" { - t.Skip("Unable to locate Caffe model files for tests") + t.Skip("Unable to locate Tensorflow model files for tests") } - net := ReadNet(path+"/bvlc_googlenet.caffemodel", path+"/bvlc_googlenet.prototxt") + net := ReadNet(path+"/tensorflow_inception_graph.pb", "") if net.Empty() { - t.Errorf("Unable to load Caffe model using ReadNet") + t.Errorf("Unable to load Tensorflow model using ReadNet") } defer net.Close() - checkNet(t, net) + checkTensorflowNet(t, net) } -func TestReadNetMemory(t *testing.T) { - path := os.Getenv("GOCV_CAFFE_TEST_FILES") +func TestReadNetMemoryFromTensorflow(t *testing.T) { + path := os.Getenv("GOCV_TENSORFLOW_TEST_FILES") if path == "" { - t.Skip("Unable to locate Caffe model files for tests") + t.Skip("Unable to locate Tensorflow model files for tests") } - bModel, err := ioutil.ReadFile(path + "/bvlc_googlenet.caffemodel") + bModel, err := os.ReadFile(path + "/tensorflow_inception_graph.pb") if err != nil { t.Errorf("Failed to load model from file: %v", err) } - _, err = ReadNetBytes("caffe", nil, nil) + _, err = ReadNetBytes("tensorflow", nil, nil) if err == nil { t.Errorf("Should have error for reading nil model bytes") } - bConfig, err := ioutil.ReadFile(path + "/bvlc_googlenet.prototxt") - if err != nil { - t.Errorf("Failed to load config from file: %v", err) - } - - _, err = ReadNetBytes("caffe", bModel, nil) - if err == nil { - t.Errorf("Should have error for reading nil config bytes") - } - - net, err := ReadNetBytes("caffe", bModel, bConfig) + net, err := ReadNetBytes("tensorflow", bModel, nil) if err != nil { t.Errorf("Failed to read net bytes: %v", err) } if net.Empty() { - t.Errorf("Unable to load Caffe model using ReadNetBytes") + t.Errorf("Unable to load Tensorflow model using ReadNetBytes") } defer net.Close() - checkNet(t, net) -} - -func checkCaffeNet(t *testing.T, net Net) { - img := IMRead("images/space_shuttle.jpg", IMReadColor) - if img.Empty() { - t.Error("Invalid Mat in Caffe test") - } - defer img.Close() - - blob := BlobFromImage(img, 1.0, image.Pt(224, 224), NewScalar(0, 0, 0, 0), false, false) - if blob.Empty() { - t.Error("Invalid blob in Caffe test") - } - defer blob.Close() - - net.SetInput(blob, "data") - prob := net.Forward("prob") - defer prob.Close() - if prob.Empty() { - t.Error("Invalid prob in Caffe test") - } - - probMat := prob.Reshape(1, 1) - defer probMat.Close() - _, maxVal, minLoc, maxLoc := MinMaxLoc(probMat) - - if round(float64(maxVal), 0.00005) != 0.9998 { - t.Errorf("Caffe maxVal incorrect: %v\n", round(float64(maxVal), 0.00005)) - } - - if minLoc.X != 955 || minLoc.Y != 0 { - t.Errorf("Caffe minLoc incorrect: %v\n", minLoc) - } - - if maxLoc.X != 812 || maxLoc.Y != 0 { - t.Errorf("Caffe maxLoc incorrect: %v\n", maxLoc) - } + checkTensorflowNet(t, net) } -func TestCaffeDisk(t *testing.T) { - path := os.Getenv("GOCV_CAFFE_TEST_FILES") +func TestReadNetDiskFromONNX(t *testing.T) { + path := os.Getenv("GOCV_ONNX_TEST_FILES") if path == "" { - t.Skip("Unable to locate Caffe model files for tests") + t.Skip("Unable to locate ONNX model files for tests") } - net := ReadNetFromCaffe(path+"/bvlc_googlenet.prototxt", path+"/bvlc_googlenet.caffemodel") + net := ReadNet(filepath.Join(path, "googlenet-9.onnx"), "") if net.Empty() { - t.Errorf("Unable to load Caffe model") + t.Errorf("Unable to load ONNX model using ReadNet") } defer net.Close() - checkCaffeNet(t, net) + checkONNXNet(t, net) } -func TestCaffeMemory(t *testing.T) { - path := os.Getenv("GOCV_CAFFE_TEST_FILES") +func TestReadNetMemoryFromONNX(t *testing.T) { + path := os.Getenv("GOCV_ONNX_TEST_FILES") if path == "" { - t.Skip("Unable to locate Caffe model files for tests") + t.Skip("Unable to locate ONNX model files for tests") } - _, err := ReadNetFromCaffeBytes(nil, nil) - if err == nil { - t.Errorf("Should have error for reading nil model bytes") - } - - bPrototxt, err := ioutil.ReadFile(path + "/bvlc_googlenet.prototxt") + bModel, err := os.ReadFile(filepath.Join(path, "googlenet-9.onnx")) if err != nil { - t.Errorf("Failed to load Caffe prototxt from file: %v", err) + t.Errorf("Failed to load model from file: %v", err) } - _, err = ReadNetFromCaffeBytes(bPrototxt, nil) + _, err = ReadNetBytes("onnx", nil, nil) if err == nil { - t.Errorf("Should have error for reading nil config bytes") + t.Errorf("Should have error for reading nil model bytes") } - bCaffeModel, err := ioutil.ReadFile(path + "/bvlc_googlenet.caffemodel") + net, err := ReadNetBytes("onnx", bModel, nil) if err != nil { - t.Errorf("Failed to load Caffe caffemodel from file: %v", err) - } - net, err := ReadNetFromCaffeBytes(bPrototxt, bCaffeModel) - if err != nil { - t.Errorf("Error reading caffe from bytes: %v", err) + t.Errorf("Failed to read net bytes: %v", err) } if net.Empty() { - t.Errorf("Unable to load Caffe model") + t.Errorf("Unable to load Caffe model using ReadNetBytes") } defer net.Close() - - checkCaffeNet(t, net) + checkONNXNet(t, net) } func checkTensorflowNet(t *testing.T, net Net) { @@ -294,7 +150,7 @@ func TestTensorflowMemory(t *testing.T) { t.Skip("Unable to locate Tensorflow model file for tests") } - b, err := ioutil.ReadFile(path + "/tensorflow_inception_graph.pb") + b, err := os.ReadFile(path + "/tensorflow_inception_graph.pb") if err != nil { t.Errorf("Failed to load tensorflow model from file: %v", err) } @@ -316,7 +172,7 @@ func TestOnnxMemory(t *testing.T) { t.Skip("Unable to locate ONNX model file for tests") } - b, err := ioutil.ReadFile(filepath.Join(path, "googlenet-9.onnx")) + b, err := os.ReadFile(filepath.Join(path, "googlenet-9.onnx")) if err != nil { t.Errorf("Failed to load ONNX from file: %v", err) }