Skip to content

Commit

Permalink
Optionally restrict the range of ephemeral ports
Browse files Browse the repository at this point in the history
  • Loading branch information
sirf committed Oct 27, 2022
1 parent f08956b commit 439c5e8
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/aioice/ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import socket
import threading
from itertools import count
from typing import Dict, List, Optional, Set, Text, Tuple, Union, cast
from typing import Dict, Iterable, List, Optional, Set, Text, Tuple, Union, cast

import netifaces

from . import mdns, stun, turn
from .candidate import Candidate, candidate_foundation, candidate_priority
from .utils import random_string
from .utils import create_datagram_endpoint, random_string

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -282,6 +282,7 @@ class Connection:
:param turn_transport: The transport for TURN server, `"udp"` or `"tcp"`.
:param use_ipv4: Whether to use IPv4 candidates.
:param use_ipv6: Whether to use IPv6 candidates.
:param ephemeral_ports: Set of allowed ephemeral local ports to bind to.
"""

def __init__(
Expand All @@ -296,6 +297,7 @@ def __init__(
turn_transport: str = "udp",
use_ipv4: bool = True,
use_ipv6: bool = True,
ephemeral_ports: Optional[Iterable[int]] = None,
) -> None:
self.ice_controlling = ice_controlling
#: Local username, automatically set to a random value.
Expand Down Expand Up @@ -340,6 +342,7 @@ def __init__(
self._tie_breaker = secrets.randbits(64)
self._use_ipv4 = use_ipv4
self._use_ipv6 = use_ipv6
self._ephemeral_ports = ephemeral_ports

@property
def local_candidates(self) -> List[Candidate]:
Expand Down Expand Up @@ -847,16 +850,14 @@ async def get_component_candidates(
self, component: int, addresses: List[str], timeout: int = 5
) -> List[Candidate]:
candidates = []
loop = asyncio.get_event_loop()

# gather host candidates
host_protocols = []
for address in addresses:
# create transport
try:
transport, protocol = await loop.create_datagram_endpoint(
lambda: StunProtocol(self), local_addr=(address, 0)
)
transport, protocol = await create_datagram_endpoint(
lambda: StunProtocol(self), local_address=address, local_ports=self._ephemeral_ports)
sock = transport.get_extra_info("socket")
if sock is not None:
sock.setsockopt(
Expand Down
35 changes: 35 additions & 0 deletions src/aioice/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import os
import random
import secrets
import string
from typing import Iterable, Optional, Tuple


def random_string(length: int) -> str:
Expand All @@ -10,3 +13,35 @@ def random_string(length: int) -> str:

def random_transaction_id() -> bytes:
return os.urandom(12)


async def create_datagram_endpoint(protocol_factory,
remote_addr: Tuple[str, int] = None,
local_address: str = None,
local_ports: Optional[Iterable[int]] = None,
):
"""
Asynchronousley create a datagram endpoint.
:param protocol_factory: Callable returning a protocol instance.
:param remote_addr: Remote address and port.
:param local_address: Local address to bind to.
:param local_ports: Set of allowed local ports to bind to.
"""
if local_ports is not None:
ports = list(local_ports)
random.shuffle(ports)
else:
ports = (0,)
loop = asyncio.get_event_loop()
for port in ports:
try:
transport, protocol = await loop.create_datagram_endpoint(
protocol_factory, remote_addr=remote_addr, local_addr=(local_address, port)
)
return transport, protocol
except OSError as exc:
if port == ports[-1]:
# this was the last port, give up
raise exc
raise ValueError("local_ports must not be empty")
62 changes: 62 additions & 0 deletions tests/test_ice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import os
import random
import socket
import unittest
from unittest import mock
Expand Down Expand Up @@ -1161,6 +1162,67 @@ async def test_repr(self):
conn._id = 1
self.assertEqual(repr(conn), "Connection(1)")

@asynctest
async def test_connection_ephemeral_ports(self):
addresses = ["127.0.0.1"]

# Let the OS pick a random port - should always yield a candidate
conn1 = ice.Connection(ice_controlling=True)
c = await conn1.get_component_candidates(0, addresses)
self.assertTrue(c[0].port >= 1 and c[0].port <= 65535)

# Try opening a new connection with the same port - should never yield candidates
conn2 = ice.Connection(ice_controlling=True, ephemeral_ports=[c[0].port])
c = await conn2.get_component_candidates(0, addresses)
self.assertEqual(len(c), 0) # port already in use, no candidates
await conn1.close()

# Empty set of ports - illegal argument
conn3 = ice.Connection(ice_controlling=True, ephemeral_ports=[])
with self.assertRaises(ValueError):
await conn3.get_component_candidates(0, addresses)

# Range of 100 ports
lower = random.randint(1024, 65536 - 100)
upper = lower + 100
ports = set(range(lower, upper))
try:
# The MDNS port is often in use, avoid it
ports.remove(mdns.MDNS_PORT)
except KeyError:
pass

# Exhaust the range of ports - should always yield candidates
conns = []
for i in range(0, len(ports)):
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn.get_component_candidates(i, addresses)
if c:
self.assertTrue(c[0].port >= lower and c[0].port < upper)
conns.append(conn)
self.assertGreaterEqual(len(conns), len(ports) - 1) # Account for at most 1 port in use by another process

# Open one more connection from the same range - should never yield candidates
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn.get_component_candidates(0, addresses)
self.assertEqual(len(c), 0) # all ports are exhausted, no candidates

# Close one connection and try again - should always yield a candidate
await conns.pop().close()
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn.get_component_candidates(0, addresses)
self.assertTrue(c[0].port >= lower and c[0].port < upper)
await conn.close()

# cleanup
for conn in conns:
await conn.close()

# Bind to wildcard local address - should always yield a candidate
conn = ice.Connection(ice_controlling=True)
c = await conn.get_component_candidates(0, [None])
self.assertTrue(c[0].port >= 1 and c[0].port <= 65535)
await conn.close()

class StunProtocolTest(unittest.TestCase):
@asynctest
Expand Down

0 comments on commit 439c5e8

Please sign in to comment.