Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: added test for generate_sbom function #4060

Merged
merged 3 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions cve_bin_tool/output_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def __init__(
self.sbom_format = sbom_format
self.sbom_root = sbom_root
self.offline = offline
self.sbom_packages = {}

def output_cves(self, outfile, output_type="console"):
"""Output a list of CVEs
Expand Down Expand Up @@ -914,7 +915,6 @@ def generate_sbom(
):
"""Create SBOM package and generate SBOM file."""
# Create SBOM
sbom_packages = {}
sbom_relationships = []
my_package = SBOMPackage()
sbom_relationship = SBOMRelationship()
Expand All @@ -933,7 +933,7 @@ def generate_sbom(
my_package.set_supplier("UNKNOWN", "NOASSERTION")

# Store package data
sbom_packages[(my_package.get_name(), my_package.get_value("version"))] = (
self.sbom_packages[(my_package.get_name(), my_package.get_value("version"))] = (
my_package.get_package()
)
sbom_relationship.initialise()
Expand All @@ -945,18 +945,18 @@ def generate_sbom(
my_package.initialise()
my_package.set_name(product_data.product)
my_package.set_version(product_data.version)
if product_data.vendor != "UNKNOWN":
if product_data.vendor.casefold() != "UNKNOWN".casefold():
my_package.set_supplier("Organization", product_data.vendor)
my_package.set_licensedeclared(license)
my_package.set_licenseconcluded(license)
if not (
(my_package.get_name(), my_package.get_value("version"))
in sbom_packages
in self.sbom_packages
and product_data.vendor == "unknown"
):
location = product_data.location
my_package.set_evidence(location) # Set location directly
sbom_packages[
self.sbom_packages[
(my_package.get_name(), my_package.get_value("version"))
] = my_package.get_package()
sbom_relationship.initialise()
Expand All @@ -967,7 +967,7 @@ def generate_sbom(

# Generate SBOM
my_sbom = SBOM()
my_sbom.add_packages(sbom_packages)
my_sbom.add_packages(self.sbom_packages)
my_sbom.add_relationships(sbom_relationships)
my_generator = SBOMGenerator(
sbom_type=sbom_type,
Expand Down
77 changes: 77 additions & 0 deletions test/test_output_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import unittest
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, call, patch

from jsonschema import validate
from rich.console import Console
Expand Down Expand Up @@ -1124,6 +1125,20 @@ class TestOutputEngine(unittest.TestCase):
]

def setUp(self) -> None:
self.all_product_data = [
ProductInfo(
product="product1",
version="1.0",
vendor="VendorA",
location="/usr/local/bin/product",
),
ProductInfo(
product="product2",
version="2.0",
vendor="unknown",
location="/usr/local/bin/product",
),
]
self.output_engine = OutputEngine(
all_cve_data=self.MOCK_OUTPUT,
scanned_dir="",
Expand All @@ -1134,6 +1149,68 @@ def setUp(self) -> None:
)
self.mock_file = tempfile.NamedTemporaryFile("w+", encoding="utf-8")

def test_generate_sbom(self):
with patch(
"cve_bin_tool.output_engine.SBOMPackage"
) as mock_sbom_package, patch("cve_bin_tool.output_engine.SBOMRelationship"):
mock_package_instance = MagicMock()
mock_sbom_package.return_value = mock_package_instance

self.output_engine.generate_sbom(
all_product_data=self.all_product_data,
filename="test.sbom",
sbom_type="spdx",
sbom_format="tag",
sbom_root="CVE-SCAN",
)

# Assertions
mock_package_instance.set_name.assert_any_call("CVEBINTOOL-CVE-SCAN")

# Check if set_name is called for each product
expected_calls = [
call(product.product) for product in self.all_product_data
]
mock_package_instance.set_name.assert_has_calls(
expected_calls, any_order=True
)

# Check if set_version is called for each product
expected_calls = [
call(product.version) for product in self.all_product_data
]
mock_package_instance.set_version.assert_has_calls(
expected_calls, any_order=True
)

# Check if set_supplier is called for VendorA
mock_package_instance.set_supplier.assert_any_call(
"Organization", "VendorA"
)

for call_args in mock_package_instance.set_supplier.call_args_list:
args, _ = call_args
self.assertNotEqual(args, ("Organization", "unknown"))

# Check if set_licensedeclared and set_licenseconcluded are called for each product
expected_calls = [call("NOASSERTION")] * len(self.all_product_data)
mock_package_instance.set_licensedeclared.assert_has_calls(
expected_calls, any_order=True
)
mock_package_instance.set_licenseconcluded.assert_has_calls(
expected_calls, any_order=True
)

# Ensure packages are added to sbom_packages correctly
expected_packages = {
mock_package_instance.get_package.return_value,
mock_package_instance.get_package.return_value,
}
actual_packages = [
package for package in self.output_engine.sbom_packages.values()
]
self.assertEqual(actual_packages, list(expected_packages))

def tearDown(self) -> None:
self.mock_file.close()

Expand Down
Loading