fix: close playwright properly when building ssg (#581)
* fix: close playwright properly when building ssg

* simplify by storing in a list of tuples

---------

Co-authored-by: Maarten A. Breddels <[email protected]>
iisakkirotko and maartenbreddels committed Apr 4, 2024
1 parent dc2cfe1 commit 51b366d
Showing 1 changed file with 45 additions and 16 deletions.
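
In short, the change swaps the multiprocessing.pool.ThreadPool for a ThreadPoolExecutor subclass whose worker threads close their thread-local Playwright browser and context manager right before the thread exits, so cleanup happens on the same thread that created them. Below is a minimal, standalone sketch of that pattern, assuming Python 3.8+. Like the actual change, it reuses the private concurrent.futures.thread._worker helper; the _Resource class is only a hypothetical stand-in for the per-thread Playwright objects.

import concurrent.futures
import concurrent.futures.thread
import threading
import weakref


class _Resource(threading.local):
    # Hypothetical stand-in for the per-thread Playwright browser/context manager.
    closed = False

    def close(self):
        self.closed = True
        print(f"closing resource on {threading.current_thread().name}")


resource = _Resource()  # thread-local: each worker thread gets its own state


def _worker_with_cleanup(*args, **kwargs):
    # Run the stock worker loop, then close this thread's resource just
    # before the worker thread exits.
    try:
        concurrent.futures.thread._worker(*args, **kwargs)
    finally:
        resource.close()


class CleanupThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    # Copy of ThreadPoolExecutor._adjust_thread_count (CPython 3.8+),
    # with _worker swapped for _worker_with_cleanup.
    def _adjust_thread_count(self):
        # if idle threads are available, don't spin new threads
        if self._idle_semaphore.acquire(timeout=0):
            return

        # When the executor gets lost, the weakref callback will wake up
        # the worker threads.
        def weakref_cb(_, q=self._work_queue):
            q.put(None)

        num_threads = len(self._threads)
        if num_threads < self._max_workers:
            thread_name = "%s_%d" % (self._thread_name_prefix or self, num_threads)
            t = threading.Thread(
                name=thread_name,
                target=_worker_with_cleanup,
                args=(weakref.ref(self, weakref_cb), self._work_queue, self._initializer, self._initargs),
            )
            t.start()
            self._threads.add(t)
            concurrent.futures.thread._threads_queues[t] = self._work_queue


if __name__ == "__main__":
    with CleanupThreadPoolExecutor(max_workers=1) as pool:
        pool.submit(lambda: None).result()
    # "closing resource on ..." is printed when the worker thread shuts down.

The design point is that sync Playwright objects belong to the thread that created them, so closing them from the main thread after the pool is torn down is unreliable; hooking the cleanup into the worker loop itself avoids that.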
61 changes: 45 additions & 16 deletions packages/solara-enterprise/solara_enterprise/ssg.py
@@ -1,11 +1,12 @@
+import concurrent.futures.thread
 import logging
-import multiprocessing.pool
 import threading
 import time
 import typing
 import urllib
+import weakref
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import solara
 from rich import print as rprint
@@ -29,7 +30,7 @@ class Playwright(threading.local):


 pw = Playwright()
-playwrights: List[Playwright] = []
+_used: List[Tuple["playwright.sync_api.Browser", "playwright.sync_api._context_manager.PlaywrightContextManager"]] = []


 class SSGData(TypedDict):
@@ -44,15 +45,48 @@ def _get_playwright():
         return pw
     from playwright.sync_api import sync_playwright

+    pw.number = 42
     pw.context_manager = sync_playwright()
     pw.sync_playwright = pw.context_manager.start()

     pw.browser = pw.sync_playwright.chromium.launch(headless=not settings.ssg.headed)
     pw.page = pw.browser.new_page()
-    playwrights.append(pw)
+    _used.append((pw.browser, pw.context_manager))
     return pw


+def _worker_with_cleanup(*args, **kwargs):
+    try:
+        concurrent.futures.thread._worker(*args, **kwargs)
+    finally:
+        pw = _get_playwright()
+        pw.browser.close()
+        pw.context_manager.__exit__(None, None, None)
+
+
+class CleanupThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
+    def _adjust_thread_count(self):
+        # copy of the original code with _worker replaced
+        # if idle threads are available, don't spin new threads
+        if self._idle_semaphore.acquire(timeout=0):
+            return
+
+        # When the executor gets lost, the weakref callback will wake up
+        # the worker threads.
+        def weakref_cb(_, q=self._work_queue):
+            q.put(None)
+
+        num_threads = len(self._threads)
+        if num_threads < self._max_workers:
+            thread_name = "%s_%d" % (self._thread_name_prefix or self, num_threads)
+            t = threading.Thread(
+                name=thread_name, target=_worker_with_cleanup, args=(weakref.ref(self, weakref_cb), self._work_queue, self._initializer, self._initargs)
+            )
+            t.start()
+            self._threads.add(t)  # type: ignore
+            concurrent.futures.thread._threads_queues[t] = self._work_queue  # type: ignore
+
+
 def ssg_crawl(base_url: str):
     license.check("SSG")
     import solara.server.app
@@ -69,31 +103,26 @@ def ssg_crawl(base_url: str):
     # although in theory we should be able to run this with multiple threads
     # there are issues with uvloop:
     # e.g.: "Racing with another loop to spawn a process."
-    thread_pool = multiprocessing.pool.ThreadPool(1)
+    thread_pool = CleanupThreadPoolExecutor(max_workers=1)

     results = []
     for route in routes:
-        results.append(thread_pool.apply_async(ssg_crawl_route, [f"{base_url}/", route, build_path, thread_pool]))
+        results.append(thread_pool.submit(ssg_crawl_route, f"{base_url}/", route, build_path, thread_pool))

     def wait(async_result):
-        results = async_result.get()
+        results = async_result.result()
         for result in results:
             wait(result)

     for result in results:
         wait(result)
-    thread_pool.close()
-    thread_pool.join()
-    for pw in playwrights:
-        assert pw.browser is not None
-        assert pw.context_manager is not None
-        pw.browser.close()
-        pw.context_manager.__exit__(None, None, None)

+    thread_pool.shutdown()
+
     rprint("Done building SSG")


-def ssg_crawl_route(base_url: str, route: solara.Route, build_path: Path, thread_pool: multiprocessing.pool.ThreadPool):
+def ssg_crawl_route(base_url: str, route: solara.Route, build_path: Path, thread_pool: CleanupThreadPoolExecutor):
     # if route
     url = base_url + (route.path if route.path != "/" else "")
     if not route.children:
@@ -146,7 +175,7 @@ def ssg_crawl_route(base_url: str, route: solara.Route, build_path: Path, thread
         rprint(f"Skipping existing render: {path}")
     results = []
     for child in route.children:
-        result = thread_pool.apply_async(ssg_crawl_route, [url + "/", child, build_path / Path(route.path), thread_pool])
+        result = thread_pool.submit(ssg_crawl_route, url + "/", child, build_path / Path(route.path), thread_pool)
         results.append(result)
     return results
