Skip to content

Commit

Permalink
Add connect helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed May 21, 2020
1 parent d4d6b0e commit 5a7b285
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 99 deletions.
4 changes: 4 additions & 0 deletions ntex/CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.1.18] - 2020-05-xx

* ntex::connect: Add `connect` helper function

## [0.1.17] - 2020-05-18

* ntex::util: Add Variant service
Expand Down
16 changes: 15 additions & 1 deletion ntex/src/connect/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
//! Tcp connector service
use std::future::Future;

mod error;
mod message;
mod resolve;
Expand All @@ -15,7 +17,7 @@ pub use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
pub use trust_dns_resolver::error::ResolveError;
use trust_dns_resolver::system_conf::read_system_conf;

use crate::rt::Arbiter;
use crate::rt::{net::TcpStream, Arbiter};

pub use self::error::ConnectError;
pub use self::message::{Address, Connect};
Expand Down Expand Up @@ -46,3 +48,15 @@ pub fn default_resolver() -> AsyncResolver {
resolver
}
}

/// Resolve and connect to remote host
pub fn connect<T: Address, U>(
message: U,
) -> impl Future<Output = Result<TcpStream, ConnectError>>
where
Connect<T>: From<U>,
{
service::ConnectServiceResponse::new(
Resolver::new(default_resolver()).lookup(message.into()),
)
}
62 changes: 38 additions & 24 deletions ntex/src/connect/openssl.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::future::Future;
use std::io;
use std::task::{Context, Poll};

Expand Down Expand Up @@ -33,6 +34,42 @@ impl<T> OpensslConnector<T> {
}
}

impl<T: Address + 'static> OpensslConnector<T> {
/// Resolve and connect to remote host
pub fn connect<U>(
&self,
message: U,
) -> impl Future<Output = Result<SslStream<TcpStream>, ConnectError>>
where
Connect<T>: From<U>,
{
let message = Connect::from(message);
let host = message.host().to_string();
let conn = self.connector.call(message);
let openssl = self.openssl.clone();

async move {
let io = conn.await?;
trace!("SSL Handshake start for: {:?}", host);

match openssl.configure() {
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()),
Ok(config) => match tokio_openssl::connect(config, &host, io).await {
Ok(io) => {
trace!("SSL Handshake success: {:?}", host);
Ok(io)
}
Err(e) => {
trace!("SSL Handshake error: {:?}", e);
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))
.into())
}
},
}
}
}
}

impl<T> Clone for OpensslConnector<T> {
fn clone(&self) -> Self {
OpensslConnector {
Expand Down Expand Up @@ -68,30 +105,7 @@ impl<T: Address + 'static> Service for OpensslConnector<T> {
}

fn call(&self, req: Connect<T>) -> Self::Future {
let host = req.host().to_string();
let conn = self.connector.call(req);
let openssl = self.openssl.clone();

async move {
let io = conn.await?;
trace!("SSL Handshake start for: {:?}", host);

match openssl.configure() {
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()),
Ok(config) => match tokio_openssl::connect(config, &host, io).await {
Ok(io) => {
trace!("SSL Handshake success: {:?}", host);
Ok(io)
}
Err(e) => {
trace!("SSL Handshake error: {:?}", e);
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))
.into())
}
},
}
}
.boxed_local()
self.connect(req).boxed_local()
}
}

Expand Down
61 changes: 37 additions & 24 deletions ntex/src/connect/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::task::{Context, Poll};
pub use rust_tls::Session;
pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};

use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
use futures::future::{ok, Future, FutureExt, LocalBoxFuture, Ready};
use tokio_rustls::{self, TlsConnector};
use webpki::DNSNameRef;

Expand Down Expand Up @@ -37,6 +37,41 @@ impl<T> RustlsConnector<T> {
}
}

impl<T: Address + 'static> RustlsConnector<T> {
/// Resolve and connect to remote host
pub fn connect<U>(
&self,
message: U,
) -> impl Future<Output = Result<TlsStream<TcpStream>, ConnectError>>
where
Connect<T>: From<U>,
{
let req = Connect::from(message);
let host = req.host().to_string();
let conn = self.connector.call(req);
let config = self.config.clone();

async move {
let io = conn.await?;
trace!("SSL Handshake start for: {:?}", host);

let host = DNSNameRef::try_from_ascii_str(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?;

match TlsConnector::from(config).connect(host, io).await {
Ok(io) => {
trace!("SSL Handshake success: {:?}", host);
Ok(io)
}
Err(e) => {
trace!("SSL Handshake error: {:?}", e);
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)).into())
}
}
}
}
}

impl<T> Clone for RustlsConnector<T> {
fn clone(&self) -> Self {
Self {
Expand Down Expand Up @@ -72,29 +107,7 @@ impl<T: Address + 'static> Service for RustlsConnector<T> {
}

fn call(&self, req: Connect<T>) -> Self::Future {
let host = req.host().to_string();
let conn = self.connector.call(req);
let config = self.config.clone();

async move {
let io = conn.await?;
trace!("SSL Handshake start for: {:?}", host);

let host = DNSNameRef::try_from_ascii_str(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?;

match TlsConnector::from(config).connect(host, io).await {
Ok(io) => {
trace!("SSL Handshake success: {:?}", host);
Ok(io)
}
Err(e) => {
trace!("SSL Handshake error: {:?}", e);
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)).into())
}
}
}
.boxed_local()
self.connect(req).boxed_local()
}
}

Expand Down
125 changes: 75 additions & 50 deletions ntex/src/connect/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::pin::Pin;
use std::task::{Context, Poll};

use either::Either;
use futures::future::{self, err, ok, FutureExt, LocalBoxFuture, Ready};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};

use crate::rt::net::TcpStream;
use crate::service::{Service, ServiceFactory};
Expand All @@ -26,6 +26,19 @@ impl<T> Connector<T> {
}
}

impl<T: Address> Connector<T> {
/// Resolve and connect to remote host
pub fn connect<U>(
&self,
message: U,
) -> impl Future<Output = Result<TcpStream, ConnectError>>
where
Connect<T>: From<U>,
{
ConnectServiceResponse::new(self.resolver.lookup(message.into()))
}
}

impl<T> Default for Connector<T> {
fn default() -> Self {
Connector {
Expand Down Expand Up @@ -70,74 +83,55 @@ impl<T: Address> Service for Connector<T> {

#[inline]
fn call(&self, req: Connect<T>) -> Self::Future {
ConnectServiceResponse {
state: ConnectState::Resolve(self.resolver.lookup(req)),
}
ConnectServiceResponse::new(self.resolver.lookup(req))
}
}

enum ConnectState<T: Address> {
Resolve(<Resolver<T> as Service>::Future),
Connect(ConnectFut<T>),
}

impl<T: Address> ConnectState<T> {
fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Either<Poll<Result<TcpStream, ConnectError>>, Connect<T>> {
match self {
ConnectState::Resolve(ref mut fut) => match Pin::new(fut).poll(cx) {
Poll::Pending => Either::Left(Poll::Pending),
Poll::Ready(Ok(res)) => Either::Right(res),
Poll::Ready(Err(err)) => Either::Left(Poll::Ready(Err(err))),
},
ConnectState::Connect(ref mut fut) => Either::Left(Pin::new(fut).poll(cx)),
}
}
Connect(TcpConnectorResponse<T>),
}

#[doc(hidden)]
pub struct ConnectServiceResponse<T: Address> {
state: ConnectState<T>,
}

impl<T: Address> ConnectServiceResponse<T> {
pub(super) fn new(fut: <Resolver<T> as Service>::Future) -> Self {
ConnectServiceResponse {
state: ConnectState::Resolve(fut),
}
}
}

impl<T: Address> Future for ConnectServiceResponse<T> {
type Output = Result<TcpStream, ConnectError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = match self.state.poll(cx) {
Either::Right(res) => {
self.state = ConnectState::Connect(connect(res));
self.state.poll(cx)
}
Either::Left(res) => return res,
};

match res {
Either::Left(res) => res,
Either::Right(_) => panic!(),
match self.state {
ConnectState::Resolve(ref mut fut) => match Pin::new(fut).poll(cx)? {
Poll::Pending => Poll::Pending,
Poll::Ready(address) => {
let port = address.port();
let Connect { req, addr, .. } = address;

if let Some(addr) = addr {
self.state = ConnectState::Connect(TcpConnectorResponse::new(
req, port, addr,
));
self.poll(cx)
} else {
error!("TCP connector: got unresolved address");
Poll::Ready(Err(ConnectError::Unresolved))
}
}
},
ConnectState::Connect(ref mut fut) => Pin::new(fut).poll(cx),
}
}
}

type ConnectFut<T> =
future::Either<TcpConnectorResponse<T>, Ready<Result<TcpStream, ConnectError>>>;

/// Connect to remote host.
///
/// Ip address must be resolved.
fn connect<T: Address>(address: Connect<T>) -> ConnectFut<T> {
let port = address.port();
let Connect { req, addr, .. } = address;

if let Some(addr) = addr {
future::Either::Left(TcpConnectorResponse::new(req, port, addr))
} else {
error!("TCP connector: got unresolved address");
future::Either::Right(err(ConnectError::Unresolved))
}
}

/// Tcp stream connector response future
struct TcpConnectorResponse<T> {
req: Option<T>,
Expand Down Expand Up @@ -215,3 +209,34 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[ntex_rt::test]
async fn test_connect() {
let server = crate::server::test_server(|| {
crate::fn_service(|_| async { Ok::<_, ()>(()) })
});

let srv = Connector::default();
let result = srv.connect("").await;
assert!(result.is_err());
let result = srv.connect("localhost-111").await;
assert!(result.is_err());

let srv = Connector::default();
let result = srv.connect(format!("{}", server.addr())).await;
assert!(result.is_ok());

let msg = Connect::new(format!("{}", server.addr())).set_addrs(vec![
format!("127.0.0.1:{}", server.addr().port() - 1)
.parse()
.unwrap(),
server.addr(),
]);
let result = crate::connect::connect(msg).await;
assert!(result.is_ok());
}
}

0 comments on commit 5a7b285

Please sign in to comment.