Skip to content

Commit

Permalink
Merge pull request #47 from randovania/feature/bmsas
Browse files Browse the repository at this point in the history
Speedup BMSAS
  • Loading branch information
henriquegemignani authored Aug 1, 2023
2 parents ddd3598 + 45f41d9 commit d2c3ac3
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 19 deletions.
13 changes: 12 additions & 1 deletion src/mercury_engine_data_structures/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,23 @@ def make_vector(value: construct.Construct):
arr.name = "items"
get_len = construct.len_(construct.this.items)

return construct.FocusedSeq(
result = construct.FocusedSeq(
"items",
"count" / construct.Rebuild(construct.Int32ul, get_len),
arr,
)

def _emitparse(code):
return f"ListContainer(({value._compileparse(code)}) for i in range({construct.Int32ul._compileparse(code)}))"
result._emitparse = _emitparse

def _emitbuild(code):
return (f"(reuse(len(obj), lambda obj: {construct.Int32ul._compilebuild(code)}),"
f" list({value._compilebuild(code)} for obj in obj), obj)[2]")
result._emitbuild = _emitbuild

return result


def make_enum(values: typing.Union[typing.List[str], typing.Dict[str, int]], *,
add_invalid: bool = True):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import operator

import construct.expr

# Workaround construct's bug (See https://github.com/construct/construct/issues/1039)
construct.expr.opnames[operator.and_] = "&"
construct.expr.opnames[operator.or_] = "|"


# Hex for some reason doesn't support compilation for building, despite being trivial to do
# So let's hack it in.
def _hex_emitbuild(self, code):
return self.subcon._compilebuild(code)


construct.Hex._emitbuild = _hex_emitbuild
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import construct
from construct.core import (
FixedSized,
GreedyBytes,
Expand All @@ -17,6 +18,7 @@ def _decode(self, obj, context, path):
except UnicodeDecodeError as e:
raise StringError(f"string decoding failed: {e}", path=path) from e


def PaddedStringRobust(length, encoding):
r"""
Configurable, fixed-length or variable-length string field.
Expand Down Expand Up @@ -52,7 +54,7 @@ def _emitfulltype(ksy, bitwise):
return macro


def PascalStringRobust(lengthfield, encoding):
def PascalStringRobust(lengthfield: construct.Construct, encoding):
r"""
Length-prefixed string. The length field can be variable length (such as VarInt) or fixed length (such as Int64ub).
:class:`~construct.core.VarInt` is recommended when designing new protocols. Stored length is in bytes,
Expand Down Expand Up @@ -86,6 +88,16 @@ def _emitseq(ksy, bitwise):
]
macro._emitseq = _emitseq

def _emitbuild(code: construct.CodeGen):
i = code.allocateId()
code.append(f"def add_prefix_{i}(io, obj): return {lengthfield._compilebuild(code)}")

return (f"reuse(obj.encode({repr(encoding)}),"
f" lambda encoded: (add_prefix_{i}(io, len(encoded)), io.write(encoded)))")

macro._emitbuild = _emitbuild


return macro

def CStringRobust(encoding):
Expand Down
29 changes: 25 additions & 4 deletions src/mercury_engine_data_structures/formats/bmsas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,33 @@

StrId = PascalStringRobust(Int16ul, "utf-8")


class StrIdOrInt(Select):
def __init__(self):
super().__init__(StrId, Hex(Int64ul))

def _emitparse(self, code: construct.CodeGen):
code.append(f"""
def parse_str_or_int(io):
fallback = io.tell()
try:
return {StrId._compileparse(code)}
except UnicodeDecodeError:
io.seek(fallback)
return {Int64ul._compileparse(code)}
""")
return "parse_str_or_int(io)"

def _emitbuild(self, code: construct.CodeGen):
return f"({Int64ul._compilebuild(code)}) if isinstance(obj, int) else ({StrId._compilebuild(code)})"


Argument = Struct(
key=PropertyEnum,
value=Switch(
construct.this.key[0],
{
's': Select(StrId, Hex(Int64ul)),
's': StrIdOrInt(),
'f': Float,
'b': Flag,
'u': Int32ul,
Expand Down Expand Up @@ -93,7 +114,7 @@
unk7=Float,
unk8=Int32ul,
unk9=Float,
unk10=If(construct.this.unk0 & 32, Select(PropertyEnum, Hex(Int64ul))),
unk10=If(construct.this.unk0 & 32, Hex(Int64ul)),
unk11=If(construct.this.unk0 & 64, StrId),
unk12=make_vector(Struct(
unk1=Float,
Expand All @@ -114,7 +135,7 @@
unk14=make_vector(Struct(
unk0=Int64ul,
curve=StrId,
unk1=make_vector(Select(PropertyEnum, Hex(Int64ul))),
unk1=make_vector(Hex(Int64ul)),
unk2=Int32ul,
)),
)
Expand All @@ -126,7 +147,7 @@
unk=Hex(Int32ul),
animations=make_vector(Animation),
_end=construct.Terminated,
)
).compile()


class Bmsas(BaseResource):
Expand Down
66 changes: 53 additions & 13 deletions src/mercury_engine_data_structures/formats/property_enum.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,39 @@
import enum
import typing
import warnings
from typing import Dict
from typing import Dict, Tuple

import construct

from mercury_engine_data_structures import crc, dread_data


class HashSet(enum.Enum):
DREAD_PROPERTY = enum.auto()
DREAD_FILE_NAME = enum.auto()

def get_hashes(self) -> Tuple[Dict[str, int], Dict[int, str]]:
if self == HashSet.DREAD_PROPERTY:
return dread_data.all_name_to_property_id(), dread_data.all_property_id_to_name()
elif self == HashSet.DREAD_FILE_NAME:
return dread_data.all_name_to_asset_id(), dread_data.all_asset_id_to_name()
else:
raise ValueError("Unknown")


class CRCAdapter(construct.Adapter):
def __init__(self, subcon, known_hashes: Dict[str, int], allow_unknowns=False, display_warnings=True):
super().__init__(subcon)
self.known_hashes = known_hashes
self.inverted_hashes = {value: name for name, value in known_hashes.items()}
known_hashes: Dict[str, int]
inverted_hashes: Dict[int, str]

def __init__(self, hash_set: HashSet, allow_unknowns=False, display_warnings=True):
super().__init__(construct.Hex(construct.Int64ul))
self._raw_subcon = construct.Int64ul
self.hash_set = hash_set
self.known_hashes, self.inverted_hashes = hash_set.get_hashes()
self.allow_unknowns = allow_unknowns
self.display_warnings = display_warnings

def _decode(self, obj, context, path):
def _decode(self, obj: int, context, path):
try:
return self.inverted_hashes[obj]
except KeyError:
Expand All @@ -24,11 +43,11 @@ def _decode(self, obj, context, path):
warnings.warn(UserWarning(msg))
return obj
raise construct.MappingError(
"parsing failed, "+msg,
"parsing failed, " + msg,
path=path,
)

def _encode(self, obj, context, path):
def _encode(self, obj: typing.Union[str, int], context, path):
try:
return self.known_hashes[obj]
except KeyError:
Expand All @@ -42,13 +61,34 @@ def _encode(self, obj, context, path):
return crc.crc64(obj)

raise construct.MappingError(
"building failed, "+msg,
"building failed, " + msg,
path=path
)

def _emitparse(self, code: construct.CodeGen):
n = self.hash_set.name
code.append("from mercury_engine_data_structures.formats.property_enum import HashSet")
code.append(f"known_hashes_{n}, inverted_hashes_{n} = HashSet.{n}.get_hashes()")

if self.allow_unknowns:
return f"reuse({self.subcon._compileparse(code)}, lambda key: inverted_hashes_{n}.get(key, key))"
else:
return f"inverted_hashes_{n}[{self.subcon._compileparse(code)}]"

def _emitbuild(self, code: construct.CodeGen):
if self.allow_unknowns:
raise NotImplementedError

n = self.hash_set.name
code.append("from mercury_engine_data_structures.formats.property_enum import HashSet")
code.append(f"known_hashes_{n}, inverted_hashes_{n} = HashSet.{n}.get_hashes()")

ret: str = self._raw_subcon._compilebuild(code)
return ret.replace(".pack(obj)", f".pack(known_hashes_{n}[obj])")


PropertyEnum = CRCAdapter(construct.Hex(construct.Int64ul), dread_data.all_name_to_property_id())
PropertyEnumUnsafe = CRCAdapter(construct.Hex(construct.Int64ul), dread_data.all_name_to_property_id(), True)
PropertyEnum = CRCAdapter(HashSet.DREAD_PROPERTY)
PropertyEnumUnsafe = CRCAdapter(HashSet.DREAD_PROPERTY, True)

FileNameEnum = CRCAdapter(construct.Hex(construct.Int64ul), dread_data.all_name_to_asset_id())
FileNameEnumUnsafe = CRCAdapter(construct.Hex(construct.Int64ul), dread_data.all_name_to_asset_id(), True, False)
FileNameEnum = CRCAdapter(HashSet.DREAD_FILE_NAME)
FileNameEnumUnsafe = CRCAdapter(HashSet.DREAD_FILE_NAME, True, False)

0 comments on commit d2c3ac3

Please sign in to comment.