Skip to content

Commit

Permalink
Fix tls config unsoundness (#295)
Browse files Browse the repository at this point in the history
  • Loading branch information
torkleyy authored Aug 31, 2023
1 parent 2964c5c commit a909b7a
Showing 1 changed file with 71 additions and 42 deletions.
113 changes: 71 additions & 42 deletions src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ impl<'a> Debug for X509<'a> {
any(esp_idf_esp_tls_using_mbedtls, esp_idf_esp_tls_using_wolfssl)
))]
mod esptls {
use core::convert::{TryFrom, TryInto};
use core::task::{Context, Poll};
use core::time::Duration;

Expand Down Expand Up @@ -125,63 +124,50 @@ mod esptls {
is_plain_tcp: false,
}
}
}

impl<'a> Default for Config<'a> {
fn default() -> Self {
Self::new()
}
}

impl<'a> TryFrom<&'a Config<'a>> for sys::esp_tls_cfg {
type Error = EspError;

fn try_from(cfg: &'a Config<'a>) -> Result<Self, EspError> {
let mut rcfg: Self = Default::default();
fn try_into_raw(&self, bufs: &mut RawConfigBufs) -> Result<sys::esp_tls_cfg, EspError> {
let mut rcfg: sys::esp_tls_cfg = Default::default();

if let Some(ca_cert) = cfg.ca_cert {
if let Some(ca_cert) = self.ca_cert {
rcfg.__bindgen_anon_1.cacert_buf = ca_cert.data().as_ptr();
rcfg.__bindgen_anon_2.cacert_bytes = ca_cert.data().len() as u32;
}

if let Some(client_cert) = cfg.client_cert {
if let Some(client_cert) = self.client_cert {
rcfg.__bindgen_anon_3.clientcert_buf = client_cert.data().as_ptr();
rcfg.__bindgen_anon_4.clientcert_bytes = client_cert.data().len() as u32;
}

if let Some(client_key) = cfg.client_key {
if let Some(client_key) = self.client_key {
rcfg.__bindgen_anon_5.clientkey_buf = client_key.data().as_ptr();
rcfg.__bindgen_anon_6.clientkey_bytes = client_key.data().len() as u32;
}

if let Some(ckp) = cfg.client_key_password {
if let Some(ckp) = self.client_key_password {
rcfg.clientkey_password = ckp.as_ptr();
rcfg.clientkey_password_len = ckp.len() as u32;
}

// allow up to 9 protocols
let mut alpn_protos: [*const i8; 10];
let mut alpn_protos_cbuf = [0u8; 99];
if let Some(protos) = cfg.alpn_protos {
alpn_protos = cstr_arr_from_str_slice(protos, &mut alpn_protos_cbuf)?;
rcfg.alpn_protos = alpn_protos.as_mut_ptr();
if let Some(protos) = self.alpn_protos {
bufs.alpn_protos = cstr_arr_from_str_slice(protos, &mut bufs.alpn_protos_cbuf)?;
rcfg.alpn_protos = bufs.alpn_protos.as_mut_ptr();
}

rcfg.non_block = cfg.non_block;
rcfg.use_secure_element = cfg.use_secure_element;
rcfg.timeout_ms = cfg.timeout_ms as i32;
rcfg.use_global_ca_store = cfg.use_global_ca_store;
rcfg.non_block = self.non_block;
rcfg.use_secure_element = self.use_secure_element;
rcfg.timeout_ms = self.timeout_ms as i32;
rcfg.use_global_ca_store = self.use_global_ca_store;

if let Some(common_name) = cfg.common_name {
let mut common_name_buf = [0; MAX_COMMON_NAME_LENGTH + 1];
if let Some(common_name) = self.common_name {
rcfg.common_name =
cstr_from_str_truncating(common_name, &mut common_name_buf).as_ptr();
cstr_from_str_truncating(common_name, &mut bufs.common_name_buf).as_ptr();
}

rcfg.skip_common_name = cfg.skip_common_name;
rcfg.skip_common_name = self.skip_common_name;

let mut raw_kac: sys::tls_keep_alive_cfg;
if let Some(kac) = &cfg.keep_alive_cfg {
if let Some(kac) = &self.keep_alive_cfg {
raw_kac = sys::tls_keep_alive_cfg {
keep_alive_enable: kac.enable,
keep_alive_idle: kac.idle.as_secs() as i32,
Expand All @@ -192,7 +178,7 @@ mod esptls {
}

let mut raw_psk: sys::psk_key_hint;
if let Some(psk) = &cfg.psk_hint_key {
if let Some(psk) = &self.psk_hint_key {
raw_psk = sys::psk_key_hint {
key: psk.key.as_ptr(),
key_size: psk.key.len(),
Expand All @@ -202,11 +188,11 @@ mod esptls {
}

#[cfg(esp_idf_mbedtls_certificate_bundle)]
if cfg.use_crt_bundle_attach {
if self.use_crt_bundle_attach {
rcfg.crt_bundle_attach = Some(sys::esp_crt_bundle_attach);
}

rcfg.is_plain_tcp = cfg.is_plain_tcp;
rcfg.is_plain_tcp = self.is_plain_tcp;

#[cfg(esp_idf_comp_lwip_enabled)]
{
Expand All @@ -217,6 +203,28 @@ mod esptls {
}
}

impl<'a> Default for Config<'a> {
fn default() -> Self {
Self::new()
}
}

struct RawConfigBufs {
alpn_protos: [*const i8; 10],
alpn_protos_cbuf: [u8; 99],
common_name_buf: [u8; MAX_COMMON_NAME_LENGTH + 1],
}

impl Default for RawConfigBufs {
fn default() -> Self {
RawConfigBufs {
alpn_protos: [core::ptr::null(); 10],
alpn_protos_cbuf: [0; 99],
common_name_buf: [0; MAX_COMMON_NAME_LENGTH + 1],
}
}
}

#[derive(Clone, Debug)]
pub struct KeepAliveConfig {
/// Enable keep-alive timeout
Expand Down Expand Up @@ -293,9 +301,16 @@ mod esptls {
/// * `ESP_TLS_ERR_SSL_WANT_WRITE` if the socket is in non-blocking mode and it is not ready for writing
/// * `EWOULDBLOCK` if the socket is in non-blocking mode and it is not ready either for reading or writing (a peculiarity/bug of the `esp-tls` C module)
pub fn connect(&mut self, host: &str, port: u16, cfg: &Config) -> Result<(), EspError> {
let rcfg = cfg.try_into()?;
let mut bufs = RawConfigBufs::default();
let rcfg = cfg.try_into_raw(&mut bufs)?;

let res = self.internal_connect(host, port, cfg.non_block, &rcfg);

// Make sure buffers are held long enough
#[allow(clippy::drop_non_drop)]
drop(bufs);

self.internal_connect(host, port, cfg.non_block, &rcfg)
res
}
}

Expand Down Expand Up @@ -342,9 +357,16 @@ mod esptls {
any(not(esp_idf_version_major = "5"), not(esp_idf_version_minor = "0"))
))]
pub fn negotiate(&mut self, host: &str, cfg: &Config) -> Result<(), EspError> {
let rcfg = cfg.try_into()?;
let mut bufs = RawConfigBufs::default();
let rcfg = cfg.try_into_raw(&mut bufs)?;

self.internal_connect(host, 0, cfg.non_block, &rcfg)
let res = self.internal_connect(host, 0, cfg.non_block, &rcfg);

// Make sure buffers are held long enough
#[allow(clippy::drop_non_drop)]
drop(bufs);

res
}

#[allow(clippy::unnecessary_cast)]
Expand Down Expand Up @@ -551,7 +573,8 @@ mod esptls {
hostname: &str,
cfg: &Config<'_>,
) -> Result<(), EspError> {
let mut rcfg: sys::esp_tls_cfg = cfg.try_into()?;
let mut bufs = RawConfigBufs::default();
let mut rcfg: sys::esp_tls_cfg = cfg.try_into_raw(&mut bufs)?;

// It is a bit unintuitive, but when an async socket is being adopted, `non_block` should be set to false.
//
Expand All @@ -565,7 +588,7 @@ mod esptls {
// must be already connected anyway (API requirement).
rcfg.non_block = false;

loop {
let res = loop {
let res = self
.0
.borrow_mut()
Expand All @@ -575,7 +598,13 @@ mod esptls {
Err(e) => self.wait(e).await?,
other => break other,
}
}
};

// Make sure buffers are held long enough
#[allow(clippy::drop_non_drop)]
drop(bufs);

res
}

/// Read in the supplied buffer. Returns the number of bytes read.
Expand Down Expand Up @@ -619,7 +648,7 @@ mod esptls {
async fn wait(&self, error: EspError) -> Result<(), EspError> {
const EWOULDBLOCK_I32: i32 = EWOULDBLOCK as i32;

match error.code() as i32 {
match error.code() {
// EWOULDBLOCK models the "0" return code of esp_mbedtls_handshake() which does not allow us
// to figure out whether we need the socket to become readable or writable
// The code below is therefore a hack which just waits with a timeout for the socket to (eventually)
Expand Down

0 comments on commit a909b7a

Please sign in to comment.