Skip to content

Commit

Permalink
Optionally restrict the range of ephemeral ports
Browse files Browse the repository at this point in the history
  • Loading branch information
sirf committed Oct 27, 2022
1 parent f08956b commit b0f060f
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/aioice/ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import socket
import threading
from itertools import count
from typing import Dict, List, Optional, Set, Text, Tuple, Union, cast
from typing import Dict, Iterable, List, Optional, Set, Text, Tuple, Union, cast

import netifaces

from . import mdns, stun, turn
from .candidate import Candidate, candidate_foundation, candidate_priority
from .utils import random_string
from .utils import create_datagram_endpoint, random_string

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -282,6 +282,7 @@ class Connection:
:param turn_transport: The transport for TURN server, `"udp"` or `"tcp"`.
:param use_ipv4: Whether to use IPv4 candidates.
:param use_ipv6: Whether to use IPv6 candidates.
:param ephemeral_ports: Set of allowed ephemeral local ports to bind to.
"""

def __init__(
Expand All @@ -296,6 +297,7 @@ def __init__(
turn_transport: str = "udp",
use_ipv4: bool = True,
use_ipv6: bool = True,
ephemeral_ports: Optional[Iterable[int]] = None,
) -> None:
self.ice_controlling = ice_controlling
#: Local username, automatically set to a random value.
Expand Down Expand Up @@ -340,6 +342,7 @@ def __init__(
self._tie_breaker = secrets.randbits(64)
self._use_ipv4 = use_ipv4
self._use_ipv6 = use_ipv6
self._ephemeral_ports = ephemeral_ports

@property
def local_candidates(self) -> List[Candidate]:
Expand Down Expand Up @@ -847,16 +850,14 @@ async def get_component_candidates(
self, component: int, addresses: List[str], timeout: int = 5
) -> List[Candidate]:
candidates = []
loop = asyncio.get_event_loop()

# gather host candidates
host_protocols = []
for address in addresses:
# create transport
try:
transport, protocol = await loop.create_datagram_endpoint(
lambda: StunProtocol(self), local_addr=(address, 0)
)
transport, protocol = await create_datagram_endpoint(
lambda: StunProtocol(self), local_address=address, ephemeral_ports=self._ephemeral_ports)
sock = transport.get_extra_info("socket")
if sock is not None:
sock.setsockopt(
Expand Down
35 changes: 35 additions & 0 deletions src/aioice/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import os
import random
import secrets
import string
from typing import Iterable, Optional, Tuple


def random_string(length: int) -> str:
Expand All @@ -10,3 +13,35 @@ def random_string(length: int) -> str:

def random_transaction_id() -> bytes:
return os.urandom(12)


async def create_datagram_endpoint(protocol_factory,
remote_addr: Tuple[str, int] = None,
local_address: str = None,
ephemeral_ports: Optional[Iterable[int]] = None,
):
"""
Asynchronousley create a datagram endpoint.
:param protocol_factory: Callable returning a protocol instance.
:param remote_addr: Remote address and port.
:param local_address: Local address to bind to.
:param ephemeral_ports: Set of allowed local ephemeral ports to bind to.
"""
if ephemeral_ports is not None:
ports = list(ephemeral_ports)
random.shuffle(ports)
else:
ports = (0,)
loop = asyncio.get_event_loop()
for port in ports:
try:
transport, protocol = await loop.create_datagram_endpoint(
protocol_factory, remote_addr=remote_addr, local_addr=(local_address, port)
)
return transport, protocol
except OSError as exc:
if port == ports[-1]:
# this was the last port, give up
raise exc
raise ValueError("ephemeral_ports must not be empty")
53 changes: 53 additions & 0 deletions tests/test_ice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import os
import random
import socket
import unittest
from unittest import mock
Expand Down Expand Up @@ -1161,6 +1162,58 @@ async def test_repr(self):
conn._id = 1
self.assertEqual(repr(conn), "Connection(1)")

@asynctest
async def test_connection_ephemeral_ports(self):
addresses = ["127.0.0.1"]

# Let the OS pick a random port - should never fail
conn1 = ice.Connection(ice_controlling=True)
c = await conn1.get_component_candidates(0, addresses)
self.assertTrue(c[0].port >= 1 and c[0].port <= 65535)

# Try opening a new connection with the same port - should always fail
conn2 = ice.Connection(ice_controlling=True, ephemeral_ports=[c[0].port])
c = await conn2.get_component_candidates(0, addresses)
self.assertTrue(len(c) == 0) # port already in use, no candidates
await conn1.close()

# Empty set of ports - illegal argument
conn3 = ice.Connection(ice_controlling=True, ephemeral_ports=[])
with self.assertRaises(ValueError) as exc:
await conn3.get_component_candidates(0, addresses)

# Range of ports
lower = random.randint(1024, 65536 - 100)
upper = lower + 100
ports = list(range(lower, upper))

conn4 = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn4.get_component_candidates(0, addresses)
self.assertTrue(c[0].port >= lower and c[0].port < upper)

# Exhaust the range of ports
conns = []
for i in range(1, len(ports)):
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn.get_component_candidates(i, addresses)
self.assertTrue(c[0].port >= lower and c[0].port < lower + len(ports))
conns.append(conn)

# Open one more connection from the same range - should always fail
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn.get_component_candidates(0, addresses)
self.assertTrue(len(c) == 0) # all ports are already in use, no candidates

# Close one connection and try again - should never fail
await conn4.close()
conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports)
c = await conn.get_component_candidates(0, addresses)
self.assertTrue(c[0].port >= lower and c[0].port < lower + len(ports))

# cleanup
await conn.close()
for conn in conns:
await conn.close()

class StunProtocolTest(unittest.TestCase):
@asynctest
Expand Down

0 comments on commit b0f060f

Please sign in to comment.