Skip to content

Commit

Permalink
Keep persistent connection
Browse files Browse the repository at this point in the history
Instead of re-connecting to Suricata for every scrape request, keep
a single persistent connection instead protected by a single mutex.

Not expecting many current scrapes, so that should be acceptable.
  • Loading branch information
awelzel committed Aug 1, 2023
1 parent 4fe74d2 commit 1fd1b16
Showing 1 changed file with 74 additions and 36 deletions.
110 changes: 74 additions & 36 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"

"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -274,51 +275,84 @@ var (
}
)

// Send a version message and dump-counters command over the
// Suricata unix socket and return the dump-counters response
// as map[string]interface{}
//
// May want to cleanup/generalize if there's ever a reason to support
// more commands.
func dumpCounters(conn net.Conn) (map[string]interface{}, error) {
var parsed map[string]interface{}
var line []byte
var err error
var cmdData []byte
func NewSuricataClient(socketPath string) *SuricataClient {
return &SuricataClient{socketPath: socketPath}
}

type SuricataClient struct {
socketPath string
conn net.Conn
}

func (c *SuricataClient) Close() {
if c.conn != nil {
c.conn.Close()
}

c.conn = nil
}

func (c *SuricataClient) EnsureConnection() error {
if c.conn == nil {
conn, err := net.Dial("unix", c.socketPath)
if err != nil {
return err
}

c.conn = conn
if err := c.Handshake(); err != nil {
return err
}
}

return nil
}

// Do the version handshake. Returns nil or the error.
func (c *SuricataClient) Handshake() error {
// Send the version as hand-shake.
cmdData, _ = json.Marshal(map[string]string{
cmdData, _ := json.Marshal(map[string]string{
"version": "0.2",
})
fmt.Fprintf(conn, "%s\n", string(cmdData))
fmt.Fprintf(c.conn, "%s\n", string(cmdData))

reader := bufio.NewReader(conn)
line, err = reader.ReadBytes('\n')
reader := bufio.NewReader(c.conn)
line, err := reader.ReadBytes('\n')
if err != nil {
return nil, err
c.Close()
return fmt.Errorf("Failed read response from Suricata: %v", err)
}

var parsed map[string]interface{}
err = json.Unmarshal(line, &parsed)
if err != nil {
return nil, fmt.Errorf("Failed to parse version response from Suricata: %v", err)
c.Close()
return fmt.Errorf("Failed to parse version response from Suricata: %v", err)
}

if parsed["return"] != "OK" {
return nil, fmt.Errorf("No OK response from Suricata: %v", parsed)
c.Close()
return fmt.Errorf("No \"OK\" response from Suricata: %v", parsed)
}

// Send dump-counters command.
cmdData, _ = json.Marshal(map[string]string{
return nil
}

// Send dump-counters command and return JSON as parsed map[string]interface{}
func (c *SuricataClient) DumpCounters() (map[string]interface{}, error) {
cmdData, _ := json.Marshal(map[string]string{
"command": "dump-counters",
})
fmt.Fprintf(conn, "%s\n", string(cmdData))
fmt.Fprintf(c.conn, "%s\n", string(cmdData))

// Read until '\n' shows up or there was an error. A lot of data
// is retuned, so may read short.
reader := bufio.NewReader(c.conn)
var response []byte
for {
data, err := reader.ReadBytes('\n')
if err != nil {
c.Close()
return nil, err
}

Expand All @@ -328,11 +362,13 @@ func dumpCounters(conn net.Conn) (map[string]interface{}, error) {
}
}

parsed = make(map[string]interface{})
var parsed map[string]interface{}
if err := json.Unmarshal(response, &parsed); err != nil {
c.Close()
return nil, err
}
if parsed["return"] != "OK" {
c.Close()
return nil, fmt.Errorf("ERROR: No OK response from Suricata: %v", parsed)
}

Expand Down Expand Up @@ -360,14 +396,6 @@ func newConstMetric(m metricInfo, data map[string]interface{}, labelValues ...st
return prometheus.MustNewConstMetric(m.desc, m.t, value, labelValues...)
}

type suricataCollector struct {
socketPath string
}

func (sc *suricataCollector) Describe(ch chan<- *prometheus.Desc) {
// No need?
}

// Extract Napatech related metrics from message
func handleNapatechMetrics(ch chan<- prometheus.Metric, message map[string]interface{}) {
if napaTotal, ok := message["napa_total"].(map[string]interface{}); ok {
Expand Down Expand Up @@ -613,15 +641,25 @@ func produceMetrics(ch chan<- prometheus.Metric, counters map[string]interface{}
handleNapatechMetrics(ch, message)
}

type suricataCollector struct {
client *SuricataClient
mu sync.Mutex // SuricataClient is not re-entrant, easy way out.
}

func (sc *suricataCollector) Describe(ch chan<- *prometheus.Desc) {
// No need?
}

func (sc *suricataCollector) Collect(ch chan<- prometheus.Metric) {
conn, err := net.Dial("unix", sc.socketPath)
if err != nil {
log.Printf("ERROR: Failed to connect to %v: %v", sc.socketPath, err)
sc.mu.Lock()
defer sc.mu.Unlock()

if err := sc.client.EnsureConnection(); err != nil {
log.Printf("ERROR: Failed to connect to %v", err)
return
}
defer conn.Close()

counters, err := dumpCounters(conn)
counters, err := sc.client.DumpCounters()
if err != nil {
log.Printf("ERROR: Failed to dump-counters: %v", err)
return
Expand Down Expand Up @@ -652,7 +690,7 @@ func main() {
return
}
r := prometheus.NewRegistry()
r.MustRegister(&suricataCollector{*socketPath})
r.MustRegister(&suricataCollector{NewSuricataClient(*socketPath), sync.Mutex{}})

http.Handle(*path, promhttp.HandlerFor(r, promhttp.HandlerOpts{}))
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 1fd1b16

Please sign in to comment.