Skip to content

Commit

Permalink
Simplify tunneling
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Sep 17, 2023
1 parent ca39637 commit 1ab6762
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 39 deletions.
72 changes: 33 additions & 39 deletions src/viser/_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,58 +90,52 @@ async def _make_tunnel(local_port: int, shared_state: DictProxy) -> None:
shared_state["url"] = res["url"]
shared_state["status"] = "connected"

def make_connection_task():
return asyncio.create_task(
connect(
"127.0.0.1",
local_port,
share_domain,
res["port"],
await asyncio.gather(
*[
asyncio.create_task(
_simple_proxy(
"127.0.0.1",
local_port,
share_domain,
res["port"],
)
)
)

connection_tasks = [make_connection_task() for _ in range(res["max_conn_count"])]
await asyncio.gather(*connection_tasks)


async def pipe(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
while True:
data = await r.read(4096)
if len(data) == 0:
# Done!
break
w.write(data)
await w.drain()
for _ in range(res["max_conn_count"])
]
)


async def connect(
async def _simple_proxy(
local_host: str,
local_port: int,
remote_host: str,
remote_port: int,
) -> None:
"""Establish a connection to the tunnel server."""

while True:
local_w = None
remote_w = None
async def relay(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None:
try:
local_r, local_w = await asyncio.open_connection(local_host, local_port)
remote_r, remote_w = await asyncio.open_connection(remote_host, remote_port)
await asyncio.wait(
[
asyncio.create_task(pipe(local_r, remote_w)),
asyncio.create_task(pipe(remote_r, local_w)),
],
return_when=asyncio.FIRST_COMPLETED,
)
except Exception as e:
while True:
data = await r.read(4096)
if len(data) == 0:
# Done!
break
w.write(data)
await w.drain()
except:
pass
finally:
if local_w is not None:
local_w.close()
if remote_w is not None:
remote_w.close()
w.close()
await w.wait_closed()

while True:
local_r, local_w = await asyncio.open_connection(local_host, local_port)
remote_r, remote_w = await asyncio.open_connection(remote_host, remote_port)

await asyncio.gather(
asyncio.create_task(relay(local_r, remote_w)),
asyncio.create_task(relay(remote_r, local_w)),
)

# Throttle connection attempts.
await asyncio.sleep(0.1)
Expand Down
1 change: 1 addition & 0 deletions src/viser/_viser.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def _(conn: infra.ClientConnection) -> None:
if not share:
self._share_tunnel = None
else:
rich.print("[bold](viser)[/bold] Share URL requested!")
self._share_tunnel = _ViserTunnel(port)

@self._share_tunnel.on_connect
Expand Down
1 change: 1 addition & 0 deletions src/viser/client/src/ControlPanel/ControlPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ function ConnectionStatus() {
<Loader
size="xs"
variant="bars"
color="red"
style={{ position: "absolute", ...styles }}
/>
)}
Expand Down

0 comments on commit 1ab6762

Please sign in to comment.