Skip to content

Commit

Permalink
fix: create new version comparison function (#3470)
Browse files Browse the repository at this point in the history
* fix: create new version comparison function

We had been using packaging's version parsing tools, but as they move more
towards pep440 compliance they aren't as useful for comparing arbitrary
versions that may not follow the same scheme.  This moves us to our own
function.  It may need some further tweaking for special cases such as release
candidates or dev versions.

Signed-off-by: Terri Oda <[email protected]>
  • Loading branch information
terriko authored Nov 16, 2023
1 parent 94bbae5 commit 9a95eca
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 132 deletions.
85 changes: 8 additions & 77 deletions cve_bin_tool/cve_scanner.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: GPL-3.0-or-later

import re
import sqlite3
import sys
from collections import defaultdict
from logging import Logger
from pathlib import Path
from string import ascii_lowercase
from typing import DefaultDict, Dict, List, Tuple
from typing import DefaultDict, Dict, List

from packaging.version import Version
from packaging.version import parse as parse_version
from rich.console import Console

from cve_bin_tool.cvedb import DBNAME, DISK_LOCATION_DEFAULT
Expand All @@ -20,6 +17,7 @@
from cve_bin_tool.log import LOGGER
from cve_bin_tool.theme import cve_theme
from cve_bin_tool.util import CVE, CVEData, ProductInfo, VersionInfo
from cve_bin_tool.version_compare import Version


class CVEScanner:
Expand Down Expand Up @@ -99,12 +97,8 @@ def get_cves(self, product_info: ProductInfo, triage_data: TriageData):
# Removing * from vendors that are guessed by the package list parser
vendor = product_info.vendor.replace("*", "")

# Need to manipulate version to ensure canonical form of version

parsed_version, parsed_version_between = self.canonical_convert(product_info)
# If canonical form of version numbering not found, exit
if parsed_version == "UNKNOWN":
return
# Use our Version class to do version compares
parsed_version = Version(product_info.version)

self.cursor.execute(query, [vendor, product_info.product, str(parsed_version)])

Expand Down Expand Up @@ -133,29 +127,17 @@ def get_cves(self, product_info: ProductInfo, triage_data: TriageData):
version_end_excluding,
) = cve_range

# pep-440 doesn't include versions of the type 1.1.0g used by openssl
# or versions of the type 9a used by libjpeg
# so if this is openssl or libjpeg, convert the last letter to a .number
if product_info.product in {"openssl", "libjpeg"}:
# if last character is a letter, convert it to .number
# version = self.letter_convert(product_info.version)
version_start_including = self.letter_convert(version_start_including)
version_start_excluding = self.letter_convert(version_start_excluding)
version_end_including = self.letter_convert(version_end_including)
version_end_excluding = self.letter_convert(version_end_excluding)
parsed_version = parsed_version_between

# check the start range
passes_start = False
if (
version_start_including is not self.RANGE_UNSET
and parsed_version >= parse_version(version_start_including)
and parsed_version >= Version(version_start_including)
):
passes_start = True

if (
version_start_excluding is not self.RANGE_UNSET
and parsed_version > parse_version(version_start_excluding)
and parsed_version > Version(version_start_excluding)
):
passes_start = True

Expand All @@ -170,13 +152,13 @@ def get_cves(self, product_info: ProductInfo, triage_data: TriageData):
passes_end = False
if (
version_end_including is not self.RANGE_UNSET
and parsed_version <= parse_version(version_end_including)
and parsed_version <= Version(version_end_including)
):
passes_end = True

if (
version_end_excluding is not self.RANGE_UNSET
and parsed_version < parse_version(version_end_excluding)
and parsed_version < Version(version_end_excluding)
):
passes_end = True

Expand Down Expand Up @@ -313,57 +295,6 @@ def get_cves(self, product_info: ProductInfo, triage_data: TriageData):
if product_info not in self.all_product_data:
self.all_product_data[product_info] = len(cves)

def letter_convert(self, version: str) -> str:
    """Rewrite trailing version letters as dotted numbers.

    PEP 440 style parsers do not expect openssl style versions such as 1.1.0g
    or libjpeg style versions such as 9a.  To work around that, when the last
    character (or the last two characters) of the version are keys of
    self.ALPHA_TO_NUM, each of them is replaced by "." followed by the value
    it maps to.

    Empty values and strings shorter than two characters are returned
    unchanged, as is any version whose last character is not a known letter.
    """
    # Nothing to convert for empty values or single-character strings.
    if not version or len(version) < 2:
        return version

    letter_values = self.ALPHA_TO_NUM
    head, penultimate, final = version[:-2], version[-2], version[-1]

    # Only a trailing letter triggers a rewrite.
    if final not in letter_values:
        return version

    # Two trailing letters: each one becomes its own dotted component.
    if penultimate in letter_values:
        return f"{head}.{letter_values[penultimate]}.{letter_values[final]}"

    return f"{version[:-1]}.{letter_values[final]}"

VersionType = Version

def canonical_convert(
    self, product_info: ProductInfo
) -> Tuple[VersionType, VersionType]:
    """Extract a parseable version from product_info.version.

    Returns a pair (parsed_version, version_between):

    * parsed_version is the leading numeric part of the version run through
      parse_version, or the string "UNKNOWN" when no digit could be found
      (a warning is logged in that case).
    * version_between is parse_version("") unless the product is openssl or
      libjpeg, in which case it is the version with its trailing letter
      converted by letter_convert.

    An empty version is passed straight to parse_version.
    """
    version_between = parse_version("")
    if product_info.version == "":
        return parse_version(product_info.version), version_between

    letter_suffixed = product_info.product in {"openssl", "libjpeg"}
    if letter_suffixed:
        # Keep one optional trailing letter, e.g. 1.1.0g or 9a
        pv = re.search(r"\d[.\d]*[a-z]?", product_info.version)
    elif ":" in product_info.version:
        # Handle x:a.b<string> e.g. 2:7.4+23
        components = product_info.version.split(":")
        pv = re.search(r"\d[.\d]*", components[1])
    else:
        # Handle a.b.c<string> e.g. 1.20.9rel1
        pv = re.search(r"\d[.\d]*", product_info.version)

    # Check for "no match" before touching pv.group() in any branch, so that an
    # openssl/libjpeg version without digits is reported as UNKNOWN instead of
    # raising AttributeError.
    if pv is None:
        self.logger.warning(
            f"error parsing {product_info.vendor}.{product_info.product} v{product_info.version} - manual inspection required"
        )
        return "UNKNOWN", version_between

    if letter_suffixed:
        version_between = parse_version(self.letter_convert(pv.group(0)))
    return parse_version(pv.group(0)), version_between

def affected(self):
"""Returns list of vendor.product and version tuples identified from
scan"""
Expand Down
219 changes: 219 additions & 0 deletions cve_bin_tool/version_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: GPL-3.0-or-later

import re

"""
A class for comparing arbitrary versions of products.
Splits versions up on common delimiters (".", "-" and "_") and also splits out letters
so that things like OpenSSL's 1.1.1y type of version will work too.
This may need some additional smarts for stuff like "rc" or "beta" and potentially for
things like distro versioning. I don't know yet.
"""


class CannotParseVersionException(Exception):
    """
    Raised when a version string contains a section that does not fit any of
    the shapes the version parser understands.
    """


class UnknownVersion(Exception):
    """
    Raised when the version string is empty, None or "unknown"
    (in any letter case).
    """


def parse_version(version_string: str):
    """
    Split a version string into a list of string components for comparison.

    The characters ".", "-" and "_" are all treated as separators.  Sections
    that mix digits and letters are handled as follows:

    * digits followed by letters (1a) become two components, so 1.1.1a
      yields ["1", "1", "1", "a"].  42a is treated as coming *after* 42.
    * digits followed by letters and more digits (42dev7) become "42" and
      "dev7": the trailing number is kept with its tag instead of being
      dropped, so 42dev7 and 42dev8 stay distinguishable and the tag can be
      treated as a pre-release of 42 by the comparison.
    * a section that starts with letters followed by digits (rc1, dev7) is
      kept in one piece, as it may be some kind of pre-release that the
      comparison wants to handle specially.

    Raises UnknownVersion if version_string is empty, None or "unknown"
    (in any letter case), and CannotParseVersionException if a section does
    not fit any of the shapes above.
    """
    if not version_string or version_string.lower() == "unknown":
        raise UnknownVersion(f"version string = {version_string}")

    # Convert - and _ to . so that a single split handles all three.
    # Note: there may be other non-alphanumeric characters we want to add here in the
    # future, but we'd like to look at those cases before adding them in case the version
    # logic is very different.
    normalized = version_string.strip().translate(str.maketrans("-_", ".."))

    # The number-led pattern must cover the *whole* section (fullmatch below);
    # a prefix match would silently throw away whatever follows it.
    number_then_tag = re.compile("([0-9]+)([a-zA-Z]+[0-9]*)")
    letter_number = re.compile("([a-zA-Z]+)([0-9]+)")

    components = []
    for section in normalized.split("."):
        # If it's all letters or all numbers, just add it to the list
        if section.isnumeric() or section.isalpha():
            components.append(section)
            continue

        # 42a -> "42", "a" and 42dev7 -> "42", "dev7"
        split_section = number_then_tag.fullmatch(section)
        if split_section:
            components.extend(split_section.groups())
            continue

        # rc1, dev7 and anything else starting with letters+digits stays together
        if letter_number.match(section):
            components.append(section)
            continue

        # If all else fails, complain.  A version that consists of a single
        # separator is tolerated and simply produces no components.
        if normalized != ".":
            raise CannotParseVersionException(f"version string = {normalized}")

    return components


def _extra_component_sign(component: str) -> int:
    """
    Rank a version that has one more component than an otherwise equal one.

    Returns 1 if the longer version sorts after the shorter one and -1 if it
    sorts before it.
    """
    # Special case: something that declares itself a post-release always
    # follows the release it is attached to, e.g. 1.2.post1 > 1.2
    if component.startswith("post"):
        return 1

    # Anything else that looks like a pre-release (a2, dev8, rc4) sorts
    # before the plain release, e.g. 4.5.a1 would be less than 4.5
    if re.match("([a-zA-Z]+)([0-9]+)", component):
        return -1

    # Otherwise the extra component makes it later,
    # e.g. 1.2.3 and 1.2.q are both > 1.2
    return 1


def version_compare(v1: str, v2: str):
    """
    Compare two versions by converting them to lists of components.

    returns 0 if they're the same.
    returns 1 if v1 > v2
    returns -1 if v1 < v2

    Both strings are split with parse_version(), so any exception it raises
    for either input propagates to the caller.
    """
    v1_array = parse_version(v1)
    v2_array = parse_version(v2)

    for part1, part2 in zip(v1_array, v2_array):
        # If it's all numbers, cast to int and compare
        if part1.isnumeric() and part2.isnumeric():
            num1, num2 = int(part1), int(part2)
            if num1 != num2:
                return 1 if num1 > num2 else -1

        # If they're letters just do a string compare, I don't have a better idea
        # This might be a bad choice in some cases: Do we want ag < z?
        # I suspect projects using letters in version names may not use ranges in nvd
        # for this reason (e.g. openssl)
        # Converting to lower() so that 3.14a == 3.14A
        # but this may not be ideal in all cases
        elif part1.isalpha() and part2.isalpha():
            low1, low2 = part1.lower(), part2.lower()
            if low1 != low2:
                return 1 if low1 > low2 else -1

        # They are not the same type, and we're comparing mixed letters and numbers.
        # We'll treat letters as less than numbers.
        # This will result in things like rc1, dev9, b2 getting treated like pre-releases
        # as in https://peps.python.org/pep-0440/
        # So 1.2.pre4 would be less than 1.2.1 (and so would 1.2.post1)
        elif part1.isalnum() and part2.isnumeric():
            return -1
        elif part1.isnumeric() and part2.isalnum():
            return 1

        else:
            # At least one of them is of type letter567; convert those to
            # letter.567 and run them through the compare function again.
            # Honestly it's hard to guess if .dev3 is going to be more or less than .rc4
            # unless you know the project, so hopefully people don't expect that kind of range
            # matching
            result = version_compare(
                re.sub("([a-zA-Z]+)([0-9]+)", r"\1.\2", part1),
                re.sub("([a-zA-Z]+)([0-9]+)", r"\1.\2", part2),
            )
            # Only stop on a real difference: equal tags (rc1 vs rc1) must not
            # hide differences in the components that follow them.
            if result != 0:
                return result

    # Every shared component matched, so any extra component decides.  Using the
    # same ranking for both directions keeps the comparison antisymmetric:
    # version_compare(a, b) == -version_compare(b, a).
    if len(v1_array) > len(v2_array):
        return _extra_component_sign(v1_array[len(v2_array)])
    if len(v2_array) > len(v1_array):
        return -_extra_component_sign(v2_array[len(v1_array)])

    return 0


class Version(str):
    """
    A class to make version comparisons look more pretty:
    Version("1.2") > Version("1.1")

    All comparisons are delegated to version_compare(), so they follow its
    ordering rules rather than plain string ordering, and any exception it
    raises for an unparseable version propagates.

    Equality checks against objects that are not strings return
    NotImplemented, so that e.g. Version("1.2") == None is simply False
    instead of raising from inside the version parser.
    """

    # Overriding __eq__ without __hash__ already makes instances unhashable;
    # spell that out so nobody expects Version to work as a dict key or set
    # member just because it subclasses str.
    __hash__ = None

    def __cmp__(self, other):
        """compare (never called implicitly by Python 3, kept for direct callers)"""
        return version_compare(self, other)

    def __lt__(self, other):
        """<"""
        return version_compare(self, other) < 0

    def __le__(self, other):
        """<="""
        return version_compare(self, other) <= 0

    def __gt__(self, other):
        """>"""
        return version_compare(self, other) > 0

    def __ge__(self, other):
        """>="""
        return version_compare(self, other) >= 0

    def __eq__(self, other):
        """=="""
        if not isinstance(other, str):
            return NotImplemented
        return version_compare(self, other) == 0

    def __ne__(self, other):
        """!="""
        if not isinstance(other, str):
            return NotImplemented
        return version_compare(self, other) != 0

    def __repr__(self):
        """print the version string"""
        return f"Version: {self}"
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jsonschema>=3.0.2
lib4sbom>=0.5.0
python-gnupg
packageurl-python
packaging<22.0
packaging
plotly
pyyaml>=5.4
requests
Expand Down
2 changes: 1 addition & 1 deletion test/test_csv2cve.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_csv2cve_valid_file(self, caplog):

for cve_count, product in [
[60, "haxx.curl version 7.34.0"],
[10, "mit.kerberos_5 version 1.15.1"],
[9, "mit.kerberos_5 version 1.15.1"],
]:
retrieved_cve_count = 0
for captured_line in caplog.record_tuples:
Expand Down
Loading

0 comments on commit 9a95eca

Please sign in to comment.