Skip to content

Commit

Permalink
Tunneling support, related UI improvements (#98)
Browse files Browse the repository at this point in the history
* Experimental tunneling support, related UI improvements

* Appease mypy
  • Loading branch information
brentyi authored Sep 16, 2023
1 parent 76628f7 commit dde0373
Show file tree
Hide file tree
Showing 13 changed files with 1,299 additions and 965 deletions.
59 changes: 30 additions & 29 deletions examples/01_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,38 @@

import viser

server = viser.ViserServer()

# Add a background image.
server.set_background_image(
iio.imread(Path(__file__).parent / "assets/Cal_logo.png"),
format="png",
)

# Add main image.
server.add_image(
"/img",
iio.imread(Path(__file__).parent / "assets/Cal_logo.png"),
4.0,
4.0,
format="png",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(2.0, 2.0, 0.0),
)
while True:
if __name__ == "__main__":
server = viser.ViserServer()

# Add a background image.
server.set_background_image(
iio.imread(Path(__file__).parent / "assets/Cal_logo.png"),
format="png",
)

# Add main image.
server.add_image(
"/noise",
onp.random.randint(
0,
256,
size=(400, 400, 3),
dtype=onp.uint8,
),
"/img",
iio.imread(Path(__file__).parent / "assets/Cal_logo.png"),
4.0,
4.0,
format="jpeg",
format="png",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(2.0, 2.0, -1e-2),
position=(2.0, 2.0, 0.0),
)
time.sleep(0.2)
while True:
server.add_image(
"/noise",
onp.random.randint(
0,
256,
size=(400, 400, 3),
dtype=onp.uint8,
),
4.0,
4.0,
format="jpeg",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(2.0, 2.0, -1e-2),
)
time.sleep(0.2)
3 changes: 2 additions & 1 deletion examples/07_record3d_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def main(
data_path: Path = Path(__file__).parent / "assets/record3d_dance",
downsample_factor: int = 4,
max_frames: int = 100,
share: bool = False,
) -> None:
server = viser.ViserServer()
server = viser.ViserServer(share=share)

print("Loading frames!")
loader = viser.extras.Record3dLoader(data_path)
Expand Down
3 changes: 2 additions & 1 deletion examples/08_smplx_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ def main(
num_betas: int = 10,
num_expression_coeffs: int = 10,
ext: Literal["npz", "pkl"] = "npz",
share: bool = False,
) -> None:
server = viser.ViserServer()
server = viser.ViserServer(share=share)
server.configure_theme(control_layout="collapsible", dark_mode=True)
model = smplx.create(
model_path=str(model_path),
Expand Down
27 changes: 14 additions & 13 deletions src/viser/_message_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import base64
import colorsys
import io
import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Dict, Optional, Tuple, TypeVar, Union, cast

import imageio.v3 as iio
Expand Down Expand Up @@ -147,8 +147,8 @@ def __init__(self, handler: infra.MessageHandler) -> None:
)

self._atomic_lock = threading.Lock()
self._queued_messages: queue.Queue = queue.Queue()
self._locked_thread_id = -1
self._queue_thread = ThreadPoolExecutor(max_workers=1)

def configure_theme(
self,
Expand Down Expand Up @@ -594,19 +594,20 @@ def reset_scene(self):

def _queue(self, message: _messages.Message) -> None:
"""Wrapped method for sending messages safely."""
# This implementation will retain message ordering because _queue_thread has
# just 1 worker.
from .infra._infra import error_print_wrapper
got_lock = self._atomic_lock.acquire(blocking=False)
if got_lock:
self._queue_unsafe(message)
self._atomic_lock.release()
else:
# Send when lock is acquirable, while retaining message order.
# This could be optimized!
self._queued_messages.put(message)

self._queue_thread.submit(
error_print_wrapper(lambda: self._queue_blocking(message))
)
def try_again():
with self._atomic_lock:
self._queue_unsafe(self._queued_messages.get())

def _queue_blocking(self, message: _messages.Message) -> None:
"""Wrapped method for sending messages safely. Blocks until ready to send."""
self._atomic_lock.acquire()
self._queue_unsafe(message)
self._atomic_lock.release()
threading.Thread(target=try_again).start()

@abc.abstractmethod
def _queue_unsafe(self, message: _messages.Message) -> None:
Expand Down
156 changes: 156 additions & 0 deletions src/viser/_tunnel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import asyncio
import multiprocessing as mp
import threading
import time
from multiprocessing.managers import DictProxy
from typing import Callable, Optional

import requests


class _ViserTunnel:
"""Tunneling utility for internal use."""

def __init__(self, local_port: int) -> None:
self._local_port = local_port
self._process: Optional[mp.Process] = None

manager = mp.Manager()
self._shared_state = manager.dict()
self._shared_state["status"] = "ready"
self._shared_state["url"] = None

def on_connect(self, callback: Callable[[], None]) -> None:
"""Establish the tunnel connection.
Returns URL if tunnel succeeds, otherwise None."""
assert self._process is None

self._shared_state["status"] = "connecting"

self._process = mp.Process(
target=_connect_job,
daemon=True,
args=(self._local_port, self._shared_state),
)
self._process.start()

def wait_job() -> None:
while self._shared_state["status"] == "connecting":
time.sleep(0.1)
callback()

threading.Thread(target=wait_job).start()

def get_url(self) -> Optional[str]:
"""Get tunnel URL. None if not connected (or connection failed)."""
return self._shared_state["url"]

def close(self) -> None:
"""Close the tunnel."""
if self._process is not None:
self._process.kill()
self._process.join()


def _connect_job(local_port: int, shared_state: DictProxy) -> None:
event_loop = asyncio.new_event_loop()
assert event_loop is not None
asyncio.set_event_loop(event_loop)
try:
event_loop.run_until_complete(_make_tunnel(local_port, shared_state))
except KeyboardInterrupt:
tasks = asyncio.all_tasks(event_loop)
for task in tasks:
task.cancel()
event_loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
event_loop.close()


async def _make_tunnel(local_port: int, shared_state: DictProxy) -> None:
share_domain = "share.viser.studio"

try:
response = requests.request(
"GET",
url=f"https://{share_domain}/?request_forward",
headers={"Content-Type": "application/json"},
)
if response.status_code != 200:
shared_state["status"] = "failed"
return
except requests.exceptions.ConnectionError:
shared_state["status"] = "failed"
return
except Exception as e:
shared_state["status"] = "failed"
raise e

res = response.json()
shared_state["url"] = res["url"]
shared_state["status"] = "connected"

def make_connection_task():
return asyncio.create_task(
connect(
"127.0.0.1",
local_port,
share_domain,
res["port"],
)
)

connection_tasks = [make_connection_task() for _ in range(res["max_conn_count"])]
await asyncio.gather(*connection_tasks)


async def pipe(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
while True:
data = await r.read(4096)
if len(data) == 0:
# Done!
break
w.write(data)
await w.drain()


async def connect(
local_host: str,
local_port: int,
remote_host: str,
remote_port: int,
) -> None:
"""Establish a connection to the tunnel server."""

while True:
local_w = None
remote_w = None
try:
local_r, local_w = await asyncio.open_connection(local_host, local_port)
remote_r, remote_w = await asyncio.open_connection(remote_host, remote_port)
await asyncio.wait(
[
asyncio.create_task(pipe(local_r, remote_w)),
asyncio.create_task(pipe(remote_r, local_w)),
],
return_when=asyncio.FIRST_COMPLETED,
)
except Exception:
pass
finally:
if local_w is not None:
local_w.close()
if remote_w is not None:
remote_w.close()


if __name__ == "__main__":
tunnel = _ViserTunnel(8080)
tunnel.on_connect(lambda: None)

time.sleep(2.0)
print("Trying to close")
tunnel.close()
print("Done trying to close")
time.sleep(10.0)
print("Exiting")
51 changes: 50 additions & 1 deletion src/viser/_viser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@

import numpy as onp
import numpy.typing as npt
import rich
from rich import box, style
from rich.panel import Panel
from rich.table import Table
from typing_extensions import override

from . import _client_autobuild, _messages, infra
from . import transforms as tf
from ._gui_api import GuiApi
from ._message_api import MessageApi, cast_vector
from ._scene_handles import FrameHandle, _SceneNodeHandleState
from ._tunnel import _ViserTunnel


@dataclasses.dataclass
Expand Down Expand Up @@ -234,13 +239,19 @@ class ViserServer(MessageApi, GuiApi):
Commands on a server object (`add_frame`, `add_gui_*`, ...) will be sent to all
clients, including new clients that connect after a command is called.
Args:
host: Host to bind server to.
port: Port to bind server to.
share: Experimental. If set to `True`, create and print a public, shareable URL
for this instance of viser.
"""

world_axes: FrameHandle
"""Handle for manipulating the world frame axes (/WorldAxes), which is instantiated
and then hidden by default."""

def __init__(self, host: str = "0.0.0.0", port: int = 8080):
def __init__(self, host: str = "0.0.0.0", port: int = 8080, share: bool = False):
server = infra.Server(
host=host,
port=port,
Expand Down Expand Up @@ -331,6 +342,38 @@ def _(conn: infra.ClientConnection) -> None:

# Start the server.
server.start()

# Form status print.
port = server._port # Port may have changed.
http_url = f"http://{host}:{port}"
ws_url = f"ws://{host}:{port}"
table = Table(
title=None,
show_header=False,
box=box.MINIMAL,
title_style=style.Style(bold=True),
)
table.add_row("HTTP", http_url)
table.add_row("Websocket", ws_url)
rich.print(Panel(table, title="[bold]viser[/bold]", expand=False))

# Create share tunnel if requested.
if not share:
self._share_tunnel = None
else:
self._share_tunnel = _ViserTunnel(port)

@self._share_tunnel.on_connect
def _() -> None:
assert self._share_tunnel is not None
share_url = self._share_tunnel.get_url()
if share_url is None:
rich.print("[bold](viser)[/bold] Could not generate share URL")
else:
rich.print(
f"[bold](viser)[/bold] Share URL (expires in 24 hours): {share_url}"
)

self.reset_scene()
self.world_axes = FrameHandle(
_SceneNodeHandleState(
Expand All @@ -342,6 +385,12 @@ def _(conn: infra.ClientConnection) -> None:
)
self.world_axes.visible = False

def stop(self) -> None:
"""Stop the Viser server and associated threads and tunnels."""
self._server.stop()
if self._share_tunnel is not None:
self._share_tunnel.close()

@override
def _get_api(self) -> MessageApi:
"""Message API to use."""
Expand Down
Loading

0 comments on commit dde0373

Please sign in to comment.