Skip to content

Commit

Permalink
feat: Update EPSS queries and test cases (#3172)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexbeast2 authored Aug 7, 2023
1 parent 0bdf6ad commit 06b55f7
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 19 deletions.
41 changes: 41 additions & 0 deletions cve_bin_tool/cve_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,20 @@ def get_cves(self, product_info: ProductInfo, triage_data: TriageData):
row_dict["cvss_version"] = (
row_dict["cvss_version"] or row["cvss_version"]
)
# executing query to get metric for CVE
metric_result = self.metric((row["cve_number"],))
# row_dict doesn't have "metric" as a key yet, because it is built from the result of the query on the cve_severity table,
# so declare row_dict["metric"] here.
row_dict["metric"] = {}
# # looping for result of query for metrics.
for key, value in metric_result.items():
row_dict["metric"][key] = [
value[0],
value[1],
]
self.logger.debug(
f'metrics found in CVE {row_dict["cve_number"]} is {row_dict["metric"]}'
)
cve = CVE(**row_dict)
cves.append(cve)

Expand Down Expand Up @@ -344,6 +358,33 @@ def affected(self):
for cve_data in self.all_cve_data
)

def metric(self, cve_number):
    """Return the metrics stored in the database for a single CVE.

    The query is executed on its own connection and cursor: if it is run on
    the cursor that is already being used for the CVE search, that search
    stops, so the lookup has to be executed independently.

    Args:
        cve_number: The CVE identifier, either as a bare string or wrapped
            in a one-element sequence such as ``("CVE-2023-0001",)``.

    Returns:
        A dictionary mapping each metric name to a
        ``[metric_score, metric_field]`` list. It is empty when the CVE has
        no rows in ``cve_metrics``.
    """
    # sqlite3 expects a sequence of bound values. A bare string would be
    # treated as one parameter per character, so wrap it in a tuple.
    if isinstance(cve_number, (list, tuple)):
        params = tuple(cve_number)
    else:
        params = (cve_number,)

    query = """
        SELECT metrics.metrics_name, cve_metrics.metric_score, cve_metrics.metric_field
        FROM cve_metrics, metrics
        WHERE cve_metrics.cve_number = ? AND cve_metrics.metric_id = metrics.metrics_id
        GROUP BY cve_metrics.metric_id;
    """

    conn = sqlite3.connect(self.dbname)
    try:
        cur = conn.cursor()
        try:
            met = {
                metric_name: [metric_score, metric_field]
                for metric_name, metric_score, metric_field in cur.execute(
                    query, params
                )
            }
        finally:
            cur.close()
    finally:
        # Always release the dedicated connection, even if the query fails.
        conn.close()

    self.logger.debug(f"metrics found in CVE {params[0]} is {met}")
    return met

def __enter__(self):
self.connection = sqlite3.connect(self.dbname)
self.connection.row_factory = sqlite3.Row
Expand Down
27 changes: 22 additions & 5 deletions cve_bin_tool/cvedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ async def refresh(self) -> None:
if self.version_check:
check_latest_version()

epss = epss_source.Epss_Source()
self.epss_data = await epss.update_epss()
await self.get_data()

def refresh_cache_and_update_db(self) -> None:
Expand Down Expand Up @@ -456,9 +454,13 @@ def populate_db(self) -> None:
we'll need a better parser to match those together.
"""

self.store_epss_data()
self.populate_metrics()

# EPSS uses the metrics table to get the EPSS metric id.
# It can't be run before the metrics table has been created.
self.populate_epss()
self.store_epss_data()

for idx, data in enumerate(self.data):
_, source_name = data

Expand Down Expand Up @@ -532,6 +534,7 @@ def populate_severity(self, severity_data, cursor, data_source):
cursor.execute(del_cve_range, [cve["ID"], data_source])

def populate_cve_metrics(self, severity_data, cursor):
"""Adds data into CVE metrics table"""
insert_cve_metrics = self.INSERT_QUERIES["insert_cve_metrics"]

for cve in severity_data:
Expand Down Expand Up @@ -585,6 +588,7 @@ def populate_affected(self, affected_data, cursor, data_source):
LOGGER.info(f"Unable to insert data for {data_source} - {e}")

def populate_metrics(self):
"""Adding data to metric table."""
cursor = self.db_open_and_get_cursor()
# Insert a row without specifying cve_metrics_id
insert_metrics = self.INSERT_QUERIES["insert_metrics"]
Expand All @@ -599,9 +603,19 @@ def populate_metrics(self):
self.connection.commit()
self.db_close()

def populate_epss(self):
    """Fetch Exploit Prediction Scoring System (EPSS) data for later storage.

    EPSS data helps users evaluate risks. This method opens the database,
    hands the cursor to ``Epss_Source.update_epss`` and keeps the returned
    data in ``self.epss_data``. It does not write the EPSS rows itself.
    """
    epss = epss_source.Epss_Source()
    cursor = self.db_open_and_get_cursor()
    try:
        self.epss_data = run_coroutine(epss.update_epss(cursor))
    finally:
        # Pair every db_open with a db_close, even when the update raises,
        # so a failed EPSS refresh does not leave the database open.
        self.db_close()

def metric_finder(self, cursor, cve):
# SQL query to retrieve the metrics_name based on the metrics_id
# currently cve["CVSS_version"] return 2,3 based on there version and they are mapped accordingly to there metrics name in metrics table.
"""
SQL query to retrieve the metrics_name based on the metrics_id.
Currently cve["CVSS_version"] returns 2 or 3 depending on the CVSS version, and these values are mapped to their metrics names in the metrics table.
"""
query = """
SELECT metrics_id FROM metrics
WHERE metrics_id=?
Expand All @@ -615,6 +629,9 @@ def metric_finder(self, cursor, cve):
metric = list(map(lambda x: x[0], cursor.fetchall()))
# Since the query is expected to return a single result, extract the first item from the list and store it in 'metric'
metric = metric[0]
self.LOGGER.debug(
f'For the given cve {cve["ID"]} the cvss version found {cve["CVSS_version"]} metrics ID added into database {metric}'
)
return metric

def clear_cached_data(self) -> None:
Expand Down
27 changes: 20 additions & 7 deletions cve_bin_tool/data_sources/epss_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def __init__(self, error_mode=ErrorMode.TruncTrace):
self.backup_cachedir = self.BACKUPCACHEDIR
self.epss_path = str(Path(self.cachedir) / "epss")
self.file_name = os.path.join(self.epss_path, "epss_scores-current.csv")
self.epss_metric_id = None

async def update_epss(self):
async def update_epss(self, cursor):
"""
Updates the EPSS data by downloading and parsing the CSV file.
Returns:
Expand All @@ -42,18 +43,19 @@ async def update_epss(self):
- EPSS score
- EPSS percentile
"""
self.EPSS_id_finder(cursor)
await self.download_and_parse_epss()
return self.epss_data

async def download_and_parse_epss(self):
    """Download the EPSS CSV file, then parse it into ``self.epss_data``."""
    await self.download_epss_data()
    parsed_rows = self.parse_epss_data()
    self.epss_data = parsed_rows

async def download_epss_data(self):
# Downloads the EPSS CSV file and saves it to the local filesystem.
# The download is only performed if the file is older than 24 hours.

"""Downloads the EPSS CSV file and saves it to the local filesystem.
The download is only performed if the file is older than 24 hours.
"""
os.makedirs(self.epss_path, exist_ok=True)
# Check if the file exists
if os.path.exists(self.file_name):
Expand Down Expand Up @@ -100,7 +102,17 @@ async def download_epss_data(self):
except aiohttp.ClientError as e:
self.LOGGER.error(f"An error occurred during downloading epss {e}")

def EPSS_id_finder(self, cursor):
    """Look up the id of the "EPSS" metric and store it in ``self.epss_metric_id``.

    Args:
        cursor: A database cursor on a database that has a ``metrics`` table
            with ``metrics_id`` and ``metrics_name`` columns.

    Raises:
        IndexError: If the ``metrics`` table has no row named "EPSS".
    """
    # Bind the name as a parameter. A double-quoted "EPSS" is an identifier
    # in standard SQL and only works through SQLite's legacy fallback.
    query = """
        SELECT metrics_id FROM metrics
        WHERE metrics_name = ?
    """
    cursor.execute(query, ("EPSS",))
    row = cursor.fetchone()
    if row is None:
        # Same exception type the former `fetchall()[0][0]` raised, but with
        # a message that points at the missing row.
        raise IndexError('no "EPSS" entry found in the metrics table')
    self.epss_metric_id = row[0]

def parse_epss_data(self, file_path=None):
"""Parse epss data from the file path given and return the parse data"""
parsed_data = []
if file_path is None:
file_path = self.file_name
Expand All @@ -115,9 +127,10 @@ def parse_epss_data(self, file_path=None):
# Skip the first line (header) and the next line (empty line)
next(reader)
next(reader)

# Parse the data from the remaining rows
for row in reader:
cve_id, epss_score, epss_percentile = row[:3]
parsed_data.append((cve_id, "EPSS", epss_score, epss_percentile))
parsed_data.append(
(cve_id, self.epss_metric_id, epss_score, epss_percentile)
)
return parsed_data
1 change: 1 addition & 0 deletions cve_bin_tool/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class CVE(NamedTuple):
cvss_vector: str = ""
data_source: str = ""
last_modified: str = ""
metric: dict[str, dict[float, str]] = {}


class ProductInfo(NamedTuple):
Expand Down
26 changes: 19 additions & 7 deletions test/test_source_epss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

from cve_bin_tool.cvedb import CVEDB
from cve_bin_tool.data_sources import epss_source


Expand All @@ -12,15 +13,26 @@ def setup_class(cls):
)

final_data = [
("CVE-1999-0001", "EPSS", "0.011", "0.82987"),
("CVE-2019-10354", "EPSS", "0.00287", "0.64385"),
("CVE-1999-0003", "EPSS", "0.999", "0.88555"),
("CVE-2023-28143", "EPSS", "0.00042", "0.05685"),
("CVE-2017-15360", "EPSS", "0.00078", "0.31839"),
("CVE-2008-4444", "EPSS", "0.07687", "0.93225"),
("CVE-1999-0007", "EPSS", "0.00180", "0.54020"),
("CVE-1999-0001", 1, "0.011", "0.82987"),
("CVE-2019-10354", 1, "0.00287", "0.64385"),
("CVE-1999-0003", 1, "0.999", "0.88555"),
("CVE-2023-28143", 1, "0.00042", "0.05685"),
("CVE-2017-15360", 1, "0.00078", "0.31839"),
("CVE-2008-4444", 1, "0.07687", "0.93225"),
("CVE-1999-0007", 1, "0.00180", "0.54020"),
]

def test_parse_epss(self):
    """Check that the sample EPSS file parses into the expected rows."""
    # The EPSS metric id is read from the metrics table, so the tables must
    # be created and the metrics populated before the id lookup.
    database = CVEDB()
    database.init_database()
    database.populate_metrics()

    # Set the EPSS metric id on the source from the populated table.
    db_cursor = database.db_open_and_get_cursor()
    self.epss.EPSS_id_finder(db_cursor)

    # Parse the sample file and compare it with the expected rows.
    parsed_rows = self.epss.parse_epss_data(self.epss.file_name)
    self.epss_data = parsed_rows
    database.db_close()
    assert self.epss_data == self.final_data

0 comments on commit 06b55f7

Please sign in to comment.