Skip to content

Commit

Permalink
RUST-802 Support Unix Domain Sockets (#908)
Browse files Browse the repository at this point in the history
  • Loading branch information
PureWhiteWu authored Jul 20, 2023
1 parent c242539 commit 3ea6bc1
Show file tree
Hide file tree
Showing 15 changed files with 443 additions and 135 deletions.
23 changes: 0 additions & 23 deletions src/client/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,29 +397,6 @@ pub struct Credential {
}

impl Credential {
#[cfg(all(test, not(feature = "sync"), not(feature = "tokio-sync")))]
pub(crate) fn into_document(mut self) -> Document {
use crate::bson::Bson;

let mut doc = Document::new();

if let Some(s) = self.username.take() {
doc.insert("username", s);
}

if let Some(s) = self.password.take() {
doc.insert("password", s);
} else {
doc.insert("password", Bson::Null);
}

if let Some(s) = self.source.take() {
doc.insert("db", s);
}

doc
}

pub(crate) fn resolved_source(&self) -> &str {
self.mechanism
.as_ref()
Expand Down
125 changes: 102 additions & 23 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod test;
mod resolver_config;

use std::{
borrow::Cow,
cmp::Ordering,
collections::HashSet,
convert::TryFrom,
Expand Down Expand Up @@ -91,14 +92,11 @@ lazy_static! {
};

static ref ILLEGAL_DATABASE_CHARACTERS: HashSet<&'static char> = {
['/', '\\', ' ', '"', '$', '.'].iter().collect()
['/', '\\', ' ', '"', '$'].iter().collect()
};
}

/// An enum representing the address of a MongoDB server.
///
/// Currently this just supports addresses that can be connected to over TCP, but alternative
/// address types may be supported in the future (e.g. Unix Domain Socket paths).
#[derive(Clone, Debug, Eq, Serialize)]
#[non_exhaustive]
pub enum ServerAddress {
Expand All @@ -112,6 +110,12 @@ pub enum ServerAddress {
/// The default is 27017.
port: Option<u16>,
},
/// A Unix Domain Socket path.
#[cfg(unix)]
Unix {
/// The path to the Unix Domain Socket.
path: PathBuf,
},
}

impl<'de> Deserialize<'de> for ServerAddress {
Expand Down Expand Up @@ -144,6 +148,10 @@ impl PartialEq for ServerAddress {
port: other_port,
},
) => host == other_host && port.unwrap_or(27017) == other_port.unwrap_or(27017),
#[cfg(unix)]
(Self::Unix { path }, Self::Unix { path: other_path }) => path == other_path,
#[cfg(unix)]
_ => false,
}
}
}
Expand All @@ -158,6 +166,8 @@ impl Hash for ServerAddress {
host.hash(state);
port.unwrap_or(27017).hash(state);
}
#[cfg(unix)]
Self::Unix { path } => path.hash(state),
}
}
}
Expand All @@ -173,6 +183,15 @@ impl ServerAddress {
/// Parses an address string into a `ServerAddress`.
pub fn parse(address: impl AsRef<str>) -> Result<Self> {
let address = address.as_ref();
// checks if the address is a unix domain socket
#[cfg(unix)]
{
if address.ends_with(".sock") {
return Ok(ServerAddress::Unix {
path: PathBuf::from(address),
});
}
}
let mut parts = address.split(':');
let hostname = match parts.next() {
Some(part) => {
Expand Down Expand Up @@ -243,18 +262,29 @@ impl ServerAddress {
"port": port.map(|i| Bson::Int32(i.into())).unwrap_or(Bson::Null)
}
}
#[cfg(unix)]
Self::Unix { path } => {
doc! {
"host": path.to_string_lossy().as_ref(),
"port": Bson::Null,
}
}
}
}

pub(crate) fn host(&self) -> &str {
pub(crate) fn host(&self) -> Cow<'_, str> {
match self {
Self::Tcp { host, .. } => host.as_str(),
Self::Tcp { host, .. } => Cow::Borrowed(host.as_str()),
#[cfg(unix)]
Self::Unix { path } => path.to_string_lossy(),
}
}

pub(crate) fn port(&self) -> Option<u16> {
match self {
Self::Tcp { port, .. } => *port,
#[cfg(unix)]
Self::Unix { .. } => None,
}
}
}
Expand All @@ -265,6 +295,8 @@ impl fmt::Display for ServerAddress {
Self::Tcp { host, port } => {
write!(fmt, "{}:{}", host, port.unwrap_or(DEFAULT_PORT))
}
#[cfg(unix)]
Self::Unix { path } => write!(fmt, "{}", path.display()),
}
}
}
Expand Down Expand Up @@ -1580,10 +1612,26 @@ impl ConnectionString {
None => (None, None),
};

let host_list: Result<Vec<_>> =
hosts_section.split(',').map(ServerAddress::parse).collect();

let host_list = host_list?;
let mut host_list = Vec::with_capacity(hosts_section.len());
for host in hosts_section.split(',') {
let address = if host.ends_with(".sock") {
#[cfg(unix)]
{
ServerAddress::parse(percent_decode(
host,
"Unix domain sockets must be URL-encoded",
)?)
}
#[cfg(not(unix))]
return Err(ErrorKind::InvalidArgument {
message: "Unix domain sockets are not supported on this platform".to_string(),
}
.into());
} else {
ServerAddress::parse(host)
}?;
host_list.push(address);
}

let hosts = if srv {
if host_list.len() != 1 {
Expand All @@ -1592,16 +1640,26 @@ impl ConnectionString {
}
.into());
}
// Unwrap safety: the `len` check above guarantees this can't fail.
let ServerAddress::Tcp { host, port } = host_list.into_iter().next().unwrap();

if port.is_some() {
return Err(ErrorKind::InvalidArgument {
message: "a port cannot be specified with 'mongodb+srv'".into(),
// Unwrap safety: the `len` check above guarantees this can't fail.
match host_list.into_iter().next().unwrap() {
ServerAddress::Tcp { host, port } => {
if port.is_some() {
return Err(ErrorKind::InvalidArgument {
message: "a port cannot be specified with 'mongodb+srv'".into(),
}
.into());
}
HostInfo::DnsRecord(host)
}
#[cfg(unix)]
ServerAddress::Unix { .. } => {
return Err(ErrorKind::InvalidArgument {
message: "unix sockets cannot be used with 'mongodb+srv'".into(),
}
.into());
}
.into());
}
HostInfo::DnsRecord(host)
} else {
HostInfo::HostIdentifiers(host_list)
};
Expand Down Expand Up @@ -2299,18 +2357,39 @@ mod tests {
#[test]
fn test_parse_address_with_from_str() {
let x = "localhost:27017".parse::<ServerAddress>().unwrap();
let ServerAddress::Tcp { host, port } = x;
assert_eq!(host, "localhost");
assert_eq!(port, Some(27017));
match x {
ServerAddress::Tcp { host, port } => {
assert_eq!(host, "localhost");
assert_eq!(port, Some(27017));
}
#[cfg(unix)]
_ => panic!("expected ServerAddress::Tcp"),
}

// Port defaults to 27017 (so this doesn't fail)
let x = "localhost".parse::<ServerAddress>().unwrap();
let ServerAddress::Tcp { host, port } = x;
assert_eq!(host, "localhost");
assert_eq!(port, None);
match x {
ServerAddress::Tcp { host, port } => {
assert_eq!(host, "localhost");
assert_eq!(port, None);
}
#[cfg(unix)]
_ => panic!("expected ServerAddress::Tcp"),
}

let x = "localhost:not a number".parse::<ServerAddress>();
assert!(x.is_err());

#[cfg(unix)]
{
let x = "/path/to/socket.sock".parse::<ServerAddress>().unwrap();
match x {
ServerAddress::Unix { path } => {
assert_eq!(path.to_str().unwrap(), "/path/to/socket.sock");
}
_ => panic!("expected ServerAddress::Unix"),
}
}
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
Expand Down
Loading

0 comments on commit 3ea6bc1

Please sign in to comment.