diff --git a/src/aioice/ice.py b/src/aioice/ice.py index f33a8dc..35d1d37 100644 --- a/src/aioice/ice.py +++ b/src/aioice/ice.py @@ -8,13 +8,13 @@ import socket import threading from itertools import count -from typing import Dict, List, Optional, Set, Text, Tuple, Union, cast +from typing import Dict, Iterable, List, Optional, Set, Text, Tuple, Union, cast import netifaces from . import mdns, stun, turn from .candidate import Candidate, candidate_foundation, candidate_priority -from .utils import random_string +from .utils import create_datagram_endpoint, random_string logger = logging.getLogger(__name__) @@ -282,6 +282,7 @@ class Connection: :param turn_transport: The transport for TURN server, `"udp"` or `"tcp"`. :param use_ipv4: Whether to use IPv4 candidates. :param use_ipv6: Whether to use IPv6 candidates. + :param ephemeral_ports: Set of allowed ephemeral local ports to bind to. """ def __init__( @@ -296,6 +297,7 @@ def __init__( turn_transport: str = "udp", use_ipv4: bool = True, use_ipv6: bool = True, + ephemeral_ports: Optional[Iterable[int]] = None, ) -> None: self.ice_controlling = ice_controlling #: Local username, automatically set to a random value. @@ -340,6 +342,7 @@ def __init__( self._tie_breaker = secrets.randbits(64) self._use_ipv4 = use_ipv4 self._use_ipv6 = use_ipv6 + self._ephemeral_ports = ephemeral_ports @property def local_candidates(self) -> List[Candidate]: @@ -847,16 +850,14 @@ async def get_component_candidates( self, component: int, addresses: List[str], timeout: int = 5 ) -> List[Candidate]: candidates = [] - loop = asyncio.get_event_loop() # gather host candidates host_protocols = [] for address in addresses: # create transport try: - transport, protocol = await loop.create_datagram_endpoint( - lambda: StunProtocol(self), local_addr=(address, 0) - ) + transport, protocol = await create_datagram_endpoint( + lambda: StunProtocol(self), local_address=address, local_ports=self._ephemeral_ports) sock = transport.get_extra_info("socket") if sock is not None: sock.setsockopt( diff --git a/src/aioice/utils.py b/src/aioice/utils.py index a292edf..611bede 100644 --- a/src/aioice/utils.py +++ b/src/aioice/utils.py @@ -1,6 +1,9 @@ +import asyncio import os +import random import secrets import string +from typing import Iterable, Optional, Tuple def random_string(length: int) -> str: @@ -10,3 +13,35 @@ def random_string(length: int) -> str: def random_transaction_id() -> bytes: return os.urandom(12) + + +async def create_datagram_endpoint(protocol_factory, + remote_addr: Tuple[str, int] = None, + local_address: str = None, + local_ports: Optional[Iterable[int]] = None, +): + """ + Asynchronousley create a datagram endpoint. + + :param protocol_factory: Callable returning a protocol instance. + :param remote_addr: Remote address and port. + :param local_address: Local address to bind to. + :param local_ports: Set of allowed local ports to bind to. + """ + if local_ports is not None: + ports = list(local_ports) + random.shuffle(ports) + else: + ports = (0,) + loop = asyncio.get_event_loop() + for port in ports: + try: + transport, protocol = await loop.create_datagram_endpoint( + protocol_factory, remote_addr=remote_addr, local_addr=(local_address, port) + ) + return transport, protocol + except OSError as exc: + if port == ports[-1]: + # this was the last port, give up + raise exc + raise ValueError("local_ports must not be empty") diff --git a/tests/test_ice.py b/tests/test_ice.py index fb5e308..64c9d39 100644 --- a/tests/test_ice.py +++ b/tests/test_ice.py @@ -1,6 +1,7 @@ import asyncio import functools import os +import random import socket import unittest from unittest import mock @@ -1161,6 +1162,67 @@ async def test_repr(self): conn._id = 1 self.assertEqual(repr(conn), "Connection(1)") + @asynctest + async def test_connection_ephemeral_ports(self): + addresses = ["127.0.0.1"] + + # Let the OS pick a random port - should always yield a candidate + conn1 = ice.Connection(ice_controlling=True) + c = await conn1.get_component_candidates(0, addresses) + self.assertTrue(c[0].port >= 1 and c[0].port <= 65535) + + # Try opening a new connection with the same port - should never yield candidates + conn2 = ice.Connection(ice_controlling=True, ephemeral_ports=[c[0].port]) + c = await conn2.get_component_candidates(0, addresses) + self.assertEqual(len(c), 0) # port already in use, no candidates + await conn1.close() + + # Empty set of ports - illegal argument + conn3 = ice.Connection(ice_controlling=True, ephemeral_ports=[]) + with self.assertRaises(ValueError): + await conn3.get_component_candidates(0, addresses) + + # Range of 100 ports + lower = random.randint(1024, 65536 - 100) + upper = lower + 100 + ports = set(range(lower, upper)) + try: + # The MDNS port is often in use, avoid it + ports.remove(mdns.MDNS_PORT) + except KeyError: + pass + + # Exhaust the range of ports - should always yield candidates + conns = [] + for i in range(0, len(ports)): + conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports) + c = await conn.get_component_candidates(i, addresses) + if c: + self.assertTrue(c[0].port >= lower and c[0].port < upper) + conns.append(conn) + self.assertGreaterEqual(len(conns), len(ports) - 1) # Account for at most 1 port in use by another process + + # Open one more connection from the same range - should never yield candidates + conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports) + c = await conn.get_component_candidates(0, addresses) + self.assertEqual(len(c), 0) # all ports are exhausted, no candidates + + # Close one connection and try again - should always yield a candidate + await conns.pop().close() + conn = ice.Connection(ice_controlling=True, ephemeral_ports=ports) + c = await conn.get_component_candidates(0, addresses) + self.assertTrue(c[0].port >= lower and c[0].port < upper) + await conn.close() + + # cleanup + for conn in conns: + await conn.close() + + # Bind to wildcard local address - should always yield a candidate + conn = ice.Connection(ice_controlling=True) + c = await conn.get_component_candidates(0, [None]) + self.assertTrue(c[0].port >= 1 and c[0].port <= 65535) + await conn.close() class StunProtocolTest(unittest.TestCase): @asynctest