Merge pull request #69 from randovania/feature/optimization
A lot of optimization (234% faster!)
henriquegemignani authored Aug 29, 2023
2 parents a479d23 + 9da482f commit 21b11c8
Showing 24 changed files with 496 additions and 212 deletions.
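Much of this diff adds _emitparse/_emitbuild emitters to the library's custom constructs, so they can take part in construct's code-generation ("compiled") path instead of the slower interpreted parser. As a minimal sketch of how a compiled construct is used (the ExampleHeader struct is hypothetical; compile() is construct's standard API):

import construct

ExampleHeader = construct.Struct(
    "count" / construct.Int32ul,
    "values" / construct.Array(construct.this.count, construct.Float32l),
)

compiled = ExampleHeader.compile()   # generates Python source from the emitters and executes it
data = compiled.build({"count": 2, "values": [1.0, 2.0]})
assert compiled.parse(data)["values"] == [1.0, 2.0]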
49 changes: 39 additions & 10 deletions src/mercury_engine_data_structures/common_types.py
@@ -1,4 +1,5 @@
import copy
import functools
import typing

import construct
@@ -16,6 +17,21 @@
CVector4D = construct.Array(4, Float)


def _vector_emitparse(length: int, code: construct.CodeGen):
code.append(f"CVector{length}D_Format = struct.Struct('<{length}f')")
return f"ListContainer(CVector{length}D_Format.unpack(io.read({length * 4})))"


def _vector_emitbuild(length: int, code: construct.CodeGen):
code.append(f"CVector{length}D_Format = struct.Struct('<{length}f')")
return f"(io.write(CVector{length}D_Format.pack(*obj)), obj)"


for i, vec in enumerate([CVector2D, CVector3D, CVector4D]):
vec._emitparse = functools.partial(_vector_emitparse, i + 2)
vec._emitbuild = functools.partial(_vector_emitbuild, i + 2)
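For reference, the code these emitters generate for CVector3D behaves roughly like the hand-written equivalent below (a sketch; the function names are illustrative):

import struct
from construct import ListContainer

CVector3D_Format = struct.Struct("<3f")   # three little-endian 32-bit floats

def parse_cvector3d(io):
    # Mirrors the emitted parse expression: a single unpack of 12 bytes.
    return ListContainer(CVector3D_Format.unpack(io.read(12)))

def build_cvector3d(io, obj):
    # Mirrors the emitted build expression: a single pack and write.
    io.write(CVector3D_Format.pack(*obj))
    return obj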


class ListContainerWithKeyAccess(construct.ListContainer):
def __init__(self, item_key_field: str, item_value_field: str = "value"):
super().__init__()
@@ -52,7 +68,8 @@ def __init__(self, subcon, *, allow_duplicates: bool = False):
super().__init__(subcon)
self.allow_duplicates = allow_duplicates

def _decode(self, obj: construct.ListContainer, context, path):
def _decode(self, obj: construct.ListContainer, context: construct.Container, path: str,
) -> construct.ListContainer | construct.Container:
result = construct.Container()
for item in obj:
key = item.key
@@ -63,7 +80,8 @@ def _decode(self, obj: construct.ListContainer, context, path):
result[key] = item.value
return result

def _encode(self, obj: construct.Container, context, path):
def _encode(self, obj: construct.ListContainer | construct.Container, context: construct.Container, path: str,
) -> list:
if self.allow_duplicates and isinstance(obj, list):
return obj
return construct.ListContainer(
@@ -73,30 +91,39 @@ def _encode(self, obj: construct.Container, context, path):

def _emitparse(self, code):
fname = f"parse_dict_adapter_{code.allocateId()}"
if self.allow_duplicates:
on_duplicate = "return obj"
else:
on_duplicate = 'raise ConstructError("Duplicated keys found in object")'

block = f"""
def {fname}(io, this):
obj = {self.subcon._compileparse(code)}
result = Container()
for item in obj:
result[item.key] = item.value
if len(result) != len(obj):
raise ConstructError("Duplicated keys found in object")
{on_duplicate}
return result
"""
code.append(block)
return f"{fname}(io, this)"

def _emitbuild(self, code):
fname = f"build_dict_adapter_{code.allocateId()}"
block = f"""
wrap = "obj = ListContainer(Container(key=type_, value=item) for type_, item in original_obj.items())"
if self.allow_duplicates:
wrap = f"""
if isinstance(original_obj, list):
obj = original_obj
else:
{wrap}
"""
code.append(f"""
def {fname}(original_obj, io, this):
obj = ListContainer(
Container(key=type_, value=item)
for type_, item in original_obj.items()
)
{wrap}
return {self.subcon._compilebuild(code)}
"""
code.append(block)
""")
return f"{fname}(obj, io, this)"


@@ -265,11 +292,13 @@ def make_vector(value: construct.Construct):

def _emitparse(code):
return f"ListContainer(({value._compileparse(code)}) for i in range({construct.Int32ul._compileparse(code)}))"

result._emitparse = _emitparse

def _emitbuild(code):
return (f"(reuse(len(obj), lambda obj: {construct.Int32ul._compilebuild(code)}),"
f" list({value._compilebuild(code)} for obj in obj), obj)[2]")

result._emitbuild = _emitbuild

return result
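The parser that make_vector emits is roughly equivalent to this hand-written loop when the element type is Float (a sketch; it assumes Float is a 32-bit little-endian float, as the CVectorND formats above suggest):

import struct
from construct import ListContainer

_count_struct = struct.Struct("<I")   # Int32ul element count
_item_struct = struct.Struct("<f")    # element type: Float

def parse_float_vector(io):
    # Read the count prefix, then each element in sequence.
    count, = _count_struct.unpack(io.read(4))
    return ListContainer(_item_struct.unpack(io.read(4))[0] for _ in range(count))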
20 changes: 20 additions & 0 deletions src/mercury_engine_data_structures/construct_extensions/enum.py
@@ -21,6 +21,26 @@ def _encode(self, obj: typing.Union[str, enum.IntEnum, int], context, path) -> i

return obj

def _emitbuild(self, code: construct.CodeGen):
i = code.allocateId()

mapping = ", ".join(
f"{repr(enum_entry.name)}: {enum_entry.value}"
for enum_entry in self.enum_class
)

code.append(f"""
_enum_name_to_value_{i} = {{{mapping}}}
def _encode_enum_{i}(io, obj):
# {self.name}
try:
obj = obj.value
except AttributeError:
obj = _enum_name_to_value_{i}.get(obj, obj)
return {construct.Int32ul._compilebuild(code)}
""")
return f"_encode_enum_{i}(io, obj)"


def BitMaskEnum(enum_type: typing.Type[enum.IntEnum]):
flags = {}
@@ -0,0 +1,46 @@
import construct


class SwitchComplexKey(construct.Switch):
def _insert_keyfunc(self, code: construct.CodeGen):
if id(self.keyfunc) not in code.linkedinstances:
code.linkedinstances[id(self.keyfunc)] = self.keyfunc
return f"linkedinstances[{id(self.keyfunc)}](this)"

def _emitparse(self, code: construct.CodeGen):
fname = f"switch_cases_{code.allocateId()}"
code.append(f"{fname} = {{}}")
for key, sc in self.cases.items():
code.append(f"{fname}[{repr(key)}] = lambda io,this: {sc._compileparse(code)}")
defaultfname = f"switch_defaultcase_{code.allocateId()}"
code.append(f"{defaultfname} = lambda io,this: {self.default._compileparse(code)}")
return f"{fname}.get({self._insert_keyfunc(code)}, {defaultfname})(io, this)"

def _emitbuild(self, code: construct.CodeGen):
fname = f"switch_cases_{code.allocateId()}"
code.append(f"{fname} = {{}}")
for key, sc in self.cases.items():
code.append(f"{fname}[{repr(key)}] = lambda obj,io,this: {sc._compilebuild(code)}")
defaultfname = f"switch_defaultcase_{code.allocateId()}"
code.append(f"{defaultfname} = lambda obj,io,this: {self.default._compilebuild(code)}")
return f"{fname}.get({self._insert_keyfunc(code)}, {defaultfname})(obj, io, this)"


class ComplexIfThenElse(construct.IfThenElse):
def _insert_cond(self, code: construct.CodeGen):
if id(self.condfunc) not in code.linkedinstances:
code.linkedinstances[id(self.condfunc)] = self.condfunc
return f"linkedinstances[{id(self.condfunc)}](this)"

def _emitparse(self, code):
return "(({}) if ({}) else ({}))".format(self.thensubcon._compileparse(code),
self._insert_cond(code),
self.elsesubcon._compileparse(code),)

def _emitbuild(self, code):
return (f"(({self.thensubcon._compilebuild(code)}) if ("
f"{self._insert_cond(code)}) else ({self.elsesubcon._compilebuild(code)}))")


def ComplexIf(condfunc, subcon):
return ComplexIfThenElse(condfunc, subcon, construct.Pass)
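These subclasses exist because the stock compiled Switch/IfThenElse embed the key or condition function directly into the generated source, which only works for this-expressions; registering the callable in code.linkedinstances, as above, lets plain Python callables compile too. A short usage sketch (field names are illustrative, and it uses the classes defined above):

import construct

TypedValue = construct.Struct(
    "type_name" / construct.PascalString(construct.Int32ul, "utf-8"),
    "value" / SwitchComplexKey(
        lambda this: this.type_name.lower(),   # plain callable, not a this-expression
        {
            "float": construct.Float32l,
            "int": construct.Int32sl,
        },
        default=construct.GreedyBytes,
    ),
)

OptionalBlob = construct.Struct(
    "has_blob" / construct.Flag,
    "blob" / ComplexIf(
        lambda this: this.has_blob,            # likewise a plain callable condition
        construct.Prefixed(construct.Int32ul, construct.GreedyBytes),
    ),
)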
64 changes: 62 additions & 2 deletions src/mercury_engine_data_structures/construct_extensions/strings.py
@@ -48,8 +48,10 @@ def PaddedStringRobust(length, encoding):
u'Афон'
"""
macro = StringEncodedRobust(FixedSized(length, NullStripped(GreedyBytes, pad=encodingunit(encoding))), encoding)

def _emitfulltype(ksy, bitwise):
return dict(size=length, type="strz", encoding=encoding)

macro._emitfulltype = _emitfulltype
return macro

@@ -79,13 +81,15 @@ def PascalStringRobust(lengthfield: construct.Construct, encoding):

def _emitparse(code):
return f"io.read({lengthfield._compileparse(code)}).decode({repr(encoding)})"

macro._emitparse = _emitparse

def _emitseq(ksy, bitwise):
return [
dict(id="lengthfield", type=lengthfield._compileprimitivetype(ksy, bitwise)),
dict(id="data", size="lengthfield", type="str", encoding=encoding),
]

macro._emitseq = _emitseq

def _emitbuild(code: construct.CodeGen):
@@ -97,9 +101,9 @@ def _emitbuild(code: construct.CodeGen):

macro._emitbuild = _emitbuild


return macro
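Compiled, PascalStringRobust(Int32ul, "utf-8") boils down to roughly the following read and write (a sketch; the helper names are illustrative and the Int32ul length field is an assumption):

import struct

def parse_pascal_string(io, encoding="utf-8"):
    # u32 length prefix, then that many bytes, decoded.
    length, = struct.unpack("<I", io.read(4))
    return io.read(length).decode(encoding)

def build_pascal_string(io, obj, encoding="utf-8"):
    data = obj.encode(encoding)
    io.write(struct.pack("<I", len(data)))
    io.write(data)
    return obj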


def CStringRobust(encoding):
r"""
String ending in a terminating null byte (or null bytes in case of UTF16 UTF32).
@@ -121,12 +125,66 @@ def CStringRobust(encoding):
>>> d.parse(_)
u'Афон'
"""
macro = StringEncodedRobust(NullTerminated(GreedyBytes, term=encodingunit(encoding)), encoding)
term = encodingunit(encoding)
macro = StringEncodedRobust(NullTerminated(GreedyBytes, term=term), encoding)

expected_size = 16

def _emitparse(code: construct.CodeGen):
i = code.allocateId()
code.append(f"""
def read_util_term_{i}(io):
try:
# Assume it's a BytesIO. Then use bytes.find to do the hard work in C
b = io.getvalue()
end = b.find({repr(term)}, io.tell())
if end == -1:
raise StreamError
data = io.read(end - io.tell())
io.read({len(term)})
return data
except AttributeError:
# not a BytesIO
pass
data = bytearray()
while True:
before = io.tell()
b = io.read({len(term) * expected_size})
pos = b.find({repr(term)})
if pos != -1:
io.seek(before + pos + {len(term)})
data += b[:pos]
break
if len(b) < {len(term) * expected_size}:
io.seek(before)
b = io.read({len(term)})
if b == {repr(term)}:
break
elif len(b) < {len(term)}:
raise StreamError
data += b
return data
""")

return f"read_util_term_{i}(io).decode({repr(encoding)})"

macro._emitparse = _emitparse

def _emitfulltype(ksy, bitwise):
return dict(type="strz", encoding=encoding)

macro._emitfulltype = _emitfulltype

def _emitbuild(code: construct.CodeGen):
return f"(io.write(obj.encode({repr(encoding)})), io.write({repr(term)}), obj)[-1]"

macro._emitbuild = _emitbuild

return macro
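The new parser has two paths: when the underlying stream is a BytesIO, the terminator is found with bytes.find over the whole buffer (the scan runs in C); otherwise it falls back to chunked reads. The fast path in isolation looks like this (a sketch for a single-byte terminator such as utf-8's):

import io

TERM = b"\x00"   # single-byte terminator, e.g. for utf-8

def read_until_term_fast(stream: io.BytesIO) -> bytes:
    # Scan the whole underlying buffer with bytes.find, then jump past the terminator.
    buffer = stream.getvalue()
    end = buffer.find(TERM, stream.tell())
    if end == -1:
        raise EOFError("null terminator not found")
    data = stream.read(end - stream.tell())
    stream.read(len(TERM))   # consume the terminator itself
    return data

# read_until_term_fast(io.BytesIO(b"hello\x00world")) == b"hello"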


def GreedyStringRobust(encoding):
r"""
String that reads entire stream until EOF, and writes a given string as-is.
@@ -147,7 +205,9 @@ def GreedyStringRobust(encoding):
u'Афон'
"""
macro = StringEncodedRobust(GreedyBytes, encoding)

def _emitfulltype(ksy, bitwise):
return dict(size_eos=True, type="str", encoding=encoding)

macro._emitfulltype = _emitfulltype
return macro
38 changes: 25 additions & 13 deletions src/mercury_engine_data_structures/file_tree_editor.py
@@ -31,12 +31,6 @@ class OutputFormat(enum.Enum):
ROMFS = enum.auto()


def _find_entry_for_asset_id(asset_id: AssetId, pkg_header):
for entry in pkg_header.file_entries:
if entry.asset_id == asset_id:
return entry


def _read_file_with_entry(path: Path, entry):
with path.open("rb") as f:
f.seek(entry.start_offset)
@@ -125,11 +119,13 @@ def _update_headers(self):

self._ensured_asset_ids[name] = set()

self.headers[name].entries_by_id = {}
for entry in self.headers[name].file_entries:
if self._toc.get_size_for(entry.asset_id) is None:
logger.warning("File with asset id 0x%016x in pkg %s does not have an entry in the TOC",
entry.asset_id, name)
self._add_pkg_name_for_asset_id(entry.asset_id, name)
self.headers[name].entries_by_id[entry.asset_id] = entry

def all_asset_ids(self) -> Iterator[AssetId]:
"""
@@ -184,7 +180,7 @@ def get_raw_asset(self, asset_id: NameOrAssetId, *, in_pkg: Optional[str] = None
if in_pkg is not None and name != in_pkg:
continue

entry = _find_entry_for_asset_id(asset_id, header)
entry = header.entries_by_id.get(asset_id)
if entry is not None:
logger.info("Reading asset %s from pkg %s", str(original_name), name)
return _read_file_with_entry(self.path_for_pkg(name), entry)
@@ -305,7 +301,12 @@ def get_pkg(self, pkg_name: str) -> Pkg:

return self._in_memory_pkgs[pkg_name]

def save_modifications(self, output_path: Path, output_format: OutputFormat):
def save_modifications(self, output_path: Path, output_format: OutputFormat, *, finalize_editor: bool = True):
"""Creates a mod file in the given output format with all the modifications requested.
:param output_path: Where to write the mod files.
:param output_format: If we should create PKG files or not.
:param finalize_editor: If set, this editor will no longer be usable after this function, but is faster.
"""
replacements = []
modified_pkgs = set()
asset_ids_to_copy = {}
@@ -316,6 +317,11 @@ def save_modifications(self, output_path: Path, output_format: OutputFormat):
if None in modified_pkgs:
modified_pkgs.remove(None)

if output_format == OutputFormat.ROMFS:
# Clear modified_pkgs so we don't read or write any new pkg.
# We keep system.pkg because .bmmaps don't read properly with exlaunch, and it's only 4 MB.
modified_pkgs = list(filter(lambda pkg: pkg == "packs/system/system.pkg", modified_pkgs))

# Ensure all pkgs we'll modify are in memory already.
# We'll need to read these files anyway to modify them, so do it early to speed up
# the get_raw_asset calls for _ensured_asset_ids.
@@ -367,10 +373,6 @@ def save_modifications(self, output_path: Path, output_format: OutputFormat):
}, indent=4)
output_path.joinpath("replacements.json").write_text(replacement_json, "utf-8")

# Clear modified_pkgs so we don't write any new pkg
# We keep system.pkg because .bmmaps don't read properly with exlaunch and it's only 4MB
modified_pkgs = list(filter(lambda pkg: pkg == "packs/system/system.pkg", modified_pkgs))

# Update the PKGs
for pkg_name in modified_pkgs:
logger.info("Updating %s", pkg_name)
@@ -407,4 +409,14 @@ def save_modifications(self, output_path: Path, output_format: OutputFormat):
)

self._modified_resources = {}
self._update_headers()
if finalize_editor:
# _update_headers has significant runtime costs, so avoid it.
# But let's delete these attributes so further use of this object fails explicitly.
del self.all_pkgs
del self.headers
del self._ensured_asset_ids
del self._files_for_asset_id
del self._name_for_asset_id
del self._toc
else:
self._update_headers()
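A usage sketch for the new keyword argument (the paths are illustrative; editor is a FileTreeEditor built elsewhere):

from pathlib import Path

from mercury_engine_data_structures.file_tree_editor import OutputFormat

# editor is a FileTreeEditor instance created elsewhere (construction not shown here).
editor.save_modifications(
    Path("output/romfs"),
    OutputFormat.ROMFS,
    finalize_editor=False,   # keep the lookup tables alive so the editor stays usable
)

# With the default finalize_editor=True the call is faster, but the editor's
# header/asset-id attributes are deleted and further use fails explicitly.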