Skip to content

Commit

Permalink
perf: ipyvue widget can use a faster less generate state_get
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels committed Feb 9, 2024
1 parent eefd358 commit f7ef3f3
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions solara/server/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,58 @@ def Output_exit(self, exc_type, exc_value, traceback):
ip.display_pub._hooks.pop()


def patch_ipyvue_performance():
import functools
from collections.abc import Iterable

@functools.lru_cache(None) # type: ignore
def class_traits(cls, **metadata):
# we cache it for performance reasons
return cls.class_traits(**metadata)

@functools.lru_cache(None) # type: ignore
def to_jsons_meta(cls):
traits = class_traits(cls)
callables = {}
for name, trait in traits.items():
to_json = trait.metadata.get("to_json")
if to_json:
callables[name] = to_json
return callables

def get_state_fast(self, key=None, drop_defaults=False):
cls = type(self)
traits = class_traits(cls, sync=True) # type: ignore
if key is None:
keys = list(traits)
elif isinstance(key, str):
keys = [key]
elif isinstance(key, Iterable):
keys = list(key)
else:
raise ValueError("key must be a string, an iterable of keys, or None")
state = {}
to_jsons = to_jsons_meta(cls) # type: ignore
assert drop_defaults is False
trait_values = self._trait_values
for k in keys:
if k not in trait_values:
value = getattr(self, k)
else:
value = trait_values[k]
if k in to_jsons:
wire_value = to_jsons[k](value, self)
else:
# should we call _trait_to_json?
wire_value = value
state[k] = wire_value
return state

import ipyvue

ipyvue.VueWidget.get_state = get_state_fast


def patch():
global _patched
global global_widgets_dict
Expand All @@ -307,6 +359,12 @@ def patch():
_patched = True
__builtins__["display"] = IPython.display.display

if settings.main.experimental_performance:
# this might be a bit too much
# import traitlets
# traitlets.TraitType._validate = lambda self, trait, value: value

patch_ipyvue_performance()
# the ipyvue.Template module cannot be accessed like ipyvue.Template
# because the import in ipvue overrides it
template_mod = sys.modules["ipyvue.Template"]
Expand Down

0 comments on commit f7ef3f3

Please sign in to comment.