From 3ea6bc12536fca5718f16bcf21b67612b106de08 Mon Sep 17 00:00:00 2001 From: Pure White Date: Thu, 20 Jul 2023 22:19:10 +0800 Subject: [PATCH] RUST-802 Support Unix Domain Sockets (#908) --- src/client/auth.rs | 23 -- src/client/options.rs | 125 +++++++-- src/client/options/test.rs | 102 ++++--- src/runtime.rs | 5 +- src/runtime/stream.rs | 258 +++++++++++++++++- src/sdam/description/server.rs | 10 +- src/sdam/topology.rs | 6 + src/srv.rs | 3 +- .../json/connection-string/valid-auth.json | 5 +- .../json/connection-string/valid-auth.yml | 5 +- .../connection-string/valid-warnings.json | 2 +- src/trace.rs | 4 +- src/trace/command.rs | 6 +- src/trace/connection.rs | 22 +- src/trace/server_selection.rs | 2 +- 15 files changed, 443 insertions(+), 135 deletions(-) diff --git a/src/client/auth.rs b/src/client/auth.rs index 1a0418ff9..2476fced3 100644 --- a/src/client/auth.rs +++ b/src/client/auth.rs @@ -397,29 +397,6 @@ pub struct Credential { } impl Credential { - #[cfg(all(test, not(feature = "sync"), not(feature = "tokio-sync")))] - pub(crate) fn into_document(mut self) -> Document { - use crate::bson::Bson; - - let mut doc = Document::new(); - - if let Some(s) = self.username.take() { - doc.insert("username", s); - } - - if let Some(s) = self.password.take() { - doc.insert("password", s); - } else { - doc.insert("password", Bson::Null); - } - - if let Some(s) = self.source.take() { - doc.insert("db", s); - } - - doc - } - pub(crate) fn resolved_source(&self) -> &str { self.mechanism .as_ref() diff --git a/src/client/options.rs b/src/client/options.rs index e432080f1..ebcdce90a 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -4,6 +4,7 @@ mod test; mod resolver_config; use std::{ + borrow::Cow, cmp::Ordering, collections::HashSet, convert::TryFrom, @@ -91,14 +92,11 @@ lazy_static! { }; static ref ILLEGAL_DATABASE_CHARACTERS: HashSet<&'static char> = { - ['/', '\\', ' ', '"', '$', '.'].iter().collect() + ['/', '\\', ' ', '"', '$'].iter().collect() }; } /// An enum representing the address of a MongoDB server. -/// -/// Currently this just supports addresses that can be connected to over TCP, but alternative -/// address types may be supported in the future (e.g. Unix Domain Socket paths). #[derive(Clone, Debug, Eq, Serialize)] #[non_exhaustive] pub enum ServerAddress { @@ -112,6 +110,12 @@ pub enum ServerAddress { /// The default is 27017. port: Option, }, + /// A Unix Domain Socket path. + #[cfg(unix)] + Unix { + /// The path to the Unix Domain Socket. + path: PathBuf, + }, } impl<'de> Deserialize<'de> for ServerAddress { @@ -144,6 +148,10 @@ impl PartialEq for ServerAddress { port: other_port, }, ) => host == other_host && port.unwrap_or(27017) == other_port.unwrap_or(27017), + #[cfg(unix)] + (Self::Unix { path }, Self::Unix { path: other_path }) => path == other_path, + #[cfg(unix)] + _ => false, } } } @@ -158,6 +166,8 @@ impl Hash for ServerAddress { host.hash(state); port.unwrap_or(27017).hash(state); } + #[cfg(unix)] + Self::Unix { path } => path.hash(state), } } } @@ -173,6 +183,15 @@ impl ServerAddress { /// Parses an address string into a `ServerAddress`. pub fn parse(address: impl AsRef) -> Result { let address = address.as_ref(); + // checks if the address is a unix domain socket + #[cfg(unix)] + { + if address.ends_with(".sock") { + return Ok(ServerAddress::Unix { + path: PathBuf::from(address), + }); + } + } let mut parts = address.split(':'); let hostname = match parts.next() { Some(part) => { @@ -243,18 +262,29 @@ impl ServerAddress { "port": port.map(|i| Bson::Int32(i.into())).unwrap_or(Bson::Null) } } + #[cfg(unix)] + Self::Unix { path } => { + doc! { + "host": path.to_string_lossy().as_ref(), + "port": Bson::Null, + } + } } } - pub(crate) fn host(&self) -> &str { + pub(crate) fn host(&self) -> Cow<'_, str> { match self { - Self::Tcp { host, .. } => host.as_str(), + Self::Tcp { host, .. } => Cow::Borrowed(host.as_str()), + #[cfg(unix)] + Self::Unix { path } => path.to_string_lossy(), } } pub(crate) fn port(&self) -> Option { match self { Self::Tcp { port, .. } => *port, + #[cfg(unix)] + Self::Unix { .. } => None, } } } @@ -265,6 +295,8 @@ impl fmt::Display for ServerAddress { Self::Tcp { host, port } => { write!(fmt, "{}:{}", host, port.unwrap_or(DEFAULT_PORT)) } + #[cfg(unix)] + Self::Unix { path } => write!(fmt, "{}", path.display()), } } } @@ -1580,10 +1612,26 @@ impl ConnectionString { None => (None, None), }; - let host_list: Result> = - hosts_section.split(',').map(ServerAddress::parse).collect(); - - let host_list = host_list?; + let mut host_list = Vec::with_capacity(hosts_section.len()); + for host in hosts_section.split(',') { + let address = if host.ends_with(".sock") { + #[cfg(unix)] + { + ServerAddress::parse(percent_decode( + host, + "Unix domain sockets must be URL-encoded", + )?) + } + #[cfg(not(unix))] + return Err(ErrorKind::InvalidArgument { + message: "Unix domain sockets are not supported on this platform".to_string(), + } + .into()); + } else { + ServerAddress::parse(host) + }?; + host_list.push(address); + } let hosts = if srv { if host_list.len() != 1 { @@ -1592,16 +1640,26 @@ impl ConnectionString { } .into()); } - // Unwrap safety: the `len` check above guarantees this can't fail. - let ServerAddress::Tcp { host, port } = host_list.into_iter().next().unwrap(); - if port.is_some() { - return Err(ErrorKind::InvalidArgument { - message: "a port cannot be specified with 'mongodb+srv'".into(), + // Unwrap safety: the `len` check above guarantees this can't fail. + match host_list.into_iter().next().unwrap() { + ServerAddress::Tcp { host, port } => { + if port.is_some() { + return Err(ErrorKind::InvalidArgument { + message: "a port cannot be specified with 'mongodb+srv'".into(), + } + .into()); + } + HostInfo::DnsRecord(host) + } + #[cfg(unix)] + ServerAddress::Unix { .. } => { + return Err(ErrorKind::InvalidArgument { + message: "unix sockets cannot be used with 'mongodb+srv'".into(), + } + .into()); } - .into()); } - HostInfo::DnsRecord(host) } else { HostInfo::HostIdentifiers(host_list) }; @@ -2299,18 +2357,39 @@ mod tests { #[test] fn test_parse_address_with_from_str() { let x = "localhost:27017".parse::().unwrap(); - let ServerAddress::Tcp { host, port } = x; - assert_eq!(host, "localhost"); - assert_eq!(port, Some(27017)); + match x { + ServerAddress::Tcp { host, port } => { + assert_eq!(host, "localhost"); + assert_eq!(port, Some(27017)); + } + #[cfg(unix)] + _ => panic!("expected ServerAddress::Tcp"), + } // Port defaults to 27017 (so this doesn't fail) let x = "localhost".parse::().unwrap(); - let ServerAddress::Tcp { host, port } = x; - assert_eq!(host, "localhost"); - assert_eq!(port, None); + match x { + ServerAddress::Tcp { host, port } => { + assert_eq!(host, "localhost"); + assert_eq!(port, None); + } + #[cfg(unix)] + _ => panic!("expected ServerAddress::Tcp"), + } let x = "localhost:not a number".parse::(); assert!(x.is_err()); + + #[cfg(unix)] + { + let x = "/path/to/socket.sock".parse::().unwrap(); + match x { + ServerAddress::Unix { path } => { + assert_eq!(path.to_str().unwrap(), "/path/to/socket.sock"); + } + _ => panic!("expected ServerAddress::Unix"), + } + } } #[cfg_attr(feature = "tokio-runtime", tokio::test)] diff --git a/src/client/options/test.rs b/src/client/options/test.rs index e2dee4361..6dad20c35 100644 --- a/src/client/options/test.rs +++ b/src/client/options/test.rs @@ -20,13 +20,30 @@ struct TestFile { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct TestCase { - pub description: String, - pub uri: String, - pub valid: bool, - pub warning: Option, - pub hosts: Option>, - pub auth: Option, - pub options: Option, + description: String, + uri: String, + valid: bool, + warning: Option, + hosts: Option>, + auth: Option, + options: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +struct TestAuth { + username: Option, + password: Option, + db: Option, +} + +impl TestAuth { + fn matches_client_options(&self, options: &ClientOptions) -> bool { + let credential = options.credential.as_ref(); + self.username.as_ref() == credential.and_then(|cred| cred.username.as_ref()) + && self.password.as_ref() == credential.and_then(|cred| cred.password.as_ref()) + && self.db.as_ref() == options.default_database.as_ref() + } } async fn run_test(test_file: TestFile) { @@ -43,7 +60,6 @@ async fn run_test(test_file: TestFile) { || test_case.description.contains("tlsAllowInvalidHostnames") || test_case.description.contains("single-threaded") || test_case.description.contains("serverSelectionTryOnce") - || test_case.description.contains("Unix") || test_case.description.contains("relative path") // Compression is implemented but will only pass the tests if all // the appropriate feature flags are set. That is because @@ -63,6 +79,11 @@ async fn run_test(test_file: TestFile) { continue; } + #[cfg(not(unix))] + if test_case.description.contains("Unix") { + continue; + } + let warning = test_case.warning.take().unwrap_or(false); if test_case.valid && !warning { @@ -70,12 +91,25 @@ async fn run_test(test_file: TestFile) { // hosts if let Some(mut json_hosts) = test_case.hosts.take() { // skip over unsupported host types - is_unsupported_host_type = json_hosts.iter_mut().any(|h_json| { - matches!( - h_json.remove("type").as_ref().and_then(Bson::as_str), - Some("ip_literal") | Some("unix") - ) - }); + #[cfg(not(unix))] + { + is_unsupported_host_type = json_hosts.iter_mut().any(|h_json| { + matches!( + h_json.remove("type").as_ref().and_then(Bson::as_str), + Some("ip_literal") | Some("unix") + ) + }); + } + + #[cfg(unix)] + { + is_unsupported_host_type = json_hosts.iter_mut().any(|h_json| { + matches!( + h_json.remove("type").as_ref().and_then(Bson::as_str), + Some("ip_literal") + ) + }); + } if !is_unsupported_host_type { let options = ClientOptions::parse(&test_case.uri).await.unwrap(); @@ -154,27 +188,10 @@ async fn run_test(test_file: TestFile) { assert_eq!(options_doc, json_options, "{}", test_case.description) } - // auth - if let Some(json_auth) = test_case.auth { - let json_auth: Document = json_auth - .into_iter() - .filter_map(|(k, v)| { - if let Bson::Null = v { - None - } else { - Some((k.to_lowercase(), v)) - } - }) - .collect(); + if let Some(test_auth) = test_case.auth { let options = ClientOptions::parse(&test_case.uri).await.unwrap(); - let mut expected_auth = options.credential.unwrap_or_default().into_document(); - expected_auth = expected_auth - .into_iter() - .filter(|(ref key, _)| json_auth.contains_key(key)) - .collect(); - - assert_eq!(expected_auth, json_auth); + assert!(test_auth.matches_client_options(&options)); } } } else { @@ -278,25 +295,6 @@ async fn parse_unknown_options() { parse_uri("maxstalenessms", Some("maxstalenessseconds")).await; } -#[cfg_attr(feature = "tokio-runtime", tokio::test)] -#[cfg_attr(feature = "async-std-runtime", async_std::test)] -async fn parse_with_default_database() { - let uri = "mongodb://localhost/abc"; - - assert_eq!( - ClientOptions::parse(uri).await.unwrap(), - ClientOptions { - hosts: vec![ServerAddress::Tcp { - host: "localhost".to_string(), - port: None - }], - original_uri: Some(uri.into()), - default_database: Some("abc".to_string()), - ..Default::default() - } - ); -} - #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn parse_with_no_default_database() { diff --git a/src/runtime.rs b/src/runtime.rs index b625a5101..c7c8b343c 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -170,8 +170,9 @@ pub(crate) async fn resolve_address( #[cfg(feature = "async-std-runtime")] { - let host = (address.host(), address.port().unwrap_or(27017)); - let socket_addrs = async_std::net::ToSocketAddrs::to_socket_addrs(&host).await?; + use async_std::net::ToSocketAddrs; + + let socket_addrs = format!("{}", address).to_socket_addrs().await?; Ok(socket_addrs) } } diff --git a/src/runtime/stream.rs b/src/runtime/stream.rs index 9ca6a0c55..5c4223035 100644 --- a/src/runtime/stream.rs +++ b/src/runtime/stream.rs @@ -30,6 +30,10 @@ pub(crate) enum AsyncStream { /// A TLS connection over TCP. Tls(AsyncTlsStream), + + /// A Unix domain socket connection. + #[cfg(unix)] + Unix(unix::AsyncUnixStream), } impl AsyncStream { @@ -37,17 +41,191 @@ impl AsyncStream { address: ServerAddress, tls_cfg: Option<&TlsConfig>, ) -> Result { - let inner = AsyncTcpStream::connect(&address).await?; - - // If there are TLS options, wrap the inner stream in an AsyncTlsStream. - match tls_cfg { - Some(cfg) => { - let host = address.host(); - Ok(AsyncStream::Tls( - AsyncTlsStream::connect(host, inner, cfg).await?, - )) + match &address { + ServerAddress::Tcp { host, .. } => { + let inner = AsyncTcpStream::connect(&address).await?; + + // If there are TLS options, wrap the inner stream in an AsyncTlsStream. + match tls_cfg { + Some(cfg) => Ok(AsyncStream::Tls( + AsyncTlsStream::connect(host, inner, cfg).await?, + )), + None => Ok(AsyncStream::Tcp(inner)), + } + } + #[cfg(unix)] + ServerAddress::Unix { .. } => Ok(AsyncStream::Unix( + unix::AsyncUnixStream::connect(&address).await?, + )), + } + } +} + +/// A runtime-agnostic async unix domain socket stream. +#[cfg(unix)] +mod unix { + use std::{ + ops::DerefMut, + path::Path, + pin::Pin, + task::{Context, Poll}, + }; + + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + + use crate::{client::options::ServerAddress, error::Result}; + + #[derive(Debug)] + pub(crate) enum AsyncUnixStream { + /// Wrapper around `tokio::net:UnixStream`. + #[cfg(feature = "tokio-runtime")] + Tokio(tokio::net::UnixStream), + + /// Wrapper around `async_std::net::UnixStream`. + #[cfg(feature = "async-std-runtime")] + AsyncStd(async_std::os::unix::net::UnixStream), + } + + #[cfg(feature = "tokio-runtime")] + impl From for AsyncUnixStream { + fn from(stream: tokio::net::UnixStream) -> Self { + Self::Tokio(stream) + } + } + + #[cfg(feature = "async-std-runtime")] + impl From for AsyncUnixStream { + fn from(stream: async_std::os::unix::net::UnixStream) -> Self { + Self::AsyncStd(stream) + } + } + + impl AsyncUnixStream { + #[cfg(feature = "tokio-runtime")] + async fn try_connect(address: &Path) -> Result { + use tokio::net::UnixStream; + + let stream = UnixStream::connect(address).await?; + Ok(stream.into()) + } + + #[cfg(feature = "async-std-runtime")] + async fn try_connect(address: &Path) -> Result { + use async_std::os::unix::net::UnixStream; + + let stream = UnixStream::connect(address).await?; + Ok(stream.into()) + } + + pub(crate) async fn connect(address: &ServerAddress) -> Result { + debug_assert!( + matches!(address, ServerAddress::Unix { .. }), + "address must be unix" + ); + + match address { + ServerAddress::Unix { ref path } => Self::try_connect(path.as_path()).await, + _ => unreachable!(), + } + } + } + + impl AsyncRead for AsyncUnixStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll> { + match self.deref_mut() { + #[cfg(feature = "tokio-runtime")] + Self::Tokio(ref mut inner) => Pin::new(inner).poll_read(cx, buf), + + #[cfg(feature = "async-std-runtime")] + Self::AsyncStd(ref mut inner) => { + use tokio_util::compat::FuturesAsyncReadCompatExt; + + Pin::new(&mut inner.compat()).poll_read(cx, buf) + } + } + } + } + + impl AsyncWrite for AsyncUnixStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.deref_mut() { + #[cfg(feature = "tokio-runtime")] + Self::Tokio(ref mut inner) => Pin::new(inner).poll_write(cx, buf), + + #[cfg(feature = "async-std-runtime")] + Self::AsyncStd(ref mut inner) => { + use tokio_util::compat::FuturesAsyncReadCompatExt; + + Pin::new(&mut inner.compat()).poll_write(cx, buf) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.deref_mut() { + #[cfg(feature = "tokio-runtime")] + Self::Tokio(ref mut inner) => Pin::new(inner).poll_flush(cx), + + #[cfg(feature = "async-std-runtime")] + Self::AsyncStd(ref mut inner) => { + use tokio_util::compat::FuturesAsyncReadCompatExt; + + Pin::new(&mut inner.compat()).poll_flush(cx) + } + } + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + match self.deref_mut() { + #[cfg(feature = "tokio-runtime")] + Self::Tokio(ref mut inner) => Pin::new(inner).poll_shutdown(cx), + + #[cfg(feature = "async-std-runtime")] + Self::AsyncStd(ref mut inner) => { + use tokio_util::compat::FuturesAsyncReadCompatExt; + + Pin::new(&mut inner.compat()).poll_shutdown(cx) + } + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[futures_io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tokio-runtime")] + Self::Tokio(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs), + + #[cfg(feature = "async-std-runtime")] + Self::AsyncStd(ref mut inner) => { + use tokio_util::compat::FuturesAsyncReadCompatExt; + + Pin::new(&mut inner.compat()).poll_write_vectored(cx, bufs) + } + } + } + + fn is_write_vectored(&self) -> bool { + match self { + #[cfg(feature = "tokio-runtime")] + Self::Tokio(ref inner) => inner.is_write_vectored(), + + #[cfg(feature = "async-std-runtime")] + Self::AsyncStd(_) => false, } - None => Ok(AsyncStream::Tcp(inner)), } } } @@ -154,6 +332,8 @@ impl tokio::io::AsyncRead for AsyncStream { Self::Null => Poll::Ready(Ok(())), Self::Tcp(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf), Self::Tls(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf), + #[cfg(unix)] + Self::Unix(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf), } } } @@ -168,6 +348,8 @@ impl AsyncWrite for AsyncStream { Self::Null => Poll::Ready(Ok(0)), Self::Tcp(ref mut inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf), Self::Tls(ref mut inner) => Pin::new(inner).poll_write(cx, buf), + #[cfg(unix)] + Self::Unix(ref mut inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf), } } @@ -176,6 +358,8 @@ impl AsyncWrite for AsyncStream { Self::Null => Poll::Ready(Ok(())), Self::Tcp(ref mut inner) => AsyncWrite::poll_flush(Pin::new(inner), cx), Self::Tls(ref mut inner) => Pin::new(inner).poll_flush(cx), + #[cfg(unix)] + Self::Unix(ref mut inner) => AsyncWrite::poll_flush(Pin::new(inner), cx), } } @@ -184,6 +368,32 @@ impl AsyncWrite for AsyncStream { Self::Null => Poll::Ready(Ok(())), Self::Tcp(ref mut inner) => Pin::new(inner).poll_shutdown(cx), Self::Tls(ref mut inner) => Pin::new(inner).poll_shutdown(cx), + #[cfg(unix)] + Self::Unix(ref mut inner) => Pin::new(inner).poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[futures_io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Self::Null => Poll::Ready(Ok(0)), + Self::Tcp(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs), + Self::Tls(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs), + #[cfg(unix)] + Self::Unix(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Self::Null => false, + Self::Tcp(ref inner) => inner.is_write_vectored(), + Self::Tls(ref inner) => inner.is_write_vectored(), + #[cfg(unix)] + Self::Unix(ref inner) => inner.is_write_vectored(), } } } @@ -254,4 +464,32 @@ impl AsyncWrite for AsyncTcpStream { } } } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[futures_io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + #[cfg(feature = "tokio-runtime")] + Self::Tokio(ref mut inner) => Pin::new(inner).poll_write_vectored(cx, bufs), + + #[cfg(feature = "async-std-runtime")] + Self::AsyncStd(ref mut inner) => { + use tokio_util::compat::FuturesAsyncReadCompatExt; + + Pin::new(&mut inner.compat()).poll_write_vectored(cx, bufs) + } + } + } + + fn is_write_vectored(&self) -> bool { + match self { + #[cfg(feature = "tokio-runtime")] + Self::Tokio(ref inner) => inner.is_write_vectored(), + + #[cfg(feature = "async-std-runtime")] + Self::AsyncStd(_) => false, + } + } } diff --git a/src/sdam/description/server.rs b/src/sdam/description/server.rs index ebfa91512..76c5144e3 100644 --- a/src/sdam/description/server.rs +++ b/src/sdam/description/server.rs @@ -190,9 +190,13 @@ impl PartialEq for ServerDescription { impl ServerDescription { pub(crate) fn new(address: ServerAddress) -> Self { Self { - address: ServerAddress::Tcp { - host: address.host().to_lowercase(), - port: address.port(), + address: match address { + ServerAddress::Tcp { host, port } => ServerAddress::Tcp { + host: host.to_lowercase(), + port, + }, + #[cfg(unix)] + ServerAddress::Unix { path } => ServerAddress::Unix { path }, }, server_type: Default::default(), last_update_time: None, diff --git a/src/sdam/topology.rs b/src/sdam/topology.rs index 1f061da26..476a88615 100644 --- a/src/sdam/topology.rs +++ b/src/sdam/topology.rs @@ -446,6 +446,8 @@ impl TopologyWorker { let mut servers = diff.changed_servers.into_iter().collect::>(); servers.sort_by_key(|(addr, _)| match addr { ServerAddress::Tcp { host, port } => (host, port), + #[cfg(unix)] + ServerAddress::Unix { .. } => unreachable!(), }); servers }; @@ -477,6 +479,8 @@ impl TopologyWorker { let mut addresses = diff.removed_addresses.into_iter().collect::>(); addresses.sort_by_key(|addr| match addr { ServerAddress::Tcp { host, port } => (host, port), + #[cfg(unix)] + ServerAddress::Unix { .. } => unreachable!(), }); addresses }; @@ -513,6 +517,8 @@ impl TopologyWorker { let mut addresses = diff.added_addresses.into_iter().collect::>(); addresses.sort_by_key(|addr| match addr { ServerAddress::Tcp { host, port } => (host, port), + #[cfg(unix)] + ServerAddress::Unix { .. } => unreachable!(), }); addresses }; diff --git a/src/srv.rs b/src/srv.rs index 6789b2486..ce4b145ab 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -94,7 +94,8 @@ impl SrvResolver { let domain_name = &hostname_parts[1..]; - let mut hostname_parts: Vec<_> = address.host().split('.').collect(); + let host = address.host(); + let mut hostname_parts: Vec<_> = host.split('.').collect(); // Remove empty final section, which indicates a trailing dot. if hostname_parts.last().map(|s| s.is_empty()).unwrap_or(false) { diff --git a/src/test/spec/json/connection-string/valid-auth.json b/src/test/spec/json/connection-string/valid-auth.json index 4f684ff18..176a54a09 100644 --- a/src/test/spec/json/connection-string/valid-auth.json +++ b/src/test/spec/json/connection-string/valid-auth.json @@ -284,7 +284,7 @@ }, { "description": "Escaped username (GSSAPI)", - "uri": "mongodb://user%40EXAMPLE.COM:secret@localhost/?authMechanismProperties=SERVICE_NAME:other,CANONICALIZE_HOST_NAME:true&authMechanism=GSSAPI", + "uri": "mongodb://user%40EXAMPLE.COM:secret@localhost/?authMechanismProperties=SERVICE_NAME:other,CANONICALIZE_HOST_NAME:forward,SERVICE_HOST:example.com&authMechanism=GSSAPI", "valid": true, "warning": false, "hosts": [ @@ -303,7 +303,8 @@ "authmechanism": "GSSAPI", "authmechanismproperties": { "SERVICE_NAME": "other", - "CANONICALIZE_HOST_NAME": true + "SERVICE_HOST": "example.com", + "CANONICALIZE_HOST_NAME": "forward" } } }, diff --git a/src/test/spec/json/connection-string/valid-auth.yml b/src/test/spec/json/connection-string/valid-auth.yml index 01c866ee9..f40c748fa 100644 --- a/src/test/spec/json/connection-string/valid-auth.yml +++ b/src/test/spec/json/connection-string/valid-auth.yml @@ -222,7 +222,7 @@ tests: authmechanism: "MONGODB-X509" - description: "Escaped username (GSSAPI)" - uri: "mongodb://user%40EXAMPLE.COM:secret@localhost/?authMechanismProperties=SERVICE_NAME:other,CANONICALIZE_HOST_NAME:true&authMechanism=GSSAPI" + uri: "mongodb://user%40EXAMPLE.COM:secret@localhost/?authMechanismProperties=SERVICE_NAME:other,CANONICALIZE_HOST_NAME:forward,SERVICE_HOST:example.com&authMechanism=GSSAPI" valid: true warning: false hosts: @@ -238,7 +238,8 @@ tests: authmechanism: "GSSAPI" authmechanismproperties: SERVICE_NAME: "other" - CANONICALIZE_HOST_NAME: true + SERVICE_HOST: "example.com" + CANONICALIZE_HOST_NAME: "forward" - description: "At-signs in options aren't part of the userinfo" uri: "mongodb://alice:secret@example.com/admin?replicaset=my@replicaset" diff --git a/src/test/spec/json/connection-string/valid-warnings.json b/src/test/spec/json/connection-string/valid-warnings.json index 7ede9bdd5..1eacbf8fc 100644 --- a/src/test/spec/json/connection-string/valid-warnings.json +++ b/src/test/spec/json/connection-string/valid-warnings.json @@ -95,4 +95,4 @@ "options": null } ] -} \ No newline at end of file +} diff --git a/src/trace.rs b/src/trace.rs index c32c17a6f..acce9bb59 100644 --- a/src/trace.rs +++ b/src/trace.rs @@ -35,8 +35,10 @@ impl ServerAddress { pub(crate) fn port_tracing_representation(&self) -> Option { match self { Self::Tcp { port, .. } => Some(port.unwrap_or(DEFAULT_PORT)), - // TODO: RUST-802 For Unix domain sockets we should return None here, as ports + // For Unix domain sockets we should return None here, as ports // are not meaningful for those. + #[cfg(unix)] + Self::Unix { .. } => None, } } } diff --git a/src/trace/command.rs b/src/trace/command.rs index 659e779e8..836a59252 100644 --- a/src/trace/command.rs +++ b/src/trace/command.rs @@ -43,7 +43,7 @@ impl CommandEventHandler for CommandTracingEventEmitter { requestId = event.request_id, driverConnectionId = event.connection.id, serverConnectionId = event.connection.server_id, - serverHost = event.connection.address.host(), + serverHost = event.connection.address.host().as_ref(), serverPort = event.connection.address.port_tracing_representation(), serviceId = event.service_id.map(|id| id.tracing_representation()), "Command started" @@ -59,7 +59,7 @@ impl CommandEventHandler for CommandTracingEventEmitter { requestId = event.request_id, driverConnectionId = event.connection.id, serverConnectionId = event.connection.server_id, - serverHost = event.connection.address.host(), + serverHost = event.connection.address.host().as_ref(), serverPort = event.connection.address.port_tracing_representation(), serviceId = event.service_id.map(|id| id.tracing_representation()), durationMS = event.duration.as_millis(), @@ -76,7 +76,7 @@ impl CommandEventHandler for CommandTracingEventEmitter { requestId = event.request_id, driverConnectionId = event.connection.id, serverConnectionId = event.connection.server_id, - serverHost = event.connection.address.host(), + serverHost = event.connection.address.host().as_ref(), serverPort = event.connection.address.port_tracing_representation(), serviceId = event.service_id.map(|id| id.tracing_representation()), durationMS = event.duration.as_millis(), diff --git a/src/trace/connection.rs b/src/trace/connection.rs index 0d93a1157..0e596fa66 100644 --- a/src/trace/connection.rs +++ b/src/trace/connection.rs @@ -37,7 +37,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), maxIdleTimeMS = options_ref.and_then(|o| o.max_idle_time.map(|m| m.as_millis())), maxPoolSize = options_ref.and_then(|o| o.max_pool_size), @@ -50,7 +50,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), "Connection pool ready", ); @@ -60,7 +60,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), serviceId = event.service_id.map(|id| id.tracing_representation()), "Connection pool cleared", @@ -71,7 +71,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), "Connection pool closed", ); @@ -81,7 +81,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), driverConnectionId = event.connection_id, "Connection created", @@ -92,7 +92,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), driverConnectionId = event.connection_id, "Connection ready", @@ -103,7 +103,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), driverConnectionId = event.connection_id, reason = event.reason.tracing_representation(), @@ -116,7 +116,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), "Connection checkout started", ); @@ -126,7 +126,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), reason = event.reason.tracing_representation(), error = event.error.map(|e| e.tracing_representation()), @@ -138,7 +138,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), driverConnectionId = event.connection_id, "Connection checked out", @@ -149,7 +149,7 @@ impl CmapEventHandler for ConnectionTracingEventEmitter { tracing::debug!( target: CONNECTION_TRACING_EVENT_TARGET, topologyId = self.topology_id.tracing_representation(), - serverHost = event.address.host(), + serverHost = event.address.host().as_ref(), serverPort = event.address.port_tracing_representation(), driverConnectionId = event.connection_id, "Connection checked in", diff --git a/src/trace/server_selection.rs b/src/trace/server_selection.rs index bd4c47669..742c25e82 100644 --- a/src/trace/server_selection.rs +++ b/src/trace/server_selection.rs @@ -106,7 +106,7 @@ impl ServerSelectionTracingEventEmitter<'_> { operation = self.operation_name, selector = self.criteria.tracing_representation(), topologyDescription = topology_description.tracing_representation(), - serverHost = server.address().host(), + serverHost = server.address().host().as_ref(), serverPort = server.address().port_tracing_representation(), "Server selection succeeded" );