Skip to content

Commit

Permalink
feat: Add Distribution as new root model
Browse files Browse the repository at this point in the history
This is to support including the download_url.
  • Loading branch information
jonathan-d-zhang committed Aug 11, 2024
1 parent f6222d6 commit e93f7d8
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/mainframe/endpoints/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def submit_results(
scan.score = result.score
scan.finished_by = auth.subject
scan.commit_hash = result.commit
scan.files = result.files
scan.distributions = result.distributions

# These are the rules that already have an entry in the database
rules = session.scalars(select(Rule).where(Rule.name.in_(result.rules_matched))).all()
Expand Down
4 changes: 2 additions & 2 deletions src/mainframe/models/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)

from mainframe.models import Pydantic
from mainframe.models.schemas import Files
from mainframe.models.schemas import Distributions


class Base(MappedAsDataclass, DeclarativeBase, kw_only=True):
Expand Down Expand Up @@ -102,7 +102,7 @@ class Scan(Base):

commit_hash: Mapped[Optional[str]] = mapped_column(default=None)

files: Mapped[Optional[Files]] = mapped_column(Pydantic(Files), default=None)
distributions: Mapped[Optional[Distributions]] = mapped_column(Pydantic(Distributions), default=None)


Index(None, Scan.status, postgresql_where=or_(Scan.status == Status.QUEUED, Scan.status == Status.PENDING))
Expand Down
16 changes: 12 additions & 4 deletions src/mainframe/models/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,15 @@ class File(BaseModel):
matches: list[RuleMatch]


Files = RootModel[list[File]]
Files = list[File]


class Distribution(BaseModel):
download_url: str
files: Files


Distributions = RootModel[list[Distribution]]


class ServerMetadata(BaseModel):
Expand Down Expand Up @@ -88,7 +96,7 @@ class Package(BaseModel):

commit_hash: Optional[str]

files: Optional[Files]
distributions: Optional[Distributions]

@classmethod
def from_db(cls, scan: Scan):
Expand All @@ -110,7 +118,7 @@ def from_db(cls, scan: Scan):
finished_at=scan.finished_at,
finished_by=scan.finished_by,
commit_hash=scan.commit_hash,
files=scan.files,
distributions=scan.distributions,
)

@field_serializer(
Expand Down Expand Up @@ -179,7 +187,7 @@ class PackageScanResult(PackageSpecifier):
score: int = 0
inspector_url: Optional[str] = None
rules_matched: list[str] = []
files: Optional[Files] = None
distributions: Optional[Distributions] = None


class PackageScanResultFail(PackageSpecifier):
Expand Down
13 changes: 8 additions & 5 deletions tests/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from mainframe.json_web_token import AuthenticationData
from mainframe.models.orm import Scan, Status
from mainframe.models.schemas import (
Distribution,
Distributions,
File,
Files,
Match,
Expand Down Expand Up @@ -95,21 +97,21 @@ def test_package_lookup_files(db_session: Session):
rule = RuleMatch(identifier="rule1", patterns=[pattern], metadata={"author": "remmy", "score": 5})
file = File(path="dist1/a/b.py", matches=[rule])
files = Files([file])
distros = Distributions([Distribution(download_url="http://example.com", files=files)])
scan = Scan(
name="abc",
version="1.0.0",
status=Status.FINISHED,
queued_by="remmy",
files=files,
distributions=distros,
)

with db_session.begin():
db_session.add(scan)
db_session.commit()

package = lookup_package_info(db_session, name="abc", version="1.0.0")[0]

assert package.files == files
assert package.distributions == distros


def test_handle_success(db_session: Session, test_data: list[Scan], auth: AuthenticationData, rules_state: Rules):
Expand All @@ -126,6 +128,7 @@ def test_handle_success(db_session: Session, test_data: list[Scan], auth: Authen
rule = RuleMatch(identifier="rule1", patterns=[pattern], metadata={"author": "remmy", "score": 5})
file = File(path="dist1/a/b.py", matches=[rule])
files = Files([file])
distros = Distributions([Distribution(download_url="http://example.com", files=files)])

body = PackageScanResult(
name=job.name,
Expand All @@ -134,7 +137,7 @@ def test_handle_success(db_session: Session, test_data: list[Scan], auth: Authen
score=2,
inspector_url="test inspector url",
rules_matched=["a", "b", "c"],
files=files,
distributions=distros,
)
submit_results(body, db_session, auth)

Expand All @@ -147,7 +150,7 @@ def test_handle_success(db_session: Session, test_data: list[Scan], auth: Authen
assert record.score == 2
assert record.inspector_url == "test inspector url"
assert {rule.name for rule in record.rules} == {"a", "b", "c"}
assert record.files == files
assert record.distributions == distros
else:
assert all(scan.status != Status.QUEUED for scan in test_data)

Expand Down

0 comments on commit e93f7d8

Please sign in to comment.