Skip to content

Commit

Permalink
feat: upgrade aleph-message from 0.4.2 to 0.4.8 and fix all mypy issu…
Browse files Browse the repository at this point in the history
…es. (#578)

* feat: upgrade aleph-message from 0.4.2 to 0.4.8 and fix all mypy issues.

Co-authored-by: 1yam <[email protected]>

* Feature: Allow processing of Confidential message (#573)

* Feature: Alembic migrations to allow environment_trusted_execution_firmware and environment_trusted_execution_policy

* Feature: Allow pyaleph to handle confidential instance type

* Feature: adding node_hash to VmDB

---------

Co-authored-by: 1yam <[email protected]>

---------

Co-authored-by: Andres D. Molins <[email protected]>
Co-authored-by: 1yam <[email protected]>
Co-authored-by: nesitor <[email protected]>
  • Loading branch information
4 people authored Jul 15, 2024
1 parent 692cb1c commit 25489df
Show file tree
Hide file tree
Showing 29 changed files with 583 additions and 190 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Add trusted execution fields to vms table
Revision ID: 2543def8f601
Revises: e682fc8f9506
Create Date: 2024-07-02 13:19:10.675168
"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2543def8f601"
down_revision = "e682fc8f9506"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"vms",
sa.Column("environment_trusted_execution_policy", sa.Integer(), nullable=True),
)
op.add_column(
"vms",
sa.Column("environment_trusted_execution_firmware", sa.String(), nullable=True),
)
op.add_column(
"vms",
sa.Column("node_hash", sa.String(), nullable=True),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("vms", "environment_trusted_execution_firmware")
op.drop_column("vms", "environment_trusted_execution_policy")
op.drop_column("vms", "node_hash")

# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ install_requires =
aiohttp==3.8.4
aioipfs@git+https://github.com/aleph-im/aioipfs.git@d671c79b2871bb4d6c8877ba1e7f3ffbe7d20b71
alembic==1.12.1
aleph-message==0.4.2
aleph-message==0.4.8
aleph-p2p-client@git+https://github.com/aleph-im/p2p-service-client-python@2c04af39c566217f629fd89505ffc3270fba8676
aleph-pytezos@git+https://github.com/aleph-im/aleph-pytezos.git@32dd1749a4773da494275709060632cbeba9a51b
asyncpg==0.28.0
Expand Down
5 changes: 3 additions & 2 deletions src/aleph/chains/chain_data_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, Optional, List, Any, Mapping, Set, cast, Type, Union, Self

import aio_pika.abc
from aleph_message.models import StoreContent, ItemType, Chain, MessageType
from aleph_message.models import StoreContent, ItemType, Chain, MessageType, ItemHash
from configmanager import Config
from pydantic import ValidationError

Expand Down Expand Up @@ -187,7 +187,8 @@ def _get_tx_messages_smart_contract_protocol(tx: ChainTxDb) -> List[Dict[str, An
address=payload.address,
time=payload.timestamp_seconds,
item_type=ItemType.ipfs,
item_hash=payload.content,
item_hash=ItemHash(payload.content),
metadata=None,
)
item_content = content.json(exclude_none=True)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/chains/ethereum.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async def _request_transactions(
try:
jdata = json.loads(message)
context = TxContext(
chain=CHAIN_NAME,
chain=Chain(CHAIN_NAME),
hash=event_data.transactionHash.hex(),
time=timestamp,
height=event_data.blockNumber,
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/chains/nuls2.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def _request_transactions(
jdata = json.loads(ddata)

context = TxContext(
chain=CHAIN_NAME,
chain=Chain(CHAIN_NAME),
hash=tx["hash"],
height=tx["height"],
time=tx["createTime"],
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/db/accessors/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def reject_existing_pending_message(

# The message may already be processed and someone is sending invalid copies.
# Just drop the pending message.
message_status = get_message_status(session=session, item_hash=item_hash)
message_status = get_message_status(session=session, item_hash=ItemHash(item_hash))
if message_status:
if message_status.status not in (MessageStatus.PENDING, MessageStatus.REJECTED):
delete_pending_message(session=session, pending_message=pending_message)
Expand Down
8 changes: 7 additions & 1 deletion src/aleph/db/models/vms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime as dt
from typing import Any, Optional, Dict, List

from aleph_message.models.execution.program import MachineType, Encoding
from aleph_message.models.execution import MachineType, Encoding
from aleph_message.models.execution.volume import VolumePersistence
from sqlalchemy import Column, String, ForeignKey, Boolean, Integer, TIMESTAMP
from sqlalchemy.dialects.postgresql import JSONB
Expand Down Expand Up @@ -132,6 +132,11 @@ class VmBaseDb(Base):
environment_aleph_api: bool = Column(Boolean, nullable=False)
environment_shared_cache: bool = Column(Boolean, nullable=False)

environment_trusted_execution_policy: Optional[int] = Column(Integer, nullable=True)
environment_trusted_execution_firmware: Optional[str] = Column(
String, nullable=True
)

resources_vcpus: int = Column(Integer, nullable=False)
resources_memory: int = Column(Integer, nullable=False)
resources_seconds: int = Column(Integer, nullable=False)
Expand All @@ -142,6 +147,7 @@ class VmBaseDb(Base):
cpu_vendor: Optional[str] = Column(String, nullable=True)
node_owner: Optional[str] = Column(String, nullable=True)
node_address_regex: Optional[str] = Column(String, nullable=True)
node_hash: Optional[str] = Column(String, nullable=True)

replaces: Optional[str] = Column(ForeignKey(item_hash), nullable=True)
created: dt.datetime = Column(TIMESTAMP(timezone=True), nullable=False)
Expand Down
6 changes: 3 additions & 3 deletions src/aleph/handlers/content/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def _insert_aggregate_element(session: DbSession, message: MessageDb):
content = cast(AggregateContent, message.parsed_content)
aggregate_element = AggregateElementDb(
item_hash=message.item_hash,
key=content.key,
key=str(content.key),
owner=content.address,
content=content.content,
creation_datetime=timestamp_to_datetime(message.parsed_content.time),
Expand Down Expand Up @@ -228,10 +228,10 @@ async def forget_message(self, session: DbSession, message: MessageDb) -> Set[st
key = content.key

LOGGER.debug("Deleting aggregate element %s...", message.item_hash)
delete_aggregate(session=session, owner=owner, key=key)
delete_aggregate(session=session, owner=owner, key=str(key))
delete_aggregate_element(session=session, item_hash=message.item_hash)

LOGGER.debug("Refreshing aggregate %s/%s...", owner, key)
refresh_aggregate(session=session, owner=owner, key=key)
refresh_aggregate(session=session, owner=owner, key=str(key))

return set()
4 changes: 2 additions & 2 deletions src/aleph/handlers/content/forget.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async def _forget_message(
async def _forget_item_hash(
self, session: DbSession, item_hash: str, forgotten_by: MessageDb
):
message_status = get_message_status(session=session, item_hash=item_hash)
message_status = get_message_status(session=session, item_hash=ItemHash(item_hash))
if not message_status:
raise ForgetTargetNotFound(target_hash=item_hash)

Expand All @@ -187,7 +187,7 @@ async def _forget_item_hash(
)
raise ForgetTargetNotFound(item_hash)

message = get_message_by_item_hash(session=session, item_hash=item_hash)
message = get_message_by_item_hash(session=session, item_hash=ItemHash(item_hash))
if not message:
raise ForgetTargetNotFound(item_hash)

Expand Down
3 changes: 2 additions & 1 deletion src/aleph/handlers/content/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ async def process_post(self, session: DbSession, message: MessageDb):
if (
content.type == self.balances_post_type
and content.address in self.balances_addresses
and content.content
):
LOGGER.info("Updating balances...")
update_balances(session=session, content=content.content)
Expand All @@ -150,7 +151,7 @@ async def forget_message(self, session: DbSession, message: MessageDb) -> Set[st
delete_post(session=session, item_hash=message.item_hash)

if content.type == "amend":
original_post = get_original_post(session, content.ref)
original_post = get_original_post(session, str(content.ref))
if original_post is None:
raise InternalError(
f"Could not find original post ({content.ref} for amend ({message.item_hash})."
Expand Down
81 changes: 48 additions & 33 deletions src/aleph/handlers/content/vm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import logging
import math
from typing import List, Set, overload, Protocol, Optional
from typing import List, Set, overload, Protocol, Optional, Union

from aleph_message.models import (
ProgramContent,
ExecutableContent,
InstanceContent,
MessageType,
)
from aleph_message.models.execution import MachineType
from aleph_message.models.execution.instance import RootfsVolume
from aleph_message.models.execution.program import ProgramContent
from aleph_message.models.execution.volume import (
AbstractVolume,
ImmutableVolume,
Expand Down Expand Up @@ -72,29 +75,24 @@

def _get_vm_content(message: MessageDb) -> ExecutableContent:
content = message.parsed_content
if not isinstance(content, ExecutableContent):
if not isinstance(content, (InstanceContent, ProgramContent)):
raise InvalidMessageFormat(
f"Unexpected content type for program message: {message.item_hash}"
)
return content


from aleph_message.models.execution.program import (
MachineType,
ProgramContent,
)


@overload
def _map_content_to_db_model(item_hash: str, content: InstanceContent) -> VmInstanceDb:
...
def _map_content_to_db_model(
item_hash: str, content: InstanceContent
) -> VmInstanceDb: ...


# For some reason, mypy is not happy with the overload resolution here.
# This seems linked to multiple inheritance of Pydantic base models, a deeper investigation
# is required.
@overload
def _map_content_to_db_model(item_hash: str, content: ProgramContent) -> ProgramDb: # type: ignore[misc]
def _map_content_to_db_model(item_hash: str, content: ProgramContent) -> ProgramDb:
...


Expand All @@ -117,6 +115,16 @@ def _map_content_to_db_model(item_hash, content):
node_owner = node.owner
node_address_regex = node.address_regex

trusted_execution_policy = None
trusted_execution_firmware = None
node_hash = None
if not isinstance(content, ProgramContent):
if content.environment.trusted_execution is not None:
trusted_execution_policy = content.environment.trusted_execution.policy
trusted_execution_firmware = content.environment.trusted_execution.firmware
if hasattr(content.requirements, 'node_hash'):
node_hash = content.requirements.node_hash

return db_cls(
owner=content.address,
item_hash=item_hash,
Expand All @@ -127,6 +135,8 @@ def _map_content_to_db_model(item_hash, content):
environment_internet=content.environment.internet,
environment_aleph_api=content.environment.aleph_api,
environment_shared_cache=content.environment.shared_cache,
environment_trusted_execution_policy=trusted_execution_policy,
environment_trusted_execution_firmware=trusted_execution_firmware,
resources_vcpus=content.resources.vcpus,
resources_memory=content.resources.memory,
resources_seconds=content.resources.seconds,
Expand All @@ -136,6 +146,7 @@ def _map_content_to_db_model(item_hash, content):
node_address_regex=node_address_regex,
volumes=volumes,
created=timestamp_to_datetime(content.time),
node_hash=node_hash
)


Expand Down Expand Up @@ -172,7 +183,7 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb:
content = _get_vm_content(message)
vm = _map_content_to_db_model(message.item_hash, content)

if isinstance(vm, ProgramDb):
if isinstance(vm, ProgramDb) and isinstance(content, ProgramContent):
vm.program_type = content.type
vm.persistent = bool(content.on.persistent)
vm.http_trigger = content.on.http
Expand Down Expand Up @@ -208,13 +219,16 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb:

elif isinstance(content, InstanceContent):
parent = content.rootfs.parent
vm.rootfs = RootfsVolumeDb(
parent_ref=parent.ref,
parent_use_latest=parent.use_latest,
size_mib=content.rootfs.size_mib,
persistence=content.rootfs.persistence,
)
vm.authorized_keys = content.authorized_keys
if isinstance(vm, VmInstanceDb):
vm.rootfs = RootfsVolumeDb(
parent_ref=parent.ref,
parent_use_latest=parent.use_latest,
size_mib=content.rootfs.size_mib,
persistence=content.rootfs.persistence,
)
vm.authorized_keys = content.authorized_keys
else:
raise TypeError(f"Unexpected VM message content type: {type(vm)}")

else:
raise TypeError(f"Unexpected VM message content type: {type(content)}")
Expand Down Expand Up @@ -265,18 +279,18 @@ def check_parent_volumes_size_requirements(
) -> None:
def _get_parent_volume_file(_parent: ParentVolume) -> StoredFileDb:
if _parent.use_latest:
file_tag = get_file_tag(session=session, tag=_parent.ref)
file_tag = get_file_tag(session=session, tag=FileTag(_parent.ref))
if file_tag is None:
raise InternalError(
f"Could not find latest version of parent volume {volume.parent.ref}"
f"Could not find latest version of parent volume {_parent.ref}"
)

return file_tag.file

file_pin = get_message_file_pin(session=session, item_hash=_parent.ref)
if file_pin is None:
raise InternalError(
f"Could not find original version of parent volume {volume.parent.ref}"
f"Could not find original version of parent volume {_parent.ref}"
)

return file_pin.file
Expand All @@ -285,7 +299,7 @@ class HasParent(Protocol):
parent: ParentVolume
size_mib: int

volumes_with_parent: List[HasParent] = [
volumes_with_parent: List[Union[PersistentVolume, RootfsVolume]] = [
volume
for volume in content.volumes
if isinstance(volume, PersistentVolume) and volume.parent is not None
Expand All @@ -295,16 +309,17 @@ class HasParent(Protocol):
volumes_with_parent.append(content.rootfs)

for volume in volumes_with_parent:
volume_metadata = _get_parent_volume_file(volume.parent)
volume_size = volume.size_mib * 1024 * 1024
if volume_size < volume_metadata.size:
raise VmVolumeTooSmall(
parent_size=volume_metadata.size,
parent_ref=volume.parent.ref,
parent_file=volume_metadata.hash,
volume_name=getattr(volume, "name", "rootfs"),
volume_size=volume_size,
)
if volume.parent:
volume_metadata = _get_parent_volume_file(volume.parent)
volume_size = volume.size_mib * 1024 * 1024
if volume_size < volume_metadata.size:
raise VmVolumeTooSmall(
parent_size=volume_metadata.size,
parent_ref=volume.parent.ref,
parent_file=volume_metadata.hash,
volume_name=getattr(volume, "name", "rootfs"),
volume_size=volume_size,
)


class VmMessageHandler(ContentHandler):
Expand Down
3 changes: 2 additions & 1 deletion src/aleph/handlers/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic import ValidationError
from sqlalchemy import insert

from aleph_message.models import ItemHash
from aleph.chains.signature_verifier import SignatureVerifier
from aleph.db.accessors.files import insert_content_file_pin, upsert_file
from aleph.db.accessors.messages import (
Expand Down Expand Up @@ -377,7 +378,7 @@ async def process(
"""

existing_message = get_message_by_item_hash(
session=session, item_hash=pending_message.item_hash
session=session, item_hash=ItemHash(pending_message.item_hash)
)
if existing_message:
await self.confirm_existing_message(
Expand Down
Loading

0 comments on commit 25489df

Please sign in to comment.