diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 243cffc..016c7e2 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -12,7 +12,7 @@ jobs: runs-on: "ubuntu-latest" strategy: matrix: - rust-version: ["1.70", "stable"] + rust-version: [ "1.70", "stable" ] steps: - name: Pull Neo4j Docker Image run: docker pull neo4j:5.6-enterprise @@ -57,14 +57,24 @@ jobs: TEST_NEO4J_PASS: pass TEST_NEO4J_EDITION: enterprise TEST_NEO4J_VERSION: 5.6 - run: cargo test --all + run: cargo test --workspace --all-targets + - name: doc tests + env: + TEST_NEO4J_SCHEME: neo4j + TEST_NEO4J_HOST: localhost + TEST_NEO4J_PORT: 7687 + TEST_NEO4J_USER: neo4j + TEST_NEO4J_PASS: pass + TEST_NEO4J_EDITION: enterprise + TEST_NEO4J_VERSION: 5.6 + run: cargo test --workspace --doc testkit: name: TestKit - needs: [tests] + needs: [ tests ] runs-on: ubuntu-latest strategy: matrix: - tests: [TESTKIT_TESTS] + tests: [ TESTKIT_TESTS ] config: - 4.4-community-bolt - 4.4-community-neo4j diff --git a/doc_test_utils/src/lib.rs b/doc_test_utils/src/lib.rs index 05a71d4..fce52c4 100644 --- a/doc_test_utils/src/lib.rs +++ b/doc_test_utils/src/lib.rs @@ -63,7 +63,7 @@ pub fn get_driver() -> Driver { .is_test(true) .try_init(); - let driver = neo4j::driver::Driver::new( + let driver = Driver::new( ConnectionConfig::new(get_address()), DriverConfig::new().with_auth(Arc::new(get_auth_token())), ); diff --git a/examples/basic.rs b/neo4j/examples/basic.rs similarity index 96% rename from examples/basic.rs rename to neo4j/examples/basic.rs index 06fcadd..51f5b61 100644 --- a/examples/basic.rs +++ b/neo4j/examples/basic.rs @@ -65,3 +65,13 @@ fn main() { assert_eq!(record.take_value("x"), Some(ValueReceive::Integer(123))); } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_main() { + main(); + } +} diff --git a/neo4j/src/address_.rs b/neo4j/src/address_.rs index 7b2e8df..ecf18d5 100644 --- a/neo4j/src/address_.rs +++ b/neo4j/src/address_.rs @@ -25,6 +25,10 @@ use std::vec::IntoIter; use crate::error_::Result; use resolution::{AddressResolver, CustomResolution, DnsResolution}; +// imports for docs +#[allow(unused)] +use crate::driver::DriverConfig; + pub(crate) const DEFAULT_PORT: u16 = 7687; const COLON_BYTES: usize = ':'.len_utf8(); @@ -72,6 +76,10 @@ pub struct Address { pub(crate) is_dns_resolved: bool, } +/// Note that equality of addresses is defined as equality of its [`Address::unresolved_host()`] +/// and [`Address::port()`] only. +/// Therefore, resolved to different IP addresses coming from the same host are considered equal +/// if their port is equal as well. impl PartialEq for Address { #[inline] fn eq(&self, other: &Self) -> bool { @@ -128,6 +136,41 @@ impl Address { } /// Return the host name or IP address. + /// + /// For addresses that have been resolved by the driver, this will be the final IP address after + /// all resolutions. + /// This includes: + /// * potential custom address resolver, see [`DriverConfig::with_resolver`] + /// * DNS resolution, see [`ToSocketAddrs`]. + /// + /// # Example + /// ``` + /// use neo4j::address::Address; + /// + /// let addr = Address::from(("localhost", 1234)); + /// assert_eq!(addr.host(), "localhost"); + /// ``` + /// + /// # Example (after resolution) + /// ``` + /// use neo4j::address::Address; + /// use neo4j::driver::Driver; + /// + /// use std::net::ToSocketAddrs; + /// # use doc_test_utils::get_address; + /// + /// let address: Address = get_address(); + /// # fn get_driver(_: &Address) -> Driver { + /// # doc_test_utils::get_driver() + /// # } + /// let driver: Driver = get_driver(&address); + /// let resolved_address = driver.get_server_info().unwrap().address; + /// + /// assert!(address + /// .to_socket_addrs() + /// .unwrap() + /// .any(|sock_addr| { &sock_addr.ip().to_string() == resolved_address.host() })); + /// ``` pub fn host(&self) -> &str { self.host.as_str() } @@ -137,7 +180,31 @@ impl Address { self.port } - pub(crate) fn unresolved_host(&self) -> &str { + /// Return the host name (before a potential DNS resolution). + /// + /// When a custom address resolver is registered with the driver (see + /// [`DriverConfig::with_resolver`]), `unresolved_host` will return the host name + /// from after the custom address resolver. + /// + /// # Example + /// ``` + /// use neo4j::address::Address; + /// use neo4j::driver::Driver; + /// # use doc_test_utils::get_address; + /// + /// let address: Address = get_address(); + /// # fn get_driver(_: &Address) -> Driver { + /// # doc_test_utils::get_driver() + /// # } + /// let driver: Driver = get_driver(&address); + /// let resolved_address = driver.get_server_info().unwrap().address; + /// + /// assert_eq!(address.host(), resolved_address.unresolved_host()); + /// // but not necessarily + /// // assert_eq!(address.host(), resolved_address.host()); + /// // because resolved_address.host() will be DNS resolved. + /// ``` + pub fn unresolved_host(&self) -> &str { &self.key } } diff --git a/neo4j/src/driver/io/bolt.rs b/neo4j/src/driver/io/bolt.rs index 6bba9ef..8fb350d 100644 --- a/neo4j/src/driver/io/bolt.rs +++ b/neo4j/src/driver/io/bolt.rs @@ -21,6 +21,7 @@ mod bolt5x2; mod bolt5x3; mod bolt_state; mod chunk; +mod handshake; mod message; pub(crate) mod message_parameters; mod packstream; @@ -31,8 +32,8 @@ use std::borrow::Borrow; use std::collections::{HashMap, VecDeque}; use std::fmt::{Debug, Formatter}; use std::hash::{Hash, Hasher}; -use std::io::{self, Read, Write}; -use std::net::{Shutdown, SocketAddr, TcpStream, ToSocketAddrs}; +use std::io::{Read, Write}; +use std::net::{Shutdown, TcpStream}; use std::ops::Deref; use std::result; use std::sync::atomic::{AtomicBool, Ordering}; @@ -41,16 +42,12 @@ use std::time::Duration; use atomic_refcell::AtomicRefCell; use enum_dispatch::enum_dispatch; -use log::Level::Trace; -use log::{debug, log_enabled, trace}; -use rustls::ClientConfig; -use socket2::{Socket as Socket2, TcpKeepalive}; +use log::debug; use usize_cast::FromUsize; use super::deadline::DeadlineIO; use crate::address_::Address; use crate::driver::auth::AuthToken; -use crate::driver::config::KeepAliveConfig; use crate::driver::io::bolt::message_parameters::ResetParameters; use crate::error_::{Neo4jError, Result, ServerError}; use crate::time::Instant; @@ -62,6 +59,7 @@ use bolt5x2::{Bolt5x2, Bolt5x2StructTranslator}; use bolt5x3::{Bolt5x3, Bolt5x3StructTranslator}; use bolt_state::{BoltState, BoltStateTracker}; use chunk::{Chunker, Dechunker}; +pub(crate) use handshake::{open, TcpConnector}; use message::BoltMessage; use message_parameters::{ BeginParameters, CommitParameters, DiscardParameters, GoodbyeParameters, HelloParameters, @@ -574,11 +572,11 @@ impl BoltData { qid == -1 || Some(qid) == *(self.last_qid.deref().borrow()) } - fn serialize_dict>( + fn serialize_dict( &self, serializer: &mut S, - translator: &T, - map: &HashMap, + translator: &impl BoltStructTranslator, + map: &HashMap, ValueSend>, ) -> result::Result<(), S::Error> { serializer.write_dict_header(u64::from_usize(map.len()))?; for (k, v) in map { @@ -588,10 +586,10 @@ impl BoltData { Ok(()) } - fn serialize_str_slice>( + fn serialize_str_slice( &self, serializer: &mut S, - slice: &[V], + slice: &[impl Borrow], ) -> result::Result<(), S::Error> { serializer.write_list_header(u64::from_usize(slice.len()))?; for v in slice { @@ -600,35 +598,23 @@ impl BoltData { Ok(()) } - fn serialize_str_iter, I: Iterator>( + #[inline] + fn serialize_str_iter( &self, serializer: &mut S, - iter: I, + iter: impl Iterator>, ) -> result::Result<(), S::Error> { self.serialize_str_slice(serializer, &iter.collect::>()) } - fn serialize_value( + #[inline] + fn serialize_value( &self, serializer: &mut S, - translator: &T, + translator: &impl BoltStructTranslator, v: &ValueSend, ) -> result::Result<(), S::Error> { - translator.serialize(serializer, v).map_err(Into::into) - } - - fn serialize_routing_context( - &self, - serializer: &mut S, - translator: &T, - routing_context: &HashMap, - ) -> result::Result<(), S::Error> { - serializer.write_dict_header(u64::from_usize(routing_context.len()))?; - for (k, v) in routing_context { - serializer.write_string(k.borrow())?; - self.serialize_value(serializer, translator, v)?; - } - Ok(()) + translator.serialize(serializer, v) } fn write_all(&mut self, deadline: Option) -> Result<()> { @@ -807,176 +793,3 @@ fn assert_response_field_count(name: &str, fields: &[T], expected_count: usiz ))) } } - -const BOLT_MAGIC_PREAMBLE: [u8; 4] = [0x60, 0x60, 0xB0, 0x17]; -// [bolt-version-bump] search tag when changing bolt version support -const BOLT_VERSION_OFFER: [u8; 16] = [ - 0, 3, 3, 5, // BOLT 5.3 - 5.0 - 0, 0, 4, 4, // BOLT 4.4 - 0, 0, 0, 0, // - - 0, 0, 0, 0, // - -]; - -pub(crate) fn open( - address: Arc
, - deadline: Option, - mut connect_timeout: Option, - keep_alive: Option, - tls_config: Option>, -) -> Result { - if log_enabled!(Trace) { - trace!( - "{}{}", - dbg_extra(None, None), - format!("C: {address:?}") - ); - } else { - debug!( - "{}{}", - dbg_extra(None, None), - format!("C: {address}") - ); - } - if let Some(deadline) = deadline { - let mut time_left = deadline.saturating_duration_since(Instant::now()); - if time_left == Duration::from_secs(0) { - time_left = Duration::from_nanos(1); - } - match connect_timeout { - None => connect_timeout = Some(time_left), - Some(timeout) => connect_timeout = Some(timeout.min(time_left)), - } - } - let raw_socket = Neo4jError::wrap_connect(match connect_timeout { - None => TcpStream::connect(&*address), - Some(timeout) => each_addr(&*address, |addr| TcpStream::connect_timeout(addr?, timeout)), - })?; - let raw_socket = set_tcp_keepalive(raw_socket, keep_alive)?; - let local_port = raw_socket - .local_addr() - .map(|addr| addr.port()) - .unwrap_or_default(); - - let buffered_socket = BufTcpStream::new(&raw_socket, true)?; - let mut socket = Socket::new(buffered_socket, address.unresolved_host(), tls_config)?; - - let mut deadline_io = DeadlineIO::new(&mut socket, deadline, Some(&raw_socket)); - - socket_debug!(local_port, "C: {:02X?}", BOLT_MAGIC_PREAMBLE); - wrap_write_socket( - &raw_socket, - local_port, - deadline_io.write_all(&BOLT_MAGIC_PREAMBLE), - )?; - socket_debug!(local_port, "C: {:02X?}", BOLT_VERSION_OFFER); - wrap_write_socket( - &raw_socket, - local_port, - deadline_io.write_all(&BOLT_VERSION_OFFER), - )?; - wrap_write_socket(&raw_socket, local_port, deadline_io.flush())?; - - let mut negotiated_version = [0u8; 4]; - wrap_read_socket( - &raw_socket, - local_port, - deadline_io.read_exact(&mut negotiated_version), - )?; - socket_debug!(local_port, "S: {:02X?}", negotiated_version); - - // [bolt-version-bump] search tag when changing bolt version support - let version = match negotiated_version { - [0, 0, 0, 0] => Err(Neo4jError::InvalidConfig { - message: String::from("server version not supported"), - }), - [0, 0, 3, 5] => Ok((5, 3)), - [0, 0, 2, 5] => Ok((5, 2)), - [0, 0, 1, 5] => Ok((5, 1)), - [0, 0, 0, 5] => Ok((5, 0)), - [0, 0, 4, 4] => Ok((4, 4)), - [72, 84, 84, 80] => { - // "HTTP" - Err(Neo4jError::InvalidConfig { - message: format!( - "unexpected server handshake response {:?} (looks like HTTP)", - &negotiated_version - ), - }) - } - _ => Err(Neo4jError::InvalidConfig { - message: format!( - "unexpected server handshake response {:?}", - &negotiated_version - ), - }), - }?; - - Ok(Bolt::new( - version, - socket, - Arc::new(Some(raw_socket)), - Some(local_port), - address, - )) -} - -// copied from std::net -fn each_addr(addr: A, mut f: F) -> io::Result -where - F: FnMut(io::Result<&SocketAddr>) -> io::Result, -{ - let addrs = match addr.to_socket_addrs() { - Ok(addrs) => addrs, - Err(e) => return f(Err(e)), - }; - let mut last_err = None; - for addr in addrs { - match f(Ok(&addr)) { - Ok(l) => return Ok(l), - Err(e) => last_err = Some(e), - } - } - Err(last_err.unwrap_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve to any addresses", - ) - })) -} - -fn set_tcp_keepalive(socket: TcpStream, keep_alive: Option) -> Result { - let keep_alive = match keep_alive { - None => return Ok(socket), - Some(KeepAliveConfig::Default) => TcpKeepalive::new(), - Some(KeepAliveConfig::CustomTime(time)) => TcpKeepalive::new().with_time(time), - }; - let socket = Socket2::from(socket); - socket - .set_tcp_keepalive(&keep_alive) - .map_err(|err| Neo4jError::InvalidConfig { - message: format!("failed to set tcp keepalive: {}", err), - })?; - Ok(socket.into()) -} - -fn wrap_write_socket(stream: &TcpStream, local_port: u16, res: io::Result) -> Result { - match res { - Ok(res) => Ok(res), - Err(err) => { - socket_debug!(local_port, " write error: {}", err); - let _ = stream.shutdown(Shutdown::Both); - Neo4jError::wrap_write(Err(err)) - } - } -} - -fn wrap_read_socket(stream: &TcpStream, local_port: u16, res: io::Result) -> Result { - match res { - Ok(res) => Ok(res), - Err(err) => { - socket_debug!(local_port, " read error: {}", err); - let _ = stream.shutdown(Shutdown::Both); - Neo4jError::wrap_read(Err(err)) - } - } -} diff --git a/neo4j/src/driver/io/bolt/bolt4x4/protocol.rs b/neo4j/src/driver/io/bolt/bolt4x4/protocol.rs index 4c19d47..c698348 100644 --- a/neo4j/src/driver/io/bolt/bolt4x4/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt4x4/protocol.rs @@ -13,12 +13,12 @@ // limitations under the License. use std::borrow::Borrow; +use std::collections::HashMap; use std::fmt::Debug; -use std::io::{Error as IoError, Read, Write}; +use std::io::{Read, Write}; use std::mem; use std::ops::Deref; use std::sync::Arc; -use std::time::Duration; use atomic_refcell::AtomicRefCell; use log::{debug, log_enabled, warn, Level}; @@ -65,6 +65,66 @@ impl Bolt4x4 { protocol_version, } } + + pub(in super::super) fn write_utc_patch_entry( + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + data: &BoltData, + ) -> Result<()> { + serializer.write_string("patch_bolt")?; + data.serialize_str_slice(serializer, &["utc"])?; + debug_buf!(log_buf, "{}", { + dbg_serializer.write_string("patch_bolt").unwrap(); + data.serialize_str_slice(dbg_serializer, &["utc"]).unwrap(); + dbg_serializer.flush() + }); + Ok(()) + } + + pub(in super::super) fn hello_response_handle_utc_patch( + hints: &HashMap, + translator: &AtomicRefCell, + ) { + if let Some(value) = hints.get(PATCH_BOLT_KEY) { + match value { + ValueReceive::List(value) => { + for entry in value { + match entry { + ValueReceive::String(s) if s == "utc" => { + translator.borrow_mut().enable_utc_patch(); + } + _ => {} + } + } + } + _ => { + warn!("Server sent unexpected {PATCH_BOLT_KEY} type {:?}", value); + } + } + } + } + + pub(in super::super) fn enqueue_hello_response(&self, data: &mut BoltData) { + let bolt_meta = Arc::clone(&data.meta); + let bolt_server_agent = Arc::clone(&data.server_agent); + let socket = Arc::clone(&data.socket); + let translator = Arc::clone(&self.translator); + + data.responses.push_back(BoltResponse::new( + ResponseMessage::Hello, + ResponseCallbacks::new().with_on_success(move |mut meta| { + Bolt5x0::::hello_response_handle_agent(&mut meta, &bolt_server_agent); + Self::hello_response_handle_utc_patch(&meta, &translator); + Bolt5x0::::hello_response_handle_connection_hints( + &meta, + socket.deref().as_ref(), + ); + mem::swap(&mut *bolt_meta.borrow_mut(), &mut meta); + Ok(()) + }), + )); + } } impl Default for Bolt4x4 { @@ -89,7 +149,6 @@ impl BoltProtocol f .check_no_notification_filter(Some(notification_filter))?; debug_buf_start!(log_buf); debug_buf!(log_buf, "C: HELLO"); - let translator = &*(*self.translator).borrow(); let mut dbg_serializer = PackStreamSerializerDebugImpl::new(); let mut message_buff = Vec::new(); let mut serializer = PackStreamSerializerImpl::new(&mut message_buff); @@ -99,133 +158,41 @@ impl BoltProtocol f + >::into(routing_context.is_some()) + u64::from_usize(auth.data.len()); serializer.write_dict_header(extra_size)?; - serializer.write_string("user_agent")?; - serializer.write_string(user_agent)?; debug_buf!(log_buf, " {}", { dbg_serializer.write_dict_header(extra_size).unwrap(); - dbg_serializer.write_string("user_agent").unwrap(); - dbg_serializer.write_string(user_agent).unwrap(); dbg_serializer.flush() }); - serializer.write_string("patch_bolt")?; - data.serialize_str_slice(&mut serializer, &["utc"])?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("patch_bolt").unwrap(); - data.serialize_str_slice(&mut dbg_serializer, &["utc"]) - .unwrap(); - dbg_serializer.flush() - }); + Bolt5x0::::write_user_agent_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + user_agent, + )?; - if let Some(routing_context) = routing_context { - serializer.write_string("routing")?; - data.serialize_routing_context(&mut serializer, translator, routing_context)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("routing").unwrap(); - data.serialize_routing_context(&mut dbg_serializer, translator, routing_context) - .unwrap(); - dbg_serializer.flush() - }); - } + Self::write_utc_patch_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, data)?; - for (k, v) in &auth.data { - serializer.write_string(k)?; - data.serialize_value(&mut serializer, translator, v)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string(k).unwrap(); - if k == "credentials" { - dbg_serializer.write_string("**********").unwrap(); - } else { - data.serialize_value(&mut dbg_serializer, translator, v) - .unwrap(); - } - dbg_serializer.flush() - }); - } + self.bolt5x0.write_routing_context_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + routing_context, + )?; + + self.bolt5x0.write_auth_entries( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + auth, + )?; data.auth = Some(Arc::clone(auth)); data.message_buff.push_back(vec![message_buff]); debug_buf_end!(data, log_buf); - let bolt_meta = Arc::clone(&data.meta); - let translator = Arc::clone(&self.translator); - let bolt_server_agent = Arc::clone(&data.server_agent); - let socket = Arc::clone(&data.socket); - data.responses.push_back(BoltResponse::new( - ResponseMessage::Hello, - ResponseCallbacks::new().with_on_success(move |mut meta| { - if let Some((key, value)) = meta.remove_entry(SERVER_AGENT_KEY) { - match value { - ValueReceive::String(value) => { - mem::swap(&mut *bolt_server_agent.borrow_mut(), &mut Arc::new(value)); - } - _ => { - warn!( - "Server sent unexpected {SERVER_AGENT_KEY} type {:?}", - &value - ); - meta.insert(key, value); - } - } - } - if let Some(value) = meta.get(PATCH_BOLT_KEY) { - match value { - ValueReceive::List(value) => { - for entry in value { - match entry { - ValueReceive::String(s) if s == "utc" => { - translator.borrow_mut().enable_utc_patch(); - } - _ => {} - } - } - } - _ => { - warn!("Server sent unexpected {PATCH_BOLT_KEY} type {:?}", value); - - } - } - } - if let Some(value) = meta.get(HINTS_KEY) { - match value { - ValueReceive::Map(value) => { - if let Some(timeout) = value.get(RECV_TIMEOUT_KEY) { - match timeout { - ValueReceive::Integer(timeout) if timeout > &0 => { - socket.deref().as_ref().map(|socket| { - let timeout = Some(Duration::from_secs(*timeout as u64)); - socket.set_read_timeout(timeout)?; - socket.set_write_timeout(timeout)?; - Ok(()) - }).transpose().unwrap_or_else(|err: IoError| { - warn!("Failed to set socket timeout as hinted by the server: {err}"); - None - }); - } - ValueReceive::Integer(_) => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} value {:?}", - timeout - ); - } - _ => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} type {:?}", - timeout - ); - } - } - } - } - _ => { - warn!("Server sent unexpected {HINTS_KEY} type {:?}", value); - } - } - } - mem::swap(&mut *bolt_meta.borrow_mut(), &mut meta); - Ok(()) - }), - )); + self.enqueue_hello_response(data); Ok(()) } diff --git a/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs index db6c1f5..f402717 100644 --- a/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs @@ -13,13 +13,16 @@ // limitations under the License. use std::borrow::Borrow; +use std::collections::HashMap; use std::fmt::Debug; use std::io::{Error as IoError, Read, Write}; use std::mem; +use std::net::TcpStream; use std::ops::Deref; use std::sync::Arc; use std::time::Duration; +use atomic_refcell::AtomicRefCell; use log::{debug, log_enabled, warn, Level}; use usize_cast::FromUsize; @@ -36,12 +39,14 @@ use super::super::packstream::{ }; use super::super::{ assert_response_field_count, bolt_debug, bolt_debug_extra, dbg_extra, debug_buf, debug_buf_end, - debug_buf_start, BoltData, BoltProtocol, BoltResponse, BoltStructTranslator, ConnectionState, - OnServerErrorCb, ResponseCallbacks, ResponseMessage, + debug_buf_start, BoltData, BoltMeta, BoltProtocol, BoltResponse, BoltStructTranslator, + ConnectionState, OnServerErrorCb, ResponseCallbacks, ResponseMessage, }; +use crate::driver::config::auth::AuthToken; use crate::driver::config::notification::NotificationFilter; +use crate::driver::session::bookmarks::Bookmarks; use crate::error_::{Neo4jError, Result, ServerError}; -use crate::value::ValueReceive; +use crate::value::{ValueReceive, ValueSend}; const SERVER_AGENT_KEY: &str = "server"; const HINTS_KEY: &str = "hints"; @@ -130,6 +135,384 @@ impl Bolt5x0 { } Ok(()) } + + pub(in super::super) fn write_str_entry( + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + key: &str, + value: &str, + ) -> Result<()> { + serializer.write_string(key)?; + serializer.write_string(value)?; + debug_buf!(log_buf, " {}", { + dbg_serializer.write_string(key).unwrap(); + dbg_serializer.write_string(value).unwrap(); + dbg_serializer.flush() + }); + Ok(()) + } + + pub(in super::super) fn write_int_entry( + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + key: &str, + value: i64, + ) -> Result<()> { + serializer.write_string(key)?; + serializer.write_int(value)?; + debug_buf!(log_buf, " {}", { + dbg_serializer.write_string(key).unwrap(); + dbg_serializer.write_int(value).unwrap(); + dbg_serializer.flush() + }); + Ok(()) + } + + pub(in super::super) fn write_dict_entry( + &self, + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + data: &BoltData, + key: &str, + value: &HashMap + Debug, ValueSend>, + ) -> Result<()> { + serializer.write_string(key)?; + data.serialize_dict(serializer, &self.translator, value)?; + debug_buf!(log_buf, " {}", { + dbg_serializer.write_string(key).unwrap(); + data.serialize_dict(dbg_serializer, &self.translator, value) + .unwrap(); + dbg_serializer.flush() + }); + Ok(()) + } + + pub(in super::super) fn write_user_agent_entry( + log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + user_agent: &str, + ) -> Result<()> { + Self::write_str_entry( + log_buf, + serializer, + dbg_serializer, + "user_agent", + user_agent, + ) + } + + pub(in super::super) fn write_routing_context_entry( + &self, + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + data: &BoltData, + routing_context: Option<&HashMap>, + ) -> Result<()> { + if let Some(routing_context) = routing_context { + serializer.write_string("routing")?; + data.serialize_dict(serializer, &self.translator, routing_context)?; + debug_buf!(log_buf, "{}", { + dbg_serializer.write_string("routing").unwrap(); + data.serialize_dict(dbg_serializer, &self.translator, routing_context) + .unwrap(); + dbg_serializer.flush() + }); + } + Ok(()) + } + + pub(in super::super) fn write_auth_entries( + &self, + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + data: &BoltData, + auth: &Arc, + ) -> Result<()> { + for (k, v) in &auth.data { + serializer.write_string(k)?; + data.serialize_value(serializer, &self.translator, v)?; + debug_buf!(log_buf, "{}", { + dbg_serializer.write_string(k).unwrap(); + if k == "credentials" { + dbg_serializer.write_string("**********").unwrap(); + } else { + data.serialize_value(dbg_serializer, &self.translator, v) + .unwrap(); + } + dbg_serializer.flush() + }); + } + Ok(()) + } + + pub(in super::super) fn hello_response_handle_agent( + meta: &mut BoltMeta, + bolt_server_agent: &AtomicRefCell>, + ) { + if let Some((key, value)) = meta.remove_entry(SERVER_AGENT_KEY) { + match value { + ValueReceive::String(value) => { + mem::swap(&mut *bolt_server_agent.borrow_mut(), &mut Arc::new(value)); + } + _ => { + warn!("Server sent unexpected server_agent type {:?}", &value); + meta.insert(key, value); + } + } + } + } + + pub(in super::super) fn hello_response_handle_timeout_hint( + hints: &HashMap, + socket: Option<&TcpStream>, + ) { + if let Some(timeout) = hints.get(RECV_TIMEOUT_KEY) { + match timeout { + ValueReceive::Integer(timeout) if timeout > &0 => { + socket + .map(|socket| { + let timeout = Some(Duration::from_secs(*timeout as u64)); + socket.set_read_timeout(timeout)?; + socket.set_write_timeout(timeout)?; + Ok(()) + }) + .transpose() + .unwrap_or_else(|err: IoError| { + warn!("Failed to set socket timeout as hinted by the server: {err}"); + None + }); + } + ValueReceive::Integer(_) => { + warn!( + "Server sent unexpected {RECV_TIMEOUT_KEY} value {:?}", + timeout + ); + } + _ => { + warn!( + "Server sent unexpected {RECV_TIMEOUT_KEY} type {:?}", + timeout + ); + } + } + } + } + + pub(in super::super) fn hello_response_handle_connection_hints( + meta: &BoltMeta, + socket: Option<&TcpStream>, + ) { + if let Some(value) = meta.get(HINTS_KEY) { + match value { + ValueReceive::Map(value) => { + Self::hello_response_handle_timeout_hint(value, socket); + } + _ => { + warn!("Server sent unexpected {HINTS_KEY} type {:?}", value); + } + } + } + } + + pub(in super::super) fn enqueue_hello_response(data: &mut BoltData) { + let bolt_meta = Arc::clone(&data.meta); + let bolt_server_agent = Arc::clone(&data.server_agent); + let socket = Arc::clone(&data.socket); + + data.responses.push_back(BoltResponse::new( + ResponseMessage::Hello, + ResponseCallbacks::new().with_on_success(move |mut meta| { + Self::hello_response_handle_agent(&mut meta, &bolt_server_agent); + Self::hello_response_handle_connection_hints(&meta, socket.deref().as_ref()); + mem::swap(&mut *bolt_meta.borrow_mut(), &mut meta); + Ok(()) + }), + )); + } + + pub(in super::super) fn write_parameter_dict( + &self, + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + data: &BoltData, + parameters: Option<&HashMap + Debug, ValueSend>>, + ) -> Result<()> { + match parameters { + Some(parameters) => { + data.serialize_dict(serializer, &self.translator, parameters)?; + debug_buf!(log_buf, " {}", { + data.serialize_dict(dbg_serializer, &self.translator, parameters) + .unwrap(); + dbg_serializer.flush() + }); + } + None => { + serializer.write_dict_header(0)?; + debug_buf!(log_buf, " {}", { + dbg_serializer.write_dict_header(0).unwrap(); + dbg_serializer.flush() + }); + } + } + Ok(()) + } + + pub(in super::super) fn write_bookmarks_entry( + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + data: &BoltData, + bookmarks: Option<&Bookmarks>, + ) -> Result<()> { + if let Some(bookmarks) = bookmarks { + if !bookmarks.is_empty() { + serializer.write_string("bookmarks")?; + data.serialize_str_iter(serializer, bookmarks.raw())?; + debug_buf!(log_buf, "{}", { + dbg_serializer.write_string("bookmarks").unwrap(); + data.serialize_str_iter(dbg_serializer, bookmarks.raw()) + .unwrap(); + dbg_serializer.flush() + }); + } + } + Ok(()) + } + + pub(in super::super) fn write_bookmarks_list( + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + data: &BoltData, + bookmarks: Option<&Bookmarks>, + ) -> Result<()> { + match bookmarks { + None => { + debug_buf!(log_buf, " {}", { + dbg_serializer.write_list_header(0).unwrap(); + dbg_serializer.flush() + }); + serializer.write_list_header(0)?; + } + Some(bms) => { + debug_buf!(log_buf, " {}", { + data.serialize_str_iter(dbg_serializer, bms.raw()).unwrap(); + dbg_serializer.flush() + }); + data.serialize_str_iter(serializer, bms.raw())?; + } + } + Ok(()) + } + + pub(in super::super) fn write_tx_timeout_entry( + log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + tx_timeout: Option, + ) -> Result<()> { + if let Some(tx_timeout) = tx_timeout { + Self::write_int_entry( + log_buf, + serializer, + dbg_serializer, + "tx_timeout", + tx_timeout, + )?; + } + Ok(()) + } + + pub(in super::super) fn write_tx_metadata_entry( + &self, + log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + data: &BoltData, + tx_metadata: Option<&HashMap + Debug, ValueSend>>, + ) -> Result<()> { + if let Some(tx_metadata) = tx_metadata { + if !tx_metadata.is_empty() { + self.write_dict_entry( + log_buf, + serializer, + dbg_serializer, + data, + "tx_metadata", + tx_metadata, + )?; + } + } + Ok(()) + } + + pub(in super::super) fn write_mode_entry( + log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + mode: Option<&str>, + ) -> Result<()> { + if let Some(mode) = mode { + if mode != "w" { + Self::write_str_entry(log_buf, serializer, dbg_serializer, "mode", mode)?; + } + } + Ok(()) + } + + pub(in super::super) fn write_db_entry( + log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + db: Option<&str>, + ) -> Result<()> { + if let Some(db) = db { + Self::write_str_entry(log_buf, serializer, dbg_serializer, "db", db)?; + } + Ok(()) + } + + pub(in super::super) fn write_imp_user_entry( + log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + imp_user: Option<&str>, + ) -> Result<()> { + if let Some(imp_user) = imp_user { + Self::write_str_entry(log_buf, serializer, dbg_serializer, "imp_user", imp_user)?; + } + Ok(()) + } + + pub(in super::super) fn install_qid_hook( + data: &BoltData, + callbacks: ResponseCallbacks, + ) -> ResponseCallbacks { + callbacks.with_on_success_pre_hook({ + let last_qid = Arc::clone(&data.last_qid); + move |meta| match meta.get("qid") { + Some(ValueReceive::Integer(qid)) => { + *last_qid.borrow_mut() = Some(*qid); + Ok(()) + } + None => { + *last_qid.borrow_mut() = None; + Ok(()) + } + Some(v) => Err(Neo4jError::protocol_error(format!( + "server send non-int qid: {:?}", + v + ))), + } + }) + } } pub(in super::super) struct PullOrDiscardMessageSpec { @@ -162,106 +545,39 @@ impl BoltProtocol for Bolt5x0 { + >::into(routing_context.is_some()) + u64::from_usize(auth.data.len()); serializer.write_dict_header(extra_size)?; - serializer.write_string("user_agent")?; - serializer.write_string(user_agent)?; debug_buf!(log_buf, " {}", { dbg_serializer.write_dict_header(extra_size).unwrap(); - dbg_serializer.write_string("user_agent").unwrap(); - dbg_serializer.write_string(user_agent).unwrap(); dbg_serializer.flush() }); - if let Some(routing_context) = routing_context { - serializer.write_string("routing")?; - data.serialize_routing_context(&mut serializer, &self.translator, routing_context)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("routing").unwrap(); - data.serialize_routing_context( - &mut dbg_serializer, - &self.translator, - routing_context, - ) - .unwrap(); - dbg_serializer.flush() - }); - } + Self::write_user_agent_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + user_agent, + )?; - for (k, v) in &auth.data { - serializer.write_string(k)?; - data.serialize_value(&mut serializer, &self.translator, v)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string(k).unwrap(); - if k == "credentials" { - dbg_serializer.write_string("**********").unwrap(); - } else { - data.serialize_value(&mut dbg_serializer, &self.translator, v) - .unwrap(); - } - dbg_serializer.flush() - }); - } + self.write_routing_context_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + routing_context, + )?; + + self.write_auth_entries( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + auth, + )?; data.auth = Some(Arc::clone(auth)); data.message_buff.push_back(vec![message_buff]); debug_buf_end!(data, log_buf); - let bolt_meta = Arc::clone(&data.meta); - let bolt_server_agent = Arc::clone(&data.server_agent); - let socket = Arc::clone(&data.socket); - data.responses.push_back(BoltResponse::new( - ResponseMessage::Hello, - ResponseCallbacks::new().with_on_success(move |mut meta| { - if let Some((key, value)) = meta.remove_entry(SERVER_AGENT_KEY) { - match value { - ValueReceive::String(value) => { - mem::swap(&mut *bolt_server_agent.borrow_mut(), &mut Arc::new(value)); - } - _ => { - warn!("Server sent unexpected server_agent type {:?}", &value); - meta.insert(key, value); - } - } - } - if let Some(value) = meta.get(HINTS_KEY) { - match value { - ValueReceive::Map(value) => { - if let Some(timeout) = value.get(RECV_TIMEOUT_KEY) { - match timeout { - ValueReceive::Integer(timeout) if timeout > &0 => { - socket.deref().as_ref().map(|socket| { - let timeout = Some(Duration::from_secs(*timeout as u64)); - socket.set_read_timeout(timeout)?; - socket.set_write_timeout(timeout)?; - Ok(()) - }).transpose().unwrap_or_else(|err: IoError| { - warn!("Failed to set socket timeout as hinted by the server: {err}"); - None - }); - } - ValueReceive::Integer(_) => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} value {:?}", - timeout - ); - } - _ => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} type {:?}", - timeout - ); - } - } - } - } - _ => { - warn!("Server sent unexpected {HINTS_KEY} type {:?}", value); - } - } - } - mem::swap(&mut *bolt_meta.borrow_mut(), &mut meta); - Ok(()) - }), - )); + Self::enqueue_hello_response(data); Ok(()) } @@ -343,23 +659,13 @@ impl BoltProtocol for Bolt5x0 { dbg_serializer.flush() }); - match parameters { - Some(parameters) => { - data.serialize_dict(&mut serializer, &self.translator, parameters)?; - debug_buf!(log_buf, " {}", { - data.serialize_dict(&mut dbg_serializer, &self.translator, parameters) - .unwrap(); - dbg_serializer.flush() - }); - } - None => { - serializer.write_dict_header(0)?; - debug_buf!(log_buf, " {}", { - dbg_serializer.write_dict_header(0).unwrap(); - dbg_serializer.flush() - }); - } - } + self.write_parameter_dict( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + parameters, + )?; let extra_size = [ bookmarks.is_some() && !bookmarks.unwrap().is_empty(), @@ -379,91 +685,41 @@ impl BoltProtocol for Bolt5x0 { dbg_serializer.flush() }); - if let Some(bookmarks) = bookmarks { - if !bookmarks.is_empty() { - serializer.write_string("bookmarks")?; - data.serialize_str_iter(&mut serializer, bookmarks.raw())?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("bookmarks").unwrap(); - data.serialize_str_iter(&mut dbg_serializer, bookmarks.raw()) - .unwrap(); - dbg_serializer.flush() - }); - } - } + Self::write_bookmarks_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + bookmarks, + )?; - if let Some(tx_timeout) = tx_timeout { - serializer.write_string("tx_timeout")?; - serializer.write_int(tx_timeout)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("tx_timeout").unwrap(); - dbg_serializer.write_int(tx_timeout).unwrap(); - dbg_serializer.flush() - }); - } + Self::write_tx_timeout_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + tx_timeout, + )?; - if let Some(tx_metadata) = tx_metadata { - if !tx_metadata.is_empty() { - serializer.write_string("tx_metadata")?; - data.serialize_dict(&mut serializer, &self.translator, tx_metadata)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("tx_metadata").unwrap(); - data.serialize_dict(&mut dbg_serializer, &self.translator, tx_metadata) - .unwrap(); - dbg_serializer.flush() - }); - } - } + self.write_tx_metadata_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + tx_metadata, + )?; - if let Some(mode) = mode { - if mode != "w" { - serializer.write_string("mode")?; - serializer.write_string(mode)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("mode").unwrap(); - dbg_serializer.write_string(mode).unwrap(); - dbg_serializer.flush() - }); - } - } + Self::write_mode_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, mode)?; - if let Some(db) = db { - serializer.write_string("db")?; - serializer.write_string(db)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("db").unwrap(); - dbg_serializer.write_string(db).unwrap(); - dbg_serializer.flush() - }); - } + Self::write_db_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, db)?; - if let Some(imp_user) = imp_user { - serializer.write_string("imp_user")?; - serializer.write_string(imp_user)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("imp_user").unwrap(); - dbg_serializer.write_string(imp_user).unwrap(); - dbg_serializer.flush() - }); - } + Self::write_imp_user_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + imp_user, + )?; - callbacks = callbacks.with_on_success_pre_hook({ - let last_qid = Arc::clone(&data.last_qid); - move |meta| match meta.get("qid") { - Some(ValueReceive::Integer(qid)) => { - *last_qid.borrow_mut() = Some(*qid); - Ok(()) - } - None => { - *last_qid.borrow_mut() = None; - Ok(()) - } - Some(v) => Err(Neo4jError::protocol_error(format!( - "server send non-int qid: {:?}", - v - ))), - } - }); + callbacks = Self::install_qid_hook(data, callbacks); data.message_buff.push_back(vec![message_buff]); data.responses @@ -552,71 +808,39 @@ impl BoltProtocol for Bolt5x0 { }); serializer.write_dict_header(extra_size)?; - if let Some(bookmarks) = bookmarks { - if !bookmarks.is_empty() { - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("bookmarks").unwrap(); - data.serialize_str_iter(&mut dbg_serializer, bookmarks.raw()) - .unwrap(); - dbg_serializer.flush() - }); - serializer.write_string("bookmarks").unwrap(); - data.serialize_str_iter(&mut serializer, bookmarks.raw())?; - } - } + Self::write_bookmarks_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + bookmarks, + )?; - if let Some(tx_timeout) = tx_timeout { - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("tx_timeout").unwrap(); - dbg_serializer.write_int(tx_timeout).unwrap(); - dbg_serializer.flush() - }); - serializer.write_string("tx_timeout")?; - serializer.write_int(tx_timeout)?; - } + Self::write_tx_timeout_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + tx_timeout, + )?; - if let Some(tx_metadata) = tx_metadata { - if !tx_metadata.is_empty() { - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("tx_metadata").unwrap(); - data.serialize_dict(&mut dbg_serializer, &self.translator, tx_metadata) - .unwrap(); - dbg_serializer.flush() - }); - serializer.write_string("tx_metadata")?; - data.serialize_dict(&mut serializer, &self.translator, tx_metadata)?; - } - } + self.write_tx_metadata_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + tx_metadata, + )?; - if let Some(mode) = mode { - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("mode").unwrap(); - dbg_serializer.write_string(mode).unwrap(); - dbg_serializer.flush() - }); - serializer.write_string("mode")?; - serializer.write_string(mode)?; - } + Self::write_mode_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, mode)?; - if let Some(db) = db { - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("db").unwrap(); - dbg_serializer.write_string(db).unwrap(); - dbg_serializer.flush() - }); - serializer.write_string("db")?; - serializer.write_string(db)?; - } + Self::write_db_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, db)?; - if let Some(imp_user) = imp_user { - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("imp_user").unwrap(); - dbg_serializer.write_string(imp_user).unwrap(); - dbg_serializer.flush() - }); - serializer.write_string("imp_user")?; - serializer.write_string(imp_user)?; - } + Self::write_imp_user_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + imp_user, + )?; data.message_buff.push_back(vec![message_buff]); data.responses @@ -677,30 +901,20 @@ impl BoltProtocol for Bolt5x0 { let mut serializer = PackStreamSerializerImpl::new(&mut message_buff); serializer.write_struct_header(0x66, 3)?; - data.serialize_routing_context(&mut serializer, &self.translator, routing_context)?; + data.serialize_dict(&mut serializer, &self.translator, routing_context)?; debug_buf!(log_buf, " {}", { - data.serialize_routing_context(&mut dbg_serializer, &self.translator, routing_context) + data.serialize_dict(&mut dbg_serializer, &self.translator, routing_context) .unwrap(); dbg_serializer.flush() }); - match bookmarks { - None => { - debug_buf!(log_buf, " {}", { - dbg_serializer.write_list_header(0).unwrap(); - dbg_serializer.flush() - }); - serializer.write_list_header(0)?; - } - Some(bms) => { - debug_buf!(log_buf, " {}", { - data.serialize_str_iter(&mut dbg_serializer, bms.raw()) - .unwrap(); - dbg_serializer.flush() - }); - data.serialize_str_iter(&mut serializer, bms.raw())?; - } - } + Self::write_bookmarks_list( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + bookmarks, + )?; let extra_size = >::into(db.is_some()) + >::into(imp_user.is_some()); @@ -712,25 +926,14 @@ impl BoltProtocol for Bolt5x0 { dbg_serializer.flush() }); - if let Some(db) = db { - serializer.write_string("db")?; - serializer.write_string(db)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("db").unwrap(); - dbg_serializer.write_string(db).unwrap(); - dbg_serializer.flush() - }); - } + Self::write_db_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, db)?; - if let Some(imp_user) = imp_user { - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("imp_user").unwrap(); - dbg_serializer.write_string(imp_user).unwrap(); - dbg_serializer.flush() - }); - serializer.write_string("imp_user")?; - serializer.write_string(imp_user)?; - } + Self::write_imp_user_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + imp_user, + )?; data.message_buff.push_back(vec![message_buff]); data.responses diff --git a/neo4j/src/driver/io/bolt/bolt5x1/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x1/protocol.rs index cc12fda..e4b1f40 100644 --- a/neo4j/src/driver/io/bolt/bolt5x1/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x1/protocol.rs @@ -14,13 +14,10 @@ use std::borrow::Borrow; use std::fmt::Debug; -use std::io::{Error as IoError, Read, Write}; -use std::mem; -use std::ops::Deref; +use std::io::{Read, Write}; use std::sync::Arc; -use std::time::Duration; -use log::{debug, log_enabled, warn, Level}; +use log::{debug, log_enabled, Level}; use usize_cast::FromUsize; use super::super::bolt5x0::Bolt5x0; @@ -39,6 +36,7 @@ use super::super::{ BoltProtocol, BoltResponse, BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, ResponseMessage, }; +use crate::driver::config::auth::AuthToken; use crate::error_::Result; use crate::value::ValueReceive; @@ -49,7 +47,7 @@ const RECV_TIMEOUT_KEY: &str = "connection.recv_timeout_seconds"; #[derive(Debug)] pub(crate) struct Bolt5x1 { translator: T, - bolt5x0: Bolt5x0, + pub(in super::super) bolt5x0: Bolt5x0, protocol_version: ServerAwareBoltVersion, } @@ -93,27 +91,13 @@ impl Bolt5x1 { let mut serializer = PackStreamSerializerImpl::new(&mut message_buff); serializer.write_struct_header(0x6A, 1)?; - let auth_size = u64::from_usize(auth.data.len()); - serializer.write_dict_header(auth_size)?; - debug_buf!(log_buf, " {}", { - dbg_serializer.write_dict_header(auth_size).unwrap(); - dbg_serializer.flush() - }); - - for (k, v) in &auth.data { - serializer.write_string(k)?; - data.serialize_value(&mut serializer, &self.translator, v)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string(k).unwrap(); - if k == "credentials" { - dbg_serializer.write_string("**********").unwrap(); - } else { - data.serialize_value(&mut dbg_serializer, &self.translator, v) - .unwrap(); - } - dbg_serializer.flush() - }); - } + self.write_auth_dict( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + auth, + )?; data.message_buff.push_back(vec![message_buff]); data.responses @@ -135,6 +119,25 @@ impl Bolt5x1 { bolt_debug!(data, "C: LOGOFF"); Ok(()) } + + pub(in super::super) fn write_auth_dict( + &self, + mut log_buf: Option<&mut String>, + serializer: &mut PackStreamSerializerImpl, + dbg_serializer: &mut PackStreamSerializerDebugImpl, + data: &BoltData, + auth: &Arc, + ) -> Result<()> { + let auth_size = u64::from_usize(auth.data.len()); + serializer.write_dict_header(auth_size)?; + debug_buf!(log_buf, " {}", { + dbg_serializer.write_dict_header(auth_size).unwrap(); + dbg_serializer.flush() + }); + self.bolt5x0 + .write_auth_entries(log_buf, serializer, dbg_serializer, data, auth)?; + Ok(()) + } } impl BoltProtocol for Bolt5x1 { @@ -165,90 +168,30 @@ impl BoltProtocol for Bolt5x1 { let extra_size = 1 + >::into(routing_context.is_some()); serializer.write_dict_header(extra_size)?; - serializer.write_string("user_agent")?; - serializer.write_string(user_agent)?; debug_buf!(log_buf, " {}", { dbg_serializer.write_dict_header(extra_size).unwrap(); - dbg_serializer.write_string("user_agent").unwrap(); - dbg_serializer.write_string(user_agent).unwrap(); dbg_serializer.flush() }); - if let Some(routing_context) = routing_context { - serializer.write_string("routing")?; - data.serialize_routing_context(&mut serializer, &self.translator, routing_context)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("routing").unwrap(); - data.serialize_routing_context( - &mut dbg_serializer, - &self.translator, - routing_context, - ) - .unwrap(); - dbg_serializer.flush() - }); - } + Bolt5x0::::write_user_agent_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + user_agent, + )?; + + self.bolt5x0.write_routing_context_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + routing_context, + )?; data.message_buff.push_back(vec![message_buff]); debug_buf_end!(data, log_buf); - let bolt_meta = Arc::clone(&data.meta); - let bolt_server_agent = Arc::clone(&data.server_agent); - let socket = Arc::clone(&data.socket); - data.responses.push_back(BoltResponse::new( - ResponseMessage::Hello, - ResponseCallbacks::new().with_on_success(move |mut meta| { - if let Some((key, value)) = meta.remove_entry(SERVER_AGENT_KEY) { - match value { - ValueReceive::String(value) => { - mem::swap(&mut *bolt_server_agent.borrow_mut(), &mut Arc::new(value)); - } - _ => { - warn!("Server sent unexpected server_agent type {:?}", &value); - meta.insert(key, value); - } - } - } - if let Some(value) = meta.get(HINTS_KEY) { - match value { - ValueReceive::Map(value) => { - if let Some(timeout) = value.get(RECV_TIMEOUT_KEY) { - match timeout { - ValueReceive::Integer(timeout) if timeout > &0 => { - socket.deref().as_ref().map(|socket| { - let timeout = Some(Duration::from_secs(*timeout as u64)); - socket.set_read_timeout(timeout)?; - socket.set_write_timeout(timeout)?; - Ok(()) - }).transpose().unwrap_or_else(|err: IoError| { - warn!("Failed to set socket timeout as hinted by the server: {err}"); - None - }); - } - ValueReceive::Integer(_) => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} value {:?}", - timeout - ); - } - _ => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} type {:?}", - timeout - ); - } - } - } - } - _ => { - warn!("Server sent unexpected {HINTS_KEY} type {:?}", value); - } - } - } - mem::swap(&mut *bolt_meta.borrow_mut(), &mut meta); - Ok(()) - }), - )); + Bolt5x0::::enqueue_hello_response(data); Ok(()) } diff --git a/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs index 69cd288..d337596 100644 --- a/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs @@ -14,16 +14,13 @@ use std::borrow::Borrow; use std::fmt::Debug; -use std::io::{Error as IoError, Read, Write}; -use std::mem; -use std::ops::Deref; -use std::sync::Arc; -use std::time::Duration; +use std::io::{Read, Write}; use crate::driver::notification::NotificationFilter; -use log::{debug, log_enabled, warn, Level}; +use log::{debug, log_enabled, Level}; use usize_cast::FromUsize; +use super::super::bolt5x0::Bolt5x0; use super::super::bolt5x1::Bolt5x1; use super::super::bolt_common::ServerAwareBoltVersion; use super::super::message::BoltMessage; @@ -39,7 +36,7 @@ use super::super::{ bolt_debug_extra, dbg_extra, debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, BoltResponse, BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, ResponseMessage, }; -use crate::error_::{Neo4jError, Result}; +use crate::error_::Result; use crate::value::ValueReceive; const SERVER_AGENT_KEY: &str = "server"; @@ -49,7 +46,7 @@ const RECV_TIMEOUT_KEY: &str = "connection.recv_timeout_seconds"; #[derive(Debug)] pub(crate) struct Bolt5x2 { translator: T, - bolt5x1: Bolt5x1, + pub(in super::super) bolt5x1: Bolt5x1, protocol_version: ServerAwareBoltVersion, } @@ -62,7 +59,7 @@ impl Bolt5x2 { } } - pub(in super::super) fn notification_filter_size( + pub(in super::super) fn notification_filter_entries_count( notification_filter: Option<&NotificationFilter>, ) -> u64 { match notification_filter { @@ -77,7 +74,7 @@ impl Bolt5x2 { } } - pub(in super::super) fn write_notification_filter( + pub(in super::super) fn write_notification_filter_entries( mut log_buf: Option<&mut String>, serializer: &mut PackStreamSerializerImpl, dbg_serializer: &mut PackStreamSerializerDebugImpl, @@ -154,34 +151,30 @@ impl BoltProtocol for Bolt5x2 { serializer.write_struct_header(0x01, 1)?; let extra_size = 1 - + Self::notification_filter_size(Some(notification_filter)) + + Self::notification_filter_entries_count(Some(notification_filter)) + >::into(routing_context.is_some()); serializer.write_dict_header(extra_size)?; - serializer.write_string("user_agent")?; - serializer.write_string(user_agent)?; debug_buf!(log_buf, " {}", { dbg_serializer.write_dict_header(extra_size).unwrap(); - dbg_serializer.write_string("user_agent").unwrap(); - dbg_serializer.write_string(user_agent).unwrap(); dbg_serializer.flush() }); - if let Some(routing_context) = routing_context { - serializer.write_string("routing")?; - data.serialize_routing_context(&mut serializer, &self.translator, routing_context)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("routing").unwrap(); - data.serialize_routing_context( - &mut dbg_serializer, - &self.translator, - routing_context, - ) - .unwrap(); - dbg_serializer.flush() - }); - } + Bolt5x0::::write_user_agent_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + user_agent, + )?; + + self.bolt5x1.bolt5x0.write_routing_context_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + routing_context, + )?; - Self::write_notification_filter( + Self::write_notification_filter_entries( log_buf.as_mut(), &mut serializer, &mut dbg_serializer, @@ -191,63 +184,7 @@ impl BoltProtocol for Bolt5x2 { data.message_buff.push_back(vec![message_buff]); debug_buf_end!(data, log_buf); - let bolt_meta = Arc::clone(&data.meta); - let bolt_server_agent = Arc::clone(&data.server_agent); - let socket = Arc::clone(&data.socket); - data.responses.push_back(BoltResponse::new( - ResponseMessage::Hello, - ResponseCallbacks::new().with_on_success(move |mut meta| { - if let Some((key, value)) = meta.remove_entry(SERVER_AGENT_KEY) { - match value { - ValueReceive::String(value) => { - mem::swap(&mut *bolt_server_agent.borrow_mut(), &mut Arc::new(value)); - } - _ => { - warn!("Server sent unexpected server_agent type {:?}", &value); - meta.insert(key, value); - } - } - } - if let Some(value) = meta.get(HINTS_KEY) { - match value { - ValueReceive::Map(value) => { - if let Some(timeout) = value.get(RECV_TIMEOUT_KEY) { - match timeout { - ValueReceive::Integer(timeout) if timeout > &0 => { - socket.deref().as_ref().map(|socket| { - let timeout = Some(Duration::from_secs(*timeout as u64)); - socket.set_read_timeout(timeout)?; - socket.set_write_timeout(timeout)?; - Ok(()) - }).transpose().unwrap_or_else(|err: IoError| { - warn!("Failed to set socket timeout as hinted by the server: {err}"); - None - }); - } - ValueReceive::Integer(_) => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} value {:?}", - timeout - ); - } - _ => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} type {:?}", - timeout - ); - } - } - } - } - _ => { - warn!("Server sent unexpected {HINTS_KEY} type {:?}", value); - } - } - } - mem::swap(&mut *bolt_meta.borrow_mut(), &mut meta); - Ok(()) - }), - )); + Bolt5x0::::enqueue_hello_response(data); Ok(()) } @@ -313,25 +250,15 @@ impl BoltProtocol for Bolt5x2 { dbg_serializer.flush() }); - match parameters { - Some(parameters) => { - data.serialize_dict(&mut serializer, &self.translator, parameters)?; - debug_buf!(log_buf, " {}", { - data.serialize_dict(&mut dbg_serializer, &self.translator, parameters) - .unwrap(); - dbg_serializer.flush() - }); - } - None => { - serializer.write_dict_header(0)?; - debug_buf!(log_buf, " {}", { - dbg_serializer.write_dict_header(0).unwrap(); - dbg_serializer.flush() - }); - } - } + self.bolt5x1.bolt5x0.write_parameter_dict( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + parameters, + )?; - let extra_size = Self::notification_filter_size(notification_filter) + let extra_size = Self::notification_filter_entries_count(notification_filter) + [ bookmarks.is_some() && !bookmarks.unwrap().is_empty(), tx_timeout.is_some(), @@ -350,98 +277,53 @@ impl BoltProtocol for Bolt5x2 { dbg_serializer.flush() }); - if let Some(bookmarks) = bookmarks { - if !bookmarks.is_empty() { - serializer.write_string("bookmarks")?; - data.serialize_str_iter(&mut serializer, bookmarks.raw())?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("bookmarks").unwrap(); - data.serialize_str_iter(&mut dbg_serializer, bookmarks.raw()) - .unwrap(); - dbg_serializer.flush() - }); - } - } + Bolt5x0::::write_bookmarks_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + bookmarks, + )?; - if let Some(tx_timeout) = tx_timeout { - serializer.write_string("tx_timeout")?; - serializer.write_int(tx_timeout)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("tx_timeout").unwrap(); - dbg_serializer.write_int(tx_timeout).unwrap(); - dbg_serializer.flush() - }); - } + Bolt5x0::::write_tx_timeout_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + tx_timeout, + )?; - if let Some(tx_metadata) = tx_metadata { - if !tx_metadata.is_empty() { - serializer.write_string("tx_metadata")?; - data.serialize_dict(&mut serializer, &self.translator, tx_metadata)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("tx_metadata").unwrap(); - data.serialize_dict(&mut dbg_serializer, &self.translator, tx_metadata) - .unwrap(); - dbg_serializer.flush() - }); - } - } + self.bolt5x1.bolt5x0.write_tx_metadata_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + tx_metadata, + )?; - if let Some(mode) = mode { - if mode != "w" { - serializer.write_string("mode")?; - serializer.write_string(mode)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("mode").unwrap(); - dbg_serializer.write_string(mode).unwrap(); - dbg_serializer.flush() - }); - } - } + Bolt5x0::::write_mode_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + mode, + )?; - if let Some(db) = db { - serializer.write_string("db")?; - serializer.write_string(db)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("db").unwrap(); - dbg_serializer.write_string(db).unwrap(); - dbg_serializer.flush() - }); - } + Bolt5x0::::write_db_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, db)?; - if let Some(imp_user) = imp_user { - serializer.write_string("imp_user")?; - serializer.write_string(imp_user)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("imp_user").unwrap(); - dbg_serializer.write_string(imp_user).unwrap(); - dbg_serializer.flush() - }); - } + Bolt5x0::::write_imp_user_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + imp_user, + )?; - Self::write_notification_filter( + Self::write_notification_filter_entries( log_buf.as_mut(), &mut serializer, &mut dbg_serializer, notification_filter, )?; - callbacks = callbacks.with_on_success_pre_hook({ - let last_qid = Arc::clone(&data.last_qid); - move |meta| match meta.get("qid") { - Some(ValueReceive::Integer(qid)) => { - *last_qid.borrow_mut() = Some(*qid); - Ok(()) - } - None => { - *last_qid.borrow_mut() = None; - Ok(()) - } - Some(v) => Err(Neo4jError::protocol_error(format!( - "server send non-int qid: {:?}", - v - ))), - } - }); + callbacks = Bolt5x0::::install_qid_hook(data, callbacks); data.message_buff.push_back(vec![message_buff]); data.responses @@ -491,7 +373,7 @@ impl BoltProtocol for Bolt5x2 { let mut serializer = PackStreamSerializerImpl::new(&mut message_buff); serializer.write_struct_header(0x11, 1)?; - let extra_size = Self::notification_filter_size(Some(notification_filter)) + let extra_size = Self::notification_filter_entries_count(Some(notification_filter)) + [ bookmarks.is_some() && !bookmarks.unwrap().is_empty(), tx_timeout.is_some(), @@ -576,7 +458,7 @@ impl BoltProtocol for Bolt5x2 { serializer.write_string(imp_user)?; } - Self::write_notification_filter( + Self::write_notification_filter_entries( log_buf.as_mut(), &mut serializer, &mut dbg_serializer, diff --git a/neo4j/src/driver/io/bolt/bolt5x3/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x3/protocol.rs index 6d0ed2d..e0a9dfe 100644 --- a/neo4j/src/driver/io/bolt/bolt5x3/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x3/protocol.rs @@ -14,14 +14,11 @@ use std::borrow::Borrow; use std::fmt::Debug; -use std::io::{Error as IoError, Read, Write}; -use std::mem; -use std::ops::Deref; -use std::sync::Arc; -use std::time::Duration; +use std::io::{Read, Write}; -use log::{debug, log_enabled, warn, Level}; +use log::{debug, log_enabled, Level}; +use super::super::bolt5x0::Bolt5x0; use super::super::bolt5x2::Bolt5x2; use super::super::bolt_common::{ ServerAwareBoltVersion, BOLT_AGENT_LANGUAGE, BOLT_AGENT_LANGUAGE_DETAILS, BOLT_AGENT_PLATFORM, @@ -38,7 +35,7 @@ use super::super::packstream::{ }; use super::super::{ bolt_debug_extra, dbg_extra, debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, - BoltResponse, BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, ResponseMessage, + BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, }; use crate::error_::Result; use crate::value::ValueReceive; @@ -50,7 +47,7 @@ const RECV_TIMEOUT_KEY: &str = "connection.recv_timeout_seconds"; #[derive(Debug)] pub(crate) struct Bolt5x3 { translator: T, - bolt5x2: Bolt5x2, + pub(in super::super) bolt5x2: Bolt5x2, protocol_version: ServerAwareBoltVersion, } @@ -63,7 +60,7 @@ impl Bolt5x3 { } } - pub(in super::super) fn write_bolt_agent( + pub(in super::super) fn write_bolt_agent_entry( mut log_buf: Option<&mut String>, serializer: &mut PackStreamSerializerImpl, dbg_serializer: &mut PackStreamSerializerDebugImpl, @@ -123,36 +120,33 @@ impl BoltProtocol for Bolt5x3 { serializer.write_struct_header(0x01, 1)?; let extra_size = 2 - + Bolt5x2::::notification_filter_size(Some(notification_filter)) + + Bolt5x2::::notification_filter_entries_count(Some(notification_filter)) + >::into(routing_context.is_some()); + serializer.write_dict_header(extra_size)?; - serializer.write_string("user_agent")?; - serializer.write_string(user_agent)?; debug_buf!(log_buf, " {}", { dbg_serializer.write_dict_header(extra_size).unwrap(); - dbg_serializer.write_string("user_agent").unwrap(); - dbg_serializer.write_string(user_agent).unwrap(); dbg_serializer.flush() }); - Self::write_bolt_agent(log_buf.as_mut(), &mut serializer, &mut dbg_serializer)?; + Bolt5x0::::write_user_agent_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + user_agent, + )?; - if let Some(routing_context) = routing_context { - serializer.write_string("routing")?; - data.serialize_routing_context(&mut serializer, &self.translator, routing_context)?; - debug_buf!(log_buf, "{}", { - dbg_serializer.write_string("routing").unwrap(); - data.serialize_routing_context( - &mut dbg_serializer, - &self.translator, - routing_context, - ) - .unwrap(); - dbg_serializer.flush() - }); - } + Self::write_bolt_agent_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer)?; + + self.bolt5x2.bolt5x1.bolt5x0.write_routing_context_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + routing_context, + )?; - Bolt5x2::::write_notification_filter( + Bolt5x2::::write_notification_filter_entries( log_buf.as_mut(), &mut serializer, &mut dbg_serializer, @@ -162,63 +156,7 @@ impl BoltProtocol for Bolt5x3 { data.message_buff.push_back(vec![message_buff]); debug_buf_end!(data, log_buf); - let bolt_meta = Arc::clone(&data.meta); - let bolt_server_agent = Arc::clone(&data.server_agent); - let socket = Arc::clone(&data.socket); - data.responses.push_back(BoltResponse::new( - ResponseMessage::Hello, - ResponseCallbacks::new().with_on_success(move |mut meta| { - if let Some((key, value)) = meta.remove_entry(SERVER_AGENT_KEY) { - match value { - ValueReceive::String(value) => { - mem::swap(&mut *bolt_server_agent.borrow_mut(), &mut Arc::new(value)); - } - _ => { - warn!("Server sent unexpected server_agent type {:?}", &value); - meta.insert(key, value); - } - } - } - if let Some(value) = meta.get(HINTS_KEY) { - match value { - ValueReceive::Map(value) => { - if let Some(timeout) = value.get(RECV_TIMEOUT_KEY) { - match timeout { - ValueReceive::Integer(timeout) if timeout > &0 => { - socket.deref().as_ref().map(|socket| { - let timeout = Some(Duration::from_secs(*timeout as u64)); - socket.set_read_timeout(timeout)?; - socket.set_write_timeout(timeout)?; - Ok(()) - }).transpose().unwrap_or_else(|err: IoError| { - warn!("Failed to set socket timeout as hinted by the server: {err}"); - None - }); - } - ValueReceive::Integer(_) => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} value {:?}", - timeout - ); - } - _ => { - warn!( - "Server sent unexpected {RECV_TIMEOUT_KEY} type {:?}", - timeout - ); - } - } - } - } - _ => { - warn!("Server sent unexpected {HINTS_KEY} type {:?}", value); - } - } - } - mem::swap(&mut *bolt_meta.borrow_mut(), &mut meta); - Ok(()) - }), - )); + Bolt5x0::::enqueue_hello_response(data); Ok(()) } diff --git a/neo4j/src/driver/io/bolt/handshake.rs b/neo4j/src/driver/io/bolt/handshake.rs new file mode 100644 index 0000000..5018908 --- /dev/null +++ b/neo4j/src/driver/io/bolt/handshake.rs @@ -0,0 +1,752 @@ +// Copyright Rouven Bauer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::{Debug, Display}; +use std::io::{self, Read, Write}; +use std::net::{Shutdown, SocketAddr, TcpStream, ToSocketAddrs}; +use std::sync::Arc; +use std::time::Duration; + +use log::Level::Trace; +use log::{debug, log_enabled, trace}; +use rustls::ClientConfig; +use socket2::{Socket as Socket2, TcpKeepalive}; + +use super::super::deadline::DeadlineIO; +pub(crate) use super::socket::{BufTcpStream, Socket}; +use super::{dbg_extra, socket_debug, Bolt}; +use crate::address_::Address; +use crate::driver::config::KeepAliveConfig; +use crate::error_::{Neo4jError, Result}; +use crate::time::Instant; + +const BOLT_MAGIC_PREAMBLE: [u8; 4] = [0x60, 0x60, 0xB0, 0x17]; +// [bolt-version-bump] search tag when changing bolt version support +const BOLT_VERSION_OFFER: [u8; 16] = [ + 0, 3, 3, 5, // BOLT 5.3 - 5.0 + 0, 0, 4, 4, // BOLT 4.4 + 0, 0, 0, 0, // - + 0, 0, 0, 0, // - +]; + +pub(crate) trait AddressProvider: Debug + Display + ToSocketAddrs + Sized + 'static { + fn unresolved_host(&self) -> &str; + fn into_address(self: Arc) -> Arc
; +} + +impl AddressProvider for Address { + fn unresolved_host(&self) -> &str { + self.unresolved_host() + } + + fn into_address(self: Arc) -> Arc
{ + self + } +} + +pub(crate) trait SocketProvider { + type RW: Read + Write + Debug; + type BuffRW: Read + Write + Debug; + + fn connect(&mut self, addr: &Arc) -> io::Result; + fn get_local_port(&mut self, sock: &Self::RW) -> u16; + fn connect_timeout(&mut self, addr: &SocketAddr, timeout: Duration) -> io::Result; + fn set_tcp_keepalive( + &mut self, + socket: Self::RW, + keepalive: Option, + ) -> io::Result; + fn shutdown(&mut self, sock: &Self::RW, how: Shutdown) -> io::Result<()>; + fn new_buffered(&mut self, sock: &Self::RW) -> Result; + fn new_deadline_io<'rw, S: Read + Write>( + &'_ mut self, + stream: S, + sock: &'rw Self::RW, + deadline: Option, + ) -> DeadlineIO<'rw, S>; + fn new_socket(&mut self, _sock: Self::RW) -> Option { + None + } +} + +pub(crate) struct TcpConnector; + +impl SocketProvider for TcpConnector { + type RW = TcpStream; + type BuffRW = BufTcpStream; + + #[inline] + fn connect(&mut self, addr: &Arc) -> io::Result { + TcpStream::connect(&**addr) + } + + #[inline] + fn get_local_port(&mut self, sock: &Self::RW) -> u16 { + sock.local_addr() + .map(|addr| addr.port()) + .unwrap_or_default() + } + + #[inline] + fn connect_timeout(&mut self, addr: &SocketAddr, timeout: Duration) -> io::Result { + TcpStream::connect_timeout(addr, timeout) + } + + #[inline] + fn set_tcp_keepalive( + &mut self, + socket: Self::RW, + keepalive: Option, + ) -> io::Result { + let keep_alive = match keepalive { + None => return Ok(socket), + Some(KeepAliveConfig::Default) => TcpKeepalive::new(), + Some(KeepAliveConfig::CustomTime(time)) => TcpKeepalive::new().with_time(time), + }; + let socket = Socket2::from(socket); + socket.set_tcp_keepalive(&keep_alive)?; + Ok(socket.into()) + } + + #[inline] + fn shutdown(&mut self, sock: &Self::RW, how: Shutdown) -> io::Result<()> { + sock.shutdown(how) + } + + #[inline] + fn new_buffered(&mut self, sock: &Self::RW) -> Result { + BufTcpStream::new(sock, true) + } + + #[inline] + fn new_deadline_io<'rw, S: Read + Write>( + &'_ mut self, + stream: S, + sock: &'rw Self::RW, + deadline: Option, + ) -> DeadlineIO<'rw, S> { + DeadlineIO::new(stream, deadline, Some(sock)) + } + + #[inline] + fn new_socket(&mut self, sock: Self::RW) -> Option { + Some(sock) + } +} + +pub(crate) fn open( + mut socket_provider: S, + address: Arc, + deadline: Option, + connect_timeout: Option, + keep_alive: Option, + tls_config: Option>, +) -> Result>> { + if log_enabled!(Trace) { + trace!( + "{}{}", + dbg_extra(None, None), + format!("C: {address:?}") + ); + } else { + debug!( + "{}{}", + dbg_extra(None, None), + format!("C: {address}") + ); + } + + let timeout = combine_connection_timout(connect_timeout, deadline); + let raw_socket = Neo4jError::wrap_connect(match timeout { + None => socket_provider.connect(&address), + Some(timeout) => { + let mut timeout = timeout; + each_addr(&*address, |addr| { + match socket_provider.connect_timeout(addr?, timeout) { + Ok(connection) => Ok(connection), + Err(e) => { + timeout = combine_connection_timout(connect_timeout, deadline) + .expect("timeout cannot disappear"); + Err(e) + } + } + }) + } + })?; + let raw_socket = socket_provider + .set_tcp_keepalive(raw_socket, keep_alive) + .map_err(|err| Neo4jError::InvalidConfig { + message: format!("failed to set tcp keepalive: {}", err), + })?; + let local_port = socket_provider.get_local_port(&raw_socket); + + let buffered_socket = socket_provider.new_buffered(&raw_socket)?; + let mut socket = Socket::new(buffered_socket, address.unresolved_host(), tls_config)?; + + let mut deadline_io = socket_provider.new_deadline_io(&mut socket, &raw_socket, deadline); + + socket_debug!(local_port, "C: {:02X?}", BOLT_MAGIC_PREAMBLE); + wrap_socket_write( + &mut socket_provider, + &raw_socket, + local_port, + deadline_io.write_all(&BOLT_MAGIC_PREAMBLE), + )?; + socket_debug!(local_port, "C: {:02X?}", BOLT_VERSION_OFFER); + wrap_socket_write( + &mut socket_provider, + &raw_socket, + local_port, + deadline_io.write_all(&BOLT_VERSION_OFFER), + )?; + wrap_socket_write( + &mut socket_provider, + &raw_socket, + local_port, + deadline_io.flush(), + )?; + + let mut negotiated_version = [0u8; 4]; + wrap_socket_read( + &mut socket_provider, + &raw_socket, + local_port, + deadline_io.read_exact(&mut negotiated_version), + )?; + socket_debug!(local_port, "S: {:02X?}", negotiated_version); + + let version = decode_version_offer(&negotiated_version)?; + + Ok(Bolt::new( + version, + socket, + Arc::new(socket_provider.new_socket(raw_socket)), + Some(local_port), + address.into_address(), + )) +} + +// [bolt-version-bump] search tag when changing bolt version support +fn decode_version_offer(offer: &[u8; 4]) -> Result<(u8, u8)> { + match offer { + [0, 0, 0, 0] => Err(Neo4jError::InvalidConfig { + message: String::from("server version not supported"), + }), + [_, _, 3, 5] => Ok((5, 3)), + [_, _, 2, 5] => Ok((5, 2)), + [_, _, 1, 5] => Ok((5, 1)), + [_, _, 0, 5] => Ok((5, 0)), + [_, _, 4, 4] => Ok((4, 4)), + [72, 84, 84, 80] => { + // "HTTP" + Err(Neo4jError::InvalidConfig { + message: format!( + "unexpected server handshake response {:?} (looks like HTTP)", + offer + ), + }) + } + _ => Err(Neo4jError::InvalidConfig { + message: format!("unexpected server handshake response {:?}", offer), + }), + } +} + +fn combine_connection_timout( + connect_timeout: Option, + deadline: Option, +) -> Option { + deadline.map(|deadline| { + let mut time_left = deadline.saturating_duration_since(Instant::now()); + if time_left == Duration::from_secs(0) { + time_left = Duration::from_nanos(1); + } + match connect_timeout { + None => time_left, + Some(timeout) => timeout.min(time_left), + } + }) +} + +// copied from std::net +fn each_addr(addr: A, mut f: F) -> io::Result +where + F: FnMut(io::Result<&SocketAddr>) -> io::Result, +{ + let addrs = match addr.to_socket_addrs() { + Ok(addrs) => addrs, + Err(e) => return f(Err(e)), + }; + let mut last_err = None; + for addr in addrs { + match f(Ok(&addr)) { + Ok(l) => return Ok(l), + Err(e) => last_err = Some(e), + } + } + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any addresses", + ) + })) +} + +fn wrap_socket_write( + socket_provider: &mut S, + stream: &S::RW, + local_port: u16, + res: io::Result, +) -> Result { + match res { + Ok(res) => Ok(res), + Err(err) => { + socket_debug!(local_port, " write error: {}", err); + let _ = socket_provider.shutdown(stream, Shutdown::Both); + Neo4jError::wrap_write(Err(err)) + } + } +} + +fn wrap_socket_read( + socket_provider: &mut S, + stream: &S::RW, + local_port: u16, + res: io::Result, +) -> Result { + match res { + Ok(res) => Ok(res), + Err(err) => { + socket_debug!(local_port, " read error: {}", err); + let _ = socket_provider.shutdown(stream, Shutdown::Both); + Neo4jError::wrap_read(Err(err)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::any::Any; + use std::cell::RefCell; + use std::collections::VecDeque; + use std::fmt::Formatter; + use std::rc::Rc; + use std::str::FromStr; + use std::vec; + + use rstest::*; + + // [bolt-version-bump] search tag when changing bolt version support + #[rstest] + #[case([0, 0, 4, 4], (4, 4))] + #[case([0, 0, 0, 5], (5, 0))] + #[case([0, 0, 1, 5], (5, 1))] + #[case([0, 0, 2, 5], (5, 2))] + #[case([0, 0, 3, 5], (5, 3))] + fn test_decode_version_offer( + #[case] mut offer: [u8; 4], + #[case] expected: (u8, u8), + #[values([0, 0], [1, 2], [255, 254])] garbage: [u8; 2], + ) { + offer[0..2].copy_from_slice(&garbage); + assert_eq!(decode_version_offer(dbg!(&offer)).unwrap(), expected); + } + + #[test] + fn test_unsupported_server_version() { + let res = decode_version_offer(&[0, 0, 0, 0]); + let Err(Neo4jError::InvalidConfig { message }) = res else { + panic!("Expected InvalidConfig error, got {:?}", res); + }; + let message = message.to_lowercase(); + assert!(message.contains("server version not supported")); + } + + #[test] + fn test_server_version_looks_like_http() { + let res = decode_version_offer(&[72, 84, 84, 80]); + let Err(Neo4jError::InvalidConfig { message }) = res else { + panic!("Expected InvalidConfig error, got {:?}", res); + }; + let message = message.to_lowercase(); + assert!(message.contains("unexpected server handshake response")); + assert!(message.contains("looks like http")); + } + + // [bolt-version-bump] search tag when changing bolt version support + #[rstest] + #[case([0, 0, 0, 1])] // driver didn't offer version 1 + #[case([0, 0, 0, 2])] // driver didn't offer version 2 + #[case([0, 0, 0, 3])] // driver didn't offer version 3 + #[case([0, 0, 0, 4])] // driver didn't offer version 4.0 + #[case([0, 0, 1, 4])] // driver didn't offer version 4.1 + #[case([0, 0, 2, 4])] // driver didn't offer version 4.2 + #[case([0, 0, 3, 4])] // driver didn't offer version 4.3 + #[case([0, 0, 4, 5])] // driver didn't offer version 5.4 + #[case([0, 0, 0, 6])] // driver didn't offer version 6.0 + fn test_garbage_server_version( + #[case] mut offer: [u8; 4], + #[values([0, 0], [1, 2], [255, 254])] garbage: [u8; 2], + ) { + offer[0..2].copy_from_slice(&garbage); + let res = decode_version_offer(&offer); + let Err(Neo4jError::InvalidConfig { message }) = res else { + panic!("Expected InvalidConfig error, got {:?}", res); + }; + let message = message.to_lowercase(); + assert!(message.contains("unexpected server handshake response")); + } + + #[derive(Debug)] + struct MockSocket { + received_data: VecDeque, + response: VecDeque, + } + + #[derive(Debug, Clone)] + struct RRMockSocket(Rc>); + + impl RRMockSocket { + fn new(sock: MockSocket) -> Self { + Self(Rc::new(RefCell::new(sock))) + } + } + + impl Read for RRMockSocket { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.borrow_mut().response.read(buf) + } + } + + impl Write for RRMockSocket { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.borrow_mut().received_data.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.borrow_mut().received_data.flush() + } + } + + #[derive(Debug)] + struct MockAddress { + host: String, + resolves_to: Vec, + } + + impl MockAddress { + fn new_localhost() -> Self { + Self { + host: "localhost".to_string(), + resolves_to: vec![SocketAddr::from(([127, 0, 0, 1], 7687))], + } + } + + fn get_single_socket_addr(&self) -> SocketAddr { + assert_eq!(self.resolves_to.len(), 1); + self.resolves_to[0] + } + } + + impl Display for MockAddress { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Debug::fmt(self, f) + } + } + + impl ToSocketAddrs for MockAddress { + type Iter = vec::IntoIter; + + fn to_socket_addrs(&self) -> io::Result { + Ok(self.resolves_to.clone().into_iter()) + } + } + + impl AddressProvider for MockAddress { + fn unresolved_host(&self) -> &str { + &self.host + } + + fn into_address(self: Arc) -> Arc
{ + Arc::new(Address::from((self.host.as_str(), 1234))) + } + } + + struct MockSocketProvider { + spawn: Box io::Result>, + spawned: VecDeque, + calls: Vec, + } + + impl MockSocketProvider { + fn new io::Result + 'static>(spawn: F) -> Self { + Self { + spawn: Box::new(spawn), + spawned: VecDeque::new(), + calls: Vec::new(), + } + } + + fn new_handshake_5_0() -> Self { + Self::new(|| { + Ok(MockSocket { + received_data: VecDeque::new(), + response: [0, 0, 0, 5].into_iter().collect(), + }) + }) + } + + fn new_handshake_5_0_failing_first_n(mut n: usize) -> Self { + Self::new(move || { + if n == 0 { + Ok(MockSocket { + received_data: VecDeque::new(), + response: [0, 0, 0, 5].into_iter().collect(), + }) + } else { + n -= 1; + Err(io::Error::new(io::ErrorKind::Other, "scripted failure")) + } + }) + } + } + + #[derive(Debug, Clone)] + enum MockSocketProviderCall { + Connect(Arc), + ConnectTimeout { + timeout: Duration, + address: SocketAddr, + }, + SetTcpKeepalive { + keepalive: Option, + }, + Shutdown { + how: Shutdown, + }, + NewBuffered, + NewDeadlineIO { + deadline: Option, + }, + } + + impl SocketProvider for &mut MockSocketProvider { + type RW = RRMockSocket; + type BuffRW = RRMockSocket; + + fn connect(&mut self, addr: &Arc) -> io::Result { + let address_any: Box = Box::new(Arc::clone(addr)); + let mock_address = address_any + .downcast::>() + .expect("MockSocketProvider can only be used with MockAddress"); + + self.calls + .push(MockSocketProviderCall::Connect(*mock_address)); + let sock = RRMockSocket::new((self.spawn)()?); + self.spawned.push_back(sock.clone()); + Ok(sock) + } + + fn get_local_port(&mut self, _sock: &Self::RW) -> u16 { + 1234 + } + + fn connect_timeout( + &mut self, + addr: &SocketAddr, + timeout: Duration, + ) -> io::Result { + self.calls.push(MockSocketProviderCall::ConnectTimeout { + timeout, + address: *addr, + }); + let sock = RRMockSocket::new((self.spawn)()?); + self.spawned.push_back(sock.clone()); + Ok(sock) + } + + fn set_tcp_keepalive( + &mut self, + socket: Self::RW, + keepalive: Option, + ) -> io::Result { + self.calls + .push(MockSocketProviderCall::SetTcpKeepalive { keepalive }); + Ok(socket) + } + + fn shutdown(&mut self, _sock: &Self::RW, how: Shutdown) -> io::Result<()> { + self.calls.push(MockSocketProviderCall::Shutdown { how }); + Ok(()) + } + + fn new_buffered(&mut self, sock: &Self::RW) -> Result { + self.calls.push(MockSocketProviderCall::NewBuffered); + Ok(sock.clone()) + } + + fn new_deadline_io<'rw, S: Read + Write>( + &'_ mut self, + stream: S, + _sock: &'rw Self::RW, + deadline: Option, + ) -> DeadlineIO<'rw, S> { + self.calls + .push(MockSocketProviderCall::NewDeadlineIO { deadline }); + DeadlineIO::new(stream, deadline, None) + } + } + + #[test] + fn test_open() { + let mut provider = MockSocketProvider::new_handshake_5_0(); + let bolt = open( + &mut provider, + Arc::new(MockAddress::new_localhost()), + None, + None, + None, + None, + ) + .unwrap(); + + assert_eq!(bolt.protocol_version(), (5, 0)); + assert_eq!(provider.spawned.len(), 1); + assert_eq!(provider.spawned[0].0.borrow().received_data, { + let mut data = VecDeque::new(); + data.extend(&BOLT_MAGIC_PREAMBLE); + data.extend(&BOLT_VERSION_OFFER); + data + }); + } + + fn get_connect_calls(provider: &MockSocketProvider) -> Vec { + provider + .calls + .iter() + .filter(|call| { + matches!( + call, + MockSocketProviderCall::Connect(_) + | MockSocketProviderCall::ConnectTimeout { .. } + ) + }) + .map(Clone::clone) + .collect::>() + } + + fn get_connect_call(provider: &MockSocketProvider) -> MockSocketProviderCall { + let mut connects = get_connect_calls(provider); + assert_eq!(dbg!(&connects).len(), 1); + connects.pop().unwrap() + } + + #[test] + fn test_open_address() { + let mut provider = MockSocketProvider::new_handshake_5_0(); + let address = Arc::new(MockAddress::new_localhost()); + open(&mut provider, Arc::clone(&address), None, None, None, None).unwrap(); + + let connect = get_connect_call(&provider); + let MockSocketProviderCall::Connect(connect_address) = connect else { + panic!("expected Connect call") + }; + assert!(Arc::ptr_eq(&connect_address, &address)); + } + + #[test] + fn test_open_timeout_address() { + let mut provider = MockSocketProvider::new_handshake_5_0(); + let address = Arc::new(MockAddress::new_localhost()); + let resolved_address = address.get_single_socket_addr(); + const DEADLINE_IN: u64 = 1000; + let deadline = Instant::now() + .checked_add(Duration::from_secs(DEADLINE_IN)) + .unwrap(); + let timeout = Duration::from_secs(500); + + open( + &mut provider, + Arc::clone(&address), + Some(deadline), + Some(timeout), + None, + None, + ) + .unwrap(); + + let connect = get_connect_call(&provider); + let MockSocketProviderCall::ConnectTimeout { + address: connect_address, + timeout: connect_timeout, + } = connect + else { + panic!("expected Connect call") + }; + assert_eq!(connect_address, resolved_address); + assert!(dbg!(connect_timeout) < Duration::from_secs(DEADLINE_IN)) + } + + #[test] + fn test_open_timeout_retry() { + let sock_addr1 = SocketAddr::from_str("127.0.0.1:7687").unwrap(); + let sock_addr2 = SocketAddr::from_str("[::1]:7687").unwrap(); + let address = Arc::new(MockAddress { + host: "localhost".to_string(), + resolves_to: vec![sock_addr1, sock_addr2], + }); + let mut provider = MockSocketProvider::new_handshake_5_0_failing_first_n(1); + const DEADLINE_IN: u64 = 200; + let deadline = Instant::now() + .checked_add(Duration::from_secs(DEADLINE_IN)) + .unwrap(); + let timeout = Duration::from_secs(400); + + open( + &mut provider, + Arc::clone(&address), + Some(deadline), + Some(timeout), + None, + None, + ) + .unwrap(); + + let mut connect_calls: VecDeque<_> = get_connect_calls(&provider).into(); + assert_eq!(dbg!(&connect_calls).len(), 2); + let connect = connect_calls.pop_front().unwrap(); + let MockSocketProviderCall::ConnectTimeout { + address: connect_address, + timeout: connect_timeout1, + } = connect + else { + panic!("expected Connect call") + }; + assert_eq!(dbg!(connect_address), dbg!(sock_addr1)); + assert!(dbg!(connect_timeout1) < Duration::from_secs(DEADLINE_IN)); + + let connect = connect_calls.pop_front().unwrap(); + let MockSocketProviderCall::ConnectTimeout { + address: connect_address, + timeout: connect_timeout2, + } = connect + else { + panic!("expected Connect call") + }; + assert_eq!(connect_address, sock_addr2); + assert!(dbg!(connect_timeout2) < connect_timeout1); + } +} diff --git a/neo4j/src/driver/io/pool/single_pool.rs b/neo4j/src/driver/io/pool/single_pool.rs index cc15b92..9e48a60 100644 --- a/neo4j/src/driver/io/pool/single_pool.rs +++ b/neo4j/src/driver/io/pool/single_pool.rs @@ -128,6 +128,7 @@ impl InnerPool { last_err = match address { Ok(address) => { match bolt::open( + bolt::TcpConnector, address, deadline, self.config.connection_timeout, diff --git a/neo4j/src/driver/session.rs b/neo4j/src/driver/session.rs index 137ece1..438f652 100644 --- a/neo4j/src/driver/session.rs +++ b/neo4j/src/driver/session.rs @@ -920,7 +920,7 @@ impl<'driver, 'session, KM: Borrow + Debug, M: Borrow