Skip to content

Commit

Permalink
fix: changed metric ids in cvedb to constants and fixed test (#4473) (#…
Browse files Browse the repository at this point in the history
…4475)

* fixes #4473
  • Loading branch information
weichslgartner authored Oct 15, 2024
1 parent 81a15cf commit 2b15bc5
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 22 deletions.
10 changes: 7 additions & 3 deletions cve_bin_tool/cvedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
DBNAME = "cve.db"
OLD_CACHE_DIR = Path("~") / ".cache" / "cvedb"

EPSS_METRIC_ID = 1
CVSS_2_METRIC_ID = 2
CVSS_3_METRIC_ID = 3


class CVEDB:
"""
Expand Down Expand Up @@ -615,9 +619,9 @@ def populate_metrics(self):
# Insert a row without specifying cve_metrics_id
insert_metrics = self.INSERT_QUERIES["insert_metrics"]
data = [
(1, "EPSS"),
(2, "CVSS-2"),
(3, "CVSS-3"),
(EPSS_METRIC_ID, "EPSS"),
(CVSS_2_METRIC_ID, "CVSS-2"),
(CVSS_3_METRIC_ID, "CVSS-3"),
]
# Execute the insert query for each row
for row in data:
Expand Down
21 changes: 6 additions & 15 deletions cve_bin_tool/data_sources/epss_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ def __init__(self, error_mode=ErrorMode.TruncTrace):
self.backup_cachedir = self.BACKUPCACHEDIR
self.epss_path = str(Path(self.cachedir) / "epss")
self.file_name = os.path.join(self.epss_path, "epss_scores-current.csv")
self.epss_metric_id = None
self.source_name = self.SOURCE

async def update_epss(self, cursor):
async def update_epss(self):
"""
Updates the EPSS data by downloading and parsing the CSV file.
Returns:
Expand All @@ -51,7 +50,6 @@ async def update_epss(self, cursor):
"""
self.LOGGER.debug("Fetching EPSS data...")

self.EPSS_id_finder(cursor)
await self.download_epss_data()
self.epss_data = self.parse_epss_data()
return self.epss_data
Expand Down Expand Up @@ -110,15 +108,6 @@ async def download_epss_data(self):
except aiohttp.ClientError as e:
self.LOGGER.error(f"An error occurred during downloading epss {e}")

def EPSS_id_finder(self, cursor):
"""Search for metric id in EPSS table"""
query = """
SELECT metrics_id FROM metrics
WHERE metrics_name = "EPSS"
"""
cursor.execute(query)
self.epss_metric_id = cursor.fetchall()[0][0]

def parse_epss_data(self, file_path=None):
"""Parse epss data from the file path given and return the parse data"""
parsed_data = []
Expand All @@ -138,9 +127,11 @@ def parse_epss_data(self, file_path=None):
# Parse the data from the remaining rows
for row in reader:
cve_id, epss_score, epss_percentile = row[:3]
parsed_data.append(
(cve_id, self.epss_metric_id, epss_score, epss_percentile)
)

# prevent circular dependency
from cve_bin_tool.cvedb import EPSS_METRIC_ID

parsed_data.append((cve_id, EPSS_METRIC_ID, epss_score, epss_percentile))
return parsed_data

async def get_cve_data(self):
Expand Down
5 changes: 1 addition & 4 deletions test/test_source_epss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,12 @@ def setup_class(cls):
]

def test_parse_epss(self):
# EPSS need metrics table to populated in the database. To get the EPSS metric id from table.
# EPSS need metrics table to populated in the database. EPSS metric id is a constant.
cvedb = CVEDB()
# creating table
cvedb.init_database()
# populating metrics
cvedb.populate_metrics()
cursor = cvedb.db_open_and_get_cursor()
# seting EPSS_metric_id
self.epss.EPSS_id_finder(cursor)
# parsing the data
self.epss_data = self.epss.parse_epss_data(self.epss.file_name)
cvedb.db_close()
Expand Down

0 comments on commit 2b15bc5

Please sign in to comment.