diff --git a/Cargo.lock b/Cargo.lock index 8a12fba05..f9460fd52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -340,11 +340,19 @@ dependencies = [ name = "chia-client" version = "0.10.0" dependencies = [ + "anyhow", "chia-protocol", + "chia-ssl", "chia-traits 0.10.0", + "dns-lookup", + "env_logger", "futures-util", + "hex", + "hex-literal", "log", "native-tls", + "rand", + "semver", "sha2", "thiserror", "tokio", @@ -899,6 +907,18 @@ dependencies = [ "syn 2.0.70", ] +[[package]] +name = "dns-lookup" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5766087c2235fec47fafa4cfecc81e494ee679d0fd4a59887ea0919bfb0e4fc" +dependencies = [ + "cfg-if", + "libc", + "socket2", + "windows-sys 0.48.0", +] + [[package]] name = "ecdsa" version = "0.16.9" @@ -939,6 +959,29 @@ dependencies = [ "zeroize", ] +[[package]] +name = "env_filter" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -1256,6 +1299,12 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "idna" version = "0.5.0" @@ -2206,9 +2255,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.22" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" [[package]] name = "serde" @@ -2263,6 +2312,15 @@ dependencies = [ "digest", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "signature" version = "2.2.0" @@ -2517,11 +2575,26 @@ dependencies = [ "bytes", "libc", "mio", + "num_cpus", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", + "tokio-macros", "windows-sys 0.48.0", ] +[[package]] +name = "tokio-macros" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.70", +] + [[package]] name = "tokio-native-tls" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index bdd25731a..82aa25ff5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ unused_imports = "warn" unused_import_braces = "deny" unreachable_code = "deny" unreachable_patterns = "deny" -dead_code = "deny" +dead_code = "warn" deprecated = "deny" deprecated_in_future = "deny" trivial_casts = "deny" @@ -137,3 +137,6 @@ libfuzzer-sys = "0.4" wasm-bindgen = "0.2.92" log = "0.4.22" native-tls = "0.2.12" +dns-lookup = "2.0.4" +semver = "1.0.23" +env_logger = "0.11.3" diff --git a/crates/chia-client/Cargo.toml b/crates/chia-client/Cargo.toml index 89f1cbd74..aed4bc764 100644 --- a/crates/chia-client/Cargo.toml +++ b/crates/chia-client/Cargo.toml @@ -22,3 +22,14 @@ thiserror = { workspace = true } sha2 = { workspace = true } log = { workspace = true } native-tls = { workspace = true } +dns-lookup = { workspace = true } +hex-literal = { workspace = true } +rand = { workspace = true } +semver = { workspace = true } +hex = { workspace = true } + +[dev-dependencies] +chia-ssl = { path = "../chia-ssl" } +tokio = { workspace = true, features = ["full"] } +anyhow = { workspace = true } +env_logger = { workspace = true } diff --git a/crates/chia-client/examples/client.rs b/crates/chia-client/examples/client.rs new file mode 100644 index 000000000..8b9e4d348 --- /dev/null +++ b/crates/chia-client/examples/client.rs @@ -0,0 +1,54 @@ +use std::time::Duration; + +use chia_client::{create_tls_connector, Client, ClientOptions, Event}; +use chia_ssl::ChiaCertificate; +use tokio::time::sleep; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + env_logger::init(); + + log::info!("Generating certificate"); + let cert = ChiaCertificate::generate()?; + let tls_connector = create_tls_connector(cert.cert_pem.as_bytes(), cert.key_pem.as_bytes())?; + + log::info!("Creating client"); + let (client, mut receiver) = Client::with_options( + tls_connector, + ClientOptions { + target_peers: 20, + ..Default::default() + }, + ); + + log::info!("Connecting to DNS introducers"); + client.find_peers().await; + + let client_clone = client.clone(); + + tokio::spawn(async move { + loop { + sleep(Duration::from_secs(10)).await; + let count = client_clone.peer_count().await; + log::info!("Currently connected to {} peers", count); + client.find_peers().await; + } + }); + + while let Some(event) = receiver.recv().await { + match event { + Event::Message(peer_id, message) => { + log::info!( + "Received message from peer {}: {:?}", + peer_id, + message.msg_type + ); + } + Event::ConnectionClosed(peer_id) => { + log::info!("Peer {} disconnected", peer_id); + } + } + } + + Ok(()) +} diff --git a/crates/chia-client/examples/peer_discovery.rs b/crates/chia-client/examples/peer_discovery.rs new file mode 100644 index 000000000..3791ef49c --- /dev/null +++ b/crates/chia-client/examples/peer_discovery.rs @@ -0,0 +1,68 @@ +use std::time::Duration; + +use chia_client::{create_tls_connector, Peer}; +use chia_protocol::{Handshake, NodeType, ProtocolMessageTypes}; +use chia_ssl::ChiaCertificate; +use chia_traits::Streamable; +use dns_lookup::lookup_host; +use tokio::time::timeout; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + env_logger::init(); + + let cert = ChiaCertificate::generate()?; + let tls = create_tls_connector(cert.cert_pem.as_bytes(), cert.key_pem.as_bytes())?; + + for ip in lookup_host("dns-introducer.chia.net")? { + let Ok(response) = + timeout(Duration::from_secs(3), Peer::connect(ip, 8444, tls.clone())).await + else { + log::info!("{ip} exceeded connection timeout of 3 seconds"); + continue; + }; + + let (peer, mut receiver) = response?; + + peer.send(Handshake { + network_id: "mainnet".to_string(), + protocol_version: "0.0.37".to_string(), + software_version: "0.0.0".to_string(), + server_port: 0, + node_type: NodeType::Wallet, + capabilities: vec![ + (1, "1".to_string()), + (2, "1".to_string()), + (3, "1".to_string()), + ], + }) + .await?; + + let Ok(message) = timeout(Duration::from_secs(1), receiver.recv()).await else { + log::info!("{ip} exceeded timeout of 1 second"); + continue; + }; + + let Some(message) = message else { + log::info!("{ip} did not send any messages"); + continue; + }; + + if message.msg_type != ProtocolMessageTypes::Handshake { + log::info!("{ip} sent an unexpected message {:?}", message.msg_type); + continue; + } + + let Ok(handshake) = Handshake::from_bytes(&message.data) else { + log::info!("{ip} sent an invalid handshake"); + continue; + }; + + log::info!( + "{ip} handshake sent with protocol version {}", + handshake.protocol_version + ); + } + + Ok(()) +} diff --git a/crates/chia-client/src/client.rs b/crates/chia-client/src/client.rs new file mode 100644 index 000000000..7644068e0 --- /dev/null +++ b/crates/chia-client/src/client.rs @@ -0,0 +1,346 @@ +use std::{ + collections::{HashMap, HashSet}, + net::IpAddr, + str::FromStr, + sync::Arc, + time::Duration, +}; + +use chia_protocol::{ + Handshake, Message, NodeType, ProtocolMessageTypes, RequestPeers, RespondPeers, +}; +use chia_traits::Streamable; +use dns_lookup::lookup_host; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use native_tls::TlsConnector; +use rand::{seq::SliceRandom, thread_rng}; +use semver::Version; +use tokio::{ + sync::{mpsc, Mutex, RwLock, RwLockWriteGuard}, + time::timeout, +}; + +use crate::{Error, Event, Network, Peer, PeerId, Result}; + +/// A client that can connect to many different peers on the network. +#[derive(Debug, Clone)] +pub struct Client(Arc); + +#[derive(Debug, Clone)] +pub struct ClientOptions { + /// The network to connect to. By default, this is mainnet. + pub network: Network, + + /// The type of service that this client represents. + /// This defaults to [`NodeType::Wallet`], since that is the most common use case for this library. + pub node_type: NodeType, + + /// The capabilities that this client supports. + /// This defaults to the standard capabilities all Chia services connect with. + pub capabilities: Vec<(u16, String)>, + + /// The minimum protocol version that this client supports. + /// Currently defaults to `0.0.37`, which is supported by a majority of the network. + /// If the protocol version of the peer is lower than this, the connection will be rejected. + pub protocol_version: Version, + + /// The software version of this client. + /// This is not important for the handshake, but is sent to the peer for informational purposes. + /// Defaults to `0.0.0`, since this isn't a Chia full node. + pub software_version: String, + + /// The ideal number of peers that should be connected at any given time. + /// This defaults to `5`. + pub target_peers: usize, + + /// How long to wait when trying to connect to a peer. + pub connection_timeout: Duration, + + /// How long to wait for a handshake response from a peer before disconnecting. + pub handshake_timeout: Duration, + + /// How long to wait for a response to a request for peers. + pub request_peers_timeout: Duration, +} + +impl Default for ClientOptions { + fn default() -> Self { + Self { + network: Network::mainnet(), + node_type: NodeType::Wallet, + capabilities: vec![ + (1, "1".to_string()), + (2, "1".to_string()), + (3, "1".to_string()), + ], + protocol_version: Version::parse("0.0.37").expect("invalid version"), + software_version: "0.0.0".to_string(), + target_peers: 5, + connection_timeout: Duration::from_secs(3), + handshake_timeout: Duration::from_secs(1), + request_peers_timeout: Duration::from_secs(3), + } + } +} + +#[derive(Debug)] +struct ClientInner { + peers: Arc>>, + message_sender: Arc>>, + options: ClientOptions, + tls_connector: TlsConnector, +} + +impl Client { + pub fn new(tls_connector: TlsConnector) -> (Self, mpsc::Receiver) { + Self::with_options(tls_connector, ClientOptions::default()) + } + + pub fn with_options( + tls_connector: TlsConnector, + options: ClientOptions, + ) -> (Self, mpsc::Receiver) { + let (sender, receiver) = mpsc::channel(32); + + let client = Self(Arc::new(ClientInner { + peers: Arc::new(RwLock::new(HashMap::new())), + message_sender: Arc::new(Mutex::new(sender)), + options, + tls_connector, + })); + + (client, receiver) + } + + pub async fn peer_count(&self) -> usize { + self.0.peers.read().await.len() + } + + pub async fn find_peers(&self) { + // If we don't have any peers, try to connect to DNS introducers. + if self.peer_count().await == 0 && self.connect_dns().await { + return; + } + + let mut peers = self.0.peers.write().await; + + // If we still don't have any peers, we can't do anything. + if peers.len() >= self.0.options.target_peers { + return; + } + + if peers.is_empty() { + log::error!("No peers connected after DNS lookups"); + return; + } + + for (peer_id, peer) in peers.clone() { + if peers.len() >= self.0.options.target_peers { + break; + } + + // Request new peers from the peer. + let Ok(Ok(response)): std::result::Result, _> = timeout( + self.0.options.request_peers_timeout, + peer.request_infallible(RequestPeers::new()), + ) + .await + else { + log::info!("Failed to request peers from peer {peer_id}"); + peers.remove(&peer_id); + continue; + }; + + log::info!("Requested peers from peer {peer_id}"); + + let mut ips = HashSet::new(); + + for item in response.peer_list { + // If we can't parse the IP address, skip it. + let Ok(ip_addr) = IpAddr::from_str(&item.host) else { + log::debug!("Failed to parse IP address {}", item.host); + continue; + }; + + ips.insert((ip_addr, item.port)); + } + + // Keep connecting peers until the peer list is exhausted, + // then move on to the next peer to request from. + let mut iter = ips.into_iter(); + + loop { + let required_peers = self.0.options.target_peers - peers.len(); + let next_peers: Vec<_> = iter.by_ref().take(required_peers).collect(); + if next_peers.is_empty() { + break; + } + self.connect_peers(&mut peers, next_peers).await; + } + } + } + + async fn connect_dns(&self) -> bool { + log::info!("Requesting peers from DNS introducer"); + + // Lock the peer map early to prevent adding too many connections. + let mut peers = self.0.peers.write().await; + + let mut ips = Vec::new(); + + for dns_introducer in &self.0.options.network.dns_introducers { + // If a DNS introducer lookup fails, we just skip it. + let Ok(result) = lookup_host(dns_introducer) else { + log::warn!("Failed to lookup DNS introducer `{dns_introducer}`"); + continue; + }; + ips.extend(result); + } + + // Shuffle the list of IPs so that we don't always connect to the same ones. + // This also prevents bias towards IPv4 or IPv6. + ips.as_mut_slice().shuffle(&mut thread_rng()); + + // Keep track of where we are in the peer list. + let mut cursor = 0; + + while peers.len() < self.0.options.target_peers { + // If we've reached the end of the list of IPs, stop early. + if cursor >= ips.len() { + break; + } + + // Calculate how many peers we still need to connect to. + let required_peers = self.0.options.target_peers - peers.len(); + + // Get the remaining peers we can and need to connect to. + let peers_to_try = &ips[cursor..ips.len().min(cursor + required_peers)]; + + // Increment the cursor by the number of peers we're trying to connect to. + cursor += required_peers; + + self.connect_peers( + &mut peers, + peers_to_try + .iter() + .map(|ip| (*ip, self.0.options.network.default_port)) + .collect(), + ) + .await; + } + + peers.len() >= self.0.options.target_peers + } + + async fn connect_peers( + &self, + peers: &mut RwLockWriteGuard<'_, HashMap>, + potential_ips: Vec<(IpAddr, u16)>, + ) { + let ips: Vec<(IpAddr, u16)> = potential_ips + .into_iter() + .filter(|&(ip, _port)| !peers.values().any(|peer| peer.ip_addr() == ip)) + .collect(); + + // Add the connections and wait for them to complete. + let mut connections = FuturesUnordered::new(); + + for (ip, port) in ips { + connections.push(self.connect_peer(ip, port)); + } + + while let Some(result) = connections.next().await { + let (ip, peer, mut receiver) = match result { + Ok(result) => result, + Err(error) => { + log::debug!("Failed to connect to peer: {error}"); + continue; + } + }; + + let peer_id = peer.peer_id(); + peers.insert(peer_id, peer); + + let message_sender = self.0.message_sender.clone(); + let peer_map = self.0.peers.clone(); + + // Spawn a task to propagate messages from the peer. + tokio::spawn(async move { + while let Some(message) = receiver.recv().await { + if let Err(error) = message_sender + .lock() + .await + .send(Event::Message(peer_id, message)) + .await + { + log::debug!("Failed to send client message event: {error}"); + break; + } + } + peer_map.write().await.remove(&peer_id); + + if let Err(error) = message_sender + .lock() + .await + .send(Event::ConnectionClosed(peer_id)) + .await + { + log::debug!("Failed to send client connection closed event: {error}"); + } + + log::info!("Peer {ip} disconnected"); + }); + + log::info!("Connected to peer {ip}"); + } + } + + /// Does not lock the peer map or add the peer automatically. + /// This prevents deadlocks when called from within a lock. + async fn connect_peer( + &self, + ip: IpAddr, + port: u16, + ) -> Result<(IpAddr, Peer, mpsc::Receiver)> { + let (peer, mut receiver) = timeout( + self.0.options.connection_timeout, + Peer::connect(ip, port, self.0.tls_connector.clone()), + ) + .await??; + + let options = &self.0.options; + + peer.send(Handshake { + network_id: options.network.network_id.clone(), + protocol_version: options.protocol_version.to_string(), + software_version: options.software_version.clone(), + server_port: 0, + node_type: options.node_type, + capabilities: options.capabilities.clone(), + }) + .await?; + + let Some(message) = timeout(options.handshake_timeout, receiver.recv()).await? else { + return Err(Error::ExpectedHandshake); + }; + + if message.msg_type != ProtocolMessageTypes::Handshake { + return Err(Error::ExpectedHandshake); + }; + + let handshake = Handshake::from_bytes(&message.data)?; + + let Ok(protocol_version) = Version::parse(&handshake.protocol_version) else { + return Err(Error::InvalidProtocolVersion(handshake.protocol_version)); + }; + + if protocol_version < options.protocol_version { + return Err(Error::OutdatedProtocolVersion( + protocol_version, + options.protocol_version.clone(), + )); + } + + Ok((ip, peer, receiver)) + } +} diff --git a/crates/chia-client/src/error.rs b/crates/chia-client/src/error.rs index fcb99bcd9..0bf9abc50 100644 --- a/crates/chia-client/src/error.rs +++ b/crates/chia-client/src/error.rs @@ -1,12 +1,22 @@ use chia_protocol::ProtocolMessageTypes; +use semver::Version; use thiserror::Error; -use tokio::sync::oneshot::error::RecvError; +use tokio::{sync::oneshot::error::RecvError, time::error::Elapsed}; #[derive(Debug, Error)] pub enum Error { #[error("Peer is missing certificate")] MissingCertificate, + #[error("Handshake not received")] + ExpectedHandshake, + + #[error("Invalid protocol version {0}")] + InvalidProtocolVersion(String), + + #[error("Outdated protocol version {0}, expected {1}")] + OutdatedProtocolVersion(Version, Version), + #[error("Streamable error: {0}")] Streamable(#[from] chia_traits::Error), @@ -27,6 +37,12 @@ pub enum Error { #[error("Failed to receive message")] Recv(#[from] RecvError), + + #[error("Timeout error: {0}")] + Timeout(#[from] Elapsed), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), } pub type Result = std::result::Result; diff --git a/crates/chia-client/src/event.rs b/crates/chia-client/src/event.rs index ec08b4e2c..81c198101 100644 --- a/crates/chia-client/src/event.rs +++ b/crates/chia-client/src/event.rs @@ -1,8 +1,9 @@ -use chia_protocol::{CoinStateUpdate, Handshake, NewPeakWallet}; +use chia_protocol::Message; + +use crate::PeerId; #[derive(Debug, Clone)] pub enum Event { - Handshake(Handshake), - NewPeakWallet(NewPeakWallet), - CoinStateUpdate(CoinStateUpdate), + Message(PeerId, Message), + ConnectionClosed(PeerId), } diff --git a/crates/chia-client/src/lib.rs b/crates/chia-client/src/lib.rs index 71d9b32c6..3c990dc94 100644 --- a/crates/chia-client/src/lib.rs +++ b/crates/chia-client/src/lib.rs @@ -1,10 +1,16 @@ +mod client; mod error; mod event; +mod network; mod peer; mod request_map; mod response; +mod tls; +pub use client::*; pub use error::*; pub use event::*; +pub use network::*; pub use peer::*; pub use response::*; +pub use tls::*; diff --git a/crates/chia-client/src/network.rs b/crates/chia-client/src/network.rs new file mode 100644 index 000000000..af229e184 --- /dev/null +++ b/crates/chia-client/src/network.rs @@ -0,0 +1,39 @@ +use chia_protocol::Bytes32; +use hex_literal::hex; + +#[derive(Debug, Clone)] +pub struct Network { + pub network_id: String, + pub default_port: u16, + pub genesis_challenge: Bytes32, + pub dns_introducers: Vec, +} + +impl Network { + pub fn mainnet() -> Self { + Self { + network_id: "mainnet".to_string(), + default_port: 8444, + genesis_challenge: Bytes32::new(hex!( + "ccd5bb71183532bff220ba46c268991a3ff07eb358e8255a65c30a2dce0e5fbb" + )), + dns_introducers: vec![ + "dns-introducer.chia.net".to_string(), + "chia.ctrlaltdel.ch".to_string(), + "seeder.dexie.space".to_string(), + "chia.hoffmang.com".to_string(), + ], + } + } + + pub fn testnet11() -> Self { + Self { + network_id: "testnet11".to_string(), + default_port: 58444, + genesis_challenge: Bytes32::new(hex!( + "37a90eb5185a9c4439a91ddc98bbadce7b4feba060d50116a067de66bf236615" + )), + dns_introducers: vec!["dns-introducer-testnet11.chia.net".to_string()], + } + } +} diff --git a/crates/chia-client/src/peer.rs b/crates/chia-client/src/peer.rs index 16eb1355e..05ad574aa 100644 --- a/crates/chia-client/src/peer.rs +++ b/crates/chia-client/src/peer.rs @@ -1,22 +1,21 @@ -use std::sync::Arc; +use std::{fmt, net::IpAddr, sync::Arc}; -use chia_protocol::{ - ChiaProtocolMessage, CoinStateUpdate, Handshake, Message, NewPeakWallet, ProtocolMessageTypes, -}; +use chia_protocol::{ChiaProtocolMessage, Message}; use chia_traits::Streamable; use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, StreamExt, }; +use native_tls::TlsConnector; use sha2::{digest::FixedOutput, Digest, Sha256}; use tokio::{ net::TcpStream, sync::{mpsc, oneshot, Mutex}, task::JoinHandle, }; -use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; +use tokio_tungstenite::{Connector, MaybeTlsStream, WebSocketStream}; -use crate::{request_map::RequestMap, Error, Event, Response, Result}; +use crate::{request_map::RequestMap, Error, Response, Result}; type WebSocket = WebSocketStream>; type Sink = SplitSink; @@ -25,6 +24,12 @@ type Stream = SplitStream; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct PeerId([u8; 32]); +impl fmt::Display for PeerId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(self.0)) + } +} + #[derive(Debug, Clone)] pub struct Peer(Arc); @@ -34,13 +39,45 @@ struct PeerInner { inbound_handle: JoinHandle>, requests: Arc, peer_id: PeerId, + ip_addr: IpAddr, } impl Peer { - pub fn new(ws: WebSocket) -> Result<(Self, mpsc::Receiver)> { - let cert = match ws.get_ref() { - MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_certificate()?, - _ => None, + pub async fn connect( + ip: IpAddr, + port: u16, + tls_connector: TlsConnector, + ) -> Result<(Self, mpsc::Receiver)> { + let uri = if ip.is_ipv4() { + format!("wss://{ip}:{port}/ws") + } else { + format!("wss://[{ip}]:{port}/ws") + }; + Self::connect_addr(&uri, tls_connector).await + } + + pub async fn connect_addr( + uri: &str, + tls_connector: TlsConnector, + ) -> Result<(Self, mpsc::Receiver)> { + let (ws, _) = tokio_tungstenite::connect_async_tls_with_config( + uri, + None, + false, + Some(Connector::NativeTls(tls_connector)), + ) + .await?; + Self::from_websocket(ws) + } + + pub fn from_websocket(ws: WebSocket) -> Result<(Self, mpsc::Receiver)> { + let (addr, cert) = match ws.get_ref() { + MaybeTlsStream::NativeTls(tls) => { + let tls_stream = tls.get_ref(); + let tcp_stream = tls_stream.get_ref().get_ref(); + (tcp_stream.peer_addr()?, tls_stream.peer_certificate()?) + } + _ => return Err(Error::MissingCertificate), }; let Some(cert) = cert else { @@ -63,6 +100,7 @@ impl Peer { inbound_handle, requests, peer_id, + ip_addr: addr.ip(), })); Ok((peer, receiver)) @@ -72,6 +110,10 @@ impl Peer { self.0.peer_id } + pub fn ip_addr(&self) -> IpAddr { + self.0.ip_addr + } + pub async fn send(&self, body: T) -> Result<()> where T: Streamable + ChiaProtocolMessage, @@ -147,46 +189,29 @@ impl Drop for PeerInner { async fn handle_inbound_messages( mut stream: Stream, - sender: mpsc::Sender, + sender: mpsc::Sender, requests: Arc, ) -> Result<()> { while let Some(message) = stream.next().await { let message = Message::from_bytes(&message?.into_data())?; - match message.msg_type { - ProtocolMessageTypes::CoinStateUpdate => { - let event = Event::CoinStateUpdate(CoinStateUpdate::from_bytes(&message.data)?); - sender.send(event).await.map_err(|error| { - log::error!("Failed to send `CoinStateUpdate` event: {error}"); - Error::EventNotSent - })?; - } - ProtocolMessageTypes::NewPeakWallet => { - let event = Event::NewPeakWallet(NewPeakWallet::from_bytes(&message.data)?); - sender.send(event).await.map_err(|error| { - log::error!("Failed to send `NewPeakWallet` event: {error}"); - Error::EventNotSent - })?; - } - ProtocolMessageTypes::Handshake => { - let event = Event::Handshake(Handshake::from_bytes(&message.data)?); - sender.send(event).await.map_err(|error| { - log::error!("Failed to send `Handshake` event: {error}"); - Error::EventNotSent - })?; - } - kind => { - let Some(id) = message.id else { - log::error!("Received unknown message without an id."); - return Err(Error::UnexpectedMessage(kind)); - }; - let Some(request) = requests.remove(id).await else { - log::error!("Received message with untracked id {id}."); - return Err(Error::UnexpectedMessage(kind)); - }; - request.send(message); - } - } + let Some(id) = message.id else { + sender.send(message).await.map_err(|error| { + log::debug!("Failed to send peer message event: {error}"); + Error::EventNotSent + })?; + continue; + }; + + let Some(request) = requests.remove(id).await else { + log::warn!( + "Received {:?} message with untracked id {id}", + message.msg_type + ); + return Err(Error::UnexpectedMessage(message.msg_type)); + }; + + request.send(message); } Ok(()) } diff --git a/crates/chia-client/src/tls.rs b/crates/chia-client/src/tls.rs new file mode 100644 index 000000000..efbd9fc10 --- /dev/null +++ b/crates/chia-client/src/tls.rs @@ -0,0 +1,11 @@ +use native_tls::{Identity, TlsConnector}; + +pub fn create_tls_connector( + cert_pem: &[u8], + key_pem: &[u8], +) -> Result { + TlsConnector::builder() + .identity(Identity::from_pkcs8(cert_pem, key_pem)?) + .danger_accept_invalid_certs(true) + .build() +}