feat: upgrade aleph-message from 0.4.2 to 0.4.8 and fix all mypy issues. #578

Merged · 2 commits · Jul 15, 2024
43 changes: 43 additions & 0 deletions deployment/migrations/versions/0023_add_trusted_execution_fields_to_vms_.py
@@ -0,0 +1,43 @@
"""Add trusted execution fields to vms table

Revision ID: 2543def8f601
Revises: e682fc8f9506
Create Date: 2024-07-02 13:19:10.675168

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2543def8f601"
down_revision = "e682fc8f9506"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"vms",
sa.Column("environment_trusted_execution_policy", sa.Integer(), nullable=True),
)
op.add_column(
"vms",
sa.Column("environment_trusted_execution_firmware", sa.String(), nullable=True),
)
op.add_column(
"vms",
sa.Column("node_hash", sa.String(), nullable=True),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("vms", "environment_trusted_execution_firmware")
op.drop_column("vms", "environment_trusted_execution_policy")
op.drop_column("vms", "node_hash")

# ### end Alembic commands ###

Codecov (codecov/patch) check warning: added lines 39-41 of deployment/migrations/versions/0023_add_trusted_execution_fields_to_vms_.py (the drop_column calls in downgrade) were not covered by tests.
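The Codecov warning above flags the downgrade path of this migration (file lines 39-41) as untested. A minimal sketch of a test that would exercise both directions of revision 2543def8f601 via Alembic's command API; the alembic.ini path and the database URL below are placeholders, not part of this PR:

# Hypothetical round-trip test for migration 2543def8f601 (illustrative, not part of this PR).
# Assumes an alembic.ini at the repository root and a disposable test database.
from alembic import command
from alembic.config import Config


def test_trusted_execution_migration_round_trip() -> None:
    cfg = Config("alembic.ini")  # placeholder config path
    cfg.set_main_option(
        "sqlalchemy.url",
        "postgresql://aleph:aleph@localhost/aleph_test",  # placeholder URL
    )

    # Upgrade to this revision: adds the three new columns to the "vms" table.
    command.upgrade(cfg, "2543def8f601")

    # Downgrade to the previous revision: exercises the three drop_column calls
    # (file lines 39-41) that Codecov reported as uncovered.
    command.downgrade(cfg, "e682fc8f9506")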
2 changes: 1 addition & 1 deletion setup.cfg
@@ -40,7 +40,7 @@ install_requires =
aiohttp==3.8.4
aioipfs@git+https://github.com/aleph-im/aioipfs.git@d671c79b2871bb4d6c8877ba1e7f3ffbe7d20b71
alembic==1.12.1
aleph-message==0.4.2
aleph-message==0.4.8
aleph-p2p-client@git+https://github.com/aleph-im/p2p-service-client-python@2c04af39c566217f629fd89505ffc3270fba8676
aleph-pytezos@git+https://github.com/aleph-im/aleph-pytezos.git@32dd1749a4773da494275709060632cbeba9a51b
asyncpg==0.28.0
5 changes: 3 additions & 2 deletions src/aleph/chains/chain_data_service.py
@@ -3,7 +3,7 @@
from typing import Dict, Optional, List, Any, Mapping, Set, cast, Type, Union, Self

import aio_pika.abc
from aleph_message.models import StoreContent, ItemType, Chain, MessageType
from aleph_message.models import StoreContent, ItemType, Chain, MessageType, ItemHash
from configmanager import Config
from pydantic import ValidationError

@@ -187,7 +187,8 @@ def _get_tx_messages_smart_contract_protocol(tx: ChainTxDb) -> List[Dict[str, An
address=payload.address,
time=payload.timestamp_seconds,
item_type=ItemType.ipfs,
item_hash=payload.content,
item_hash=ItemHash(payload.content),
metadata=None,
)
item_content = content.json(exclude_none=True)
else:
2 changes: 1 addition & 1 deletion src/aleph/chains/ethereum.py
@@ -204,7 +204,7 @@ async def _request_transactions(
try:
jdata = json.loads(message)
context = TxContext(
chain=CHAIN_NAME,
chain=Chain(CHAIN_NAME),
hash=event_data.transactionHash.hex(),
time=timestamp,
height=event_data.blockNumber,
2 changes: 1 addition & 1 deletion src/aleph/chains/nuls2.py
@@ -119,7 +119,7 @@ async def _request_transactions(
jdata = json.loads(ddata)

context = TxContext(
chain=CHAIN_NAME,
chain=Chain(CHAIN_NAME),
hash=tx["hash"],
height=tx["height"],
time=tx["createTime"],
2 changes: 1 addition & 1 deletion src/aleph/db/accessors/messages.py
@@ -562,7 +562,7 @@ def reject_existing_pending_message(

# The message may already be processed and someone is sending invalid copies.
# Just drop the pending message.
message_status = get_message_status(session=session, item_hash=item_hash)
message_status = get_message_status(session=session, item_hash=ItemHash(item_hash))
if message_status:
if message_status.status not in (MessageStatus.PENDING, MessageStatus.REJECTED):
delete_pending_message(session=session, pending_message=pending_message)
8 changes: 7 additions & 1 deletion src/aleph/db/models/vms.py
@@ -1,7 +1,7 @@
import datetime as dt
from typing import Any, Optional, Dict, List

from aleph_message.models.execution.program import MachineType, Encoding
from aleph_message.models.execution import MachineType, Encoding
from aleph_message.models.execution.volume import VolumePersistence
from sqlalchemy import Column, String, ForeignKey, Boolean, Integer, TIMESTAMP
from sqlalchemy.dialects.postgresql import JSONB
@@ -132,6 +132,11 @@ class VmBaseDb(Base):
environment_aleph_api: bool = Column(Boolean, nullable=False)
environment_shared_cache: bool = Column(Boolean, nullable=False)

environment_trusted_execution_policy: Optional[int] = Column(Integer, nullable=True)
environment_trusted_execution_firmware: Optional[str] = Column(
String, nullable=True
)

resources_vcpus: int = Column(Integer, nullable=False)
resources_memory: int = Column(Integer, nullable=False)
resources_seconds: int = Column(Integer, nullable=False)
@@ -142,6 +147,7 @@ class VmBaseDb(Base):
cpu_vendor: Optional[str] = Column(String, nullable=True)
node_owner: Optional[str] = Column(String, nullable=True)
node_address_regex: Optional[str] = Column(String, nullable=True)
node_hash: Optional[str] = Column(String, nullable=True)

replaces: Optional[str] = Column(ForeignKey(item_hash), nullable=True)
created: dt.datetime = Column(TIMESTAMP(timezone=True), nullable=False)
6 changes: 3 additions & 3 deletions src/aleph/handlers/content/aggregate.py
@@ -46,7 +46,7 @@ async def _insert_aggregate_element(session: DbSession, message: MessageDb):
content = cast(AggregateContent, message.parsed_content)
aggregate_element = AggregateElementDb(
item_hash=message.item_hash,
key=content.key,
key=str(content.key),
owner=content.address,
content=content.content,
creation_datetime=timestamp_to_datetime(message.parsed_content.time),
@@ -228,10 +228,10 @@ async def forget_message(self, session: DbSession, message: MessageDb) -> Set[st
key = content.key

LOGGER.debug("Deleting aggregate element %s...", message.item_hash)
delete_aggregate(session=session, owner=owner, key=key)
delete_aggregate(session=session, owner=owner, key=str(key))
delete_aggregate_element(session=session, item_hash=message.item_hash)

LOGGER.debug("Refreshing aggregate %s/%s...", owner, key)
refresh_aggregate(session=session, owner=owner, key=key)
refresh_aggregate(session=session, owner=owner, key=str(key))

return set()
4 changes: 2 additions & 2 deletions src/aleph/handlers/content/forget.py
@@ -164,7 +164,7 @@ async def _forget_message(
async def _forget_item_hash(
self, session: DbSession, item_hash: str, forgotten_by: MessageDb
):
message_status = get_message_status(session=session, item_hash=item_hash)
message_status = get_message_status(session=session, item_hash=ItemHash(item_hash))
if not message_status:
raise ForgetTargetNotFound(target_hash=item_hash)

@@ -187,7 +187,7 @@ async def _forget_item_hash(
)
raise ForgetTargetNotFound(item_hash)

message = get_message_by_item_hash(session=session, item_hash=item_hash)
message = get_message_by_item_hash(session=session, item_hash=ItemHash(item_hash))
if not message:
raise ForgetTargetNotFound(item_hash)

3 changes: 2 additions & 1 deletion src/aleph/handlers/content/post.py
@@ -132,6 +132,7 @@ async def process_post(self, session: DbSession, message: MessageDb):
if (
content.type == self.balances_post_type
and content.address in self.balances_addresses
and content.content
):
LOGGER.info("Updating balances...")
update_balances(session=session, content=content.content)
@@ -150,7 +151,7 @@ async def forget_message(self, session: DbSession, message: MessageDb) -> Set[st
delete_post(session=session, item_hash=message.item_hash)

if content.type == "amend":
original_post = get_original_post(session, content.ref)
original_post = get_original_post(session, str(content.ref))
if original_post is None:
raise InternalError(
f"Could not find original post ({content.ref} for amend ({message.item_hash})."
81 changes: 48 additions & 33 deletions src/aleph/handlers/content/vm.py
@@ -1,13 +1,16 @@
import logging
import math
from typing import List, Set, overload, Protocol, Optional
from typing import List, Set, overload, Protocol, Optional, Union

from aleph_message.models import (
ProgramContent,
ExecutableContent,
InstanceContent,
MessageType,
)
from aleph_message.models.execution import MachineType
from aleph_message.models.execution.instance import RootfsVolume
from aleph_message.models.execution.program import ProgramContent
from aleph_message.models.execution.volume import (
AbstractVolume,
ImmutableVolume,
@@ -72,29 +75,24 @@

def _get_vm_content(message: MessageDb) -> ExecutableContent:
content = message.parsed_content
if not isinstance(content, ExecutableContent):
if not isinstance(content, (InstanceContent, ProgramContent)):
raise InvalidMessageFormat(
f"Unexpected content type for program message: {message.item_hash}"
)
return content


from aleph_message.models.execution.program import (
MachineType,
ProgramContent,
)


@overload
def _map_content_to_db_model(item_hash: str, content: InstanceContent) -> VmInstanceDb:
...
def _map_content_to_db_model(
item_hash: str, content: InstanceContent
) -> VmInstanceDb: ...


# For some reason, mypy is not happy with the overload resolution here.
# This seems linked to multiple inheritance of Pydantic base models, a deeper investigation
# is required.
@overload
def _map_content_to_db_model(item_hash: str, content: ProgramContent) -> ProgramDb: # type: ignore[misc]
def _map_content_to_db_model(item_hash: str, content: ProgramContent) -> ProgramDb:
...


@@ -117,6 +115,16 @@ def _map_content_to_db_model(item_hash, content):
node_owner = node.owner
node_address_regex = node.address_regex

trusted_execution_policy = None
trusted_execution_firmware = None
node_hash = None
if not isinstance(content, ProgramContent):
if content.environment.trusted_execution is not None:
trusted_execution_policy = content.environment.trusted_execution.policy
trusted_execution_firmware = content.environment.trusted_execution.firmware
if hasattr(content.requirements, 'node_hash'):
node_hash = content.requirements.node_hash

return db_cls(
owner=content.address,
item_hash=item_hash,
@@ -127,6 +135,8 @@ def _map_content_to_db_model(item_hash, content):
environment_internet=content.environment.internet,
environment_aleph_api=content.environment.aleph_api,
environment_shared_cache=content.environment.shared_cache,
environment_trusted_execution_policy=trusted_execution_policy,
environment_trusted_execution_firmware=trusted_execution_firmware,
resources_vcpus=content.resources.vcpus,
resources_memory=content.resources.memory,
resources_seconds=content.resources.seconds,
@@ -136,6 +146,7 @@ def _map_content_to_db_model(item_hash, content):
node_address_regex=node_address_regex,
volumes=volumes,
created=timestamp_to_datetime(content.time),
node_hash=node_hash
)


@@ -172,7 +183,7 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb:
content = _get_vm_content(message)
vm = _map_content_to_db_model(message.item_hash, content)

if isinstance(vm, ProgramDb):
if isinstance(vm, ProgramDb) and isinstance(content, ProgramContent):
vm.program_type = content.type
vm.persistent = bool(content.on.persistent)
vm.http_trigger = content.on.http
@@ -208,13 +219,16 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb:

elif isinstance(content, InstanceContent):
parent = content.rootfs.parent
vm.rootfs = RootfsVolumeDb(
parent_ref=parent.ref,
parent_use_latest=parent.use_latest,
size_mib=content.rootfs.size_mib,
persistence=content.rootfs.persistence,
)
vm.authorized_keys = content.authorized_keys
if isinstance(vm, VmInstanceDb):
vm.rootfs = RootfsVolumeDb(
parent_ref=parent.ref,
parent_use_latest=parent.use_latest,
size_mib=content.rootfs.size_mib,
persistence=content.rootfs.persistence,
)
vm.authorized_keys = content.authorized_keys
else:
raise TypeError(f"Unexpected VM message content type: {type(vm)}")

else:
raise TypeError(f"Unexpected VM message content type: {type(content)}")
@@ -265,18 +279,18 @@ check_parent_volumes_size_requirements(
) -> None:
def _get_parent_volume_file(_parent: ParentVolume) -> StoredFileDb:
if _parent.use_latest:
file_tag = get_file_tag(session=session, tag=_parent.ref)
file_tag = get_file_tag(session=session, tag=FileTag(_parent.ref))
if file_tag is None:
raise InternalError(
f"Could not find latest version of parent volume {volume.parent.ref}"
f"Could not find latest version of parent volume {_parent.ref}"
)

return file_tag.file

file_pin = get_message_file_pin(session=session, item_hash=_parent.ref)
if file_pin is None:
raise InternalError(
f"Could not find original version of parent volume {volume.parent.ref}"
f"Could not find original version of parent volume {_parent.ref}"
)

return file_pin.file
@@ -285,7 +299,7 @@ class HasParent(Protocol):
parent: ParentVolume
size_mib: int

volumes_with_parent: List[HasParent] = [
volumes_with_parent: List[Union[PersistentVolume, RootfsVolume]] = [
volume
for volume in content.volumes
if isinstance(volume, PersistentVolume) and volume.parent is not None
@@ -295,16 +309,17 @@ class HasParent(Protocol):
volumes_with_parent.append(content.rootfs)

for volume in volumes_with_parent:
volume_metadata = _get_parent_volume_file(volume.parent)
volume_size = volume.size_mib * 1024 * 1024
if volume_size < volume_metadata.size:
raise VmVolumeTooSmall(
parent_size=volume_metadata.size,
parent_ref=volume.parent.ref,
parent_file=volume_metadata.hash,
volume_name=getattr(volume, "name", "rootfs"),
volume_size=volume_size,
)
if volume.parent:
volume_metadata = _get_parent_volume_file(volume.parent)
volume_size = volume.size_mib * 1024 * 1024
if volume_size < volume_metadata.size:
raise VmVolumeTooSmall(
parent_size=volume_metadata.size,
parent_ref=volume.parent.ref,
parent_file=volume_metadata.hash,
volume_name=getattr(volume, "name", "rootfs"),
volume_size=volume_size,
)


class VmMessageHandler(ContentHandler):
3 changes: 2 additions & 1 deletion src/aleph/handlers/message_handler.py
@@ -10,6 +10,7 @@
from pydantic import ValidationError
from sqlalchemy import insert

from aleph_message.models import ItemHash
from aleph.chains.signature_verifier import SignatureVerifier
from aleph.db.accessors.files import insert_content_file_pin, upsert_file
from aleph.db.accessors.messages import (
@@ -377,7 +378,7 @@ async def process(
"""

existing_message = get_message_by_item_hash(
session=session, item_hash=pending_message.item_hash
session=session, item_hash=ItemHash(pending_message.item_hash)
)
if existing_message:
await self.confirm_existing_message(