diff --git a/conf/cuckooml.conf b/conf/cuckooml.conf index 71d65e46..b4cb0296 100644 --- a/conf/cuckooml.conf +++ b/conf/cuckooml.conf @@ -38,3 +38,6 @@ compare_new_samples = true # Set folder for samples to be compared against clustering test_directory = sample_data/test + +# Enable plotting functionality +plotting = true diff --git a/modules/processing/cuckooml.py b/modules/processing/cuckooml.py index 14912c24..d99cf02c 100644 --- a/modules/processing/cuckooml.py +++ b/modules/processing/cuckooml.py @@ -15,11 +15,24 @@ from lib.cuckoo.common.constants import CUCKOO_ROOT from math import log +global PLOTTING_LIBRARIES_IMPORTED +PLOTTING_LIBRARIES_IMPORTED = False + +if Config("cuckooml").cuckooml.plotting: + try: + import matplotlib.pyplot as plt + import seaborn as sns + PLOTTING_LIBRARIES_IMPORTED = True + except ImportError, e: + print >> sys.stderr, "Plotting libraries \ + (matplotlib and seaborn) are not available." + print >> sys.stderr, e + PLOTTING_LIBRARIES_IMPORTED = False + + try: - import matplotlib.pyplot as plt import numpy as np import pandas as pd - import seaborn as sns from hdbscan import HDBSCAN from sklearn import metrics from sklearn.cluster import DBSCAN @@ -797,6 +810,19 @@ def filter_dataset(self, dataset=None, feature_coverage=0.1, def detect_abnormal_behaviour(self, count_dataset=None, figures=True): """Detect samples that behave significantly different than others.""" + + # Safety check for plotting + if not PLOTTING_LIBRARIES_IMPORTED and figures: + print >> sys.stderr, "Warning: plotting libraries were not imported. \n" \ + "Plots wont be produced." + + if not Config("cuckooml").cuckooml.plotting: + print >> sys.stderr, " Plotting is disabled in cuckooml config." + else: + print >> sys.stderr, "Plotting libraries are missing." + figures = False + + if count_dataset is None: # Pull all count features count_features = self.feature_category(":count:") @@ -1133,6 +1159,18 @@ def performance_metric(clustering, labels, data, noise): def clustering_label_distribution(self, clustering, labels, plot=False): """Get statistics about number of ground truth labels per cluster.""" + + # Safety check for plotting + if not PLOTTING_LIBRARIES_IMPORTED and plot: + print >> sys.stderr, "Warning: plotting libraries were not imported. \n" \ + "Plots wont be produced." + + if not Config("cuckooml").cuckooml.plotting: + print >> sys.stderr, " Plotting is disabled in cuckooml config." + else: + print >> sys.stderr, "Plotting libraries are missing." + plot = False + cluster_ids = set(clustering["label"].tolist()) labels_ids = set(labels["label"].tolist()) cluster_distribution = {}