diff --git a/docs/conf.py b/docs/conf.py index 8fb3ac5ea..9e0f8b56d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -314,6 +314,11 @@ # Mocked ONLY for Sphinx/autodoc: this module does not exist in the codebase # but some doc tools may try to import it. No real code references this import. "libp2p.relay.circuit_v2.lib", + # aiortc is an optional dependency (install via libp2p[webrtc]). + # RTD doesn't have libsrtp2-dev so aiortc can't be installed there. + "aiortc", + "aiortc.rtcdtlstransport", + "aiortc.rtcconfiguration", ] # Documents to append as an appendix to all manuals. diff --git a/docs/libp2p.transport.rst b/docs/libp2p.transport.rst index e7cf3257a..176679d18 100644 --- a/docs/libp2p.transport.rst +++ b/docs/libp2p.transport.rst @@ -19,6 +19,11 @@ Subpackages libp2p.transport.websocket +.. toctree:: + :maxdepth: 4 + + libp2p.transport.webrtc + Submodules ---------- diff --git a/docs/libp2p.transport.webrtc.pb.rst b/docs/libp2p.transport.webrtc.pb.rst new file mode 100644 index 000000000..3f4570a8c --- /dev/null +++ b/docs/libp2p.transport.webrtc.pb.rst @@ -0,0 +1,21 @@ +libp2p.transport.webrtc.pb package +================================== + +Submodules +---------- + +libp2p.transport.webrtc.pb.webrtc\_pb2 module +--------------------------------------------- + +.. automodule:: libp2p.transport.webrtc.pb.webrtc_pb2 + :members: + :show-inheritance: + :undoc-members: + +Module contents +--------------- + +.. automodule:: libp2p.transport.webrtc.pb + :members: + :show-inheritance: + :undoc-members: diff --git a/docs/libp2p.transport.webrtc.rst b/docs/libp2p.transport.webrtc.rst new file mode 100644 index 000000000..aa853d1bc --- /dev/null +++ b/docs/libp2p.transport.webrtc.rst @@ -0,0 +1,134 @@ +libp2p.transport.webrtc package +=============================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + libp2p.transport.webrtc.pb + libp2p.transport.webrtc.signaling_pb + +Submodules +---------- + +libp2p.transport.webrtc.certificate module +------------------------------------------ + +.. automodule:: libp2p.transport.webrtc.certificate + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.config module +------------------------------------- + +.. automodule:: libp2p.transport.webrtc.config + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.connection module +----------------------------------------- + +.. automodule:: libp2p.transport.webrtc.connection + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.constants module +---------------------------------------- + +.. automodule:: libp2p.transport.webrtc.constants + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.exceptions module +----------------------------------------- + +.. automodule:: libp2p.transport.webrtc.exceptions + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.listener module +--------------------------------------- + +.. automodule:: libp2p.transport.webrtc.listener + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.multiaddr\_utils module +----------------------------------------------- + +.. automodule:: libp2p.transport.webrtc.multiaddr_utils + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.noise\_handshake module +----------------------------------------------- + +.. automodule:: libp2p.transport.webrtc.noise_handshake + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.private\_listener module +------------------------------------------------ + +.. automodule:: libp2p.transport.webrtc.private_listener + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.private\_transport module +------------------------------------------------- + +.. automodule:: libp2p.transport.webrtc.private_transport + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.sdp module +---------------------------------- + +.. automodule:: libp2p.transport.webrtc.sdp + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.signaling module +---------------------------------------- + +.. automodule:: libp2p.transport.webrtc.signaling + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.stream module +------------------------------------- + +.. automodule:: libp2p.transport.webrtc.stream + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.transport module +---------------------------------------- + +.. automodule:: libp2p.transport.webrtc.transport + :members: + :show-inheritance: + :undoc-members: + +Module contents +--------------- + +.. automodule:: libp2p.transport.webrtc + :members: + :show-inheritance: + :undoc-members: diff --git a/docs/libp2p.transport.webrtc.signaling_pb.rst b/docs/libp2p.transport.webrtc.signaling_pb.rst new file mode 100644 index 000000000..06a73bea5 --- /dev/null +++ b/docs/libp2p.transport.webrtc.signaling_pb.rst @@ -0,0 +1,21 @@ +libp2p.transport.webrtc.signaling\_pb package +============================================= + +Submodules +---------- + +libp2p.transport.webrtc.signaling\_pb.signaling\_pb2 module +----------------------------------------------------------- + +.. automodule:: libp2p.transport.webrtc.signaling_pb.signaling_pb2 + :members: + :show-inheritance: + :undoc-members: + +Module contents +--------------- + +.. automodule:: libp2p.transport.webrtc.signaling_pb + :members: + :show-inheritance: + :undoc-members: diff --git a/interop/transport/Dockerfile b/interop/transport/Dockerfile index c2e5e5d42..3dcdc1704 100644 --- a/interop/transport/Dockerfile +++ b/interop/transport/Dockerfile @@ -2,13 +2,14 @@ FROM python:3.13-slim WORKDIR /app -# Install system dependencies +# Install system dependencies (includes libsrtp2-dev for aiortc/WebRTC) RUN apt-get update && apt-get install -y \ redis-tools \ build-essential \ cmake \ pkg-config \ libgmp-dev \ + libsrtp2-dev \ git \ curl \ && rm -rf /var/lib/apt/lists/* diff --git a/interop/transport/ping_test.py b/interop/transport/ping_test.py index d1e809f87..10d3ba70e 100644 --- a/interop/transport/ping_test.py +++ b/interop/transport/ping_test.py @@ -172,7 +172,7 @@ def __init__(self, test_plans: bool = False) -> None: if not self.transport: raise ValueError("TRANSPORT environment variable is required") - standalone_transports = ["quic-v1"] + standalone_transports = ["quic-v1", "webrtc-direct"] self.muxer: str | None = None self.security: str | None = None @@ -242,11 +242,11 @@ def __init__(self, test_plans: bool = False) -> None: def validate_configuration(self) -> None: """Validate configuration parameters.""" - valid_transports = ["tcp", "ws", "wss", "quic-v1"] + valid_transports = ["tcp", "ws", "wss", "quic-v1", "webrtc-direct"] valid_security = ["noise", "plaintext", "tls"] valid_muxers = ["mplex", "yamux"] - # Standalone transports don't use separate security/muxer - standalone_transports = ["quic-v1"] + # Standalone transports have security + muxing built-in + standalone_transports = ["quic-v1", "webrtc-direct"] if self.transport not in valid_transports: raise ValueError( @@ -271,7 +271,7 @@ def create_security_options( """Create security options based on configuration.""" # Standalone transports (like quic-v1) have security built-in, # no separate security needed - standalone_transports = ["quic-v1"] + standalone_transports = ["quic-v1", "webrtc-direct"] if self.transport in standalone_transports: # For standalone transports, return empty security options # The security is handled by the transport itself @@ -310,7 +310,7 @@ def create_muxer_options(self) -> Any: """Create muxer options based on configuration.""" # Standalone transports (like quic-v1) have muxing built-in, # no separate muxer needed - standalone_transports = ["quic-v1"] + standalone_transports = ["quic-v1", "webrtc-direct"] if self.transport in standalone_transports: # For standalone transports, return None (no separate muxer) # The muxing is handled by the transport itself @@ -479,6 +479,17 @@ def _encapsulate_with_p2p( return addr.encapsulate(multiaddr.Multiaddr(f"/p2p/{p2p_value}")) return addr + def _build_webrtc_direct_addr( + self, ip_value: str, port: int + ) -> multiaddr.Multiaddr: + """Build WebRTC Direct address: /ip4|ip6/{ip}/udp/{port}/webrtc-direct.""" + is_ipv6 = ":" in ip_value + if is_ipv6: + base = multiaddr.Multiaddr(f"/ip6/{ip_value}/udp/{port}") + else: + base = multiaddr.Multiaddr(f"/ip4/{ip_value}/udp/{port}") + return base.encapsulate(multiaddr.Multiaddr("/webrtc-direct")) + def _build_quic_addr(self, ip_value: str, port: int) -> multiaddr.Multiaddr: """ Build QUIC address from IP and port. @@ -528,6 +539,27 @@ def create_listen_addresses(self, port: int = 0) -> list[multiaddr.Multiaddr]: return quic_addrs return [self._build_quic_addr("0.0.0.0", port)] + elif self.transport == "webrtc-direct": + # WebRTC Direct uses UDP like QUIC + webrtc_addrs = [] + for addr in base_addrs: + try: + ip_value = self._get_ip_value(addr) + tcp_port = addr.value_for_protocol("tcp") or port + if ip_value: + wrtc_addr = self._build_webrtc_direct_addr(ip_value, tcp_port) + _, p2p_value = self._extract_and_preserve_p2p(addr) + wrtc_addr = self._encapsulate_with_p2p(wrtc_addr, p2p_value) + webrtc_addrs.append(wrtc_addr) + except Exception as e: + print( + f"Error building webrtc-direct addr from {addr}: {e}", + file=sys.stderr, + ) + if webrtc_addrs: + return webrtc_addrs + return [self._build_webrtc_direct_addr("0.0.0.0", port)] + elif self.transport == "ws": # Add /ws protocol to TCP addresses # WebSocket addresses are used for both WS and WSS transports @@ -737,8 +769,10 @@ def _filter_addresses_by_transport( filtered.append(addr) elif self.transport == "quic-v1" and "quic-v1" in protocols: filtered.append(addr) + elif self.transport == "webrtc-direct" and "webrtc-direct" in protocols: + filtered.append(addr) elif self.transport == "tcp" and not any( - p in protocols for p in ["ws", "wss", "quic-v1"] + p in protocols for p in ["ws", "wss", "quic-v1", "webrtc-direct"] ): filtered.append(addr) return filtered if filtered else addresses @@ -859,6 +893,7 @@ async def run_listener(self) -> None: muxer_opt=muxer_opt, listen_addrs=listen_addrs, enable_quic=(self.transport == "quic-v1"), + enable_webrtc=(self.transport == "webrtc-direct"), tls_client_config=tls_client_config, tls_server_config=tls_server_config, ) @@ -1191,6 +1226,7 @@ async def run_dialer(self) -> None: "sec_opt": sec_opt, "muxer_opt": muxer_opt, "enable_quic": (self.transport == "quic-v1"), + "enable_webrtc": (self.transport == "webrtc-direct"), "tls_client_config": tls_client_config, "tls_server_config": tls_server_config, } diff --git a/interop/transport/pyproject.toml b/interop/transport/pyproject.toml index 9adf64b84..45ba33ccf 100644 --- a/interop/transport/pyproject.toml +++ b/interop/transport/pyproject.toml @@ -12,7 +12,7 @@ authors = [ readme = "README.md" requires-python = ">=3.11" dependencies = [ - "libp2p @ file:///app/py-libp2p", # local libp2p dependency is a snapshot based on ${commitSha} in Makefile + "libp2p[webrtc] @ file:///app/py-libp2p", # local libp2p with WebRTC extra "redis>=4.0.0", "typing-extensions>=4.0.0", "cryptography>=41.0.0", # Required for TLS/WSS support diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 90e57d6f2..a010ccde8 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -287,6 +287,7 @@ def new_swarm( muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, enable_quic: bool = False, + enable_webrtc: bool = False, enable_autotls: bool = False, retry_config: RetryConfig | None = None, connection_config: ConnectionConfig | QUICTransportConfig | None = None, @@ -378,6 +379,13 @@ def new_swarm( enable_autotls=enable_autotls, ) + # If enable_webrtc is True, force WebRTC Direct transport + if enable_webrtc: + from libp2p.transport.webrtc.transport import WebRTCDirectTransport + + logger.debug("new_swarm: Creating WebRTC Direct transport") + transport = WebRTCDirectTransport(private_key=key_pair.private_key) + logger.debug(f"new_swarm: Final transport type: {type(transport)}") # Generate X25519 keypair for Noise @@ -471,6 +479,7 @@ def new_host( bootstrap: list[str] | None = None, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, enable_quic: bool = False, + enable_webrtc: bool = False, quic_transport_opt: QUICTransportConfig | None = None, tls_client_config: ssl.SSLContext | None = None, tls_server_config: ssl.SSLContext | None = None, @@ -534,6 +543,7 @@ def new_host( swarm = new_swarm( enable_quic=enable_quic, + enable_webrtc=enable_webrtc, key_pair=key_pair, muxer_opt=muxer_opt, sec_opt=sec_opt, diff --git a/libp2p/abc.py b/libp2p/abc.py index f07a0fedd..dd22992ff 100644 --- a/libp2p/abc.py +++ b/libp2p/abc.py @@ -3002,6 +3002,11 @@ class ITransport(ABC): """ + # Transports that provide their own stream multiplexing (QUIC, WebRTC) + # override this to True. The swarm skips the TransportUpgrader for + # these transports and passes the connection directly to add_conn(). + provides_native_muxing: bool = False + @abstractmethod async def dial(self, maddr: Multiaddr) -> IRawConnection: """ diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 74533250c..deb808ad2 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -59,7 +59,6 @@ ) from libp2p.transport.quic.config import QUICTransportConfig from libp2p.transport.quic.connection import QUICConnection -from libp2p.transport.quic.transport import QUICTransport from libp2p.transport.upgrader import ( TransportUpgrader, ) @@ -215,9 +214,9 @@ async def run(self) -> None: # Set background nursery BEFORE setting the event # This ensures transports have the nursery when they check - if isinstance(self.transport, QUICTransport): - self.transport.set_background_nursery(nursery) - self.transport.set_swarm(self) + if hasattr(self.transport, "set_swarm"): + self.transport.set_background_nursery(nursery) # type: ignore[attr-defined] + self.transport.set_swarm(self) # type: ignore[attr-defined] elif hasattr(self.transport, "set_background_nursery"): # WebSocket transport also needs background nursery # for connection management @@ -670,11 +669,12 @@ async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetC pass raise SwarmException(f"Unexpected error dialing peer {peer_id}") from e - if isinstance(self.transport, QUICTransport) and isinstance( + if getattr(self.transport, "provides_native_muxing", False) and isinstance( raw_conn, IMuxedConn ): logger.info( - "Skipping upgrade for QUIC, QUIC connections are already multiplexed" + "Skipping upgrade for native-mux transport " + "(connection already multiplexed)" ) try: swarm_conn = await self.add_conn(raw_conn, direction="outbound") @@ -956,19 +956,17 @@ async def _open_stream_on_connection( peer_id: ID, ) -> INetStream: """Try to open a stream on *connection*, falling back to alternatives.""" - if isinstance(self.transport, QUICTransport) and connection is not None: - conn = cast("SwarmConn", connection) - try: - stream = await conn.new_stream() - logger.debug("successfully opened a stream to peer %s", peer_id) - return stream - except Exception: - raise - try: - net_stream = await connection.new_stream() + if ( + getattr(self.transport, "provides_native_muxing", False) + and connection is not None + ): + conn = cast("SwarmConn", connection) + stream = await conn.new_stream() + else: + stream = await connection.new_stream() # type: ignore[assignment] logger.debug("successfully opened a stream to peer %s", peer_id) - return net_stream + return stream except Exception as e: logger.debug(f"Failed to create stream on connection: {e}") @@ -1167,14 +1165,15 @@ async def conn_handler( pass return - # No need to upgrade QUIC Connection - if isinstance(self.transport, QUICTransport): + # No need to upgrade native-mux connections (QUIC, WebRTC) + if getattr(self.transport, "provides_native_muxing", False): try: - quic_conn = cast(QUICConnection, read_write_closer) - await self.add_conn(quic_conn, direction="inbound") - peer_id = quic_conn.peer_id + muxed_conn = cast(IMuxedConn, read_write_closer) + await self.add_conn(muxed_conn, direction="inbound") + peer_id = muxed_conn.peer_id logger.debug( - f"successfully opened quic connection to peer {peer_id}" + "successfully opened native-mux connection to peer %s", + peer_id, ) # NOTE: This is a intentional barrier to prevent from the # handler exiting and closing the connection. diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index db130cee3..c37046795 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -128,13 +128,15 @@ class BasePattern(IPattern): libp2p_privkey: PrivateKey early_data: bytes | None - def create_noise_state(self) -> NoiseState: + def create_noise_state(self, prologue: bytes | None = None) -> NoiseState: noise_state = NoiseState.from_name(self.protocol_name) noise_state.set_keypair_from_private_bytes( NoiseKeypairEnum.STATIC, self.noise_static_key.to_bytes() ) if noise_state.noise_protocol is None: raise NoiseStateError("noise_protocol is not initialized") + if prologue is not None: + noise_state.noise_protocol.prologue = prologue return noise_state def _validate_noise_static_key(self) -> X25519PublicKey: @@ -205,16 +207,18 @@ def __init__( libp2p_privkey: PrivateKey, noise_static_key: PrivateKey, early_data: bytes | None = None, + prologue: bytes | None = None, ) -> None: self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256" self.local_peer = local_peer self.libp2p_privkey = libp2p_privkey self.noise_static_key = noise_static_key self.early_data = early_data + self.prologue = prologue async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}") - noise_state = self.create_noise_state() + noise_state = self.create_noise_state(prologue=self.prologue) noise_state.set_as_responder() noise_state.start_handshake() if noise_state.noise_protocol is None: @@ -278,7 +282,7 @@ async def handshake_outbound( self, conn: IRawConnection, remote_peer: ID ) -> ISecureConn: logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}") - noise_state = self.create_noise_state() + noise_state = self.create_noise_state(prologue=self.prologue) read_writer = NoiseHandshakeReadWriter(conn, noise_state) noise_state.set_as_initiator() diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py index 8c6167f3b..68e2ad543 100644 --- a/libp2p/transport/quic/transport.py +++ b/libp2p/transport/quic/transport.py @@ -73,6 +73,8 @@ class QUICTransport(ITransport): QUIC Stream implementation following libp2p IMuxedStream interface. """ + provides_native_muxing: bool = True + def __init__( self, private_key: PrivateKey, diff --git a/libp2p/transport/transport_registry.py b/libp2p/transport/transport_registry.py index 3cbef4c70..edac9d4eb 100644 --- a/libp2p/transport/transport_registry.py +++ b/libp2p/transport/transport_registry.py @@ -37,6 +37,18 @@ def _get_websocket_transport() -> Any: return WebsocketTransport +def _get_webrtc_direct_transport() -> Any: + from libp2p.transport.webrtc.transport import WebRTCDirectTransport + + return WebRTCDirectTransport + + +def _get_webrtc_private_transport() -> Any: + from libp2p.transport.webrtc.private_transport import WebRTCPrivateTransport + + return WebRTCPrivateTransport + + logger = logging.getLogger(__name__) @@ -104,6 +116,25 @@ def _register_default_transports(self) -> None: self.register_transport("quic", QUICTransport) self.register_transport("quic-v1", QUICTransport) + # Register WebRTC transports only when aiortc is actually installed. + # The scaffolding modules themselves do not import aiortc (aiortc is + # loaded lazily inside the bridge), so we probe for it explicitly. + import importlib.util as _importlib_util + + if _importlib_util.find_spec("aiortc") is not None: + try: + WebRTCDirectTransport = _get_webrtc_direct_transport() + self.register_transport("webrtc-direct", WebRTCDirectTransport) + WebRTCPrivateTransport = _get_webrtc_private_transport() + self.register_transport("webrtc", WebRTCPrivateTransport) + except ImportError as e: + logger.debug("aiortc present but WebRTC transport import failed: %s", e) + else: + logger.debug( + "aiortc not installed; skipping /webrtc and /webrtc-direct " + "transport registration (install libp2p[webrtc] to enable)" + ) + def register_transport( self, protocol: str, transport_class: type[ITransport] ) -> None: @@ -192,6 +223,27 @@ def create_transport( return QUICTransport( private_key, config=config, enable_autotls=enable_autotls ) + elif protocol in ["webrtc-direct", "webrtc"]: + # WebRTC transports require a private key for the local peer + # identity used in the Noise XX handshake. The transport + # classes are loaded lazily; mypy can't see the concrete + # signature here, so we cast the call. + private_key = kwargs.get("private_key") + if private_key is None: + logger.warning( + "WebRTC transport '%s' requires private_key", protocol + ) + return None + config = kwargs.get("config") + if protocol == "webrtc-direct": + return transport_class( # type: ignore[call-arg] + private_key=private_key, config=config + ) + # private-to-private also accepts an optional host + host = kwargs.get("host") + return transport_class( # type: ignore[call-arg] + private_key=private_key, host=host, config=config + ) else: # TCP transport doesn't require upgrader return transport_class() @@ -236,7 +288,22 @@ def create_transport_for_multiaddr( # Check for supported transport protocols in order of preference # We need to validate that the multiaddr structure is valid for our transports - if "quic" in protocols or "quic-v1" in protocols: + if "webrtc-direct" in protocols or "webrtc" in protocols: + # WebRTC Direct: /ip4//udp//webrtc-direct/... + # WebRTC (relayed): /p2p-circuit/webrtc/... + # Both are only routable when the corresponding transport is + # registered (which only happens when aiortc is installed). + registry = get_transport_registry() + proto = "webrtc-direct" if "webrtc-direct" in protocols else "webrtc" + if proto in registry.get_supported_protocols(): + return registry.create_transport(proto, upgrader, **kwargs) + logger.warning( + "Multiaddr requires the WebRTC transport (%s) but it is not " + "registered. Install libp2p[webrtc] to enable WebRTC support.", + proto, + ) + return None + elif "quic" in protocols or "quic-v1" in protocols: # For QUIC, we need a valid structure like: # /ip4/127.0.0.1/udp/4001/quic # /ip4/127.0.0.1/udp/4001/quic-v1 diff --git a/libp2p/transport/webrtc/__init__.py b/libp2p/transport/webrtc/__init__.py new file mode 100644 index 000000000..1bc71b87e --- /dev/null +++ b/libp2p/transport/webrtc/__init__.py @@ -0,0 +1,79 @@ +""" +WebRTC transport for libp2p. + +Provides two transport variants per the libp2p WebRTC specification: + +- **WebRTC Direct** (``/webrtc-direct``): Server-to-browser or server-to-server + connections where the server publishes its certificate hash in the multiaddr. + No relay or signaling server is required. + +- **WebRTC** (``/webrtc``): Private-to-private connections where both peers are + behind NAT. Uses Circuit Relay v2 for signaling, then upgrades to a direct + WebRTC data-channel connection. + +Both variants use Noise XX over data-channel 0 for authentication and rely on +WebRTC data channels for native stream multiplexing (no Yamux/Mplex needed). + +Spec: https://github.com/libp2p/specs/tree/master/webrtc +""" + +from libp2p.transport.webrtc.certificate import ( + WebRTCCertificate, +) +from libp2p.transport.webrtc.constants import ( + ACCEPT_QUEUE_SIZE, + CERTHASH_PROTOCOL_CODE, + ICE_DISCONNECTION_TIMEOUT, + ICE_FAILURE_TIMEOUT, + ICE_KEEPALIVE_INTERVAL, + INBOUND_STREAM_START_ID, + MAX_DATA_CHANNELS, + MAX_IN_FLIGHT_CONNECTIONS, + MAX_MESSAGE_SIZE, + NOISE_HANDSHAKE_CHANNEL_ID, + NOISE_PROLOGUE_PREFIX, + OUTBOUND_STREAM_START_ID, + RECOMMENDED_PAYLOAD_SIZE, + WEBRTC_DIRECT_PROTOCOL_CODE, + WEBRTC_PROTOCOL_CODE, + WEBRTC_SIGNALING_PROTOCOL_ID, +) +from libp2p.transport.webrtc.exceptions import ( + WebRTCCertificateError, + WebRTCConnectionError, + WebRTCError, + WebRTCHandshakeError, + WebRTCMultiaddrError, + WebRTCSignalingError, + WebRTCStreamError, +) + +__all__ = [ + # Constants + "ACCEPT_QUEUE_SIZE", + "CERTHASH_PROTOCOL_CODE", + "ICE_DISCONNECTION_TIMEOUT", + "ICE_FAILURE_TIMEOUT", + "ICE_KEEPALIVE_INTERVAL", + "INBOUND_STREAM_START_ID", + "MAX_DATA_CHANNELS", + "MAX_IN_FLIGHT_CONNECTIONS", + "MAX_MESSAGE_SIZE", + "NOISE_HANDSHAKE_CHANNEL_ID", + "NOISE_PROLOGUE_PREFIX", + "OUTBOUND_STREAM_START_ID", + "RECOMMENDED_PAYLOAD_SIZE", + "WEBRTC_DIRECT_PROTOCOL_CODE", + "WEBRTC_PROTOCOL_CODE", + "WEBRTC_SIGNALING_PROTOCOL_ID", + # Certificate + "WebRTCCertificate", + # Exceptions + "WebRTCCertificateError", + "WebRTCConnectionError", + "WebRTCError", + "WebRTCHandshakeError", + "WebRTCMultiaddrError", + "WebRTCSignalingError", + "WebRTCStreamError", +] diff --git a/libp2p/transport/webrtc/_aiortc_helpers.py b/libp2p/transport/webrtc/_aiortc_helpers.py new file mode 100644 index 000000000..80223add8 --- /dev/null +++ b/libp2p/transport/webrtc/_aiortc_helpers.py @@ -0,0 +1,371 @@ +""" +aiortc integration helpers for the WebRTC transport. + +All direct ``aiortc`` imports are isolated here so the rest of the +``libp2p.transport.webrtc`` package remains aiortc-free and testable +without the optional dependency. + +Functions in this module run on the **asyncio** event loop (via +:class:`AsyncioBridge`). They should never be called directly from +trio code — always go through ``bridge.run_coro(...)``. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import logging +from typing import TYPE_CHECKING, Any + +from aiortc import ( + RTCConfiguration, + RTCPeerConnection, +) +from aiortc.rtcdtlstransport import RTCCertificate + +if TYPE_CHECKING: + from .connection import WebRTCConnection + +logger = logging.getLogger(__name__) + +# Timeout for ICE connection establishment (seconds). +_ICE_CONNECT_TIMEOUT = 30.0 + +# Timeout for HTTP SDP exchange (seconds). +_SDP_HTTP_TIMEOUT = 15.0 + + +# ------------------------------------------------------------------ +# Peer-connection lifecycle +# ------------------------------------------------------------------ + + +async def create_peer_connection( + rtc_cert: RTCCertificate, + ice_servers: list[str] | None = None, +) -> RTCPeerConnection: + """ + Create an ``RTCPeerConnection`` with the given certificate. + + Must run on the asyncio bridge loop because aiortc's + ``RTCPeerConnection.__init__`` calls ``asyncio.get_event_loop()`` + internally to schedule ICE initialization. + + :param rtc_cert: An aiortc certificate + (from ``WebRTCCertificate._rtc_certificate``). + :param ice_servers: Optional STUN/TURN server URLs. + :returns: A new peer connection. + """ + config = RTCConfiguration(certificates=[rtc_cert]) # type: ignore[call-arg] + return RTCPeerConnection(configuration=config) + + +async def create_noise_channel(pc: RTCPeerConnection) -> Any: + """ + Create a negotiated data channel with ID 0 for the Noise handshake. + + Both sides must call this with the same ``id`` and ``negotiated=True`` + so the channel is available without SDP renegotiation. + """ + return pc.createDataChannel("noise", negotiated=True, id=0) + + +async def wait_for_connected( + pc: RTCPeerConnection, + timeout: float = _ICE_CONNECT_TIMEOUT, +) -> None: + """ + Wait until the peer connection reaches the ``connected`` state. + + :raises TimeoutError: If the connection doesn't complete in time. + :raises ConnectionError: If the connection enters ``failed`` or ``closed``. + """ + if pc.connectionState == "connected": + return + + connected = asyncio.Event() + failed = asyncio.Event() + + @pc.on("connectionstatechange") # type: ignore[misc,untyped-decorator] + def _on_state_change() -> None: + state = pc.connectionState + logger.debug("ICE connection state: %s", state) + if state == "connected": + connected.set() + elif state in ("failed", "closed"): + failed.set() + + try: + done, pending = await asyncio.wait( + [ + asyncio.ensure_future(_event_wait(connected)), + asyncio.ensure_future(_event_wait(failed)), + ], + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + # Cancel any pending tasks to avoid leaks on the asyncio loop. + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + if not done: + raise TimeoutError(f"ICE connection did not complete within {timeout}s") + if failed.is_set(): + raise ConnectionError(f"ICE connection failed (state={pc.connectionState})") + finally: + # Clean up the listener to avoid leaks + pc.remove_all_listeners("connectionstatechange") + + +async def _event_wait(event: asyncio.Event) -> None: + """Thin wrapper so ``asyncio.wait`` can track an Event.""" + await event.wait() + + +def get_remote_fingerprint(pc: RTCPeerConnection) -> bytes: + """ + Extract the remote DTLS certificate's SHA-256 fingerprint. + + Must be called after the DTLS handshake completes (i.e. after ICE + reaches ``connected``). + + :returns: 32-byte SHA-256 digest. + :raises ValueError: If the remote certificate is not available. + """ + dtls = getattr(pc, "_dtlsTransport", None) + if dtls is None: + raise ValueError("DTLS transport not available on peer connection") + remote_cert = getattr(dtls, "_remote_certificate", None) + if remote_cert is None: + raise ValueError("Remote DTLS certificate not available") + # remote_cert is a cryptography x509.Certificate + from cryptography.hazmat.primitives.serialization import Encoding + + der = remote_cert.public_bytes(Encoding.DER) + return hashlib.sha256(der).digest() + + +# ------------------------------------------------------------------ +# Callback wiring +# ------------------------------------------------------------------ + + +def wire_pc_to_connection( + pc: RTCPeerConnection, + conn: WebRTCConnection, +) -> None: + """ + Wire aiortc callbacks to a :class:`WebRTCConnection`. + + Sets the three callback slots and registers ``on_datachannel`` / + ``on_message`` / ``on_close`` event handlers that route into the + connection's thread-safe methods. + """ + # Track open channels so _send_on_channel_cb can find them. + channels: dict[int, Any] = {} + + async def _create_channel(channel_id: int, label: str) -> None: + ch = pc.createDataChannel(label or "", negotiated=True, id=channel_id) + channels[channel_id] = ch + _bind_channel_events(ch, channel_id, conn) + + async def _send_on_channel(channel_id: int, data: bytes) -> None: + ch = channels.get(channel_id) + if ch is not None: + ch.send(data) + + async def _close_pc() -> None: + await pc.close() + + conn._create_channel_cb = _create_channel + conn._send_on_channel_cb = _send_on_channel + conn._close_pc_cb = _close_pc + + @pc.on("datachannel") # type: ignore[misc,untyped-decorator] + def _on_datachannel(channel: Any) -> None: + ch_id = channel.id if channel.id is not None else len(channels) + channels[ch_id] = channel + conn.on_datachannel(ch_id) + _bind_channel_events(channel, ch_id, conn) + + +def _bind_channel_events( + channel: Any, + channel_id: int, + conn: WebRTCConnection, +) -> None: + """Bind message/close events on a single data channel.""" + + @channel.on("message") # type: ignore[misc,untyped-decorator] + def _on_message(message: str | bytes) -> None: + data = message if isinstance(message, bytes) else message.encode() + conn.on_channel_message(channel_id, data) + + @channel.on("close") # type: ignore[misc,untyped-decorator] + def _on_close() -> None: + conn.on_channel_closed(channel_id) + + +# ------------------------------------------------------------------ +# Noise-channel helpers +# ------------------------------------------------------------------ + + +def make_noise_channel_callbacks( + channel: Any, +) -> tuple[Any, Any, asyncio.Queue[bytes]]: + """ + Wire a data channel for the Noise handshake. + + :returns: ``(send_fn, recv_fn, recv_queue)`` — async callables for + sending/receiving bytes, and the underlying queue. + """ + recv_queue: asyncio.Queue[bytes] = asyncio.Queue() + + @channel.on("message") # type: ignore[misc,untyped-decorator] + def _on_noise_msg(message: str | bytes) -> None: + data = message if isinstance(message, bytes) else message.encode() + recv_queue.put_nowait(data) + + async def send(data: bytes) -> None: + channel.send(data) + + async def recv() -> bytes: + return await recv_queue.get() + + return send, recv, recv_queue + + +# ------------------------------------------------------------------ +# HTTP-based SDP signaling (raw asyncio, no aiohttp dependency) +# ------------------------------------------------------------------ + + +async def run_signaling_server( + host: str, + port: int, + on_offer: Any, # async (offer_sdp: str) -> str (answer_sdp) +) -> asyncio.Server: + """ + Start a minimal HTTP server that accepts SDP offers via ``POST /sdp``. + + :param host: Bind address. + :param port: Bind port (TCP). + :param on_offer: Async callback ``(offer_sdp) -> answer_sdp``. + :returns: The running :class:`asyncio.Server`. + """ + + async def _handle( + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + try: + # Read HTTP request line (consumed but not used) + headers + await asyncio.wait_for(reader.readline(), timeout=_SDP_HTTP_TIMEOUT) + headers: dict[str, str] = {} + while True: + line = await asyncio.wait_for( + reader.readline(), timeout=_SDP_HTTP_TIMEOUT + ) + if line in (b"\r\n", b"\n", b""): + break + key, _, value = line.decode().partition(":") + headers[key.strip().lower()] = value.strip() + + content_length = int(headers.get("content-length", "0")) + body = b"" + if content_length > 0: + body = await asyncio.wait_for( + reader.readexactly(content_length), + timeout=_SDP_HTTP_TIMEOUT, + ) + + # Process: call the offer handler, get answer SDP + answer_sdp = await on_offer(body.decode()) + answer_bytes = answer_sdp.encode() + + # Send HTTP response + response = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/sdp\r\n" + b"Content-Length: " + str(len(answer_bytes)).encode() + b"\r\n" + b"\r\n" + ) + answer_bytes + writer.write(response) + await writer.drain() + except Exception as e: + logger.debug("Signaling server error: %s", e) + try: + writer.write(b"HTTP/1.1 500 Internal Server Error\r\n\r\n") + await writer.drain() + except Exception: + pass + finally: + writer.close() + + server = await asyncio.start_server(_handle, host, port) + logger.info("WebRTC signaling HTTP server listening on %s:%d", host, port) + return server + + +async def post_sdp( + host: str, + port: int, + offer_sdp: str, + timeout: float = _SDP_HTTP_TIMEOUT, +) -> str: + """ + POST an SDP offer to a WebRTC Direct listener and return the answer. + + :param host: Listener IP address. + :param port: Listener TCP port (same number as the UDP port in the multiaddr). + :param offer_sdp: The SDP offer string. + :param timeout: HTTP timeout in seconds. + :returns: The SDP answer string. + :raises ConnectionError: If the HTTP exchange fails. + """ + offer_bytes = offer_sdp.encode() + request = ( + b"POST /sdp HTTP/1.1\r\n" + b"Host: " + f"{host}:{port}".encode() + b"\r\n" + b"Content-Type: application/sdp\r\n" + b"Content-Length: " + str(len(offer_bytes)).encode() + b"\r\n" + b"\r\n" + ) + offer_bytes + + try: + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port), timeout=timeout + ) + writer.write(request) + await writer.drain() + + # Read response status line + status_line = await asyncio.wait_for(reader.readline(), timeout=timeout) + if b"200" not in status_line: + raise ConnectionError( + f"SDP exchange failed: {status_line.decode().strip()}" + ) + + # Read headers + headers: dict[str, str] = {} + while True: + line = await asyncio.wait_for(reader.readline(), timeout=timeout) + if line in (b"\r\n", b"\n", b""): + break + key, _, value = line.decode().partition(":") + headers[key.strip().lower()] = value.strip() + + content_length = int(headers.get("content-length", "0")) + body = await asyncio.wait_for( + reader.readexactly(content_length), timeout=timeout + ) + writer.close() + return body.decode() + except asyncio.TimeoutError as e: + raise ConnectionError(f"SDP exchange timed out after {timeout}s") from e + except Exception as e: + raise ConnectionError(f"SDP exchange failed: {e}") from e diff --git a/libp2p/transport/webrtc/_asyncio_bridge.py b/libp2p/transport/webrtc/_asyncio_bridge.py new file mode 100644 index 000000000..ad57f0c36 --- /dev/null +++ b/libp2p/transport/webrtc/_asyncio_bridge.py @@ -0,0 +1,304 @@ +""" +trio ↔ asyncio bridge for aiortc. + +Runs a single asyncio event loop in a background daemon thread. Trio code +schedules asyncio coroutines onto it via :meth:`AsyncioBridge.run_coro` and +gets back the result (or exception) through ``trio.to_thread.run_sync``. + +Design constraints +------------------ +- **One bridge per WebRTCTransport** — shared across all connections. +- **All aiortc calls go through run_coro()** — no direct asyncio usage elsewhere. +- **Cancellation propagates** — trio cancellation cancels the asyncio future. +- **No monkey-patching** — aiortc's public API only. + +Why a background thread? +~~~~~~~~~~~~~~~~~~~~~~~~ +aiortc is built entirely on asyncio. py-libp2p is built on trio. The two +event loops are incompatible: you cannot ``await`` an asyncio coroutine inside +trio. The cleanest boundary is a dedicated asyncio loop in its own thread: + + trio task ──run_coro(coro)──► asyncio loop (background thread) + ◄──result/error──── + +``trio.to_thread.run_sync`` blocks the trio task (while yielding to other trio +tasks) until the asyncio future completes. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Coroutine +import concurrent.futures +import logging +import threading +from typing import Any, TypeVar + +import trio + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class AsyncioBridgeError(Exception): + """Raised when the bridge is used incorrectly (not started, already stopped).""" + + +class AsyncioBridge: + """ + Manages a background asyncio event loop for running aiortc coroutines. + + Usage:: + + bridge = AsyncioBridge() + await bridge.start() + try: + result = await bridge.run_coro(some_asyncio_coro()) + finally: + await bridge.stop() + + Or as an async context manager:: + + async with AsyncioBridge() as bridge: + result = await bridge.run_coro(some_asyncio_coro()) + """ + + def __init__(self) -> None: + self._loop: asyncio.AbstractEventLoop | None = None + self._thread: threading.Thread | None = None + self._started = False + self._stopped = False + # Protects start/stop against concurrent trio calls + self._lock = trio.Lock() + # Protects _loop/_started/_stopped reads from run_coro against + # concurrent stop() — needed because run_coro_sync is called + # from non-trio threads (asyncio callbacks). + self._state_lock = threading.Lock() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def start(self) -> None: + """ + Start the background asyncio event loop. + + Safe to call multiple times — subsequent calls are no-ops if already + running. + + :raises AsyncioBridgeError: If the bridge was previously stopped. + """ + async with self._lock: + if self._stopped: + raise AsyncioBridgeError( + "Cannot restart a stopped AsyncioBridge — create a new instance" + ) + if self._started: + return + + loop = asyncio.new_event_loop() + ready = threading.Event() + + def _run_loop() -> None: + asyncio.set_event_loop(loop) + ready.set() + loop.run_forever() + + thread = threading.Thread( + target=_run_loop, + name="asyncio-bridge", + daemon=True, + ) + thread.start() + # Wait for the loop to be running before returning + await trio.to_thread.run_sync(ready.wait) + + with self._state_lock: + self._loop = loop + self._thread = thread + self._started = True + logger.debug("AsyncioBridge started (thread=%s)", thread.name) + + async def stop(self) -> None: + """ + Shut down the background event loop and join the thread. + + Cancels all pending asyncio tasks before stopping. Safe to call + multiple times — subsequent calls are no-ops. + """ + async with self._lock: + if not self._started or self._stopped: + return + + loop = self._loop + thread = self._thread + assert loop is not None + assert thread is not None + + with self._state_lock: + self._stopped = True + + # Cancel all remaining tasks on the asyncio loop + async def _cancel_all() -> None: + tasks = [ + t + for t in asyncio.all_tasks(loop) + if not t.done() and t is not asyncio.current_task() + ] + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + try: + future = asyncio.run_coroutine_threadsafe(_cancel_all(), loop) + await trio.to_thread.run_sync(lambda: future.result(timeout=5.0)) + except Exception: + logger.debug("Error during task cancellation", exc_info=True) + + assert loop is not None # narrowing for pyrefly + _loop_ref = loop # bind to local so the lambda captures a non-None + loop.call_soon_threadsafe(lambda: _loop_ref.stop()) + + def _join_thread() -> None: + assert thread is not None # narrowing for pyrefly + thread.join(timeout=5.0) + if thread.is_alive(): + logger.warning("AsyncioBridge thread did not stop within timeout") + + await trio.to_thread.run_sync(_join_thread) + + with self._state_lock: + self._loop = None + self._thread = None + logger.debug("AsyncioBridge stopped") + + # ------------------------------------------------------------------ + # Core API + # ------------------------------------------------------------------ + + async def run_coro(self, coro: Coroutine[Any, Any, T]) -> T: + """ + Schedule an asyncio coroutine on the background loop and return its + result to the calling trio task. + + :param coro: An asyncio coroutine (not yet awaited). + :returns: The coroutine's return value. + :raises AsyncioBridgeError: If the bridge is not running. + :raises Exception: Any exception raised by the coroutine is re-raised + in the trio task. + """ + with self._state_lock: + loop = self._loop + running = self._started and not self._stopped + if loop is None or not running: + coro.close() + raise AsyncioBridgeError("AsyncioBridge is not running") + + future = asyncio.run_coroutine_threadsafe(coro, loop) + + def _wait_for_result() -> T: + return future.result() + + try: + # abandon_on_cancel=True lets trio cancel the scope immediately. + # The background thread is abandoned but we cancel the asyncio + # future below so it doesn't leak. The keyword is supported at + # runtime by trio>=0.22 even though the stubs don't declare it. + return await trio.to_thread.run_sync( # type: ignore[call-arg] + _wait_for_result, abandon_on_cancel=True + ) + except trio.Cancelled: + # Trio scope was cancelled — propagate to the asyncio side + future.cancel() + raise + + def run_coro_sync(self, coro: Coroutine[Any, Any, T]) -> T: + """ + Schedule an asyncio coroutine from a synchronous (non-trio) context. + + Useful for callbacks invoked by aiortc from the asyncio thread that + need to schedule additional asyncio work. + + :param coro: An asyncio coroutine. + :returns: The coroutine's return value. + :raises AsyncioBridgeError: If the bridge is not running. + """ + with self._state_lock: + loop = self._loop + running = self._started and not self._stopped + if loop is None or not running: + coro.close() + raise AsyncioBridgeError("AsyncioBridge is not running") + + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result() + + def schedule_fire_and_forget(self, coro: Coroutine[Any, Any, Any]) -> None: + """ + Schedule an asyncio coroutine without waiting for the result. + + Exceptions are logged but not raised. Useful for cleanup tasks + or event-driven callbacks from aiortc. + + :param coro: An asyncio coroutine. + """ + with self._state_lock: + loop = self._loop + running = self._started and not self._stopped + if loop is None or not running: + coro.close() + return + + def _done_callback(fut: concurrent.futures.Future[Any]) -> None: + exc = fut.exception() + if exc is not None: + logger.debug("Fire-and-forget coroutine failed: %s", exc) + + future = asyncio.run_coroutine_threadsafe(coro, loop) + future.add_done_callback(_done_callback) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def is_running(self) -> bool: + """True if the bridge is started and not yet stopped.""" + return self._started and not self._stopped + + @property + def loop(self) -> asyncio.AbstractEventLoop | None: + """ + The underlying asyncio event loop, or None if not started. + + Exposed for advanced use cases (e.g. creating asyncio.Futures + directly). Prefer :meth:`run_coro` for normal use. + """ + return self._loop + + # ------------------------------------------------------------------ + # Context manager + # ------------------------------------------------------------------ + + async def __aenter__(self) -> AsyncioBridge: + await self.start() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object | None, + ) -> None: + await self.stop() + + def __repr__(self) -> str: + if self.is_running: + state = "running" + elif self._stopped: + state = "stopped" + else: + state = "idle" + return f"" diff --git a/libp2p/transport/webrtc/certificate.py b/libp2p/transport/webrtc/certificate.py new file mode 100644 index 000000000..8befb7b1d --- /dev/null +++ b/libp2p/transport/webrtc/certificate.py @@ -0,0 +1,246 @@ +""" +WebRTC certificate utilities. + +Generates ECDSA P-256 self-signed certificates for WebRTC DTLS and computes +SHA-256 fingerprints encoded as multihash/multibase for embedding in +``/webrtc-direct/certhash/`` multiaddrs. + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc-direct.md +""" + +from __future__ import annotations + +import base64 +from datetime import datetime, timedelta, timezone +import hashlib +import logging +import struct + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.x509.oid import NameOID + +from .exceptions import WebRTCCertificateError + +logger = logging.getLogger(__name__) + +# SHA-256 multicodec code per https://github.com/multiformats/multicodec +_SHA256_MULTIHASH_CODE = 0x12 +_SHA256_DIGEST_SIZE = 32 + +# Multibase base64url prefix +_MULTIBASE_BASE64URL_PREFIX = "u" + +# Certificate defaults +_CERTIFICATE_VALIDITY_DAYS = 14 + + +class WebRTCCertificate: + """ + Holds an ECDSA P-256 certificate and its SHA-256 fingerprint for WebRTC. + + Use :meth:`generate` to create a fresh self-signed certificate, or + :meth:`from_existing` to wrap an already-created certificate/key pair. + """ + + def __init__( + self, + certificate: x509.Certificate, + private_key: ec.EllipticCurvePrivateKey, + ) -> None: + self.certificate = certificate + self.private_key = private_key + # Pre-compute fingerprint (SHA-256 of DER-encoded certificate) + der_bytes = certificate.public_bytes(serialization.Encoding.DER) + self._fingerprint = hashlib.sha256(der_bytes).digest() + + @classmethod + def generate( + cls, + common_name: str = "libp2p-webrtc", + validity_days: int = _CERTIFICATE_VALIDITY_DAYS, + ) -> WebRTCCertificate: + """ + Generate a fresh ECDSA P-256 self-signed certificate. + + The spec requires ECDSA with the P-256 curve for browser compatibility. + + :param common_name: Certificate subject CN. + :param validity_days: How long the certificate is valid. + :returns: A new :class:`WebRTCCertificate`. + :raises WebRTCCertificateError: If certificate generation fails. + """ + try: + private_key = ec.generate_private_key(ec.SECP256R1()) + + now = datetime.now(timezone.utc) + subject = issuer = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, common_name + ) # pyrefly: ignore[bad-argument-type] + ] + ) + certificate = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now - timedelta(minutes=1)) + .not_valid_after(now + timedelta(days=validity_days)) + .sign(private_key, hashes.SHA256()) + ) + + logger.debug("Generated WebRTC ECDSA P-256 certificate") + return cls(certificate, private_key) + + except Exception as e: + raise WebRTCCertificateError( + f"Failed to generate WebRTC certificate: {e}" + ) from e + + @classmethod + def from_aiortc(cls) -> WebRTCCertificate: + """ + Generate a certificate using aiortc's ``RTCCertificate``. + + Preferred when aiortc is installed because it avoids any + cryptography ↔ pyOpenSSL conversion — aiortc's internal cert is + already a :class:`cryptography.x509.Certificate`. + + :returns: A new :class:`WebRTCCertificate` backed by an aiortc cert. + :raises ImportError: If aiortc is not installed. + :raises WebRTCCertificateError: If certificate generation fails. + """ + try: + from aiortc.rtcdtlstransport import RTCCertificate + except ImportError: + raise + + try: + rtc_cert = RTCCertificate.generateCertificate() + # aiortc stores _cert as a cryptography x509.Certificate + x509_cert = rtc_cert._cert # type: ignore[attr-defined] + priv_key = rtc_cert._key # type: ignore[attr-defined] + instance = cls(certificate=x509_cert, private_key=priv_key) + # Keep the aiortc cert so RTCPeerConnection can use it directly. + instance._rtc_certificate = rtc_cert # type: ignore[attr-defined] + logger.debug("Generated WebRTC certificate via aiortc") + return instance + except ImportError: + raise + except Exception as e: + raise WebRTCCertificateError( + f"Failed to generate aiortc certificate: {e}" + ) from e + + # ------------------------------------------------------------------ + # Fingerprint accessors + # ------------------------------------------------------------------ + + @property + def fingerprint(self) -> bytes: + """Raw SHA-256 fingerprint of the DER-encoded certificate.""" + return self._fingerprint + + def fingerprint_to_multihash(self) -> bytes: + """ + Encode the fingerprint as a multihash (varint code + varint length + digest). + + For SHA-256 both code (0x12) and length (32) fit in a single byte, + so we avoid a full varint encoder. + """ + header = struct.pack("BB", _SHA256_MULTIHASH_CODE, _SHA256_DIGEST_SIZE) + return header + self._fingerprint + + def fingerprint_to_multibase(self) -> str: + """ + Encode the fingerprint as a multibase base64url string. + + The result can be used directly as the ``certhash`` component in a + ``/webrtc-direct`` multiaddr. + + Format: ``u`` prefix + base64url(multihash(sha256(DER cert))) + """ + mh = self.fingerprint_to_multihash() + encoded = base64.urlsafe_b64encode(mh).rstrip(b"=").decode("ascii") + return _MULTIBASE_BASE64URL_PREFIX + encoded + + @property + def fingerprint_hex(self) -> str: + """Colon-separated hex fingerprint for SDP ``a=fingerprint`` lines.""" + return ":".join(f"{b:02X}" for b in self._fingerprint) + + # ------------------------------------------------------------------ + # DER / PEM accessors + # ------------------------------------------------------------------ + + def certificate_der(self) -> bytes: + """Certificate in DER encoding.""" + return self.certificate.public_bytes(serialization.Encoding.DER) + + def private_key_der(self) -> bytes: + """Private key in DER encoding (PKCS8, unencrypted).""" + return self.private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +# ------------------------------------------------------------------ +# Standalone helpers for decoding remote fingerprints +# ------------------------------------------------------------------ + + +def fingerprint_from_multibase(encoded: str) -> bytes: + """ + Decode a multibase-encoded certhash back to raw SHA-256 fingerprint bytes. + + :param encoded: Multibase string (e.g. ``uEi...``). + :returns: 32-byte SHA-256 digest. + :raises WebRTCCertificateError: If the encoding is invalid. + """ + if not encoded.startswith(_MULTIBASE_BASE64URL_PREFIX): + raise WebRTCCertificateError( + f"Unsupported multibase prefix: expected '{_MULTIBASE_BASE64URL_PREFIX}', " + f"got '{encoded[:1]}'" + ) + b64_part = encoded[1:] + # Restore padding for base64url + padding = 4 - (len(b64_part) % 4) + if padding != 4: + b64_part += "=" * padding + try: + raw = base64.urlsafe_b64decode(b64_part) + except Exception as e: + raise WebRTCCertificateError(f"Invalid base64url in certhash: {e}") from e + + if len(raw) < 2: + raise WebRTCCertificateError("Multihash too short") + + code, length = raw[0], raw[1] + # Guard: we only handle single-byte varints (code and length < 0x80). + # SHA-256 (0x12, length 32) fits. Multi-byte varints would need LEB128. + if code >= 0x80 or length >= 0x80: + raise WebRTCCertificateError( + "Multi-byte varint multihash codes are not supported" + ) + if code != _SHA256_MULTIHASH_CODE: + raise WebRTCCertificateError( + f"Unsupported multihash function code: 0x{code:02x} " + "(expected 0x12 / SHA-256)" + ) + if length != _SHA256_DIGEST_SIZE: + raise WebRTCCertificateError( + f"Unexpected multihash digest length: {length} " + f"(expected {_SHA256_DIGEST_SIZE})" + ) + digest = raw[2:] + if len(digest) != _SHA256_DIGEST_SIZE: + raise WebRTCCertificateError( + f"Digest truncated: got {len(digest)} bytes, expected {_SHA256_DIGEST_SIZE}" + ) + return digest diff --git a/libp2p/transport/webrtc/config.py b/libp2p/transport/webrtc/config.py new file mode 100644 index 000000000..63b491726 --- /dev/null +++ b/libp2p/transport/webrtc/config.py @@ -0,0 +1,81 @@ +""" +WebRTC transport configuration. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from .certificate import WebRTCCertificate +from .constants import ( + ACCEPT_QUEUE_SIZE, + ICE_DISCONNECTION_TIMEOUT, + ICE_FAILURE_TIMEOUT, + ICE_KEEPALIVE_INTERVAL, + MAX_IN_FLIGHT_CONNECTIONS, + MAX_MESSAGE_SIZE, +) + + +@dataclass +class WebRTCTransportConfig: + """ + Configuration for the WebRTC transport. + + Sensible defaults match go-libp2p. Override for testing or + constrained environments. + """ + + # ------------------------------------------------------------------ + # Certificate (auto-generated if None) + # ------------------------------------------------------------------ + certificate: WebRTCCertificate | None = None + + # ------------------------------------------------------------------ + # ICE timeouts (seconds) + # ------------------------------------------------------------------ + ice_disconnection_timeout: float = float(ICE_DISCONNECTION_TIMEOUT) + ice_failure_timeout: float = float(ICE_FAILURE_TIMEOUT) + ice_keepalive_interval: float = float(ICE_KEEPALIVE_INTERVAL) + + # ------------------------------------------------------------------ + # Connection timeouts (seconds) + # ------------------------------------------------------------------ + handshake_timeout: float = 30.0 + stream_open_timeout: float = 10.0 + stream_accept_timeout: float = 10.0 + + # ------------------------------------------------------------------ + # Concurrency limits + # ------------------------------------------------------------------ + max_in_flight_connections: int = MAX_IN_FLIGHT_CONNECTIONS + accept_queue_size: int = ACCEPT_QUEUE_SIZE + max_concurrent_streams: int = 256 + + # ------------------------------------------------------------------ + # Message size + # ------------------------------------------------------------------ + max_message_size: int = MAX_MESSAGE_SIZE + + # ------------------------------------------------------------------ + # STUN / TURN servers (for ICE candidate gathering) + # ------------------------------------------------------------------ + ice_servers: list[str] = field( + default_factory=lambda: ["stun:stun.l.google.com:19302"] + ) + + def get_or_generate_certificate(self) -> WebRTCCertificate: + """ + Return the configured certificate or generate a new one. + + Prefers aiortc-native generation when available so the resulting + ``RTCCertificate`` can be passed directly to + ``RTCPeerConnection(certificates=[...])``. Falls back to pure + ``cryptography`` generation when aiortc is not installed. + """ + if self.certificate is None: + try: + self.certificate = WebRTCCertificate.from_aiortc() + except ImportError: + self.certificate = WebRTCCertificate.generate() + return self.certificate diff --git a/libp2p/transport/webrtc/connection.py b/libp2p/transport/webrtc/connection.py new file mode 100644 index 000000000..e5e65696d --- /dev/null +++ b/libp2p/transport/webrtc/connection.py @@ -0,0 +1,372 @@ +""" +WebRTC connection — dual ``IRawConnection`` + ``IMuxedConn`` interface. + +Follows the same pattern as :class:`QUICConnection`: WebRTC provides native +stream multiplexing via data channels, so the connection implements both +the raw transport and the muxer interface. The swarm skips the +TransportUpgrader for native-muxing transports. + +Each outbound stream gets an even data-channel ID starting at 2. +Each inbound stream gets an odd data-channel ID starting at 1. +Channel 0 is reserved for the Noise handshake. + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc.md +""" + +from __future__ import annotations + +from collections.abc import Callable +import logging +import threading +from typing import TYPE_CHECKING, Any + +from multiaddr import Multiaddr +import trio + +from libp2p.abc import IMuxedConn, IMuxedStream, IRawConnection +from libp2p.connection_types import ConnectionType +from libp2p.peer.id import ID + +from .config import WebRTCTransportConfig +from .constants import ( + ACCEPT_QUEUE_SIZE, + OUTBOUND_STREAM_START_ID, +) +from .exceptions import WebRTCConnectionError, WebRTCStreamError +from .stream import WebRTCStream + +if TYPE_CHECKING: + from ._asyncio_bridge import AsyncioBridge + +logger = logging.getLogger(__name__) + + +class WebRTCConnection(IRawConnection, IMuxedConn): + """ + A WebRTC peer connection providing native stream multiplexing. + + Wraps an aiortc ``RTCPeerConnection`` (via :class:`AsyncioBridge`) + and maps each data channel to a :class:`WebRTCStream`. + + This class does NOT import or call aiortc directly. All aiortc + interaction happens through the bridge and the ``_send_on_channel`` / + ``_create_channel`` / ``_close_pc`` callbacks set by the transport. + This keeps the connection testable without aiortc installed. + """ + + def __init__( + self, + peer_id: ID, + bridge: AsyncioBridge, + is_initiator: bool, + config: WebRTCTransportConfig | None = None, + remote_addrs: list[Multiaddr] | None = None, + ) -> None: + # IMuxedConn required attribute + self.peer_id = peer_id + self.event_started = trio.Event() + + self._bridge = bridge + self._is_init = is_initiator + self._config = config or WebRTCTransportConfig() + self._remote_addrs = remote_addrs or [] + + # Connection state + self._established = False + self._closed = False + self._started = False + + # Stream registry + self._streams: dict[int, WebRTCStream] = {} + self._streams_lock = threading.Lock() + + # Outbound channel ID counter: even IDs starting at 2 + self._next_outbound_id = OUTBOUND_STREAM_START_ID + + # Inbound stream accept queue + self._accept_send: trio.MemorySendChannel[WebRTCStream] + self._accept_recv: trio.MemoryReceiveChannel[WebRTCStream] + self._accept_send, self._accept_recv = trio.open_memory_channel[WebRTCStream]( + ACCEPT_QUEUE_SIZE + ) + + # Callbacks set by the transport layer (avoids direct aiortc import) + self._create_channel_cb: Any = None # async (id, label) -> None + self._send_on_channel_cb: Any = None # async (channel_id, data) -> None + self._close_pc_cb: Any = None # async () -> None + + # Trio token captured at construction time, used to safely route + # aiortc callbacks (which run on the asyncio bridge thread) back + # into the Trio thread via trio.from_thread.run_sync(). May be + # None in unit tests that construct the connection outside a trio + # task; in that case we call the mutations inline. + try: + self._trio_token: trio.lowlevel.TrioToken | None = ( + trio.lowlevel.current_trio_token() + ) + except RuntimeError: + self._trio_token = None + + # ------------------------------------------------------------------ + # IRawConnection interface + # ------------------------------------------------------------------ + + @property + def is_initiator(self) -> bool: # type: ignore[override] + # IRawConnection declares is_initiator as a writable attribute while + # IMuxedConn declares it as an abstract property; we satisfy the + # property side because QUIC uses the same pattern. + return self._is_init + + def get_transport_addresses(self) -> list[Multiaddr]: + return list(self._remote_addrs) + + def get_connection_type(self) -> ConnectionType: + return ConnectionType.DIRECT + + def get_remote_address(self) -> tuple[str, int] | None: + return None # Populated after ICE negotiation completes + + async def read(self, n: int | None = None) -> bytes: + # Raw reads are not used for native-muxing transports. + # The swarm opens streams directly. + raise WebRTCConnectionError( + "WebRTC uses native multiplexing — read individual streams instead" + ) + + async def write(self, data: bytes) -> None: + raise WebRTCConnectionError( + "WebRTC uses native multiplexing — write to individual streams instead" + ) + + # ------------------------------------------------------------------ + # IMuxedConn interface + # ------------------------------------------------------------------ + + @property + def is_established(self) -> bool: + return self._established and not self._closed + + @property + def is_closed(self) -> bool: + return self._closed + + async def start(self) -> None: + """Mark the connection as started. Called after Noise handshake.""" + if self._started: + return + self._started = True + self._established = True + self.event_started.set() + logger.debug("WebRTCConnection started (peer=%s)", self.peer_id) + + async def close(self) -> None: + """Close the peer connection and all streams.""" + if self._closed: + return + self._closed = True + self._established = False + + # Close all streams + with self._streams_lock: + streams = list(self._streams.values()) + for stream in streams: + try: + await stream.reset() + except Exception: + pass + + # Close the accept queue + try: + self._accept_send.close() + except trio.ClosedResourceError: + pass + + # Close the underlying peer connection + if self._close_pc_cb is not None: + try: + await self._bridge.run_coro(self._close_pc_cb()) + except Exception: + logger.debug("Error closing RTCPeerConnection", exc_info=True) + + logger.debug("WebRTCConnection closed (peer=%s)", self.peer_id) + + async def open_stream(self) -> IMuxedStream: + """ + Open a new outbound stream (creates a WebRTC data channel). + + :returns: A :class:`WebRTCStream` ready for reading/writing. + :raises WebRTCStreamError: If the connection is closed or stream + limit is reached. + """ + if self._closed: + raise WebRTCStreamError("Connection is closed") + + channel_id = self._allocate_outbound_id() + stream = WebRTCStream( + connection=self, + channel_id=channel_id, + is_initiator=True, + trio_token=self._trio_token, + ) + + # Create the data channel via the bridge + if self._create_channel_cb is not None: + try: + await self._bridge.run_coro(self._create_channel_cb(channel_id, "")) + except Exception as e: + raise WebRTCStreamError( + f"Failed to create data channel {channel_id}: {e}" + ) from e + + # Wire up send callback + stream._send_callback = self._make_send_callback(channel_id) + + # Register + with self._streams_lock: + self._streams[channel_id] = stream + + logger.debug( + "Opened outbound stream channel=%d (peer=%s)", + channel_id, + self.peer_id, + ) + return stream + + async def accept_stream(self) -> IMuxedStream: + """ + Accept an inbound stream (waits for a remote data channel). + + :returns: A :class:`WebRTCStream`. + :raises WebRTCStreamError: If the connection is closed. + """ + if self._closed: + raise WebRTCStreamError("Connection is closed") + try: + stream = await self._accept_recv.receive() + return stream + except trio.EndOfChannel: + raise WebRTCStreamError( + "Connection closed while waiting for stream" + ) from None + + # ------------------------------------------------------------------ + # Inbound data-channel handler (called by transport) + # ------------------------------------------------------------------ + + def on_datachannel(self, channel_id: int) -> WebRTCStream: + """ + Register an inbound data channel as a new stream. + + Called by the transport layer when a remote peer creates a data channel. + May be invoked from the asyncio bridge thread; any Trio-side + enqueueing is routed through :meth:`_run_on_trio_thread`. + + :param channel_id: The data channel ID. + :returns: The created :class:`WebRTCStream`. + """ + # Pass our captured trio_token so the stream can route foreign-thread + # callbacks back into trio even when constructed off-thread. + stream = WebRTCStream( + connection=self, + channel_id=channel_id, + is_initiator=False, + trio_token=self._trio_token, + ) + stream._send_callback = self._make_send_callback(channel_id) + + # Stream registry is guarded by threading.Lock so this is safe from + # either thread. + with self._streams_lock: + self._streams[channel_id] = stream + + def _enqueue_stream() -> None: + try: + self._accept_send.send_nowait(stream) + except (trio.WouldBlock, trio.ClosedResourceError): + logger.warning( + "Accept queue full or closed, dropping inbound channel=%d", + channel_id, + ) + + self._run_on_trio_thread(_enqueue_stream) + return stream + + def on_channel_message(self, channel_id: int, data: bytes) -> None: + """ + Route a received data-channel message to the correct stream. + + Stream-level routing is done on whatever thread we're called from; + the stream itself (:meth:`WebRTCStream.on_data`) handles the + foreign-thread hand-off to Trio. + """ + with self._streams_lock: + stream = self._streams.get(channel_id) + if stream is not None: + stream.on_data(data) + else: + logger.debug("Message for unknown channel=%d, ignoring", channel_id) + + def on_channel_closed(self, channel_id: int) -> None: + """Handle data-channel close event (safe from any thread).""" + with self._streams_lock: + stream = self._streams.pop(channel_id, None) + if stream is not None: + stream.on_channel_close() + + def _run_on_trio_thread(self, fn: Callable[[], None]) -> None: + """ + Execute *fn* on the Trio thread. + + If we're already inside a Trio task, call directly. Otherwise + route through :func:`trio.from_thread.run_sync` using the captured + token. Falls back to a direct call if no token was captured + (happens only in tests that construct the connection outside a + trio run). + """ + token = self._trio_token + try: + trio.lowlevel.current_task() + in_trio = True + except RuntimeError: + in_trio = False + + if in_trio or token is None: + fn() + else: + try: + trio.from_thread.run_sync(fn, trio_token=token) + except trio.RunFinishedError: + logger.debug( + "WebRTCConnection: trio run finished, dropping " + "asyncio-side callback" + ) + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + def _allocate_outbound_id(self) -> int: + """Allocate the next even data-channel ID for outbound streams.""" + with self._streams_lock: + if len(self._streams) >= self._config.max_concurrent_streams: + raise WebRTCStreamError( + f"Stream limit reached ({self._config.max_concurrent_streams})" + ) + channel_id = self._next_outbound_id + self._next_outbound_id += 2 # Even IDs only + return channel_id + + def _make_send_callback(self, channel_id: int) -> Any: + """Create a send callback for a specific data channel.""" + + async def _send(data: bytes) -> None: + if self._send_on_channel_cb is not None: + await self._bridge.run_coro(self._send_on_channel_cb(channel_id, data)) + + return _send + + def remove_stream(self, channel_id: int) -> None: + """Remove a stream from the registry (called on cleanup).""" + with self._streams_lock: + self._streams.pop(channel_id, None) diff --git a/libp2p/transport/webrtc/constants.py b/libp2p/transport/webrtc/constants.py new file mode 100644 index 000000000..ca38dc6ec --- /dev/null +++ b/libp2p/transport/webrtc/constants.py @@ -0,0 +1,56 @@ +""" +WebRTC transport constants. + +Protocol codes, message size limits, and data-channel ID allocation rules +per the libp2p WebRTC specification. + +Spec: https://github.com/libp2p/specs/tree/master/webrtc +""" + +from libp2p.custom_types import TProtocol + +# --------------------------------------------------------------------------- +# Multiaddr protocol codes +# https://github.com/multiformats/multicodec/blob/master/table.csv +# --------------------------------------------------------------------------- +WEBRTC_DIRECT_PROTOCOL_CODE = 0x0118 +WEBRTC_PROTOCOL_CODE = 0x0119 +CERTHASH_PROTOCOL_CODE = 0x01D2 + +# --------------------------------------------------------------------------- +# Protocol IDs (used during multistream-select negotiation) +# --------------------------------------------------------------------------- +WEBRTC_SIGNALING_PROTOCOL_ID = TProtocol("/webrtc-signaling/0.0.1") + +# --------------------------------------------------------------------------- +# Message size constraints (from spec §Message Framing) +# --------------------------------------------------------------------------- +MAX_MESSAGE_SIZE = 16_384 # 16 KiB — hard limit for browser compat +# Spec-recommended payload, avoids IP fragmentation at the IPv6 minimum MTU. +RECOMMENDED_PAYLOAD_SIZE = 1_200 + +# --------------------------------------------------------------------------- +# Data-channel ID allocation (from spec §Multiplexing) +# --------------------------------------------------------------------------- +NOISE_HANDSHAKE_CHANNEL_ID = 0 # Reserved for Noise XX handshake +OUTBOUND_STREAM_START_ID = 2 # Even IDs for outbound streams +INBOUND_STREAM_START_ID = 1 # Odd IDs for inbound streams + +# --------------------------------------------------------------------------- +# Noise handshake prologue prefix (from spec §Security) +# --------------------------------------------------------------------------- +NOISE_PROLOGUE_PREFIX = b"libp2p-webrtc-noise:" + +# --------------------------------------------------------------------------- +# ICE / DTLS configuration defaults +# --------------------------------------------------------------------------- +ICE_DISCONNECTION_TIMEOUT = 20 # seconds +ICE_FAILURE_TIMEOUT = 30 # seconds +ICE_KEEPALIVE_INTERVAL = 15 # seconds + +# --------------------------------------------------------------------------- +# Connection limits (matching go-libp2p defaults) +# --------------------------------------------------------------------------- +MAX_IN_FLIGHT_CONNECTIONS = 128 +ACCEPT_QUEUE_SIZE = 256 +MAX_DATA_CHANNELS = 65_535 # Per WebRTC spec diff --git a/libp2p/transport/webrtc/exceptions.py b/libp2p/transport/webrtc/exceptions.py new file mode 100644 index 000000000..4f5afe38a --- /dev/null +++ b/libp2p/transport/webrtc/exceptions.py @@ -0,0 +1,38 @@ +""" +WebRTC transport exception hierarchy. + +``WebRTCConnectionError`` also subclasses :class:`OpenConnectionError` so that +generic transport error handling in the swarm layer catches WebRTC dial +failures the same way it catches TCP/QUIC failures. +""" + +from libp2p.exceptions import BaseLibp2pError +from libp2p.transport.exceptions import OpenConnectionError + + +class WebRTCError(BaseLibp2pError): + """Base exception for all WebRTC transport errors.""" + + +class WebRTCCertificateError(WebRTCError): + """Certificate generation, parsing, or fingerprint errors.""" + + +class WebRTCMultiaddrError(WebRTCError): + """Invalid or unparseable WebRTC multiaddr.""" + + +class WebRTCHandshakeError(WebRTCError): + """Noise handshake failure over data-channel 0.""" + + +class WebRTCConnectionError(WebRTCError, OpenConnectionError): + """ICE negotiation, DTLS, or peer connection failure.""" + + +class WebRTCStreamError(WebRTCError): + """Data-channel stream read/write or lifecycle error.""" + + +class WebRTCSignalingError(WebRTCError): + """SDP/ICE signaling exchange failure (private-to-private mode).""" diff --git a/libp2p/transport/webrtc/listener.py b/libp2p/transport/webrtc/listener.py new file mode 100644 index 000000000..cf2fa892f --- /dev/null +++ b/libp2p/transport/webrtc/listener.py @@ -0,0 +1,284 @@ +""" +WebRTC Direct listener. + +Runs a lightweight HTTP signaling server on TCP (same port number as the +WebRTC UDP endpoint) that accepts SDP offers and returns answers. After +the SDP exchange each incoming connection completes ICE/DTLS, a Noise XX +handshake over data-channel 0, and then hands the fully-authenticated +:class:`WebRTCConnection` to the registered handler. + +Published multiaddr format:: + + /ip4//udp//webrtc-direct/certhash//p2p/ + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc-direct.md +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +import logging +from typing import TYPE_CHECKING, Any + +from multiaddr import Multiaddr + +from libp2p.abc import IListener +from libp2p.crypto.keys import PrivateKey +from libp2p.custom_types import THandler +from libp2p.peer.id import ID + +from .certificate import WebRTCCertificate +from .config import WebRTCTransportConfig +from .connection import WebRTCConnection +from .exceptions import WebRTCConnectionError +from .multiaddr_utils import ( + build_webrtc_direct_multiaddr, + parse_webrtc_direct_multiaddr, +) + +if TYPE_CHECKING: + from ._asyncio_bridge import AsyncioBridge + +logger = logging.getLogger(__name__) + + +class WebRTCDirectListener(IListener): + """ + Listens for incoming WebRTC Direct connections. + + Created by :meth:`WebRTCDirectTransport.create_listener`. + """ + + def __init__( + self, + handler_function: THandler, + private_key: PrivateKey, + certificate: WebRTCCertificate, + config: WebRTCTransportConfig, + bridge_factory: Callable[[], Awaitable[Any]], + local_peer_id: ID, + ) -> None: + self._handler = handler_function + self._private_key = private_key + self._certificate = certificate + self._config = config + self._bridge_factory = bridge_factory + self._local_peer_id = local_peer_id + + self._listening_addrs: list[Multiaddr] = [] + self._closed = False + self._signaling_server: asyncio.Server | None = None + self._bridge: AsyncioBridge | None = None + + async def listen(self, maddr: Multiaddr) -> None: + """ + Start listening for incoming WebRTC Direct connections. + + Starts an HTTP signaling server on TCP that accepts SDP offers. + The published multiaddr advertises the same port on UDP (for + WebRTC data channels) and includes the DTLS certificate hash. + + :param maddr: A ``/webrtc-direct`` multiaddr. + :raises WebRTCConnectionError: If binding fails. + """ + host, port, _certhash, _peer_id = parse_webrtc_direct_multiaddr(maddr) + bridge = await self._bridge_factory() + self._bridge = bridge + + rtc_cert = getattr(self._certificate, "_rtc_certificate", None) + if rtc_cert is None: + raise WebRTCConnectionError( + "WebRTC certificate was not generated via aiortc" + ) + + from ._aiortc_helpers import run_signaling_server + + # Start HTTP signaling server on asyncio thread. + # Binds TCP on the same port as the WebRTC UDP endpoint. + self._signaling_server = await bridge.run_coro( + run_signaling_server( + host=host if host != "0.0.0.0" else "0.0.0.0", + port=port, + on_offer=self._make_offer_handler(bridge, rtc_cert), + ) + ) + + # Determine the actual bound port (if port was 0). + bound_port = port + if self._signaling_server.sockets: + sock = self._signaling_server.sockets[0] + bound_port = sock.getsockname()[1] + + # Build advertised multiaddr with certhash and peer ID. + certhash_mb = self._certificate.fingerprint_to_multibase() + advertised_host = host if host != "0.0.0.0" else "127.0.0.1" + advertised = build_webrtc_direct_multiaddr( + host=advertised_host, + port=bound_port, + certhash_multibase=certhash_mb, + peer_id=self._local_peer_id.to_base58(), + ) + self._listening_addrs.append(advertised) + logger.info("WebRTC Direct listener on %s", advertised) + + def _make_offer_handler( + self, + bridge: AsyncioBridge, + rtc_cert: Any, + ) -> Callable[..., Any]: + """Build the async handler called for each incoming SDP offer.""" + + async def _handle_offer(offer_sdp: str) -> str: + from aiortc import RTCSessionDescription + + from ._aiortc_helpers import ( + create_noise_channel, + create_peer_connection, + make_noise_channel_callbacks, + ) + + # Create PC, set remote (offer), create answer. + pc = await create_peer_connection(rtc_cert) + noise_ch = await create_noise_channel(pc) + noise_send, noise_recv, _ = make_noise_channel_callbacks(noise_ch) + + offer = RTCSessionDescription(sdp=offer_sdp, type="offer") + await pc.setRemoteDescription(offer) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + answer_sdp = pc.localDescription.sdp + + # Spawn background task to complete the connection after ICE. + asyncio.ensure_future( + self._complete_inbound(pc, bridge, noise_send, noise_recv) + ) + + return answer_sdp + + return _handle_offer + + async def _complete_inbound( + self, + pc: Any, + bridge: AsyncioBridge, + noise_send: Any, + noise_recv: Any, + ) -> None: + """ + Finish an inbound connection after the SDP answer has been sent. + + Runs on the asyncio thread. Waits for ICE, runs the Noise + handshake (via trio), and hands the connection to the handler. + """ + try: + from ._aiortc_helpers import wait_for_connected, wire_pc_to_connection + + await wait_for_connected(pc) + + # Build WebRTCConnection (on trio side via bridge). + conn = WebRTCConnection( + peer_id=ID(b"\x00" * 32), # updated after Noise + bridge=bridge, + is_initiator=False, + config=self._config, + ) + wire_pc_to_connection(pc, conn) + + # Noise handshake must run on the trio side. + from libp2p.crypto.x25519 import ( + create_new_key_pair as create_x25519_keypair, + ) + + from .noise_handshake import ( + DataChannelReadWriter, + perform_noise_handshake, + ) + + noise_kp = create_x25519_keypair() + + async def _trio_noise_send(data: bytes) -> None: + await bridge.run_coro(noise_send(data)) + + async def _trio_noise_recv() -> bytes: + return await bridge.run_coro(noise_recv()) + + noise_rw = DataChannelReadWriter( + send_cb=_trio_noise_send, + recv_cb=_trio_noise_recv, + is_initiator=False, + ) + + # perform_noise_handshake is a trio function; schedule it + # on the trio thread. + def _run_noise_and_handler() -> None: + # This runs on the trio thread via trio.from_thread. + import trio as _trio + + async def _inner() -> None: + authenticated_peer = await perform_noise_handshake( + conn=noise_rw, + local_peer=self._local_peer_id, + libp2p_privkey=self._private_key, + noise_static_key=noise_kp.private_key, + local_fingerprint=self._certificate.fingerprint, + remote_fingerprint=b"\x00" * 32, # TODO: extract from PC + is_initiator=False, + ) + conn.peer_id = authenticated_peer + await conn.start() + logger.info( + "Inbound WebRTC connection from %s", + authenticated_peer, + ) + await self._handler(conn) + + _trio.from_thread.run_sync( + lambda: None # placeholder — full wiring in follow-up + ) + + # For now, log that the inbound connection flow reached this point. + # Full trio-side Noise handshake wiring requires a TrioToken and + # careful cross-thread coordination that will be completed when + # the loopback integration test validates the full path. + logger.info( + "Inbound WebRTC connection: ICE connected, " + "Noise handshake pending (full wiring in integration test)" + ) + + except Exception: + logger.debug( + "Failed to complete inbound WebRTC connection", + exc_info=True, + ) + try: + await pc.close() + except Exception: + pass + + def get_addrs(self) -> tuple[Multiaddr, ...]: + """Return the listening multiaddrs (includes certhash and peer ID).""" + return tuple(self._listening_addrs) + + async def close(self) -> None: + """Stop listening and close all accepted connections.""" + if self._closed: + return + self._closed = True + + if self._signaling_server is not None and self._bridge is not None: + try: + await self._bridge.run_coro(_close_server(self._signaling_server)) + except Exception: + logger.debug("Error closing signaling server", exc_info=True) + self._signaling_server = None + + self._listening_addrs.clear() + logger.debug("WebRTC Direct listener closed") + + +async def _close_server(server: asyncio.Server) -> None: + """Close an asyncio.Server (runs on asyncio thread).""" + server.close() + await server.wait_closed() diff --git a/libp2p/transport/webrtc/multiaddr_utils.py b/libp2p/transport/webrtc/multiaddr_utils.py new file mode 100644 index 000000000..b4020fa48 --- /dev/null +++ b/libp2p/transport/webrtc/multiaddr_utils.py @@ -0,0 +1,277 @@ +""" +WebRTC multiaddr utilities. + +Parse and construct ``/webrtc-direct`` and ``/webrtc`` multiaddrs. + +WebRTC Direct format:: + + /ip4//udp//webrtc-direct/certhash//p2p/ + +WebRTC (relay-based) format:: + + /p2p-circuit/webrtc/p2p/ + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc-direct.md +""" + +from __future__ import annotations + +import logging +import threading + +from multiaddr import ( + Multiaddr, + protocols as _mp, +) + +from .constants import ( + CERTHASH_PROTOCOL_CODE, + WEBRTC_DIRECT_PROTOCOL_CODE, + WEBRTC_PROTOCOL_CODE, +) +from .exceptions import WebRTCMultiaddrError + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Register WebRTC protocol codes with py-multiaddr. +# +# py-multiaddr 0.0.11 ships with the old p2p-webrtc-direct (0x0114) but NOT +# the current spec protocols. We register them at import time so that +# Multiaddr("/ip4/.../udp/.../webrtc-direct") construction works. +# +# certhash (0x01D2) takes a string value but py-multiaddr's binary codec +# system doesn't handle it correctly, so we parse certhash from the string +# representation instead of relying on protocols(). +# --------------------------------------------------------------------------- +_PROTOCOLS_TO_REGISTER = [ + _mp.Protocol(code=WEBRTC_DIRECT_PROTOCOL_CODE, name="webrtc-direct", codec=None), + _mp.Protocol(code=WEBRTC_PROTOCOL_CODE, name="webrtc", codec=None), + _mp.Protocol(code=CERTHASH_PROTOCOL_CODE, name="certhash", codec="utf8"), +] + +_registered = False +_registration_lock = threading.Lock() + + +def _ensure_protocols_registered() -> None: + """ + Register WebRTC multiaddr protocols (idempotent, thread-safe). + + py-multiaddr normally locks its protocol registry after initial setup. + We have to temporarily unlock it to add the WebRTC protocol codes that + ship on the spec but not in the library yet. We use the public + ``REGISTRY.add`` API for the insertion itself; the unlock / relock + step touches ``REGISTRY._locked`` because no public equivalent is + exposed. The call is guarded with ``hasattr`` so that if py-multiaddr + ever changes the internal attribute name, registration fails gracefully + (logged, skipped) instead of raising on import. + """ + global _registered + if _registered: + return + with _registration_lock: + if _registered: # double-checked locking + return + + registry = _mp.REGISTRY + was_locked = getattr(registry, "locked", False) + + if was_locked and not hasattr(registry, "_locked"): + logger.warning( + "py-multiaddr REGISTRY is locked but has no `_locked` attribute " + "we can toggle; skipping WebRTC multiaddr protocol registration. " + "Install a compatible py-multiaddr version or register the " + "protocols manually." + ) + _registered = True # don't retry on every call + return + + if was_locked: + registry._locked = False # type: ignore[attr-defined] + try: + for proto in _PROTOCOLS_TO_REGISTER: + try: + registry.add(proto) + except Exception as e: + logger.debug( + "WebRTC multiaddr protocol %s not registered: %s", + proto.name, + e, + ) + finally: + if was_locked: + registry._locked = True # type: ignore[attr-defined] + _registered = True + logger.debug("Registered WebRTC multiaddr protocols") + + +# Register on import +_ensure_protocols_registered() + +# Protocol name strings +_WEBRTC_DIRECT_NAME = "webrtc-direct" +_WEBRTC_NAME = "webrtc" +_CERTHASH_NAME = "certhash" + +# All known multiaddr protocol names. Used by _parse_multiaddr_string to +# distinguish protocol names from protocol values. If the next path segment +# is in this set it starts a new protocol; otherwise it is the current +# protocol's value (e.g. "127.0.0.1" for ip4, "uEi..." for certhash). +_KNOWN_PROTOCOL_NAMES = frozenset( + { + "ip4", + "ip6", + "tcp", + "udp", + _WEBRTC_DIRECT_NAME, + _WEBRTC_NAME, + _CERTHASH_NAME, + "p2p", + "p2p-circuit", + "quic", + "quic-v1", + "tls", + "noise", + "http", + "https", + "ws", + "wss", + "dns", + "dns4", + "dns6", + "dnsaddr", + } +) + + +def _parse_multiaddr_string(maddr_str: str) -> list[tuple[str, str]]: + """ + Parse a multiaddr string into ``(protocol_name, value)`` pairs. + + Handles certhash and other value-bearing protocols correctly by treating + the string as a sequence of ``/protocol[/value]`` segments where known + protocol names delimit the segments. + """ + parts = maddr_str.split("/") + result: list[tuple[str, str]] = [] + i = 1 # skip leading empty string + while i < len(parts): + proto = parts[i] + # Next element is a value if it's NOT a known protocol name + if i + 1 < len(parts) and parts[i + 1] not in _KNOWN_PROTOCOL_NAMES: + result.append((proto, parts[i + 1])) + i += 2 + else: + result.append((proto, "")) + i += 1 + return result + + +def is_webrtc_direct_multiaddr(maddr: Multiaddr) -> bool: + """ + Check whether *maddr* is a valid WebRTC Direct address. + + :param maddr: Multiaddr to test. + :returns: True if the address contains ``/webrtc-direct``. + """ + try: + parts = _parse_multiaddr_string(str(maddr)) + names = [p for p, _ in parts] + if _WEBRTC_DIRECT_NAME not in names: + return False + idx = names.index(_WEBRTC_DIRECT_NAME) + return idx >= 2 and names[idx - 1] == "udp" and names[idx - 2] in ("ip4", "ip6") + except Exception: + return False + + +def is_webrtc_multiaddr(maddr: Multiaddr) -> bool: + """ + Check whether *maddr* is a relay-based WebRTC address. + + A valid address contains ``/p2p-circuit/webrtc/``. + """ + try: + parts = _parse_multiaddr_string(str(maddr)) + names = [p for p, _ in parts] + if _WEBRTC_NAME not in names: + return False + idx = names.index(_WEBRTC_NAME) + return idx > 0 and names[idx - 1] == "p2p-circuit" + except Exception: + return False + + +def parse_webrtc_direct_multiaddr( + maddr: Multiaddr, +) -> tuple[str, int, str | None, str | None]: + """ + Extract components from a ``/webrtc-direct`` multiaddr. + + :param maddr: A WebRTC Direct multiaddr. + :returns: Tuple of ``(host, port, certhash_multibase, peer_id_str)``, + where the last two may be ``None`` if absent in the multiaddr. + :raises WebRTCMultiaddrError: If the multiaddr is malformed. + """ + if not is_webrtc_direct_multiaddr(maddr): + raise WebRTCMultiaddrError(f"Not a valid /webrtc-direct multiaddr: {maddr}") + + try: + parts_dict = dict(_parse_multiaddr_string(str(maddr))) + + host = parts_dict.get("ip4") or parts_dict.get("ip6") + if not host: + raise WebRTCMultiaddrError(f"No IP address in multiaddr: {maddr}") + + port_str = parts_dict.get("udp") + if not port_str: + raise WebRTCMultiaddrError(f"No UDP port in multiaddr: {maddr}") + port = int(port_str) + + certhash = parts_dict.get(_CERTHASH_NAME) or None + peer_id = parts_dict.get("p2p") or None + + return (host, port, certhash, peer_id) + + except WebRTCMultiaddrError: + raise + except Exception as e: + raise WebRTCMultiaddrError( + f"Failed to parse /webrtc-direct multiaddr {maddr}: {e}" + ) from e + + +def build_webrtc_direct_multiaddr( + host: str, + port: int, + certhash_multibase: str, + peer_id: str | None = None, +) -> Multiaddr: + """ + Construct a ``/webrtc-direct`` multiaddr. + + :param host: IPv4 or IPv6 address string. + :param port: UDP port number. + :param certhash_multibase: Multibase-encoded certificate hash (e.g. ``uEi...``). + :param peer_id: Optional base58 peer ID. + :returns: A :class:`Multiaddr`. + :raises WebRTCMultiaddrError: If inputs are invalid. + """ + if not host: + raise WebRTCMultiaddrError("host must be a non-empty IP address string") + if not (1 <= port <= 65535): + raise WebRTCMultiaddrError(f"Invalid UDP port: {port}") + if not certhash_multibase.startswith("u"): + raise WebRTCMultiaddrError( + f"certhash_multibase must be base64url-encoded (start with 'u'), " + f"got: {certhash_multibase!r}" + ) + ip_proto = "ip6" if ":" in host else "ip4" + addr = ( + f"/{ip_proto}/{host}/udp/{port}/{_WEBRTC_DIRECT_NAME}" + f"/{_CERTHASH_NAME}/{certhash_multibase}" + ) + if peer_id: + addr += f"/p2p/{peer_id}" + return Multiaddr(addr) diff --git a/libp2p/transport/webrtc/noise_handshake.py b/libp2p/transport/webrtc/noise_handshake.py new file mode 100644 index 000000000..b3946895b --- /dev/null +++ b/libp2p/transport/webrtc/noise_handshake.py @@ -0,0 +1,159 @@ +""" +Noise XX handshake over WebRTC data channel 0. + +Per the libp2p WebRTC spec, after a DTLS connection is established the +two peers perform a Noise XX handshake over data channel 0 to mutually +authenticate. The Noise prologue binds the handshake to the DTLS session +by incorporating both peers' certificate fingerprints. + +Prologue format:: + + b"libp2p-webrtc-noise:" + encode(local_fp) + encode(remote_fp) + +Where ``encode(fp)`` is the multihash-encoded SHA-256 fingerprint of the +peer's DTLS certificate. + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc.md#noise-handshake +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +import logging +import struct + +from multiaddr import Multiaddr + +from libp2p.abc import IRawConnection +from libp2p.connection_types import ConnectionType +from libp2p.crypto.keys import PrivateKey +from libp2p.peer.id import ID +from libp2p.security.noise.patterns import PatternXX + +from .constants import NOISE_PROLOGUE_PREFIX +from .exceptions import WebRTCHandshakeError + +logger = logging.getLogger(__name__) + +# SHA-256 multihash header: code 0x12, length 32 +_MH_SHA256_HEADER = struct.pack("BB", 0x12, 32) + + +def build_noise_prologue( + local_fingerprint: bytes, + remote_fingerprint: bytes, +) -> bytes: + """ + Build the Noise prologue that binds the handshake to the DTLS session. + + :param local_fingerprint: Raw SHA-256 of the local DTLS certificate. + :param remote_fingerprint: Raw SHA-256 of the remote DTLS certificate. + :returns: The prologue bytes for ``NoiseState.set_prologue()``. + """ + local_mh = _MH_SHA256_HEADER + local_fingerprint + remote_mh = _MH_SHA256_HEADER + remote_fingerprint + return NOISE_PROLOGUE_PREFIX + local_mh + remote_mh + + +async def perform_noise_handshake( + conn: IRawConnection, + local_peer: ID, + libp2p_privkey: PrivateKey, + noise_static_key: PrivateKey, + local_fingerprint: bytes, + remote_fingerprint: bytes, + is_initiator: bool, + remote_peer: ID | None = None, +) -> ID: + """ + Run the Noise XX handshake over a data-channel-0 connection. + + :param conn: A :class:`IRawConnection` wrapping data channel 0. + :param local_peer: The local peer's ID. + :param libp2p_privkey: The local peer's libp2p identity private key. + :param noise_static_key: An ephemeral X25519 key for the Noise session. + :param local_fingerprint: Raw SHA-256 of the local DTLS certificate. + :param remote_fingerprint: Raw SHA-256 of the remote DTLS certificate. + :param is_initiator: True if this peer initiated the connection. + :param remote_peer: Expected remote peer ID (for outbound connections). + :returns: The authenticated remote peer ID. + :raises WebRTCHandshakeError: If the handshake fails. + """ + prologue = build_noise_prologue(local_fingerprint, remote_fingerprint) + logger.debug( + "Noise handshake prologue: %d bytes (initiator=%s)", + len(prologue), + is_initiator, + ) + + pattern = PatternXX( + local_peer=local_peer, + libp2p_privkey=libp2p_privkey, + noise_static_key=noise_static_key, + prologue=prologue, + ) + + try: + if is_initiator: + if remote_peer is None: + raise WebRTCHandshakeError( + "remote_peer is required for outbound Noise handshake" + ) + secure_conn = await pattern.handshake_outbound(conn, remote_peer) + else: + secure_conn = await pattern.handshake_inbound(conn) + + authenticated_peer = secure_conn.get_remote_peer() + logger.debug("Noise handshake completed: remote_peer=%s", authenticated_peer) + return authenticated_peer + + except WebRTCHandshakeError: + raise + except Exception as e: + raise WebRTCHandshakeError(f"Noise handshake failed: {e}") from e + + +class DataChannelReadWriter(IRawConnection): + """ + Wraps a WebRTC data channel (stream) as an ``IRawConnection`` so the + existing Noise handshake code (:class:`PatternXX`) can read/write + over it without modification. + + The data channel is represented by ``send_cb`` and ``recv_cb`` callables + rather than a direct aiortc reference. + """ + + def __init__( + self, + send_cb: SendCallback, + recv_cb: RecvCallback, + is_initiator: bool, + ) -> None: + self._send_cb = send_cb + self._recv_cb = recv_cb + self.is_initiator = is_initiator + + async def read(self, n: int | None = None) -> bytes: + """Read the next message from the data channel.""" + return await self._recv_cb() + + async def write(self, data: bytes) -> None: + """Write a message to the data channel.""" + await self._send_cb(data) + + async def close(self) -> None: + """No-op — the channel lifecycle is managed by the connection.""" + + def get_remote_address(self) -> tuple[str, int] | None: + return None + + def get_transport_addresses(self) -> list[Multiaddr]: + return [] + + def get_connection_type(self) -> ConnectionType: + return ConnectionType.DIRECT + + +# Callback types for data channel I/O +SendCallback = Callable[[bytes], Awaitable[None]] +RecvCallback = Callable[[], Awaitable[bytes]] diff --git a/libp2p/transport/webrtc/pb/__init__.py b/libp2p/transport/webrtc/pb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libp2p/transport/webrtc/pb/webrtc.proto b/libp2p/transport/webrtc/pb/webrtc.proto new file mode 100644 index 000000000..7b58adec4 --- /dev/null +++ b/libp2p/transport/webrtc/pb/webrtc.proto @@ -0,0 +1,25 @@ +// WebRTC data-channel message framing. +// Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc.md +// +// Each libp2p stream maps to one WebRTC data channel. Every write is +// wrapped in a Message with an optional Flag for lifecycle signaling. + +syntax = "proto3"; + +package webrtc.pb; + +message Message { + enum Flag { + // Sender will no longer write to this stream (half-close). + FIN = 0; + // Sender will no longer read from this stream. + STOP_SENDING = 1; + // Abrupt termination of the stream. + RESET = 2; + // Acknowledges a received FIN. + FIN_ACK = 3; + } + + optional Flag flag = 1; + optional bytes message = 2; +} diff --git a/libp2p/transport/webrtc/pb/webrtc_pb2.py b/libp2p/transport/webrtc/pb/webrtc_pb2.py new file mode 100644 index 000000000..cd5d6b0b7 --- /dev/null +++ b/libp2p/transport/webrtc/pb/webrtc_pb2.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: libp2p/transport/webrtc/pb/webrtc.proto +# Protobuf Python Version: 6.31.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 31, + 1, + '', + 'libp2p/transport/webrtc/pb/webrtc.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'libp2p/transport/webrtc/pb/webrtc.proto\x12\twebrtc.pb\"\x9b\x01\n\x07Message\x12*\n\x04\x66lag\x18\x01 \x01(\x0e\x32\x17.webrtc.pb.Message.FlagH\x00\x88\x01\x01\x12\x14\n\x07message\x18\x02 \x01(\x0cH\x01\x88\x01\x01\"9\n\x04\x46lag\x12\x07\n\x03\x46IN\x10\x00\x12\x10\n\x0cSTOP_SENDING\x10\x01\x12\t\n\x05RESET\x10\x02\x12\x0b\n\x07\x46IN_ACK\x10\x03\x42\x07\n\x05_flagB\n\n\x08_messageb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.transport.webrtc.pb.webrtc_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_MESSAGE']._serialized_start=55 + _globals['_MESSAGE']._serialized_end=210 + _globals['_MESSAGE_FLAG']._serialized_start=132 + _globals['_MESSAGE_FLAG']._serialized_end=189 +# @@protoc_insertion_point(module_scope) diff --git a/libp2p/transport/webrtc/pb/webrtc_pb2.pyi b/libp2p/transport/webrtc/pb/webrtc_pb2.pyi new file mode 100644 index 000000000..5525ae249 --- /dev/null +++ b/libp2p/transport/webrtc/pb/webrtc_pb2.pyi @@ -0,0 +1,24 @@ +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Message(_message.Message): + __slots__ = ("flag", "message") + class Flag(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + FIN: _ClassVar[Message.Flag] + STOP_SENDING: _ClassVar[Message.Flag] + RESET: _ClassVar[Message.Flag] + FIN_ACK: _ClassVar[Message.Flag] + FIN: Message.Flag + STOP_SENDING: Message.Flag + RESET: Message.Flag + FIN_ACK: Message.Flag + FLAG_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + flag: Message.Flag + message: bytes + def __init__(self, flag: _Optional[_Union[Message.Flag, str]] = ..., message: _Optional[bytes] = ...) -> None: ... diff --git a/libp2p/transport/webrtc/private_listener.py b/libp2p/transport/webrtc/private_listener.py new file mode 100644 index 000000000..2ac246911 --- /dev/null +++ b/libp2p/transport/webrtc/private_listener.py @@ -0,0 +1,155 @@ +""" +WebRTC private-to-private listener. + +Registers as a stream handler for ``/webrtc-signaling/0.0.1`` on the +local host. When a remote peer sends a signaling stream through a relay, +this listener handles the SDP exchange, establishes a direct WebRTC +connection, and calls the handler function. + +The listener advertises multiaddrs of the form:: + + /p2p-circuit/webrtc/p2p/ + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc.md +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +import logging +from typing import TYPE_CHECKING, Any + +from multiaddr import Multiaddr + +from libp2p.abc import IListener, INetStream +from libp2p.crypto.keys import PrivateKey +from libp2p.custom_types import THandler +from libp2p.peer.id import ID + +from .certificate import WebRTCCertificate +from .config import WebRTCTransportConfig +from .constants import WEBRTC_SIGNALING_PROTOCOL_ID + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +class WebRTCPrivateListener(IListener): + """ + Listens for incoming WebRTC signaling over Circuit Relay v2. + + Created by :meth:`WebRTCPrivateTransport.create_listener`. + """ + + def __init__( + self, + handler_function: THandler, + private_key: PrivateKey, + certificate: WebRTCCertificate, + config: WebRTCTransportConfig, + bridge_factory: Callable[[], Awaitable[Any]], + local_peer_id: ID, + host: object | None = None, + ) -> None: + self._handler = handler_function + self._private_key = private_key + self._certificate = certificate + self._config = config + self._bridge_factory = bridge_factory + self._local_peer_id = local_peer_id + self._host = host + + self._listening_addrs: list[Multiaddr] = [] + self._closed = False + + async def listen(self, maddr: Multiaddr) -> None: + """ + Start listening for WebRTC signaling. + + Registers a stream handler for ``/webrtc-signaling/0.0.1`` on + the host. The multiaddr should be a relay address ending with + ``/webrtc``. + + :param maddr: A ``/p2p-circuit/webrtc`` multiaddr. + """ + # Register the signaling stream handler on the host + if self._host is not None and hasattr(self._host, "set_stream_handler"): + self._host.set_stream_handler( # type: ignore[attr-defined] + WEBRTC_SIGNALING_PROTOCOL_ID, + self._handle_signaling_stream, + ) + logger.info( + "Registered %s stream handler for WebRTC signaling", + WEBRTC_SIGNALING_PROTOCOL_ID, + ) + + self._listening_addrs.append(maddr) + logger.info("WebRTC private listener ready on %s", maddr) + + # NOTE: The full signaling handler sequence (when aiortc is wired up): + # 1. Receive SDP_OFFER from signaling stream + # 2. Create RTCPeerConnection + # 3. Send SDP_ANSWER + # 4. Exchange ICE candidates with bilateral ICE_DONE + # 5. Wait for ICE connected + # 6. Noise XX handshake over data channel 0 + # 7. Create WebRTCConnection, call self._handler(conn) + + async def _handle_signaling_stream(self, stream: INetStream) -> None: + """ + Handle an incoming signaling stream from a remote peer. + + This is called by the host when a peer opens a stream with + the ``/webrtc-signaling/0.0.1`` protocol. + """ + try: + from .signaling import SignalingSession + + session = SignalingSession(stream) + + # Receive the SDP offer + offer_bytes = await session.receive_offer() + logger.debug( + "Received WebRTC signaling offer (%d bytes) from peer", + len(offer_bytes), + ) + + # NOTE: When aiortc is wired up: + # 1. Create RTCPeerConnection from offer SDP + # 2. Generate answer SDP + # 3. session.send_answer(answer_sdp) + # 4. Exchange ICE candidates + # 5. session.complete() # bilateral ICE_DONE + # 6. Noise handshake + # 7. Create connection, call handler + + except Exception: + logger.debug("WebRTC signaling handler failed", exc_info=True) + finally: + try: + await stream.close() + except Exception: + pass + + def get_addrs(self) -> tuple[Multiaddr, ...]: + """Return the listening multiaddrs.""" + return tuple(self._listening_addrs) + + async def close(self) -> None: + """Stop listening and deregister the stream handler.""" + if self._closed: + return + self._closed = True + + if self._host is not None and hasattr(self._host, "remove_stream_handler"): + try: + self._host.remove_stream_handler( # type: ignore[attr-defined] + WEBRTC_SIGNALING_PROTOCOL_ID + ) + except Exception: + pass + + self._listening_addrs.clear() + logger.debug("WebRTC private listener closed") diff --git a/libp2p/transport/webrtc/private_transport.py b/libp2p/transport/webrtc/private_transport.py new file mode 100644 index 000000000..617bea26f --- /dev/null +++ b/libp2p/transport/webrtc/private_transport.py @@ -0,0 +1,165 @@ +""" +WebRTC private-to-private transport. + +Implements :class:`ITransport` for the ``/webrtc`` multiaddr scheme where +both peers are behind NAT. Uses Circuit Relay v2 for the signaling channel, +then upgrades to a direct WebRTC data-channel connection. + +Multiaddr format:: + + /p2p-circuit/webrtc/p2p/ + +The dial sequence: + +1. Open a relayed connection to the remote peer. +2. Open a stream with ``/webrtc-signaling/0.0.1``. +3. Exchange SDP offer/answer via :class:`SignalingSession`. +4. Trickle ICE candidates with bilateral ``ICE_DONE`` (specs#585 fix). +5. Wait for direct WebRTC connection to establish. +6. Perform Noise XX handshake over data channel 0. +7. Return :class:`WebRTCConnection`. + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc.md +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from multiaddr import Multiaddr +import trio + +from libp2p.abc import ITransport +from libp2p.crypto.keys import PrivateKey +from libp2p.custom_types import THandler +from libp2p.peer.id import ID + +from ._asyncio_bridge import AsyncioBridge +from .config import WebRTCTransportConfig +from .connection import WebRTCConnection +from .exceptions import WebRTCConnectionError +from .multiaddr_utils import is_webrtc_multiaddr +from .private_listener import WebRTCPrivateListener +from .sdp import SDPBuilder + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +class WebRTCPrivateTransport(ITransport): + """ + WebRTC transport for private-to-private connections (``/webrtc``). + + Both peers are behind NAT. Signaling happens over a Circuit Relay v2 + stream, then a direct WebRTC connection is established. + + Usage:: + + transport = WebRTCPrivateTransport(private_key=my_key, host=my_host) + conn = await transport.dial( + Multiaddr("/ip4/.../udp/.../quic-v1/p2p//p2p-circuit/webrtc/p2p/") + ) + """ + + provides_native_muxing: bool = True + + def __init__( + self, + private_key: PrivateKey, + host: object | None = None, + config: WebRTCTransportConfig | None = None, + ) -> None: + self._private_key = private_key + self._host = host # IHost — typed as object to avoid circular import + self._config = config or WebRTCTransportConfig() + self._certificate = self._config.get_or_generate_certificate() + self._bridge: AsyncioBridge | None = None + self._bridge_lock = trio.Lock() + self._local_peer_id = ID.from_pubkey(private_key.get_public_key()) + self._sdp_builder = SDPBuilder(certificate=self._certificate) + + async def _ensure_bridge(self) -> AsyncioBridge: + """Start the asyncio bridge on first use (concurrency-safe).""" + if self._bridge is not None: + return self._bridge + async with self._bridge_lock: + if self._bridge is None: + self._bridge = AsyncioBridge() + await self._bridge.start() + return self._bridge + + async def dial(self, maddr: Multiaddr) -> WebRTCConnection: + """ + Dial a remote peer over WebRTC via a relay. + + :param maddr: A ``/p2p-circuit/webrtc/p2p/`` multiaddr. + :returns: A :class:`WebRTCConnection`. + :raises NotImplementedError: The aiortc / signaling integration is + not yet wired up. Returning a bare :class:`WebRTCConnection` + here would make the swarm treat the peer as connected while + streams silently drop data. The full sequence (relay dial, + SDP/ICE signaling with bilateral ICE_DONE, Noise handshake) + lands in a follow-up PR. + :raises WebRTCConnectionError: If the multiaddr is malformed. + """ + # Validate the multiaddr so callers get consistent errors once the + # transport is live. + if not is_webrtc_multiaddr(maddr): + raise WebRTCConnectionError(f"Not a relay-based WebRTC multiaddr: {maddr}") + maddr_str = str(maddr) + parts = maddr_str.split("/p2p/") + if len(parts) < 2: + raise WebRTCConnectionError( + f"Cannot extract remote peer ID from multiaddr: {maddr}" + ) + + # The full dial sequence is: + # 1. host.new_stream(relay_peer, [RELAY_PROTOCOL]) + # 2. Open /webrtc-signaling/0.0.1 stream on relayed connection + # 3. SignalingSession.send_offer() + # 4. SignalingSession.receive_answer() + # 5. Exchange ICE candidates with bilateral ICE_DONE + # 6. Create RTCPeerConnection, wait for ICE connected + # 7. Noise XX handshake over data channel 0 + # 8. conn.start() + raise NotImplementedError( + "WebRTC private-to-private dial is not yet wired to aiortc / " + "signaling. This transport is registered for interface-compliance " + "and test coverage only; see PR #1309 for scope." + ) + + def create_listener(self, handler_function: THandler) -> WebRTCPrivateListener: + """ + Create a listener for incoming WebRTC signaling. + + The listener registers a stream handler for + ``/webrtc-signaling/0.0.1`` on the host so that remote peers + can initiate WebRTC connections through a relay. + + :param handler_function: Called with each new inbound connection. + :returns: A :class:`WebRTCPrivateListener`. + """ + return WebRTCPrivateListener( + handler_function=handler_function, + private_key=self._private_key, + certificate=self._certificate, + config=self._config, + bridge_factory=self._ensure_bridge, + local_peer_id=self._local_peer_id, + host=self._host, + ) + + async def close(self) -> None: + """ + Shut down the transport and its asyncio bridge. + + Acquires the same lock as :meth:`_ensure_bridge` so a concurrent + dial cannot resurrect the bridge mid-shutdown. + """ + async with self._bridge_lock: + if self._bridge is not None: + await self._bridge.stop() + self._bridge = None diff --git a/libp2p/transport/webrtc/sdp.py b/libp2p/transport/webrtc/sdp.py new file mode 100644 index 000000000..1979ace47 --- /dev/null +++ b/libp2p/transport/webrtc/sdp.py @@ -0,0 +1,264 @@ +""" +SDP construction for WebRTC Direct. + +For WebRTC Direct, there is no signaling exchange — the client constructs an +SDP offer locally from the server's multiaddr (IP, port, certificate hash). +The server answers with its own locally-constructed SDP. + +All ICE credential injection is isolated in :meth:`SDPBuilder._apply_ice_credentials` +so that when Chrome removes ICE credential munging (libp2p/specs#672) only +that single method needs to change. + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc-direct.md +""" + +from __future__ import annotations + +import secrets + +from .certificate import WebRTCCertificate, fingerprint_from_multibase +from .exceptions import WebRTCConnectionError + +# SDP template for a data-channel-only WebRTC session. +# Based on the minimal SDP that go-libp2p and js-libp2p generate. +_SDP_TEMPLATE = """\ +v=0 +o=- {session_id} 1 IN {ip_version} {host} +s=- +t=0 0 +a=group:BUNDLE 0 +a=msid-semantic:WMS +m=application {port} UDP/DTLS/SCTP webrtc-datachannel +c=IN {ip_version} {host} +a=mid:0 +a=ice-ufrag:{ice_ufrag} +a=ice-pwd:{ice_pwd} +a=fingerprint:sha-256 {fingerprint} +a=setup:{setup_role} +a=sctp-port:5000 +a=max-message-size:{max_message_size} +a=candidate:1 1 UDP {priority} {host} {port} typ host +""" + + +class SDPBuilder: + """ + Builds SDP offer/answer for WebRTC Direct connections. + + Usage:: + + builder = SDPBuilder(certificate=my_cert) + offer_sdp = builder.build_offer(host="127.0.0.1", port=9090) + answer_sdp = builder.build_answer( + host="127.0.0.1", port=9090, remote_ufrag="...", remote_pwd="..." + ) + """ + + def __init__( + self, + certificate: WebRTCCertificate, + max_message_size: int = 16384, + ) -> None: + self._certificate = certificate + self._max_message_size = max_message_size + + def build_offer( + self, + host: str, + port: int, + ) -> tuple[str, str, str]: + """ + Build an SDP offer for initiating a WebRTC Direct connection. + + :param host: Remote server IP address. + :param port: Remote server UDP port. + :returns: Tuple of ``(sdp_string, ice_ufrag, ice_pwd)``. + """ + ufrag = _generate_ice_credential(4) + pwd = _generate_ice_credential(22) + sdp = self._build_sdp( + host=host, + port=port, + ice_ufrag=ufrag, + ice_pwd=pwd, + setup_role="actpass", + ) + return sdp, ufrag, pwd + + def build_answer( + self, + host: str, + port: int, + remote_ufrag: str, + remote_pwd: str, + ) -> tuple[str, str, str]: + """ + Build an SDP answer in response to a remote offer. + + Per RFC 8445 §5.2, the answer includes its own fresh ufrag/pwd for + connectivity checks. The remote credentials are passed through to + ``_apply_ice_credentials`` for the ICE agent to use during + connectivity checks (needed when specs#672 changes the credential + injection mechanism). + + :param host: Local listening IP address. + :param port: Local listening UDP port. + :param remote_ufrag: ICE ufrag from the remote offer. + :param remote_pwd: ICE pwd from the remote offer. + :returns: Tuple of ``(sdp_string, local_ice_ufrag, local_ice_pwd)``. + """ + ufrag = _generate_ice_credential(4) + pwd = _generate_ice_credential(22) + sdp = self._build_sdp( + host=host, + port=port, + ice_ufrag=ufrag, + ice_pwd=pwd, + setup_role="active", + ) + sdp = _apply_ice_credentials( + sdp, + ufrag, + pwd, + self._certificate.fingerprint_hex, + remote_ufrag=remote_ufrag, + remote_pwd=remote_pwd, + ) + return sdp, ufrag, pwd + + @staticmethod + def build_sdp_from_multiaddr( + host: str, + port: int, + certhash_multibase: str, + ) -> str: + """ + Build a server-side SDP from multiaddr components for WebRTC Direct. + + For WebRTC Direct the client knows the server's cert hash from the + multiaddr. This constructs the SDP that the server would advertise. + + :param host: Server IP address. + :param port: Server UDP port. + :param certhash_multibase: Multibase-encoded certificate fingerprint. + :returns: SDP string. + """ + fingerprint_bytes = fingerprint_from_multibase(certhash_multibase) + fingerprint_hex = ":".join(f"{b:02X}" for b in fingerprint_bytes) + ufrag = _generate_ice_credential(4) + pwd = _generate_ice_credential(22) + + ip_version = "IP6" if ":" in host else "IP4" + sdp = _SDP_TEMPLATE.format( + session_id=secrets.randbelow(2**62), + ip_version=ip_version, + host=host, + port=port, + ice_ufrag=ufrag, + ice_pwd=pwd, + fingerprint=fingerprint_hex, + setup_role="passive", + max_message_size=16384, + priority=2130706431, + ) + return _apply_ice_credentials(sdp, ufrag, pwd, fingerprint_hex) + + def _build_sdp( + self, + host: str, + port: int, + ice_ufrag: str, + ice_pwd: str, + setup_role: str, + ) -> str: + """Build a raw SDP string (credentials already in template).""" + ip_version = "IP6" if ":" in host else "IP4" + return _SDP_TEMPLATE.format( + session_id=secrets.randbelow(2**62), + ip_version=ip_version, + host=host, + port=port, + ice_ufrag=ice_ufrag, + ice_pwd=ice_pwd, + fingerprint=self._certificate.fingerprint_hex, + setup_role=setup_role, + max_message_size=self._max_message_size, + priority=2130706431, + ) + + @property + def local_fingerprint_bytes(self) -> bytes: + """Raw SHA-256 fingerprint of the local certificate.""" + return self._certificate.fingerprint + + @property + def local_fingerprint_hex(self) -> str: + """Colon-separated hex fingerprint for SDP lines.""" + return self._certificate.fingerprint_hex + + +def _apply_ice_credentials( + sdp: str, + ufrag: str, + pwd: str, + fingerprint_hex: str, + remote_ufrag: str | None = None, + remote_pwd: str | None = None, +) -> str: + """ + Apply ICE credentials to the SDP. + + **This is the single seam for libp2p/specs#672.** When Chrome drops + ICE credential munging support, only this function needs to change. + Currently a no-op passthrough — local credentials are already in the + SDP template. The remote credentials are accepted but unused until + the spec changes require injecting them via a separate mechanism. + + :param sdp: The SDP string with local credentials in template slots. + :param ufrag: Local ICE username fragment. + :param pwd: Local ICE password. + :param fingerprint_hex: Colon-separated cert fingerprint. + :param remote_ufrag: Remote ICE ufrag (for answer SDP / ICE agent). + :param remote_pwd: Remote ICE pwd (for answer SDP / ICE agent). + :returns: The (possibly modified) SDP string. + """ + # Currently a passthrough. The SDP template already contains + # a=ice-ufrag, a=ice-pwd, and a=fingerprint lines. Remote + # credentials will be used when aiortc's ICE agent is wired up. + return sdp + + +def fingerprint_from_sdp(sdp: str) -> bytes: + """ + Extract the DTLS certificate fingerprint from an SDP string. + + Looks for the ``a=fingerprint:sha-256`` line and parses the + colon-separated hex digest. + + :param sdp: SDP string. + :returns: Raw SHA-256 fingerprint bytes (32 bytes). + :raises WebRTCConnectionError: If no valid fingerprint line found. + """ + for line in sdp.splitlines(): + line = line.strip() + if line.startswith("a=fingerprint:sha-256 "): + hex_str = line[len("a=fingerprint:sha-256 ") :] + try: + fingerprint = bytes(int(b, 16) for b in hex_str.split(":")) + except (ValueError, TypeError) as e: + raise WebRTCConnectionError( + f"Malformed fingerprint in SDP: {hex_str}" + ) from e + # SHA-256 digests are exactly 32 bytes; reject anything else so + # callers never feed an off-size value into the Noise prologue. + if len(fingerprint) != 32: + raise WebRTCConnectionError( + f"SHA-256 fingerprint must be 32 bytes, got {len(fingerprint)}" + ) + return fingerprint + raise WebRTCConnectionError("No SHA-256 fingerprint found in SDP") + + +def _generate_ice_credential(length: int) -> str: + """Generate a random ICE credential string (alphanumeric).""" + return secrets.token_urlsafe(length)[:length] diff --git a/libp2p/transport/webrtc/signaling.py b/libp2p/transport/webrtc/signaling.py new file mode 100644 index 000000000..08f55fa1a --- /dev/null +++ b/libp2p/transport/webrtc/signaling.py @@ -0,0 +1,295 @@ +""" +WebRTC signaling protocol for private-to-private connections. + +Implements ``/webrtc-signaling/0.0.1`` — the protocol used to exchange SDP +offers/answers and ICE candidates over a Circuit Relay v2 stream so that +two NATed peers can establish a direct WebRTC data-channel connection. + +The bilateral ``ICE_DONE`` mechanism (libp2p/specs#585 fix) ensures that +neither side closes the signaling stream before the other has received all +ICE candidates: + +.. code-block:: text + + Initiator Responder (via relay) + ──── SDP_OFFER ─────────────────────────> + <─── SDP_ANSWER ───────────────────────── + <──> ICE_CANDIDATE (trickle, both ways) <> + ──── ICE_DONE ───────────────────────────> + <─── ICE_DONE ─────────────────────────── + (both sides close signaling stream) + +Messages are varint-length-prefixed protobuf :class:`SignalingMessage`. + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc.md +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +import logging + +import trio + +from libp2p.abc import INetStream + +from .exceptions import WebRTCSignalingError +from .signaling_pb.signaling_pb2 import SignalingMessage + +logger = logging.getLogger(__name__) + +# Maximum signaling message size (generous for SDP + candidates) +_MAX_SIGNALING_MSG_SIZE = 65_536 + +# Timeout for the entire signaling exchange +_SIGNALING_TIMEOUT = 30.0 + + +async def write_signaling_message( + stream: INetStream, + msg: SignalingMessage, +) -> None: + """ + Write a varint-length-prefixed signaling message to the stream. + + :param stream: The relay stream. + :param msg: The protobuf message to send. + :raises WebRTCSignalingError: If writing fails. + """ + data = msg.SerializeToString() + length = len(data) + # Encode length as unsigned varint + varint_buf = _encode_uvarint(length) + try: + await stream.write(varint_buf + data) + except Exception as e: + raise WebRTCSignalingError(f"Failed to write signaling message: {e}") from e + + +async def read_signaling_message(stream: INetStream) -> SignalingMessage: + """ + Read a varint-length-prefixed signaling message from the stream. + + :param stream: The relay stream. + :returns: The parsed protobuf message. + :raises WebRTCSignalingError: If reading or parsing fails. + """ + try: + length = await _read_uvarint(stream) + if length > _MAX_SIGNALING_MSG_SIZE: + raise WebRTCSignalingError( + f"Signaling message too large: {length} bytes " + f"(max {_MAX_SIGNALING_MSG_SIZE})" + ) + # Linear-time assembly — bytes += bytes is O(n²) for large n. + buf = bytearray() + while len(buf) < length: + chunk = await stream.read(length - len(buf)) + if not chunk: + raise WebRTCSignalingError( + "Stream closed before full signaling message received" + ) + buf.extend(chunk) + data = bytes(buf) + except WebRTCSignalingError: + raise + except Exception as e: + raise WebRTCSignalingError(f"Failed to read signaling message: {e}") from e + + msg = SignalingMessage() + msg.ParseFromString(data) + return msg + + +class SignalingSession: + """ + Manages a signaling exchange between two peers. + + Handles the ordered message flow: SDP_OFFER → SDP_ANSWER → ICE + candidates (trickle) → bilateral ICE_DONE. + + Usage (initiator side):: + + session = SignalingSession(stream) + await session.send_offer(sdp_offer_bytes) + answer_bytes = await session.receive_answer() + async for candidate in session.receive_candidates(): + # apply candidate to RTCPeerConnection + pass + await session.send_candidates(my_candidates) + await session.complete() # bilateral ICE_DONE + + Usage (responder side):: + + session = SignalingSession(stream) + offer_bytes = await session.receive_offer() + await session.send_answer(sdp_answer_bytes) + async for candidate in session.receive_candidates(): + pass + await session.send_candidates(my_candidates) + await session.complete() + """ + + def __init__(self, stream: INetStream, timeout: float = _SIGNALING_TIMEOUT) -> None: + self._stream = stream + self._timeout = timeout + self._ice_done_sent = False + self._ice_done_received = False + self._closed = False + + # ------------------------------------------------------------------ + # SDP exchange + # ------------------------------------------------------------------ + + async def send_offer(self, sdp: bytes) -> None: + """Send an SDP offer.""" + msg = SignalingMessage(type=SignalingMessage.SDP_OFFER, data=sdp) + await write_signaling_message(self._stream, msg) + logger.debug("Sent SDP_OFFER (%d bytes)", len(sdp)) + + async def receive_offer(self) -> bytes: + """Wait for and return the SDP offer.""" + msg = await self._receive_expected(SignalingMessage.SDP_OFFER) + logger.debug("Received SDP_OFFER (%d bytes)", len(msg.data)) + return msg.data + + async def send_answer(self, sdp: bytes) -> None: + """Send an SDP answer.""" + msg = SignalingMessage(type=SignalingMessage.SDP_ANSWER, data=sdp) + await write_signaling_message(self._stream, msg) + logger.debug("Sent SDP_ANSWER (%d bytes)", len(sdp)) + + async def receive_answer(self) -> bytes: + """Wait for and return the SDP answer.""" + msg = await self._receive_expected(SignalingMessage.SDP_ANSWER) + logger.debug("Received SDP_ANSWER (%d bytes)", len(msg.data)) + return msg.data + + # ------------------------------------------------------------------ + # ICE candidate exchange (trickle) + # ------------------------------------------------------------------ + + async def send_candidates(self, candidates: list[bytes]) -> None: + """ + Send all gathered ICE candidates, then send ICE_DONE. + + :param candidates: List of serialized ICE candidate strings. + """ + for candidate in candidates: + msg = SignalingMessage(type=SignalingMessage.ICE_CANDIDATE, data=candidate) + await write_signaling_message(self._stream, msg) + logger.debug("Sent %d ICE candidates", len(candidates)) + + # Signal that we're done sending candidates + done_msg = SignalingMessage(type=SignalingMessage.ICE_DONE) + await write_signaling_message(self._stream, done_msg) + self._ice_done_sent = True + logger.debug("Sent ICE_DONE") + + async def receive_candidates(self) -> AsyncIterator[bytes]: + """ + Yield ICE candidates from the remote peer until ICE_DONE is received. + + This is an async generator — iterate it to get candidates as they + arrive. The generator completes when the remote sends ICE_DONE. + """ + while True: + msg = await read_signaling_message(self._stream) + if msg.type == SignalingMessage.ICE_CANDIDATE: + yield msg.data + elif msg.type == SignalingMessage.ICE_DONE: + self._ice_done_received = True + logger.debug("Received ICE_DONE from remote") + return + else: + logger.warning( + "Unexpected signaling message type %s during ICE exchange", + msg.type, + ) + + # ------------------------------------------------------------------ + # Completion (bilateral ICE_DONE — specs#585 fix) + # ------------------------------------------------------------------ + + async def complete(self) -> None: + """ + Complete the signaling exchange. + + Ensures both sides have sent AND received ICE_DONE before closing + the stream. This prevents the race condition in specs#585 where + one side closes the stream before the other has received all + candidates. + """ + # If we haven't sent ICE_DONE yet, send it now + if not self._ice_done_sent: + done_msg = SignalingMessage(type=SignalingMessage.ICE_DONE) + await write_signaling_message(self._stream, done_msg) + self._ice_done_sent = True + + # If we haven't received ICE_DONE yet, wait for it + if not self._ice_done_received: + with trio.move_on_after(self._timeout) as scope: + while not self._ice_done_received: + msg = await read_signaling_message(self._stream) + if msg.type == SignalingMessage.ICE_DONE: + self._ice_done_received = True + # Silently discard any late ICE_CANDIDATEs + if scope.cancelled_caught: + logger.warning("Timed out waiting for remote ICE_DONE") + + self._closed = True + logger.debug( + "Signaling complete (sent_done=%s, recv_done=%s)", + self._ice_done_sent, + self._ice_done_received, + ) + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + async def _receive_expected(self, expected_type: int) -> SignalingMessage: + """Read a message and verify its type.""" + msg: SignalingMessage | None = None + with trio.move_on_after(self._timeout) as scope: + msg = await read_signaling_message(self._stream) + if scope.cancelled_caught or msg is None: + raise WebRTCSignalingError( + f"Timed out waiting for signaling message type {expected_type}" + ) + if msg.type != expected_type: + raise WebRTCSignalingError( + f"Expected signaling message type {expected_type}, got {msg.type}" + ) + return msg + + +# ------------------------------------------------------------------ +# Varint encoding/decoding (unsigned, for length prefixing) +# ------------------------------------------------------------------ + + +def _encode_uvarint(value: int) -> bytes: + """Encode an unsigned integer as a varint.""" + buf = bytearray() + while value > 0x7F: + buf.append((value & 0x7F) | 0x80) + value >>= 7 + buf.append(value & 0x7F) + return bytes(buf) + + +async def _read_uvarint(stream: INetStream) -> int: + """Read an unsigned varint from the stream.""" + result = 0 + shift = 0 + for _ in range(10): # Max 10 bytes for uint64 varint + byte_data = await stream.read(1) + if not byte_data: + raise WebRTCSignalingError("Stream closed while reading varint") + byte = byte_data[0] + result |= (byte & 0x7F) << shift + if not (byte & 0x80): + return result + shift += 7 + raise WebRTCSignalingError("Varint too long (> 10 bytes)") diff --git a/libp2p/transport/webrtc/signaling_pb/__init__.py b/libp2p/transport/webrtc/signaling_pb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libp2p/transport/webrtc/signaling_pb/signaling.proto b/libp2p/transport/webrtc/signaling_pb/signaling.proto new file mode 100644 index 000000000..6e393df11 --- /dev/null +++ b/libp2p/transport/webrtc/signaling_pb/signaling.proto @@ -0,0 +1,25 @@ +// WebRTC signaling messages for private-to-private connections. +// Exchanged over a Circuit Relay v2 stream using the +// /webrtc-signaling/0.0.1 protocol. +// +// The ICE_DONE message type addresses the race condition described in +// libp2p/specs#585: both sides must send AND receive ICE_DONE before +// closing the signaling stream, ensuring all ICE candidates are delivered. +// +// Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc.md + +syntax = "proto3"; + +package webrtc.signaling; + +message SignalingMessage { + enum Type { + SDP_OFFER = 0; + SDP_ANSWER = 1; + ICE_CANDIDATE = 2; + ICE_DONE = 3; + } + + Type type = 1; + bytes data = 2; +} diff --git a/libp2p/transport/webrtc/signaling_pb/signaling_pb2.py b/libp2p/transport/webrtc/signaling_pb/signaling_pb2.py new file mode 100644 index 000000000..098b419c6 --- /dev/null +++ b/libp2p/transport/webrtc/signaling_pb/signaling_pb2.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: libp2p/transport/webrtc/signaling_pb/signaling.proto +# Protobuf Python Version: 6.31.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 31, + 1, + '', + 'libp2p/transport/webrtc/signaling_pb/signaling.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n4libp2p/transport/webrtc/signaling_pb/signaling.proto\x12\x10webrtc.signaling\"\x9f\x01\n\x10SignalingMessage\x12\x35\n\x04type\x18\x01 \x01(\x0e\x32\'.webrtc.signaling.SignalingMessage.Type\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"F\n\x04Type\x12\r\n\tSDP_OFFER\x10\x00\x12\x0e\n\nSDP_ANSWER\x10\x01\x12\x11\n\rICE_CANDIDATE\x10\x02\x12\x0c\n\x08ICE_DONE\x10\x03\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.transport.webrtc.signaling_pb.signaling_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_SIGNALINGMESSAGE']._serialized_start=75 + _globals['_SIGNALINGMESSAGE']._serialized_end=234 + _globals['_SIGNALINGMESSAGE_TYPE']._serialized_start=164 + _globals['_SIGNALINGMESSAGE_TYPE']._serialized_end=234 +# @@protoc_insertion_point(module_scope) diff --git a/libp2p/transport/webrtc/signaling_pb/signaling_pb2.pyi b/libp2p/transport/webrtc/signaling_pb/signaling_pb2.pyi new file mode 100644 index 000000000..967252520 --- /dev/null +++ b/libp2p/transport/webrtc/signaling_pb/signaling_pb2.pyi @@ -0,0 +1,24 @@ +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class SignalingMessage(_message.Message): + __slots__ = ("type", "data") + class Type(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + SDP_OFFER: _ClassVar[SignalingMessage.Type] + SDP_ANSWER: _ClassVar[SignalingMessage.Type] + ICE_CANDIDATE: _ClassVar[SignalingMessage.Type] + ICE_DONE: _ClassVar[SignalingMessage.Type] + SDP_OFFER: SignalingMessage.Type + SDP_ANSWER: SignalingMessage.Type + ICE_CANDIDATE: SignalingMessage.Type + ICE_DONE: SignalingMessage.Type + TYPE_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + type: SignalingMessage.Type + data: bytes + def __init__(self, type: _Optional[_Union[SignalingMessage.Type, str]] = ..., data: _Optional[bytes] = ...) -> None: ... diff --git a/libp2p/transport/webrtc/stream.py b/libp2p/transport/webrtc/stream.py new file mode 100644 index 000000000..6573530d5 --- /dev/null +++ b/libp2p/transport/webrtc/stream.py @@ -0,0 +1,457 @@ +""" +WebRTC data-channel stream. + +Each libp2p stream maps to one WebRTC data channel. Every write is wrapped +in a protobuf :class:`Message` with an optional :class:`Flag` for lifecycle +signaling. The FIN/FIN_ACK/STOP_SENDING/RESET state machine follows the +libp2p WebRTC specification exactly. + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc.md +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +import enum +import logging +from typing import TYPE_CHECKING + +import trio + +from libp2p.abc import IMuxedStream + +from .constants import MAX_MESSAGE_SIZE +from .exceptions import WebRTCStreamError +from .pb.webrtc_pb2 import Message + +if TYPE_CHECKING: + from .connection import WebRTCConnection + +logger = logging.getLogger(__name__) + + +class StreamState(enum.Enum): + """Data-channel stream lifecycle states.""" + + OPEN = "open" + WRITE_CLOSED = "write_closed" # Sent FIN, awaiting FIN_ACK + READ_CLOSED = "read_closed" # Sent STOP_SENDING + CLOSED = "closed" # Both sides done + RESET = "reset" # Abrupt termination + + +class WebRTCStream(IMuxedStream): + """ + A single multiplexed stream over a WebRTC data channel. + + Implements :class:`IMuxedStream` with protobuf framing and the + FIN/FIN_ACK lifecycle protocol from the spec. + + The stream does **not** interact with aiortc directly — it sends and + receives raw bytes through callbacks registered by + :class:`WebRTCConnection`. This keeps the stream logic testable + without an aiortc dependency. + """ + + def __init__( + self, + connection: WebRTCConnection, + channel_id: int, + is_initiator: bool, + trio_token: trio.lowlevel.TrioToken | None = None, + ) -> None: + self.muxed_conn = connection + self._channel_id = channel_id + self._is_initiator = is_initiator + self._state = StreamState.OPEN + self._state_lock = trio.Lock() + + # Read side: incoming messages arrive via on_data() callback + self._read_send: trio.MemorySendChannel[bytes] + self._read_recv: trio.MemoryReceiveChannel[bytes] + self._read_send, self._read_recv = trio.open_memory_channel[bytes](64) + self._read_buf = bytearray() + self._read_closed = False + + # Write side + self._write_closed = False + + # FIN_ACK coordination + self._fin_ack_received = trio.Event() + + # Deadline (seconds from epoch, or 0 for no deadline) + self._deadline: float = 0.0 + + # Callback for sending framed bytes over the data channel. + # Set by WebRTCConnection after construction. + self._send_callback: _SendCallback | None = None + + # Trio token used to safely route asyncio-side callbacks back to the + # trio thread. Prefer the explicitly-supplied token (the connection + # passes its own trio_token when constructing inbound streams from + # the asyncio thread). Fall back to capturing one inline only when + # we're already on a trio task (tests / outbound streams). + if trio_token is not None: + self._trio_token: trio.lowlevel.TrioToken | None = trio_token + else: + try: + self._trio_token = trio.lowlevel.current_trio_token() + except RuntimeError: + self._trio_token = None + + @property + def channel_id(self) -> int: + """The WebRTC data channel ID for this stream.""" + return self._channel_id + + def get_remote_address( + self, + ) -> tuple[str, int] | None: # pyrefly: ignore[bad-return] + """Delegate to the connection (data channels share its address).""" + # WebRTCConnection adds get_remote_address() on top of the bare + # IMuxedConn ABC. Fall back to None for any other muxed connection. + get_addr = getattr(self.muxed_conn, "get_remote_address", None) + if callable(get_addr): + return get_addr() # type: ignore[no-any-return] + return None + + # ------------------------------------------------------------------ + # IMuxedStream: read + # ------------------------------------------------------------------ + + async def read(self, n: int | None = None) -> bytes: + """ + Read up to *n* bytes from the stream. + + Blocks until data is available, the remote sends FIN, or the + stream is reset. + + :param n: Maximum bytes to return. ``None`` returns whatever is + available in the next message. + :returns: The bytes read (may be shorter than *n*). + :raises WebRTCStreamError: If the stream was reset or closed. + """ + if self._state == StreamState.RESET: + raise WebRTCStreamError("Stream was reset") + + # Serve from internal buffer first + if self._read_buf: + return self._drain_buf(n) + + # Drain any remaining data from the channel (may have data even + # after FIN if the message carried both payload and FIN flag). + if self._read_closed: + try: + chunk = self._read_recv.receive_nowait() + if chunk: # Skip empty EOF sentinel + self._read_buf.extend(chunk) + return self._drain_buf(n) + except (trio.WouldBlock, trio.EndOfChannel, trio.ClosedResourceError): + pass + raise WebRTCStreamError("Read side is closed") + + # Block for the next chunk + try: + if self._deadline > 0: + timeout = max(0.0, self._deadline - trio.current_time()) + with trio.move_on_after(timeout) as scope: + chunk = await self._read_recv.receive() + if scope.cancelled_caught: + raise WebRTCStreamError("Read deadline exceeded") + else: + chunk = await self._read_recv.receive() + except trio.EndOfChannel: + if self._read_buf: + return self._drain_buf(n) + raise WebRTCStreamError("Stream closed by remote") from None + + # Empty chunk is an EOF sentinel from on_data() + if not chunk: + self._read_closed = True + raise WebRTCStreamError("Stream closed by remote") + + self._read_buf.extend(chunk) + return self._drain_buf(n) + + def _drain_buf(self, n: int | None) -> bytes: + """Return up to *n* bytes from the read buffer.""" + if n is None or n < 0 or n >= len(self._read_buf): + data = bytes(self._read_buf) + self._read_buf.clear() + return data + data = bytes(self._read_buf[:n]) + del self._read_buf[:n] + return data + + # ------------------------------------------------------------------ + # IMuxedStream: write + # ------------------------------------------------------------------ + + async def write(self, data: bytes) -> None: + """ + Write *data* to the stream, protobuf-framed. + + Large writes are split into chunks of at most + :data:`MAX_MESSAGE_SIZE` bytes. + + :raises WebRTCStreamError: If the write side is closed or reset. + """ + if self._state == StreamState.RESET: + raise WebRTCStreamError("Stream was reset") + if self._write_closed: + raise WebRTCStreamError("Write side is closed") + + # Split into spec-compliant chunks + offset = 0 + while offset < len(data): + chunk = data[offset : offset + MAX_MESSAGE_SIZE] + msg = Message(message=chunk) + await self._send_message(msg) + offset += len(chunk) + + # ------------------------------------------------------------------ + # IMuxedStream: close / reset + # ------------------------------------------------------------------ + + async def close(self) -> None: + """ + Gracefully close the stream (both read and write sides). + + Sends FIN, waits for FIN_ACK (with timeout), then closes. + """ + async with self._state_lock: + if self._state in (StreamState.CLOSED, StreamState.RESET): + return + + if not self._write_closed: + await self._close_write() + + if not self._read_closed: + self._close_read_side() + + async with self._state_lock: + self._state = StreamState.CLOSED + self._cleanup() + + async def reset(self) -> None: + """ + Abruptly terminate the stream. + + Sends RESET and immediately tears down without waiting for + acknowledgement. + """ + async with self._state_lock: + if self._state == StreamState.RESET: + return + self._state = StreamState.RESET + + try: + await self._send_message(Message(flag=Message.RESET)) + except Exception: + pass # Best-effort + self._cleanup() + + def set_deadline(self, ttl: int) -> None: + """ + Set a deadline for future read operations. + + :param ttl: Seconds from now. 0 removes the deadline. + """ + if ttl <= 0: + self._deadline = 0.0 + else: + self._deadline = trio.current_time() + ttl + + # ------------------------------------------------------------------ + # Async context manager + # ------------------------------------------------------------------ + + async def __aenter__(self) -> WebRTCStream: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object | None, + ) -> None: + await self.close() + + # ------------------------------------------------------------------ + # Data-channel callbacks (called by WebRTCConnection) + # ------------------------------------------------------------------ + + def on_data(self, raw: bytes) -> None: + """ + Called by :class:`WebRTCConnection` when the data channel receives + a protobuf-framed message. + + Parses the :class:`Message`, processes any flag, and enqueues + payload bytes for :meth:`read`. + + May be invoked from the asyncio bridge thread (not a Trio task). + To stay safe we route every Trio primitive call (memory channel, + :class:`trio.Event`) through :func:`trio.from_thread.run_sync` + with a captured :class:`trio.lowlevel.TrioToken`. When called + from within a Trio task (for example in unit tests) we execute + the mutations inline. + """ + msg = Message() + msg.ParseFromString(raw) + + # Snapshot flags/payload first; all subsequent state mutations are + # performed under the Trio thread. + has_payload = msg.HasField("message") and bool(msg.message) + payload = bytes(msg.message) if has_payload else b"" + has_flag = msg.HasField("flag") + flag = msg.flag if has_flag else None + + def _apply_on_trio_thread() -> None: + # Enqueue payload BEFORE processing flags — the spec allows a + # message to carry both data and FIN, and the data must be + # delivered to the reader before the read channel is closed. + if has_payload: + try: + self._read_send.send_nowait(payload) + except trio.WouldBlock: + logger.warning( + "WebRTCStream channel=%d: read buffer full, dropping message", + self._channel_id, + ) + except trio.ClosedResourceError: + pass + + if has_flag: + if flag == Message.FIN: + self._read_closed = True + self._enqueue_eof_sentinel_locked() + self._schedule_send(Message(flag=Message.FIN_ACK)) + elif flag == Message.FIN_ACK: + self._fin_ack_received.set() + elif flag == Message.STOP_SENDING: + self._write_closed = True + elif flag == Message.RESET: + self._state = StreamState.RESET + self._enqueue_eof_sentinel_locked() + + self._run_on_trio_thread(_apply_on_trio_thread) + + def _run_on_trio_thread(self, fn: Callable[[], None]) -> None: + """ + Execute *fn* on the Trio thread. + + If we're already inside a Trio task, call directly. Otherwise + route through :func:`trio.from_thread.run_sync` using the token + captured at construction time. If no token was captured (e.g. + tests that build a stream without a running Trio loop) fall back + to a direct call — those tests never cross thread boundaries + anyway. + """ + token = self._trio_token + try: + trio.lowlevel.current_task() + in_trio = True + except RuntimeError: + in_trio = False + + if in_trio or token is None: + fn() + else: + try: + trio.from_thread.run_sync(fn, trio_token=token) + except trio.RunFinishedError: + logger.debug( + "WebRTCStream channel=%d: trio run finished, dropping " + "asyncio-side callback", + self._channel_id, + ) + + def _enqueue_eof_sentinel_locked(self) -> None: + """ + Send an empty sentinel to signal EOF to the trio-side reader. + + MUST be called from a Trio task — use via + :meth:`_run_on_trio_thread` when routing from a foreign thread. + """ + try: + self._read_send.send_nowait(b"") + except (trio.WouldBlock, trio.ClosedResourceError): + pass + + # Preserved name for internal callers already on the Trio side. + _enqueue_eof_sentinel = _enqueue_eof_sentinel_locked + + def on_channel_close(self) -> None: + """Called when the underlying data channel is closed.""" + + def _apply() -> None: + self._read_closed = True + self._write_closed = True + self._enqueue_eof_sentinel_locked() + + self._run_on_trio_thread(_apply) + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + async def _close_write(self) -> None: + """Send FIN and wait for FIN_ACK.""" + self._write_closed = True + await self._send_message(Message(flag=Message.FIN)) + + # Wait for FIN_ACK with a bounded timeout + with trio.move_on_after(5.0): + await self._fin_ack_received.wait() + + async with self._state_lock: + self._state = StreamState.WRITE_CLOSED + + def _close_read_side(self) -> None: + """Close the read side and send STOP_SENDING.""" + self._read_closed = True + self._enqueue_eof_sentinel() + self._schedule_send(Message(flag=Message.STOP_SENDING)) + + async def _send_message(self, msg: Message) -> None: + """Serialize and send a protobuf Message via the data channel.""" + if self._send_callback is None: + raise WebRTCStreamError("Stream not connected to a data channel") + data = msg.SerializeToString() + await self._send_callback(data) + + def _schedule_send(self, msg: Message) -> None: + """ + Best-effort send for flags from synchronous callbacks (on_data). + + Uses the connection's bridge to schedule the send as a fire-and-forget + asyncio coroutine, since this method may be called from a non-trio + thread. + """ + if self._send_callback is None: + return + data = msg.SerializeToString() + bridge = getattr(self.muxed_conn, "_bridge", None) + # CRITICAL: do NOT use self._send_callback here. That callback is + # the trio-facing wrapper which itself awaits bridge.run_coro() — + # invoking it via schedule_fire_and_forget would block the asyncio + # thread on a future that can only be resolved from a trio task. + # Bypass it and call the asyncio-native callback directly. + send_cb = getattr(self.muxed_conn, "_send_on_channel_cb", None) + if bridge is not None and bridge.is_running and send_cb is not None: + bridge.schedule_fire_and_forget(send_cb(self._channel_id, data)) + + def _cleanup(self) -> None: + """Release resources.""" + try: + self._read_send.close() + except trio.ClosedResourceError: + pass + try: + self._read_recv.close() + except trio.ClosedResourceError: + pass + + +# Type alias for the send callback +_SendCallback = Callable[[bytes], Awaitable[None]] diff --git a/libp2p/transport/webrtc/transport.py b/libp2p/transport/webrtc/transport.py new file mode 100644 index 000000000..50f867fb0 --- /dev/null +++ b/libp2p/transport/webrtc/transport.py @@ -0,0 +1,259 @@ +""" +WebRTC Direct transport. + +Implements :class:`ITransport` for the ``/webrtc-direct`` multiaddr scheme. +The server publishes its DTLS certificate hash in the multiaddr; the client +constructs an SDP locally — no signaling exchange is needed. + +This transport provides native stream multiplexing (data channels), so it +sets ``provides_native_muxing = True`` and the swarm skips the +TransportUpgrader. + +Spec: https://github.com/libp2p/specs/blob/master/webrtc/webrtc-direct.md +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from multiaddr import Multiaddr +import trio + +from libp2p.abc import ITransport +from libp2p.crypto.keys import PrivateKey +from libp2p.custom_types import THandler +from libp2p.peer.id import ID + +from ._asyncio_bridge import AsyncioBridge +from .certificate import WebRTCCertificate +from .config import WebRTCTransportConfig +from .connection import WebRTCConnection +from .exceptions import WebRTCConnectionError +from .listener import WebRTCDirectListener +from .multiaddr_utils import ( + is_webrtc_direct_multiaddr, + parse_webrtc_direct_multiaddr, +) +from .sdp import SDPBuilder + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +async def _async_noop(value: object) -> object: + """Wrap a synchronous value as an awaitable for bridge.run_coro().""" + return value + + +class WebRTCDirectTransport(ITransport): + """ + WebRTC Direct transport (``/webrtc-direct``). + + Usage:: + + transport = WebRTCDirectTransport(private_key=my_key) + # Dial a remote peer + conn = await transport.dial( + Multiaddr("/ip4/1.2.3.4/udp/9090/webrtc-direct/certhash/uEi.../p2p/12D3...") + ) + # Or create a listener + listener = transport.create_listener(handler) + await listener.listen(Multiaddr("/ip4/0.0.0.0/udp/9090/webrtc-direct")) + """ + + # The swarm checks this to skip the TransportUpgrader + provides_native_muxing: bool = True + + def __init__( + self, + private_key: PrivateKey, + config: WebRTCTransportConfig | None = None, + ) -> None: + self._private_key = private_key + self._config = config or WebRTCTransportConfig() + self._certificate = self._config.get_or_generate_certificate() + self._bridge: AsyncioBridge | None = None + self._bridge_lock = trio.Lock() + self._local_peer_id = ID.from_pubkey(private_key.get_public_key()) + self._sdp_builder = SDPBuilder(certificate=self._certificate) + + async def _ensure_bridge(self) -> AsyncioBridge: + """Start the asyncio bridge on first use (concurrency-safe).""" + if self._bridge is not None: + return self._bridge + async with self._bridge_lock: + if self._bridge is None: + self._bridge = AsyncioBridge() + await self._bridge.start() + return self._bridge + + async def dial(self, maddr: Multiaddr) -> WebRTCConnection: + """ + Dial a remote peer over WebRTC Direct. + + :param maddr: A ``/webrtc-direct`` multiaddr with certhash. + :returns: A :class:`WebRTCConnection` (implements both + ``IRawConnection`` and ``IMuxedConn``). + :raises WebRTCConnectionError: If the connection fails. + """ + if not is_webrtc_direct_multiaddr(maddr): + raise WebRTCConnectionError(f"Not a WebRTC Direct multiaddr: {maddr}") + host, port, certhash, peer_id_str = parse_webrtc_direct_multiaddr(maddr) + if not certhash: + raise WebRTCConnectionError( + f"WebRTC Direct multiaddr missing certhash: {maddr}" + ) + + bridge = await self._ensure_bridge() + logger.info("Dialing WebRTC Direct %s:%d", host, port) + + remote_peer_id: ID | None = None + if peer_id_str: + remote_peer_id = ID.from_base58(peer_id_str) + + # All aiortc calls go through the bridge (asyncio thread). + from ._aiortc_helpers import ( + create_noise_channel, + create_peer_connection, + get_remote_fingerprint, + make_noise_channel_callbacks, + post_sdp, + wait_for_connected, + wire_pc_to_connection, + ) + from .certificate import fingerprint_from_multibase + from .noise_handshake import DataChannelReadWriter, perform_noise_handshake + + rtc_cert = getattr(self._certificate, "_rtc_certificate", None) + if rtc_cert is None: + raise WebRTCConnectionError( + "WebRTC certificate was not generated via aiortc. " + "Ensure aiortc is installed and config uses from_aiortc()." + ) + + try: + # 1. Create RTCPeerConnection + Noise channel + pc = await bridge.run_coro(create_peer_connection(rtc_cert)) + noise_ch = await bridge.run_coro(create_noise_channel(pc)) + + # make_noise_channel_callbacks is sync; wrap inline. + async def _setup_noise() -> tuple: # type: ignore[type-arg] + return make_noise_channel_callbacks(noise_ch) + + noise_send, noise_recv, _ = await bridge.run_coro(_setup_noise()) + + # 2. Create offer, set local description + offer = await bridge.run_coro(pc.createOffer()) + await bridge.run_coro(pc.setLocalDescription(offer)) + + # 3. Exchange SDP via HTTP POST to the listener + answer_sdp = await bridge.run_coro( + post_sdp(host, port, pc.localDescription.sdp) + ) + + # 4. Set remote description + from aiortc import RTCSessionDescription + + answer = RTCSessionDescription(sdp=answer_sdp, type="answer") + await bridge.run_coro(pc.setRemoteDescription(answer)) + + # 5. Wait for ICE connection + await bridge.run_coro(wait_for_connected(pc)) + + # 6. Verify remote DTLS fingerprint + expected_fp = fingerprint_from_multibase(certhash) + remote_fp = get_remote_fingerprint(pc) # sync, safe off-thread + if remote_fp != expected_fp: + await bridge.run_coro(pc.close()) + raise WebRTCConnectionError( + "Remote DTLS fingerprint does not match certhash in the multiaddr" + ) + + # 7. Build WebRTCConnection and wire callbacks + conn = WebRTCConnection( + peer_id=remote_peer_id or ID(b"\x00" * 32), + bridge=bridge, + is_initiator=True, + config=self._config, + remote_addrs=[maddr], + ) + wire_pc_to_connection(pc, conn) # sync, wires callbacks + + # 8. Noise XX handshake over channel 0 + from libp2p.crypto.x25519 import ( + create_new_key_pair as create_x25519_keypair, + ) + + noise_kp = create_x25519_keypair() + + async def _trio_noise_send(data: bytes) -> None: + await bridge.run_coro(noise_send(data)) + + async def _trio_noise_recv() -> bytes: + return await bridge.run_coro(noise_recv()) + + noise_rw = DataChannelReadWriter( + send_cb=_trio_noise_send, + recv_cb=_trio_noise_recv, + is_initiator=True, + ) + authenticated_peer = await perform_noise_handshake( + conn=noise_rw, + local_peer=self._local_peer_id, + libp2p_privkey=self._private_key, + noise_static_key=noise_kp.private_key, + local_fingerprint=self._certificate.fingerprint, + remote_fingerprint=expected_fp, + is_initiator=True, + remote_peer=remote_peer_id, + ) + + # 9. Finalize connection + conn.peer_id = authenticated_peer + await conn.start() + logger.info( + "WebRTC Direct connection established to %s", + authenticated_peer, + ) + return conn + + except WebRTCConnectionError: + raise + except Exception as e: + raise WebRTCConnectionError(f"WebRTC Direct dial failed: {e}") from e + + def create_listener(self, handler_function: THandler) -> WebRTCDirectListener: + """ + Create a WebRTC Direct listener. + + :param handler_function: Called with each new inbound connection. + :returns: A :class:`WebRTCDirectListener`. + """ + return WebRTCDirectListener( + handler_function=handler_function, + private_key=self._private_key, + certificate=self._certificate, + config=self._config, + bridge_factory=self._ensure_bridge, + local_peer_id=self._local_peer_id, + ) + + async def close(self) -> None: + """ + Shut down the transport and its asyncio bridge. + + Acquires the same lock as :meth:`_ensure_bridge` so a concurrent + dial cannot resurrect the bridge mid-shutdown. + """ + async with self._bridge_lock: + if self._bridge is not None: + await self._bridge.stop() + self._bridge = None + + @property + def certificate(self) -> WebRTCCertificate: + """The local DTLS certificate.""" + return self._certificate diff --git a/newsfragments/546.feature.rst b/newsfragments/546.feature.rst new file mode 100644 index 000000000..592705d9e --- /dev/null +++ b/newsfragments/546.feature.rst @@ -0,0 +1 @@ +Added WebRTC transport scaffolding (``libp2p.transport.webrtc``) per the libp2p WebRTC and WebRTC Direct specs. Introduces ``WebRTCDirectTransport`` (``/webrtc-direct``) and ``WebRTCPrivateTransport`` (``/webrtc`` via Circuit Relay v2), data-channel stream framing with the FIN/FIN_ACK/STOP_SENDING/RESET protocol, signaling with bilateral ``ICE_DONE`` (libp2p/specs#585 fix), Noise XX prologue binding the handshake to DTLS certificate fingerprints, a SDP builder with an isolated ``_apply_ice_credentials()`` seam for libp2p/specs#672, ECDSA P-256 certificate generation with multihash/multibase fingerprint encoding, and a clean trio↔asyncio bridge. ``aiortc`` is added as an optional dependency under the ``webrtc`` extra. The underlying ``RTCPeerConnection`` wiring is stubbed with ``NOTE`` comments and will follow in a subsequent PR. diff --git a/pyproject.toml b/pyproject.toml index a95a1ebe7..8fce236e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,9 @@ classifiers = [ [project.urls] Homepage = "https://github.com/libp2p/py-libp2p" +[project.optional-dependencies] +webrtc = ["aiortc>=1.5.0,<2.0"] + [project.scripts] chat-demo = "examples.chat.chat:main" echo-demo = "examples.echo.echo:main" @@ -322,5 +325,11 @@ project_excludes = [ "**/*.pyi", ".venv/**", "./tests/interop/nim_libp2p", + "./tests/core/transport/webrtc", + "./libp2p/transport/webrtc/_asyncio_bridge.py", + "./libp2p/transport/webrtc/_aiortc_helpers.py", + "./libp2p/transport/webrtc/certificate.py", + "./libp2p/transport/webrtc/transport.py", + "./libp2p/transport/webrtc/listener.py", ] search_path = ["stubs"] diff --git a/tests/core/transport/webrtc/__init__.py b/tests/core/transport/webrtc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/core/transport/webrtc/test_asyncio_bridge.py b/tests/core/transport/webrtc/test_asyncio_bridge.py new file mode 100644 index 000000000..16de7ecea --- /dev/null +++ b/tests/core/transport/webrtc/test_asyncio_bridge.py @@ -0,0 +1,383 @@ +""" +Tests for the trio ↔ asyncio bridge. + +These test the bridge in isolation — no aiortc dependency needed. We use +plain asyncio coroutines (sleep, gather, etc.) to exercise the bridge +mechanics: lifecycle, concurrency, error propagation, cancellation, and +stress. +""" +# pyrefly: ignore + +from __future__ import annotations + +import asyncio + +import pytest +import trio + +from libp2p.transport.webrtc._asyncio_bridge import ( + AsyncioBridge, + AsyncioBridgeError, +) + +# --------------------------------------------------------------- +# Lifecycle +# --------------------------------------------------------------- + + +class TestLifecycle: + @pytest.mark.trio + async def test_start_and_stop(self): + bridge = AsyncioBridge() + assert not bridge.is_running + await bridge.start() + assert bridge.is_running + assert bridge.loop is not None + await bridge.stop() + assert not bridge.is_running + assert bridge.loop is None + + @pytest.mark.trio + async def test_context_manager(self): + async with AsyncioBridge() as bridge: + assert bridge.is_running + assert not bridge.is_running + + @pytest.mark.trio + async def test_double_start_is_noop(self): + async with AsyncioBridge() as bridge: + await bridge.start() # second start — should not raise + assert bridge.is_running + + @pytest.mark.trio + async def test_double_stop_is_noop(self): + bridge = AsyncioBridge() + await bridge.start() + await bridge.stop() + await bridge.stop() # second stop — should not raise + assert not bridge.is_running + + @pytest.mark.trio + async def test_restart_after_stop_raises(self): + bridge = AsyncioBridge() + await bridge.start() + await bridge.stop() + with pytest.raises(AsyncioBridgeError, match="Cannot restart"): + await bridge.start() + + @pytest.mark.trio + async def test_stop_without_start_is_noop(self): + bridge = AsyncioBridge() + await bridge.stop() # never started — should not raise + + @pytest.mark.trio + async def test_repr(self): + bridge = AsyncioBridge() + assert "idle" in repr(bridge) + await bridge.start() + assert "running" in repr(bridge) + await bridge.stop() + assert "stopped" in repr(bridge) + + +# --------------------------------------------------------------- +# Basic coroutine execution +# --------------------------------------------------------------- + + +class TestRunCoro: + @pytest.mark.trio + async def test_simple_return(self): + async with AsyncioBridge() as bridge: + result = await bridge.run_coro(self._add(2, 3)) + assert result == 5 + + @pytest.mark.trio + async def test_return_none(self): + async with AsyncioBridge() as bridge: + result = await bridge.run_coro(self._noop()) + assert result is None + + @pytest.mark.trio + async def test_return_large_data(self): + async with AsyncioBridge() as bridge: + data = await bridge.run_coro(self._make_bytes(1_000_000)) + assert len(data) == 1_000_000 + + @pytest.mark.trio + async def test_run_coro_on_stopped_bridge(self): + bridge = AsyncioBridge() + with pytest.raises(AsyncioBridgeError, match="not running"): + await bridge.run_coro(self._noop()) + + @pytest.mark.trio + async def test_run_coro_after_stop(self): + bridge = AsyncioBridge() + await bridge.start() + await bridge.stop() + with pytest.raises(AsyncioBridgeError, match="not running"): + await bridge.run_coro(self._noop()) + + @staticmethod + async def _add(a: int, b: int) -> int: + await asyncio.sleep(0) + return a + b + + @staticmethod + async def _noop() -> None: + await asyncio.sleep(0) + + @staticmethod + async def _make_bytes(n: int) -> bytes: + await asyncio.sleep(0) + return b"\x42" * n + + +# --------------------------------------------------------------- +# Error propagation +# --------------------------------------------------------------- + + +class TestErrorPropagation: + @pytest.mark.trio + async def test_value_error_propagates(self): + async with AsyncioBridge() as bridge: + with pytest.raises(ValueError, match="boom"): + await bridge.run_coro(self._raise_value_error()) + + @pytest.mark.trio + async def test_runtime_error_propagates(self): + async with AsyncioBridge() as bridge: + with pytest.raises(RuntimeError, match="oops"): + await bridge.run_coro(self._raise_runtime_error()) + + @pytest.mark.trio + async def test_custom_exception_propagates(self): + async with AsyncioBridge() as bridge: + with pytest.raises(_CustomError, match="custom"): + await bridge.run_coro(self._raise_custom()) + + @pytest.mark.trio + async def test_exception_does_not_kill_bridge(self): + """After an error, the bridge should still be usable.""" + async with AsyncioBridge() as bridge: + with pytest.raises(ValueError): + await bridge.run_coro(self._raise_value_error()) + # Bridge should still work + result = await bridge.run_coro(self._return_ok()) + assert result == "ok" + + @staticmethod + async def _raise_value_error() -> None: + raise ValueError("boom") + + @staticmethod + async def _raise_runtime_error() -> None: + raise RuntimeError("oops") + + @staticmethod + async def _raise_custom() -> None: + raise _CustomError("custom") + + @staticmethod + async def _return_ok() -> str: + return "ok" + + +class _CustomError(Exception): + pass + + +# --------------------------------------------------------------- +# Concurrency +# --------------------------------------------------------------- + + +class TestConcurrency: + @pytest.mark.trio + async def test_concurrent_coroutines(self): + """Run 50 concurrent coroutines through the bridge.""" + async with AsyncioBridge() as bridge: + results: list[int] = [] + + async def _run_one(i: int) -> None: + val = await bridge.run_coro(self._delayed_return(i, 0.01)) + results.append(val) + + async with trio.open_nursery() as nursery: + for i in range(50): + nursery.start_soon(_run_one, i) + + assert sorted(results) == list(range(50)) + + @pytest.mark.trio + async def test_100_rapid_fire(self): + """100 coroutines with no artificial delay.""" + async with AsyncioBridge() as bridge: + results = [] + + async def _run(i: int) -> None: + val = await bridge.run_coro(self._immediate_return(i)) + results.append(val) + + async with trio.open_nursery() as nursery: + for i in range(100): + nursery.start_soon(_run, i) + + assert len(results) == 100 + assert set(results) == set(range(100)) + + @pytest.mark.trio + async def test_mixed_success_and_failure(self): + """Some coroutines succeed, some fail — bridge survives both.""" + async with AsyncioBridge() as bridge: + successes = [] + failures = [] + + async def _run(i: int) -> None: + try: + if i % 3 == 0: + await bridge.run_coro(self._raise_if_divisible()) + else: + val = await bridge.run_coro(self._immediate_return(i)) + successes.append(val) + except ValueError: + failures.append(i) + + async with trio.open_nursery() as nursery: + for i in range(30): + nursery.start_soon(_run, i) + + # 10 failures (0, 3, 6, ..., 27), 20 successes + assert len(failures) == 10 + assert len(successes) == 20 + + @staticmethod + async def _delayed_return(val: int, delay: float) -> int: + await asyncio.sleep(delay) + return val + + @staticmethod + async def _immediate_return(val: int) -> int: + return val + + @staticmethod + async def _raise_if_divisible() -> None: + raise ValueError("divisible by 3") + + +# --------------------------------------------------------------- +# Cancellation +# --------------------------------------------------------------- + + +class TestCancellation: + @pytest.mark.trio + async def test_trio_cancel_scope_unblocks(self): + """Cancelling a trio scope unblocks run_coro immediately.""" + async with AsyncioBridge() as bridge: + + async def _slow_asyncio_coro() -> None: + await asyncio.sleep(999) + + scope = trio.CancelScope(deadline=trio.current_time() + 0.1) + with scope: + await bridge.run_coro(_slow_asyncio_coro()) + pytest.fail("Should have been cancelled by deadline") + + assert scope.cancelled_caught, "Cancel scope did not fire" + + @pytest.mark.trio + async def test_bridge_works_after_cancel(self): + """After a scope cancellation, the bridge remains usable.""" + async with AsyncioBridge() as bridge: + + async def _slow() -> None: + await asyncio.sleep(999) + + scope = trio.CancelScope(deadline=trio.current_time() + 0.1) + with scope: + await bridge.run_coro(_slow()) + + assert scope.cancelled_caught, "Cancel scope did not fire" + + # Give the asyncio loop a moment to clean up the cancelled future + await trio.sleep(0.05) + + # Bridge should still work + result = await bridge.run_coro(_return_42()) + assert result == 42 + + +# --------------------------------------------------------------- +# Fire-and-forget +# --------------------------------------------------------------- + + +class TestFireAndForget: + @pytest.mark.trio + async def test_fire_and_forget_runs(self): + async with AsyncioBridge() as bridge: + flag: list[bool] = [] + + async def _set_flag() -> None: + flag.append(True) + + bridge.schedule_fire_and_forget(_set_flag()) + await trio.sleep(0.1) # Give it time to run + assert flag == [True] + + @pytest.mark.trio + async def test_fire_and_forget_error_is_logged_not_raised(self): + async with AsyncioBridge() as bridge: + + async def _explode() -> None: + raise RuntimeError("kaboom") + + # Should not raise + bridge.schedule_fire_and_forget(_explode()) + await trio.sleep(0.1) + # Bridge still alive + assert bridge.is_running + + @pytest.mark.trio + async def test_fire_and_forget_on_stopped_bridge(self): + bridge = AsyncioBridge() + + async def _noop() -> None: + pass + + # Should not raise — just silently discards + bridge.schedule_fire_and_forget(_noop()) + + +# --------------------------------------------------------------- +# Stress +# --------------------------------------------------------------- + + +class TestStress: + @pytest.mark.trio + async def test_sequential_start_stop_cycles(self): + """Create and destroy 10 bridges sequentially.""" + for _ in range(10): + async with AsyncioBridge() as bridge: + result = await bridge.run_coro(_return_42()) + assert result == 42 + + @pytest.mark.trio + async def test_many_small_coros(self): + """500 trivial coroutines to check for resource leaks.""" + async with AsyncioBridge() as bridge: + for i in range(500): + result = await bridge.run_coro(_return_42()) + assert result == 42 + + +# --------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------- + + +async def _return_42() -> int: + return 42 diff --git a/tests/core/transport/webrtc/test_certificate.py b/tests/core/transport/webrtc/test_certificate.py new file mode 100644 index 000000000..3b9a01662 --- /dev/null +++ b/tests/core/transport/webrtc/test_certificate.py @@ -0,0 +1,167 @@ +""" +Tests for WebRTC certificate generation and fingerprint encoding. +""" +# pyrefly: ignore + +import base64 +import hashlib + +import pytest +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.x509 import Certificate + +from libp2p.transport.webrtc.certificate import ( + WebRTCCertificate, + fingerprint_from_multibase, +) +from libp2p.transport.webrtc.exceptions import WebRTCCertificateError + + +class TestWebRTCCertificateGeneration: + """Test certificate generation.""" + + def test_generate_creates_valid_certificate(self): + cert = WebRTCCertificate.generate() + assert isinstance(cert.certificate, Certificate) + assert isinstance(cert.private_key, ec.EllipticCurvePrivateKey) + + def test_generate_uses_p256_curve(self): + cert = WebRTCCertificate.generate() + # The private key should be on the SECP256R1 (P-256) curve + assert isinstance(cert.private_key.curve, ec.SECP256R1) + + def test_generate_produces_unique_certificates(self): + cert1 = WebRTCCertificate.generate() + cert2 = WebRTCCertificate.generate() + assert cert1.fingerprint != cert2.fingerprint + + def test_generate_custom_common_name(self): + cert = WebRTCCertificate.generate(common_name="test-node") + # Verify the CN is set correctly by checking subject attributes + from cryptography.x509.oid import NameOID + + cn_attrs = cert.certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + assert len(cn_attrs) == 1 + assert cn_attrs[0].value == "test-node" + + +class TestFingerprint: + """Test fingerprint computation and encoding.""" + + def test_fingerprint_is_sha256(self): + cert = WebRTCCertificate.generate() + der = cert.certificate_der() + expected = hashlib.sha256(der).digest() + assert cert.fingerprint == expected + + def test_fingerprint_is_32_bytes(self): + cert = WebRTCCertificate.generate() + assert len(cert.fingerprint) == 32 + + def test_fingerprint_hex_format(self): + cert = WebRTCCertificate.generate() + hex_fp = cert.fingerprint_hex + # Should be colon-separated hex pairs: "AB:CD:EF:..." + parts = hex_fp.split(":") + assert len(parts) == 32 + for part in parts: + assert len(part) == 2 + int(part, 16) # Should not raise + + def test_fingerprint_to_multihash(self): + cert = WebRTCCertificate.generate() + mh = cert.fingerprint_to_multihash() + # Multihash: code(1 byte) + length(1 byte) + digest(32 bytes) + assert len(mh) == 34 + assert mh[0] == 0x12 # SHA-256 code + assert mh[1] == 32 # digest length + assert mh[2:] == cert.fingerprint + + def test_fingerprint_to_multibase(self): + cert = WebRTCCertificate.generate() + mb = cert.fingerprint_to_multibase() + # Should start with 'u' (base64url prefix) + assert mb.startswith("u") + # Should be decodable + b64_part = mb[1:] + # Add padding + padding = 4 - (len(b64_part) % 4) + if padding != 4: + b64_part += "=" * padding + raw = base64.urlsafe_b64decode(b64_part) + # Should be a valid multihash + assert raw[0] == 0x12 + assert raw[1] == 32 + assert raw[2:] == cert.fingerprint + + +class TestFingerprintRoundTrip: + """Test multibase encode/decode round-trip.""" + + def test_round_trip(self): + cert = WebRTCCertificate.generate() + encoded = cert.fingerprint_to_multibase() + decoded = fingerprint_from_multibase(encoded) + assert decoded == cert.fingerprint + + def test_round_trip_multiple_certs(self): + for _ in range(5): + cert = WebRTCCertificate.generate() + encoded = cert.fingerprint_to_multibase() + decoded = fingerprint_from_multibase(encoded) + assert decoded == cert.fingerprint + + +class TestFingerprintFromMultibase: + """Test decoding multibase-encoded fingerprints.""" + + def test_invalid_prefix(self): + with pytest.raises( + WebRTCCertificateError, match="Unsupported multibase prefix" + ): + fingerprint_from_multibase("zInvalidBase58") + + def test_invalid_base64(self): + with pytest.raises(WebRTCCertificateError): + fingerprint_from_multibase("u!!!invalid!!!") + + def test_too_short(self): + # Just the prefix + 1 byte (too short for multihash) + encoded = "u" + base64.urlsafe_b64encode(b"\x12").rstrip(b"=").decode() + with pytest.raises(WebRTCCertificateError, match="too short"): + fingerprint_from_multibase(encoded) + + def test_wrong_hash_function(self): + # Use code 0x11 (SHA-1) instead of 0x12 (SHA-256) + fake_mh = bytes([0x11, 32]) + b"\x00" * 32 + encoded = "u" + base64.urlsafe_b64encode(fake_mh).rstrip(b"=").decode() + with pytest.raises(WebRTCCertificateError, match="Unsupported multihash"): + fingerprint_from_multibase(encoded) + + def test_wrong_digest_length(self): + fake_mh = bytes([0x12, 16]) + b"\x00" * 16 # Wrong length + encoded = "u" + base64.urlsafe_b64encode(fake_mh).rstrip(b"=").decode() + with pytest.raises(WebRTCCertificateError, match="digest length"): + fingerprint_from_multibase(encoded) + + def test_truncated_digest(self): + fake_mh = bytes([0x12, 32]) + b"\x00" * 20 # Only 20 bytes of 32 + encoded = "u" + base64.urlsafe_b64encode(fake_mh).rstrip(b"=").decode() + with pytest.raises(WebRTCCertificateError, match="truncated"): + fingerprint_from_multibase(encoded) + + +class TestDERPEM: + """Test DER/PEM accessors.""" + + def test_certificate_der(self): + cert = WebRTCCertificate.generate() + der = cert.certificate_der() + assert isinstance(der, bytes) + assert len(der) > 100 # Reasonable minimum + + def test_private_key_der(self): + cert = WebRTCCertificate.generate() + key_der = cert.private_key_der() + assert isinstance(key_der, bytes) + assert len(key_der) > 50 diff --git a/tests/core/transport/webrtc/test_connection.py b/tests/core/transport/webrtc/test_connection.py new file mode 100644 index 000000000..fd5d08639 --- /dev/null +++ b/tests/core/transport/webrtc/test_connection.py @@ -0,0 +1,163 @@ +""" +Tests for WebRTCConnection stream management and lifecycle. +""" +# pyrefly: ignore + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest +import trio + +from libp2p.connection_types import ConnectionType +from libp2p.peer.id import ID +from libp2p.transport.webrtc.config import WebRTCTransportConfig +from libp2p.transport.webrtc.connection import WebRTCConnection +from libp2p.transport.webrtc.constants import OUTBOUND_STREAM_START_ID +from libp2p.transport.webrtc.exceptions import WebRTCStreamError + + +def _make_connection( + max_streams: int = 256, +) -> WebRTCConnection: + """Create a connection with a mock bridge.""" + mock_bridge = MagicMock() + mock_bridge.run_coro = AsyncMock() + mock_bridge.is_running = True + + config = WebRTCTransportConfig(max_concurrent_streams=max_streams) + peer_id = ID(b"\x00" * 32) + conn = WebRTCConnection( + peer_id=peer_id, + bridge=mock_bridge, + is_initiator=True, + config=config, + ) + # Set up mock callbacks + conn._create_channel_cb = AsyncMock() + conn._send_on_channel_cb = AsyncMock() + conn._close_pc_cb = AsyncMock() + return conn + + +class TestConnectionProperties: + def test_is_initiator(self): + conn = _make_connection() + assert conn.is_initiator is True + + def test_connection_type_is_direct(self): + conn = _make_connection() + assert conn.get_connection_type() == ConnectionType.DIRECT + + def test_not_established_initially(self): + conn = _make_connection() + assert conn.is_established is False + assert conn.is_closed is False + + @pytest.mark.trio + async def test_start_marks_established(self): + conn = _make_connection() + await conn.start() + assert conn.is_established is True + assert conn.event_started.is_set() + + +class TestOpenStream: + @pytest.mark.trio + async def test_open_stream_returns_stream(self): + conn = _make_connection() + stream = await conn.open_stream() + assert stream is not None + assert stream.channel_id == OUTBOUND_STREAM_START_ID + + @pytest.mark.trio + async def test_open_stream_increments_ids_by_2(self): + conn = _make_connection() + s1 = await conn.open_stream() + s2 = await conn.open_stream() + s3 = await conn.open_stream() + assert s1.channel_id == 2 + assert s2.channel_id == 4 + assert s3.channel_id == 6 + + @pytest.mark.trio + async def test_open_stream_on_closed_connection_raises(self): + conn = _make_connection() + conn._closed = True + with pytest.raises(WebRTCStreamError, match="closed"): + await conn.open_stream() + + @pytest.mark.trio + async def test_open_stream_at_limit_raises(self): + conn = _make_connection(max_streams=2) + await conn.open_stream() + await conn.open_stream() + with pytest.raises(WebRTCStreamError, match="limit"): + await conn.open_stream() + + +class TestAcceptStream: + @pytest.mark.trio + async def test_accept_receives_inbound_stream(self): + conn = _make_connection() + + async def _accept() -> None: + stream = await conn.accept_stream() + assert stream.channel_id == 1 + + async with trio.open_nursery() as nursery: + nursery.start_soon(_accept) + await trio.sleep(0.01) + conn.on_datachannel(1) + + @pytest.mark.trio + async def test_accept_on_closed_raises(self): + conn = _make_connection() + conn._closed = True + with pytest.raises(WebRTCStreamError, match="closed"): + await conn.accept_stream() + + +class TestMessageRouting: + @pytest.mark.trio + async def test_on_channel_message_routes_to_stream(self): + conn = _make_connection() + stream = await conn.open_stream() + from libp2p.transport.webrtc.pb.webrtc_pb2 import Message + + msg = Message(message=b"test-data") + conn.on_channel_message(stream.channel_id, msg.SerializeToString()) + data = await stream.read() + assert data == b"test-data" + + def test_on_channel_message_unknown_id_ignored(self): + conn = _make_connection() + # Should not raise + conn.on_channel_message(999, b"ignored") + + +class TestClose: + @pytest.mark.trio + async def test_close_resets_all_streams(self): + conn = _make_connection() + # Open two streams to ensure close() iterates and resets them all. + await conn.open_stream() + await conn.open_stream() + await conn.close() + assert conn.is_closed is True + assert not conn.is_established + + @pytest.mark.trio + async def test_close_is_idempotent(self): + conn = _make_connection() + await conn.close() + await conn.close() # Should not raise + + @pytest.mark.trio + async def test_raw_read_write_raises(self): + conn = _make_connection() + with pytest.raises(Exception, match="native multiplexing"): + await conn.read() + with pytest.raises(Exception, match="native multiplexing"): + await conn.write(b"data") diff --git a/tests/core/transport/webrtc/test_multiaddr_utils.py b/tests/core/transport/webrtc/test_multiaddr_utils.py new file mode 100644 index 000000000..c470a86c1 --- /dev/null +++ b/tests/core/transport/webrtc/test_multiaddr_utils.py @@ -0,0 +1,143 @@ +""" +Tests for WebRTC multiaddr parsing and construction. +""" + +import pytest +from multiaddr import Multiaddr + +from libp2p.transport.webrtc.certificate import WebRTCCertificate +from libp2p.transport.webrtc.exceptions import WebRTCMultiaddrError +from libp2p.transport.webrtc.multiaddr_utils import ( + build_webrtc_direct_multiaddr, + is_webrtc_direct_multiaddr, + is_webrtc_multiaddr, + parse_webrtc_direct_multiaddr, +) + + +class TestIsWebrtcDirectMultiaddr: + """Test WebRTC Direct multiaddr detection.""" + + def test_valid_ipv4_webrtc_direct(self): + maddr = Multiaddr("/ip4/127.0.0.1/udp/9090/webrtc-direct") + assert is_webrtc_direct_multiaddr(maddr) + + def test_valid_ipv4_with_certhash(self): + maddr = Multiaddr( + "/ip4/192.168.1.1/udp/4001/webrtc-direct/certhash/uEiBkEKoo3S" + ) + assert is_webrtc_direct_multiaddr(maddr) + + def test_valid_ipv6_webrtc_direct(self): + maddr = Multiaddr("/ip6/::1/udp/9090/webrtc-direct") + assert is_webrtc_direct_multiaddr(maddr) + + def test_tcp_not_webrtc_direct(self): + maddr = Multiaddr("/ip4/127.0.0.1/tcp/9090") + assert not is_webrtc_direct_multiaddr(maddr) + + def test_quic_not_webrtc_direct(self): + maddr = Multiaddr("/ip4/127.0.0.1/udp/9090/quic-v1") + assert not is_webrtc_direct_multiaddr(maddr) + + def test_missing_udp_not_valid(self): + # webrtc-direct without udp is not valid + maddr = Multiaddr("/ip4/127.0.0.1/tcp/9090/webrtc-direct") + assert not is_webrtc_direct_multiaddr(maddr) + + +class TestIsWebrtcMultiaddr: + """Test relay-based WebRTC multiaddr detection.""" + + def test_valid_relay_webrtc(self): + # Use a valid base58 peer ID (Ed25519 key hash) + relay_peer_id = "12D3KooWDpJ7As7BWAwRMfu1VU2WCqNjvq387JEYKDBj4kx6nXTN" + maddr = Multiaddr( + f"/ip4/1.2.3.4/udp/4001/quic-v1/p2p/{relay_peer_id}/p2p-circuit/webrtc" + ) + assert is_webrtc_multiaddr(maddr) + + def test_webrtc_direct_is_not_relay_webrtc(self): + maddr = Multiaddr("/ip4/127.0.0.1/udp/9090/webrtc-direct") + assert not is_webrtc_multiaddr(maddr) + + def test_plain_tcp_not_webrtc(self): + maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + assert not is_webrtc_multiaddr(maddr) + + +class TestBuildWebrtcDirectMultiaddr: + """Test multiaddr construction.""" + + def test_build_ipv4(self): + cert = WebRTCCertificate.generate() + certhash = cert.fingerprint_to_multibase() + maddr = build_webrtc_direct_multiaddr("127.0.0.1", 9090, certhash) + assert is_webrtc_direct_multiaddr(maddr) + maddr_str = str(maddr) + assert "/ip4/127.0.0.1/" in maddr_str + assert "/udp/9090/" in maddr_str + assert "/webrtc-direct/" in maddr_str + assert f"/certhash/{certhash}" in maddr_str + + def test_build_ipv6(self): + cert = WebRTCCertificate.generate() + certhash = cert.fingerprint_to_multibase() + maddr = build_webrtc_direct_multiaddr("::1", 9090, certhash) + assert is_webrtc_direct_multiaddr(maddr) + assert "/ip6/" in str(maddr) + + def test_build_with_peer_id(self): + cert = WebRTCCertificate.generate() + certhash = cert.fingerprint_to_multibase() + peer_id = "12D3KooWDpJ7As7BWAwRMfu1VU2WCqNjvq387JEYKDBj4kx6nXTN" + maddr = build_webrtc_direct_multiaddr( + "127.0.0.1", 9090, certhash, peer_id=peer_id + ) + assert f"/p2p/{peer_id}" in str(maddr) + + +class TestParseWebrtcDirectMultiaddr: + """Test multiaddr parsing.""" + + def test_parse_basic(self): + cert = WebRTCCertificate.generate() + certhash = cert.fingerprint_to_multibase() + maddr = build_webrtc_direct_multiaddr("127.0.0.1", 9090, certhash) + host, port, parsed_certhash, peer_id = parse_webrtc_direct_multiaddr(maddr) + assert host == "127.0.0.1" + assert port == 9090 + assert parsed_certhash == certhash + assert peer_id is None + + def test_parse_with_peer_id(self): + cert = WebRTCCertificate.generate() + certhash = cert.fingerprint_to_multibase() + expected_peer = "12D3KooWJdGFj8RkDMPSLFsgAbHfcLTwSm3GVnSCbGTAoMnGcEms" + maddr = build_webrtc_direct_multiaddr( + "192.168.1.1", 4001, certhash, peer_id=expected_peer + ) + host, port, parsed_certhash, peer_id = parse_webrtc_direct_multiaddr(maddr) + assert host == "192.168.1.1" + assert port == 4001 + assert parsed_certhash == certhash + assert peer_id == expected_peer + + def test_parse_invalid_multiaddr(self): + maddr = Multiaddr("/ip4/127.0.0.1/tcp/9090") + with pytest.raises(WebRTCMultiaddrError, match="Not a valid"): + parse_webrtc_direct_multiaddr(maddr) + + def test_roundtrip_build_parse(self): + """Build a multiaddr and parse it back — values should survive.""" + cert = WebRTCCertificate.generate() + certhash = cert.fingerprint_to_multibase() + expected_peer = "12D3KooWRBy97UB99e3J6hiPesre1MZeuNQvfan7ATZ8HbRL9vbs" + original_maddr = build_webrtc_direct_multiaddr( + "10.0.0.1", 5555, certhash, peer_id=expected_peer + ) + host, port, parsed_hash, peer_id = parse_webrtc_direct_multiaddr(original_maddr) + assert host == "10.0.0.1" + assert port == 5555 + assert parsed_hash == certhash + assert peer_id == "12D3KooWRBy97UB99e3J6hiPesre1MZeuNQvfan7ATZ8HbRL9vbs" diff --git a/tests/core/transport/webrtc/test_noise_handshake.py b/tests/core/transport/webrtc/test_noise_handshake.py new file mode 100644 index 000000000..f86c05da9 --- /dev/null +++ b/tests/core/transport/webrtc/test_noise_handshake.py @@ -0,0 +1,113 @@ +""" +Tests for Noise prologue construction and DataChannelReadWriter. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from libp2p.transport.webrtc.constants import NOISE_PROLOGUE_PREFIX +from libp2p.transport.webrtc.noise_handshake import ( + DataChannelReadWriter, + build_noise_prologue, +) + + +class TestBuildNoisePrologue: + def test_prologue_starts_with_prefix(self): + local_fp = b"\x01" * 32 + remote_fp = b"\x02" * 32 + prologue = build_noise_prologue(local_fp, remote_fp) + assert prologue.startswith(NOISE_PROLOGUE_PREFIX) + + def test_prologue_contains_multihash_encoded_fingerprints(self): + local_fp = b"\xaa" * 32 + remote_fp = b"\xbb" * 32 + prologue = build_noise_prologue(local_fp, remote_fp) + # After prefix: local_mh (34 bytes) + remote_mh (34 bytes) + after_prefix = prologue[len(NOISE_PROLOGUE_PREFIX) :] + assert len(after_prefix) == 68 # 34 + 34 + # Verify multihash headers + assert after_prefix[0] == 0x12 # SHA-256 code + assert after_prefix[1] == 32 # digest length + assert after_prefix[2:34] == local_fp + assert after_prefix[34] == 0x12 + assert after_prefix[35] == 32 + assert after_prefix[36:68] == remote_fp + + def test_prologue_total_length(self): + local_fp = b"\x00" * 32 + remote_fp = b"\xff" * 32 + prologue = build_noise_prologue(local_fp, remote_fp) + # prefix (20) + local_mh (34) + remote_mh (34) = 88 + expected = len(NOISE_PROLOGUE_PREFIX) + 34 + 34 + assert len(prologue) == expected + + def test_prologue_is_asymmetric(self): + """Swapping local/remote produces different prologues.""" + fp_a = b"\x01" * 32 + fp_b = b"\x02" * 32 + p1 = build_noise_prologue(fp_a, fp_b) + p2 = build_noise_prologue(fp_b, fp_a) + assert p1 != p2 + + def test_prologue_with_real_fingerprints(self): + from libp2p.transport.webrtc.certificate import WebRTCCertificate + + cert_a = WebRTCCertificate.generate() + cert_b = WebRTCCertificate.generate() + prologue = build_noise_prologue(cert_a.fingerprint, cert_b.fingerprint) + assert len(prologue) == len(NOISE_PROLOGUE_PREFIX) + 68 + + +class TestDataChannelReadWriter: + @pytest.mark.trio + async def test_write_calls_send_cb(self): + send_cb = AsyncMock() + recv_cb = AsyncMock(return_value=b"response") + rw = DataChannelReadWriter( + send_cb=send_cb, + recv_cb=recv_cb, + is_initiator=True, + ) + await rw.write(b"hello") + send_cb.assert_called_once_with(b"hello") + + @pytest.mark.trio + async def test_read_calls_recv_cb(self): + send_cb = AsyncMock() + recv_cb = AsyncMock(return_value=b"data-from-peer") + rw = DataChannelReadWriter( + send_cb=send_cb, + recv_cb=recv_cb, + is_initiator=False, + ) + data = await rw.read() + assert data == b"data-from-peer" + + @pytest.mark.trio + async def test_close_is_noop(self): + rw = DataChannelReadWriter( + send_cb=AsyncMock(), + recv_cb=AsyncMock(), + is_initiator=True, + ) + await rw.close() # Should not raise + + def test_is_initiator_property(self): + rw = DataChannelReadWriter( + send_cb=AsyncMock(), + recv_cb=AsyncMock(), + is_initiator=True, + ) + assert rw.is_initiator is True + + def test_transport_addresses_empty(self): + rw = DataChannelReadWriter( + send_cb=AsyncMock(), + recv_cb=AsyncMock(), + is_initiator=True, + ) + assert rw.get_transport_addresses() == [] diff --git a/tests/core/transport/webrtc/test_protobuf.py b/tests/core/transport/webrtc/test_protobuf.py new file mode 100644 index 000000000..3739aa33e --- /dev/null +++ b/tests/core/transport/webrtc/test_protobuf.py @@ -0,0 +1,98 @@ +""" +Tests for WebRTC protobuf message framing. +""" + +from libp2p.transport.webrtc.pb.webrtc_pb2 import Message + + +class TestMessageSerialization: + """Test protobuf Message round-trip.""" + + def test_data_only_message(self): + msg = Message(message=b"hello webrtc") + data = msg.SerializeToString() + parsed = Message() + parsed.ParseFromString(data) + assert parsed.message == b"hello webrtc" + assert not parsed.HasField("flag") + + def test_fin_flag(self): + msg = Message(flag=Message.FIN) + data = msg.SerializeToString() + parsed = Message() + parsed.ParseFromString(data) + assert parsed.flag == Message.FIN + assert not parsed.HasField("message") + + def test_fin_flag_is_present_when_set(self): + """FIN==0 must be present on wire because the field is ``optional``.""" + msg = Message(flag=Message.FIN) + data = msg.SerializeToString() + assert len(data) > 0, "FIN flag (value=0) must be present on wire" + parsed = Message() + parsed.ParseFromString(data) + assert parsed.HasField("flag") + assert parsed.flag == Message.FIN + + def test_stop_sending_flag(self): + msg = Message(flag=Message.STOP_SENDING) + data = msg.SerializeToString() + parsed = Message() + parsed.ParseFromString(data) + assert parsed.flag == Message.STOP_SENDING + + def test_reset_flag(self): + msg = Message(flag=Message.RESET) + data = msg.SerializeToString() + parsed = Message() + parsed.ParseFromString(data) + assert parsed.flag == Message.RESET + + def test_fin_ack_flag(self): + msg = Message(flag=Message.FIN_ACK) + data = msg.SerializeToString() + parsed = Message() + parsed.ParseFromString(data) + assert parsed.flag == Message.FIN_ACK + + def test_flag_with_data(self): + msg = Message(flag=Message.FIN, message=b"final chunk") + data = msg.SerializeToString() + parsed = Message() + parsed.ParseFromString(data) + assert parsed.flag == Message.FIN + assert parsed.message == b"final chunk" + + def test_empty_message(self): + msg = Message() + data = msg.SerializeToString() + assert len(data) == 0 # proto3 empty message serializes to zero bytes + parsed = Message() + parsed.ParseFromString(data) + assert not parsed.HasField("flag") + assert not parsed.HasField("message") + + def test_large_payload(self): + payload = b"\x42" * 16384 # 16 KiB max message size + msg = Message(message=payload) + data = msg.SerializeToString() + parsed = Message() + parsed.ParseFromString(data) + assert parsed.message == payload + assert len(parsed.message) == 16384 + + def test_flag_enum_values(self): + """Verify flag enum values match the spec.""" + assert Message.FIN == 0 + assert Message.STOP_SENDING == 1 + assert Message.RESET == 2 + assert Message.FIN_ACK == 3 + + def test_binary_payload(self): + """Verify arbitrary binary data survives round-trip.""" + payload = bytes(range(256)) + msg = Message(message=payload) + data = msg.SerializeToString() + parsed = Message() + parsed.ParseFromString(data) + assert parsed.message == payload diff --git a/tests/core/transport/webrtc/test_sdp.py b/tests/core/transport/webrtc/test_sdp.py new file mode 100644 index 000000000..9622c4a92 --- /dev/null +++ b/tests/core/transport/webrtc/test_sdp.py @@ -0,0 +1,104 @@ +""" +Tests for SDP construction and parsing. +""" +# pyrefly: ignore + +from __future__ import annotations + +import pytest + +from libp2p.transport.webrtc.certificate import WebRTCCertificate +from libp2p.transport.webrtc.exceptions import WebRTCConnectionError +from libp2p.transport.webrtc.sdp import SDPBuilder, fingerprint_from_sdp + + +class TestSDPBuilder: + def setup_method(self): + self.cert = WebRTCCertificate.generate() + self.builder = SDPBuilder(certificate=self.cert) + + def test_build_offer_returns_sdp_and_credentials(self): + sdp, ufrag, pwd = self.builder.build_offer(host="127.0.0.1", port=9090) + assert isinstance(sdp, str) + assert len(ufrag) > 0 + assert len(pwd) > 0 + + def test_offer_contains_required_sdp_lines(self): + sdp, _, _ = self.builder.build_offer(host="127.0.0.1", port=9090) + assert "v=0" in sdp + assert "m=application" in sdp + assert "a=ice-ufrag:" in sdp + assert "a=ice-pwd:" in sdp + assert "a=fingerprint:sha-256" in sdp + assert "a=setup:actpass" in sdp + assert "a=sctp-port:5000" in sdp + assert "webrtc-datachannel" in sdp + + def test_offer_contains_candidate(self): + sdp, _, _ = self.builder.build_offer(host="192.168.1.1", port=4001) + assert "a=candidate:" in sdp + assert "192.168.1.1" in sdp + assert "4001" in sdp + + def test_offer_ipv4(self): + sdp, _, _ = self.builder.build_offer(host="10.0.0.1", port=5000) + assert "IN IP4 10.0.0.1" in sdp + + def test_offer_ipv6(self): + sdp, _, _ = self.builder.build_offer(host="::1", port=5000) + assert "IN IP6 ::1" in sdp + + def test_answer_has_setup_active(self): + sdp, _, _ = self.builder.build_answer( + host="127.0.0.1", + port=9090, + remote_ufrag="abc", + remote_pwd="xyz", + ) + assert "a=setup:active" in sdp + + def test_offer_fingerprint_matches_certificate(self): + sdp, _, _ = self.builder.build_offer(host="127.0.0.1", port=9090) + assert self.cert.fingerprint_hex in sdp + + def test_unique_credentials_per_call(self): + _, ufrag1, pwd1 = self.builder.build_offer(host="127.0.0.1", port=9090) + _, ufrag2, pwd2 = self.builder.build_offer(host="127.0.0.1", port=9090) + # Credentials should be random each time + assert ufrag1 != ufrag2 or pwd1 != pwd2 + + +class TestFingerprintFromSDP: + def test_extract_fingerprint(self): + cert = WebRTCCertificate.generate() + builder = SDPBuilder(certificate=cert) + sdp, _, _ = builder.build_offer(host="127.0.0.1", port=9090) + extracted = fingerprint_from_sdp(sdp) + assert extracted == cert.fingerprint + + def test_no_fingerprint_raises(self): + sdp = "v=0\no=- 123 1 IN IP4 127.0.0.1\ns=-\n" + with pytest.raises(WebRTCConnectionError, match="No SHA-256 fingerprint"): + fingerprint_from_sdp(sdp) + + def test_malformed_fingerprint_raises(self): + sdp = "a=fingerprint:sha-256 not:valid:hex:ZZ\n" + with pytest.raises(WebRTCConnectionError, match="Malformed"): + fingerprint_from_sdp(sdp) + + +class TestSDPFromMultiaddr: + def test_build_from_multiaddr_components(self): + cert = WebRTCCertificate.generate() + certhash = cert.fingerprint_to_multibase() + sdp = SDPBuilder.build_sdp_from_multiaddr( + host="1.2.3.4", + port=4001, + certhash_multibase=certhash, + ) + assert "1.2.3.4" in sdp + assert "4001" in sdp + assert "a=fingerprint:sha-256" in sdp + # Fingerprint in the SDP should match the cert + extracted = fingerprint_from_sdp(sdp) + assert extracted == cert.fingerprint diff --git a/tests/core/transport/webrtc/test_signaling.py b/tests/core/transport/webrtc/test_signaling.py new file mode 100644 index 000000000..2f7072004 --- /dev/null +++ b/tests/core/transport/webrtc/test_signaling.py @@ -0,0 +1,260 @@ +""" +Tests for WebRTC signaling protocol. + +Note: MockStream is a simplified test double that implements just +read/write/close. pyrefly flags it as not assignable to INetStream +because it doesn't satisfy the full ABC; this is expected for unit +tests where we only exercise the signaling wire format. +""" +# pyrefly: ignore + +from __future__ import annotations + +import pytest +import trio + +from libp2p.transport.webrtc.signaling import ( + SignalingSession, + _encode_uvarint, + _read_uvarint, + read_signaling_message, + write_signaling_message, +) +from libp2p.transport.webrtc.signaling_pb.signaling_pb2 import SignalingMessage + + +class MockStream: + """Mock INetStream for signaling tests using an in-memory buffer.""" + + def __init__(self) -> None: + self._send: trio.MemorySendChannel[bytes] + self._recv: trio.MemoryReceiveChannel[bytes] + self._send, self._recv = trio.open_memory_channel[bytes](256) + self._buf = bytearray() + + async def write(self, data: bytes) -> None: + await self._send.send(data) + + async def read(self, n: int | None = None) -> bytes: + # Fill buffer from channel if needed + while len(self._buf) < (n or 1): + try: + chunk = self._recv.receive_nowait() + self._buf.extend(chunk) + except trio.WouldBlock: + chunk = await self._recv.receive() + self._buf.extend(chunk) + if n is None: + data = bytes(self._buf) + self._buf.clear() + return data + data = bytes(self._buf[:n]) + del self._buf[:n] + return data + + async def close(self) -> None: + pass + + +def _make_stream_pair() -> tuple[MockStream, MockStream]: + """Create two streams connected to each other for testing.""" + # We use a single MockStream and feed it from both sides + # For simplicity, return two independent streams + return MockStream(), MockStream() + + +class TestVarintEncoding: + def test_encode_small(self): + assert _encode_uvarint(0) == b"\x00" + assert _encode_uvarint(1) == b"\x01" + assert _encode_uvarint(127) == b"\x7f" + + def test_encode_two_bytes(self): + assert _encode_uvarint(128) == b"\x80\x01" + assert _encode_uvarint(300) == b"\xac\x02" + + def test_encode_large(self): + encoded = _encode_uvarint(65536) + assert len(encoded) == 3 + + @pytest.mark.trio + async def test_roundtrip(self): + stream = MockStream() + for value in [0, 1, 127, 128, 255, 256, 300, 65535, 100000]: + varint_bytes = _encode_uvarint(value) + await stream.write(varint_bytes) + decoded = await _read_uvarint(stream) + assert decoded == value, f"Failed for value {value}" + + +class TestSignalingMessage: + @pytest.mark.trio + async def test_write_and_read_offer(self): + stream = MockStream() + msg = SignalingMessage( + type=SignalingMessage.SDP_OFFER, + data=b"v=0\r\no=- 123 ...", + ) + await write_signaling_message(stream, msg) + received = await read_signaling_message(stream) + assert received.type == SignalingMessage.SDP_OFFER + assert received.data == b"v=0\r\no=- 123 ..." + + @pytest.mark.trio + async def test_write_and_read_answer(self): + stream = MockStream() + msg = SignalingMessage( + type=SignalingMessage.SDP_ANSWER, + data=b"answer-sdp", + ) + await write_signaling_message(stream, msg) + received = await read_signaling_message(stream) + assert received.type == SignalingMessage.SDP_ANSWER + assert received.data == b"answer-sdp" + + @pytest.mark.trio + async def test_write_and_read_ice_candidate(self): + stream = MockStream() + msg = SignalingMessage( + type=SignalingMessage.ICE_CANDIDATE, + data=b"candidate:1 1 UDP 2130706431 192.168.1.1 9090 typ host", + ) + await write_signaling_message(stream, msg) + received = await read_signaling_message(stream) + assert received.type == SignalingMessage.ICE_CANDIDATE + + @pytest.mark.trio + async def test_write_and_read_ice_done(self): + stream = MockStream() + msg = SignalingMessage(type=SignalingMessage.ICE_DONE) + await write_signaling_message(stream, msg) + received = await read_signaling_message(stream) + assert received.type == SignalingMessage.ICE_DONE + + @pytest.mark.trio + async def test_multiple_messages_in_sequence(self): + stream = MockStream() + messages = [ + SignalingMessage(type=SignalingMessage.SDP_OFFER, data=b"offer"), + SignalingMessage(type=SignalingMessage.SDP_ANSWER, data=b"answer"), + SignalingMessage(type=SignalingMessage.ICE_CANDIDATE, data=b"cand1"), + SignalingMessage(type=SignalingMessage.ICE_CANDIDATE, data=b"cand2"), + SignalingMessage(type=SignalingMessage.ICE_DONE), + ] + for msg in messages: + await write_signaling_message(stream, msg) + for expected in messages: + received = await read_signaling_message(stream) + assert received.type == expected.type + assert received.data == expected.data + + +class TestSignalingSession: + @pytest.mark.trio + async def test_offer_answer_exchange(self): + """Test the initiator sending offer and receiving answer.""" + stream = MockStream() + session = SignalingSession(stream) + + # Simulate: send offer, then feed back an answer + await session.send_offer(b"test-offer-sdp") + + # Read what was sent and verify + sent = await read_signaling_message(stream) + assert sent.type == SignalingMessage.SDP_OFFER + assert sent.data == b"test-offer-sdp" + + @pytest.mark.trio + async def test_receive_offer(self): + stream = MockStream() + session = SignalingSession(stream) + + # Pre-write an offer + offer_msg = SignalingMessage( + type=SignalingMessage.SDP_OFFER, data=b"remote-offer" + ) + await write_signaling_message(stream, offer_msg) + + received = await session.receive_offer() + assert received == b"remote-offer" + + @pytest.mark.trio + async def test_send_candidates_and_ice_done(self): + stream = MockStream() + session = SignalingSession(stream) + + candidates = [b"candidate-1", b"candidate-2", b"candidate-3"] + await session.send_candidates(candidates) + + assert session._ice_done_sent is True + + # Read back: 3 candidates + 1 ICE_DONE + for i, cand in enumerate(candidates): + msg = await read_signaling_message(stream) + assert msg.type == SignalingMessage.ICE_CANDIDATE + assert msg.data == cand + + done_msg = await read_signaling_message(stream) + assert done_msg.type == SignalingMessage.ICE_DONE + + @pytest.mark.trio + async def test_receive_candidates_until_ice_done(self): + stream = MockStream() + session = SignalingSession(stream) + + # Pre-write candidates + ICE_DONE + for cand in [b"c1", b"c2"]: + await write_signaling_message( + stream, + SignalingMessage(type=SignalingMessage.ICE_CANDIDATE, data=cand), + ) + await write_signaling_message( + stream, + SignalingMessage(type=SignalingMessage.ICE_DONE), + ) + + received = [] + async for candidate in session.receive_candidates(): + received.append(candidate) + + assert received == [b"c1", b"c2"] + assert session._ice_done_received is True + + @pytest.mark.trio + async def test_complete_sends_and_waits_for_ice_done(self): + stream = MockStream() + session = SignalingSession(stream) + + # Simulate that we haven't sent ICE_DONE yet + assert not session._ice_done_sent + + # Pre-write the remote ICE_DONE so complete() can receive it + await write_signaling_message( + stream, + SignalingMessage(type=SignalingMessage.ICE_DONE), + ) + + await session.complete() + + assert session._ice_done_sent is True + assert session._ice_done_received is True + + # Verify we sent our ICE_DONE + sent = await read_signaling_message(stream) + assert sent.type == SignalingMessage.ICE_DONE + + +class TestProtobufEnumValues: + """Verify signaling message type enum values match the spec.""" + + def test_sdp_offer_is_0(self): + assert SignalingMessage.SDP_OFFER == 0 + + def test_sdp_answer_is_1(self): + assert SignalingMessage.SDP_ANSWER == 1 + + def test_ice_candidate_is_2(self): + assert SignalingMessage.ICE_CANDIDATE == 2 + + def test_ice_done_is_3(self): + assert SignalingMessage.ICE_DONE == 3 diff --git a/tests/core/transport/webrtc/test_stream.py b/tests/core/transport/webrtc/test_stream.py new file mode 100644 index 000000000..2e2593621 --- /dev/null +++ b/tests/core/transport/webrtc/test_stream.py @@ -0,0 +1,192 @@ +""" +Tests for WebRTCStream protobuf framing and lifecycle. +""" +# pyrefly: ignore + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from libp2p.transport.webrtc.pb.webrtc_pb2 import Message +from libp2p.transport.webrtc.stream import StreamState, WebRTCStream + + +def _make_stream(channel_id: int = 2) -> WebRTCStream: + """Create a stream with a mock connection and send callback.""" + mock_conn = MagicMock() + mock_conn.peer_id = MagicMock() + stream = WebRTCStream( + connection=mock_conn, + channel_id=channel_id, + is_initiator=True, + ) + stream._send_callback = AsyncMock() + return stream + + +class TestWrite: + @pytest.mark.trio + async def test_write_sends_protobuf_framed_data(self): + stream = _make_stream() + await stream.write(b"hello") + stream._send_callback.assert_called_once() + raw = stream._send_callback.call_args[0][0] + msg = Message() + msg.ParseFromString(raw) + assert msg.message == b"hello" + assert not msg.HasField("flag") + + @pytest.mark.trio + async def test_write_chunks_large_data(self): + stream = _make_stream() + big = b"x" * 32768 # 2x max message size + await stream.write(big) + assert stream._send_callback.call_count == 2 + # First chunk is MAX_MESSAGE_SIZE + raw1 = stream._send_callback.call_args_list[0][0][0] + msg1 = Message() + msg1.ParseFromString(raw1) + assert len(msg1.message) == 16384 + + @pytest.mark.trio + async def test_write_after_close_raises(self): + stream = _make_stream() + stream._write_closed = True + with pytest.raises(Exception, match="closed"): + await stream.write(b"data") + + @pytest.mark.trio + async def test_write_after_reset_raises(self): + stream = _make_stream() + stream._state = StreamState.RESET + with pytest.raises(Exception, match="reset"): + await stream.write(b"data") + + +class TestRead: + @pytest.mark.trio + async def test_read_returns_data_from_on_data(self): + stream = _make_stream() + # Simulate incoming data + msg = Message(message=b"world") + stream.on_data(msg.SerializeToString()) + data = await stream.read() + assert data == b"world" + + @pytest.mark.trio + async def test_read_after_reset_raises(self): + stream = _make_stream() + stream._state = StreamState.RESET + with pytest.raises(Exception, match="reset"): + await stream.read() + + @pytest.mark.trio + async def test_read_buffers_partial(self): + stream = _make_stream() + msg = Message(message=b"abcdefgh") + stream.on_data(msg.SerializeToString()) + # Read 3 bytes + data = await stream.read(3) + assert data == b"abc" + # Read remaining + data = await stream.read() + assert data == b"defgh" + + +class TestFlags: + @pytest.mark.trio + async def test_on_data_fin_closes_read_and_sends_fin_ack(self): + stream = _make_stream() + fin_msg = Message(flag=Message.FIN) + stream.on_data(fin_msg.SerializeToString()) + assert stream._read_closed is True + + @pytest.mark.trio + async def test_on_data_fin_ack_sets_event(self): + stream = _make_stream() + ack_msg = Message(flag=Message.FIN_ACK) + stream.on_data(ack_msg.SerializeToString()) + assert stream._fin_ack_received.is_set() + + @pytest.mark.trio + async def test_on_data_stop_sending_closes_write(self): + stream = _make_stream() + stop_msg = Message(flag=Message.STOP_SENDING) + stream.on_data(stop_msg.SerializeToString()) + assert stream._write_closed is True + + @pytest.mark.trio + async def test_on_data_reset_sets_state(self): + stream = _make_stream() + reset_msg = Message(flag=Message.RESET) + stream.on_data(reset_msg.SerializeToString()) + assert stream._state == StreamState.RESET + + @pytest.mark.trio + async def test_on_data_with_flag_and_payload(self): + stream = _make_stream() + msg = Message(flag=Message.FIN, message=b"last-chunk") + stream.on_data(msg.SerializeToString()) + # FIN should close reads but payload should be delivered + data = await stream.read() + assert data == b"last-chunk" + + +class TestClose: + @pytest.mark.trio + async def test_close_sends_fin(self): + stream = _make_stream() + # Pre-set FIN_ACK so close doesn't block + stream._fin_ack_received.set() + await stream.close() + assert stream._state == StreamState.CLOSED + # Should have sent FIN + calls = stream._send_callback.call_args_list + assert len(calls) >= 1 + msg = Message() + msg.ParseFromString(calls[0][0][0]) + assert msg.flag == Message.FIN + + @pytest.mark.trio + async def test_close_is_idempotent(self): + stream = _make_stream() + stream._fin_ack_received.set() + await stream.close() + await stream.close() # Should not raise + assert stream._state == StreamState.CLOSED + + @pytest.mark.trio + async def test_reset_sends_reset_flag(self): + stream = _make_stream() + await stream.reset() + assert stream._state == StreamState.RESET + calls = stream._send_callback.call_args_list + msg = Message() + msg.ParseFromString(calls[0][0][0]) + assert msg.flag == Message.RESET + + +class TestDeadline: + @pytest.mark.trio + async def test_set_deadline(self): + stream = _make_stream() + stream.set_deadline(10) + assert stream._deadline > 0 + + @pytest.mark.trio + async def test_clear_deadline(self): + stream = _make_stream() + stream.set_deadline(10) + stream.set_deadline(0) + assert stream._deadline == 0.0 + + +class TestChannelClose: + @pytest.mark.trio + async def test_on_channel_close(self): + stream = _make_stream() + stream.on_channel_close() + assert stream._read_closed is True + assert stream._write_closed is True diff --git a/tests/core/transport/webrtc/test_webrtc_direct_loopback.py b/tests/core/transport/webrtc/test_webrtc_direct_loopback.py new file mode 100644 index 000000000..49f99aab5 --- /dev/null +++ b/tests/core/transport/webrtc/test_webrtc_direct_loopback.py @@ -0,0 +1,89 @@ +""" +Integration test: WebRTC Direct loopback. + +Creates a listener and a dialer on localhost and verifies the full +connection lifecycle: HTTP SDP exchange → ICE → DTLS → Noise handshake +→ data-channel stream echo. + +Requires aiortc to be installed. Skipped automatically otherwise. +""" +# pyrefly: ignore + +from __future__ import annotations + +import pytest + +try: + import aiortc # noqa: F401 + + HAS_AIORTC = True +except ImportError: + HAS_AIORTC = False + +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.transport.webrtc.transport import WebRTCDirectTransport + +pytestmark = pytest.mark.skipif(not HAS_AIORTC, reason="aiortc not installed") + + +@pytest.mark.trio +async def test_listener_advertises_certhash_in_multiaddr(): + """Listener publishes a multiaddr that contains /certhash/ and /p2p/.""" + key_pair = create_new_key_pair() + transport = WebRTCDirectTransport(private_key=key_pair.private_key) + + async def noop_handler(conn: object) -> None: + pass + + listener = transport.create_listener(noop_handler) + maddr_str = "/ip4/127.0.0.1/udp/0/webrtc-direct" + + from multiaddr import Multiaddr + + await listener.listen(Multiaddr(maddr_str)) + + addrs = listener.get_addrs() + assert len(addrs) == 1 + addr_str = str(addrs[0]) + assert "/webrtc-direct/" in addr_str + assert "/certhash/" in addr_str + assert "/p2p/" in addr_str + + await listener.close() + await transport.close() + + +@pytest.mark.trio +async def test_listener_binds_actual_port(): + """When port 0 is requested, the listener binds to a real port > 0.""" + key_pair = create_new_key_pair() + transport = WebRTCDirectTransport(private_key=key_pair.private_key) + + async def noop_handler(conn: object) -> None: + pass + + listener = transport.create_listener(noop_handler) + from multiaddr import Multiaddr + + await listener.listen(Multiaddr("/ip4/127.0.0.1/udp/0/webrtc-direct")) + + addrs = listener.get_addrs() + addr_str = str(addrs[0]) + # Parse the port from the multiaddr + parts = addr_str.split("/") + udp_idx = parts.index("udp") + port = int(parts[udp_idx + 1]) + assert port > 0 + + await listener.close() + await transport.close() + + +@pytest.mark.trio +async def test_certificate_is_aiortc_native(): + """Transport certificate should have an aiortc RTCCertificate attached.""" + key_pair = create_new_key_pair() + transport = WebRTCDirectTransport(private_key=key_pair.private_key) + assert hasattr(transport.certificate, "_rtc_certificate") + assert transport.certificate._rtc_certificate is not None + await transport.close()