From 26057471e4698f4ee4162ccfbeb73d3a67a169b7 Mon Sep 17 00:00:00 2001 From: Thomas Schaller Date: Thu, 24 Aug 2023 16:08:04 +0200 Subject: [PATCH] Abstraction for synchronous esp-tls (#288) * Start wrapping esp-tls * Improve EspTls * Allow splitting into EspTlsRead and EspTlsWrite * Don't use std * Use str and do alloc-free conversion * Pass Config by reference * Import from sys reexport * Implement crt_bundle_attach, add docs * Fix build, move stuff to private, use str * Make clippy happy * Add AsRawFd * Add example for tls * Make EspTls work with v5 * Remove unnecessary cast * Use core for no-std support * Guard EspTls with appropriate cfg * Change cstr imports to core --- examples/tls.rs | 113 +++++++++++++++ src/private/cstr.rs | 106 ++++++++++++++ src/tls.rs | 329 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 546 insertions(+), 2 deletions(-) create mode 100644 examples/tls.rs diff --git a/examples/tls.rs b/examples/tls.rs new file mode 100644 index 00000000000..29f1219f503 --- /dev/null +++ b/examples/tls.rs @@ -0,0 +1,113 @@ +//! Example of using blocking TLS/TCP. +//! +//! Add your own ssid and password + +use std::ffi::CStr; +use std::io::{BufRead, BufReader}; + +use embedded_svc::io::adapters::ToStd; +use embedded_svc::io::Write; +use embedded_svc::wifi::{AuthMethod, ClientConfiguration, Configuration}; +use esp_idf_hal::prelude::Peripherals; +use esp_idf_svc::log::EspLogger; +use esp_idf_svc::tls::{self, EspTls, X509}; +use esp_idf_svc::wifi::{BlockingWifi, EspWifi}; +use esp_idf_svc::{eventloop::EspSystemEventLoop, nvs::EspDefaultNvsPartition}; +use esp_idf_sys::{self as _}; // If using the `binstart` feature of `esp-idf-sys`, always keep this module imported +use log::info; + +const SSID: &str = env!("WIFI_SSID"); +const PASSWORD: &str = env!("WIFI_PASS"); + +// cannot use include_str because we need a \0 at the end +const CA_CERT: &str = "-----BEGIN CERTIFICATE----- +MIIEvjCCA6agAwIBAgIQBtjZBNVYQ0b2ii+nVCJ+xDANBgkqhkiG9w0BAQsFADBh +MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3 +d3cuZGlnaWNlcnQuY29tMSAwHgYDVQQDExdEaWdpQ2VydCBHbG9iYWwgUm9vdCBD +QTAeFw0yMTA0MTQwMDAwMDBaFw0zMTA0MTMyMzU5NTlaME8xCzAJBgNVBAYTAlVT +MRUwEwYDVQQKEwxEaWdpQ2VydCBJbmMxKTAnBgNVBAMTIERpZ2lDZXJ0IFRMUyBS +U0EgU0hBMjU2IDIwMjAgQ0ExMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAwUuzZUdwvN1PWNvsnO3DZuUfMRNUrUpmRh8sCuxkB+Uu3Ny5CiDt3+PE0J6a +qXodgojlEVbbHp9YwlHnLDQNLtKS4VbL8Xlfs7uHyiUDe5pSQWYQYE9XE0nw6Ddn +g9/n00tnTCJRpt8OmRDtV1F0JuJ9x8piLhMbfyOIJVNvwTRYAIuE//i+p1hJInuW +raKImxW8oHzf6VGo1bDtN+I2tIJLYrVJmuzHZ9bjPvXj1hJeRPG/cUJ9WIQDgLGB +Afr5yjK7tI4nhyfFK3TUqNaX3sNk+crOU6JWvHgXjkkDKa77SU+kFbnO8lwZV21r +eacroicgE7XQPUDTITAHk+qZ9QIDAQABo4IBgjCCAX4wEgYDVR0TAQH/BAgwBgEB +/wIBADAdBgNVHQ4EFgQUt2ui6qiqhIx56rTaD5iyxZV2ufQwHwYDVR0jBBgwFoAU +A95QNVbRTLtm8KPiGxvDl7I90VUwDgYDVR0PAQH/BAQDAgGGMB0GA1UdJQQWMBQG +CCsGAQUFBwMBBggrBgEFBQcDAjB2BggrBgEFBQcBAQRqMGgwJAYIKwYBBQUHMAGG +GGh0dHA6Ly9vY3NwLmRpZ2ljZXJ0LmNvbTBABggrBgEFBQcwAoY0aHR0cDovL2Nh +Y2VydHMuZGlnaWNlcnQuY29tL0RpZ2lDZXJ0R2xvYmFsUm9vdENBLmNydDBCBgNV +HR8EOzA5MDegNaAzhjFodHRwOi8vY3JsMy5kaWdpY2VydC5jb20vRGlnaUNlcnRH +bG9iYWxSb290Q0EuY3JsMD0GA1UdIAQ2MDQwCwYJYIZIAYb9bAIBMAcGBWeBDAEB +MAgGBmeBDAECATAIBgZngQwBAgIwCAYGZ4EMAQIDMA0GCSqGSIb3DQEBCwUAA4IB +AQCAMs5eC91uWg0Kr+HWhMvAjvqFcO3aXbMM9yt1QP6FCvrzMXi3cEsaiVi6gL3z +ax3pfs8LulicWdSQ0/1s/dCYbbdxglvPbQtaCdB73sRD2Cqk3p5BJl+7j5nL3a7h +qG+fh/50tx8bIKuxT8b1Z11dmzzp/2n3YWzW2fP9NsarA4h20ksudYbj/NhVfSbC +EXffPgK2fPOre3qGNm+499iTcc+G33Mw+nur7SpZyEKEOxEXGlLzyQ4UfaJbcme6 +ce1XR2bFuAJKZTRei9AqPCCcUZlM51Ke92sRKw2Sfh3oius2FkOH6ipjv3U/697E +A7sKPPcw7+uvTPyLNhBzPvOk +-----END CERTIFICATE-----\0"; + +fn main() -> anyhow::Result<()> { + EspLogger::initialize_default(); + + let peripherals = Peripherals::take().unwrap(); + let sys_loop = EspSystemEventLoop::take()?; + let nvs = EspDefaultNvsPartition::take()?; + + let mut wifi = BlockingWifi::wrap( + EspWifi::new(peripherals.modem, sys_loop.clone(), Some(nvs))?, + sys_loop, + )?; + + connect_wifi(&mut wifi)?; + + let ip_info = wifi.wifi().sta_netif().get_ip_info()?; + + info!("Wifi DHCP info: {:?}", ip_info); + + let mut tls = EspTls::new( + "example.com", + 1234, + &tls::Config { + common_name: Some("example.com"), + ca_cert: Some(X509::pem( + CStr::from_bytes_with_nul(CA_CERT.as_bytes()).unwrap(), + )), + ..Default::default() + }, + )?; + let (reader, writer) = tls.split(); + let mut reader = BufReader::with_capacity(512, ToStd::new(reader)); + let mut line = String::new(); + // receive line by line from server and echo them back + loop { + reader.read_line(&mut line)?; + writer.write_all(line.as_bytes())?; + line.clear(); + } +} + +fn connect_wifi(wifi: &mut BlockingWifi>) -> anyhow::Result<()> { + let wifi_configuration: Configuration = Configuration::Client(ClientConfiguration { + ssid: SSID.into(), + bssid: None, + auth_method: AuthMethod::WPA2Personal, + password: PASSWORD.into(), + channel: None, + }); + + wifi.set_configuration(&wifi_configuration)?; + + wifi.start()?; + info!("Wifi started"); + + wifi.connect()?; + info!("Wifi connected"); + + wifi.wait_netif_up()?; + info!("Wifi netif up"); + + Ok(()) +} diff --git a/src/private/cstr.rs b/src/private/cstr.rs index f6c6ba43d41..89fcae01902 100644 --- a/src/private/cstr.rs +++ b/src/private/cstr.rs @@ -6,6 +6,8 @@ pub use alloc::ffi::CString; pub use core::ffi::{c_char, CStr}; +use crate::sys::{EspError, ESP_ERR_INVALID_SIZE}; + #[cfg(feature = "alloc")] pub fn set_str(buf: &mut [u8], s: &str) -> Result<(), crate::sys::EspError> { assert!(s.len() < buf.len()); @@ -78,3 +80,107 @@ pub fn nul_to_invalid_arg(_err: alloc::ffi::NulError) -> crate::sys::EspError { pub fn to_cstring_arg(value: &str) -> Result { CString::new(value).map_err(nul_to_invalid_arg) } + +/// str to cstr, will be truncated if str is larger than buf.len() - 1 +/// +/// # Panics +/// +/// * Panics if buffer is empty. +pub fn cstr_from_str_truncating<'a>(rust_str: &str, buf: &'a mut [u8]) -> &'a CStr { + assert!(!buf.is_empty()); + + let max_str_size = buf.len() - 1; // account for NUL + let truncated_str = &rust_str[..max_str_size.min(rust_str.len())]; + buf[..truncated_str.len()].copy_from_slice(truncated_str.as_bytes()); + buf[truncated_str.len()] = b'\0'; + + CStr::from_bytes_with_nul(&buf[..truncated_str.len() + 1]).unwrap() +} + +/// Convert slice of rust strs to NULL-terminated fixed size array of c string pointers +/// +/// # Panics +/// +/// * Panics if cbuf is empty. +/// * Panics if N is <= 1 +pub fn cstr_arr_from_str_slice( + rust_strs: &[&str], + mut cbuf: &mut [u8], +) -> Result<[*const i8; N], EspError> { + assert!(N > 1); + assert!(!cbuf.is_empty()); + + // ensure last element stays NULL + if rust_strs.len() > N - 1 { + return Err(EspError::from_infallible::()); + } + + let mut cstrs = [core::ptr::null(); N]; + + for (i, s) in rust_strs.iter().enumerate() { + let max_str_size = cbuf.len() - 1; // account for NUL + if s.len() > max_str_size { + return Err(EspError::from_infallible::()); + } + cbuf[..s.len()].copy_from_slice(s.as_bytes()); + cbuf[s.len()] = b'\0'; + let cstr = CStr::from_bytes_with_nul(&cbuf[..s.len() + 1]).unwrap(); + cstrs[i] = cstr.as_ptr(); + + cbuf = &mut cbuf[s.len() + 1..]; + } + + Ok(cstrs) +} + +#[cfg(test)] +mod tests { + use super::{cstr_arr_from_str_slice, cstr_from_str_truncating, CStr}; + + #[test] + fn cstr_from_str_happy() { + let mut same_size = [0u8; 6]; + let hello = cstr_from_str_truncating("Hello", &mut same_size); + assert_eq!(hello.to_bytes(), b"Hello"); + + let mut larger = [0u8; 42]; + let hello = cstr_from_str_truncating("Hello", &mut larger); + assert_eq!(hello.to_bytes(), b"Hello"); + } + + #[test] + fn cstr_from_str_unhappy() { + let mut smaller = [0u8; 6]; + let hello = cstr_from_str_truncating("Hello World", &mut smaller); + assert_eq!(hello.to_bytes(), b"Hello"); + } + + #[test] + fn cstr_arr_happy() { + let mut same_size = [0u8; 13]; + let hello = cstr_arr_from_str_slice::<3>(&["Hello", "World"], &mut same_size).unwrap(); + assert_eq!(unsafe { CStr::from_ptr(hello[0]) }.to_bytes(), b"Hello"); + assert_eq!(unsafe { CStr::from_ptr(hello[1]) }.to_bytes(), b"World"); + assert_eq!(hello[2], core::ptr::null()); + } + + #[test] + #[should_panic] + fn cstr_arr_unhappy_n1() { + let mut cbuf = [0u8; 25]; + let _ = cstr_arr_from_str_slice::<1>(&["Hello"], &mut cbuf); + } + + #[test] + fn cstr_arr_unhappy_n_too_small() { + let mut cbuf = [0u8; 25]; + assert!(cstr_arr_from_str_slice::<2>(&["Hello", "World"], &mut cbuf).is_err()); + } + + #[test] + #[should_panic] + fn cstr_arr_unhappy_cbuf_too_small() { + let mut cbuf = [0u8; 12]; + assert!(cstr_arr_from_str_slice::<3>(&["Hello", "World"], &mut cbuf).is_err()); + } +} diff --git a/src/tls.rs b/src/tls.rs index 147c58cc502..c797fa846ae 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,7 +1,12 @@ -//! TLS-related helper types -use core::ffi::{c_char, CStr}; +//! Type safe abstraction for esp-tls + use core::fmt::Debug; +use crate::private::cstr::{c_char, CStr}; + +#[cfg(esp_idf_comp_esp_tls_enabled)] +pub use self::esptls::*; + #[derive(Copy, Clone, Eq, PartialEq)] pub struct X509<'a>(&'a [u8]); @@ -48,3 +53,323 @@ impl<'a> Debug for X509<'a> { write!(f, "X509(...)") } } + +#[cfg(esp_idf_comp_esp_tls_enabled)] +mod esptls { + use core::time::Duration; + + use embedded_svc::io; + + use super::X509; + use crate::{ + errors::EspIOError, + private::cstr::{cstr_arr_from_str_slice, cstr_from_str_truncating, CStr}, + sys::{self, EspError, ESP_ERR_NO_MEM, ESP_FAIL}, + }; + + /// see https://www.ietf.org/rfc/rfc3280.txt ub-common-name-length + const MAX_COMMON_NAME_LENGTH: usize = 64; + + /// Wrapper for `esp-tls` module. Only supports synchronous operation for now. + pub struct EspTls { + reader: EspTlsRead, + writer: EspTlsWrite, + } + + impl EspTls { + /// Create a new blocking TLS/SSL connection. + /// + /// This function establishes a TLS/SSL connection with the specified host in a blocking manner. + /// + /// # Errors + /// + /// * `ESP_ERR_INVALID_SIZE` if `cfg.alpn_protos` exceeds 9 elements or avg 10 bytes/ALPN + /// * `ESP_ERR_NO_MEM` if TLS context could not be allocated + /// * `ESP_FAIL` if connection could not be established + pub fn new(host: &str, port: u16, cfg: &Config) -> Result { + let mut rcfg: sys::esp_tls_cfg = unsafe { core::mem::zeroed() }; + + if let Some(ca_cert) = cfg.ca_cert { + rcfg.__bindgen_anon_1.cacert_buf = ca_cert.data().as_ptr(); + rcfg.__bindgen_anon_2.cacert_bytes = ca_cert.data().len() as u32; + } + + if let Some(client_cert) = cfg.client_cert { + rcfg.__bindgen_anon_3.clientcert_buf = client_cert.data().as_ptr(); + rcfg.__bindgen_anon_4.clientcert_bytes = client_cert.data().len() as u32; + } + + if let Some(client_key) = cfg.client_key { + rcfg.__bindgen_anon_5.clientkey_buf = client_key.data().as_ptr(); + rcfg.__bindgen_anon_6.clientkey_bytes = client_key.data().len() as u32; + } + + if let Some(ckp) = cfg.client_key_password { + rcfg.clientkey_password = ckp.as_ptr(); + rcfg.clientkey_password_len = ckp.len() as u32; + } + + // allow up to 9 protocols + let mut alpn_protos: [*const i8; 10]; + let mut alpn_protos_cbuf = [0u8; 99]; + if let Some(protos) = cfg.alpn_protos { + alpn_protos = cstr_arr_from_str_slice(protos, &mut alpn_protos_cbuf)?; + rcfg.alpn_protos = alpn_protos.as_mut_ptr(); + } + + rcfg.non_block = cfg.non_block; + rcfg.use_secure_element = cfg.use_secure_element; + rcfg.timeout_ms = cfg.timeout_ms as i32; + rcfg.use_global_ca_store = cfg.use_global_ca_store; + + if let Some(common_name) = cfg.common_name { + let mut common_name_buf = [0; MAX_COMMON_NAME_LENGTH + 1]; + rcfg.common_name = + cstr_from_str_truncating(common_name, &mut common_name_buf).as_ptr(); + } + + rcfg.skip_common_name = cfg.skip_common_name; + + let mut raw_kac: sys::tls_keep_alive_cfg; + if let Some(kac) = &cfg.keep_alive_cfg { + raw_kac = sys::tls_keep_alive_cfg { + keep_alive_enable: kac.enable, + keep_alive_idle: kac.idle.as_secs() as i32, + keep_alive_interval: kac.interval.as_secs() as i32, + keep_alive_count: kac.count as i32, + }; + rcfg.keep_alive_cfg = &mut raw_kac as *mut _; + } + + let mut raw_psk: sys::psk_key_hint; + if let Some(psk) = &cfg.psk_hint_key { + raw_psk = sys::psk_key_hint { + key: psk.key.as_ptr(), + key_size: psk.key.len(), + hint: psk.hint.as_ptr(), + }; + rcfg.psk_hint_key = &mut raw_psk as *mut _; + } + + #[cfg(esp_idf_mbedtls_certificate_bundle)] + if cfg.use_crt_bundle_attach { + rcfg.crt_bundle_attach = Some(sys::esp_crt_bundle_attach); + } + + rcfg.is_plain_tcp = cfg.is_plain_tcp; + rcfg.if_name = core::ptr::null_mut(); + + let tls = unsafe { sys::esp_tls_init() }; + if tls.is_null() { + return Err(EspError::from_infallible::()); + } + let ret = unsafe { + sys::esp_tls_conn_new_sync( + host.as_bytes().as_ptr() as *const i8, + host.len() as i32, + port as i32, + &rcfg, + tls, + ) + }; + + if ret == 1 { + Ok(EspTls { + reader: EspTlsRead { raw: tls }, + writer: EspTlsWrite { raw: tls }, + }) + } else { + unsafe { + sys::esp_tls_conn_destroy(tls); + } + + Err(EspError::from_infallible::()) + } + } + + pub fn read(&mut self, buf: &mut [u8]) -> Result { + self.reader.read(buf) + } + + pub fn write(&mut self, buf: &[u8]) -> Result { + self.writer.write(buf) + } + + pub fn split(&mut self) -> (&mut EspTlsRead, &mut EspTlsWrite) { + (&mut self.reader, &mut self.writer) + } + } + + #[cfg(feature = "std")] + impl std::os::fd::AsRawFd for EspTls { + fn as_raw_fd(&self) -> std::os::fd::RawFd { + let mut fd = -1; + let _ = unsafe { sys::esp_tls_get_conn_sockfd(self.reader.raw, &mut fd) }; + + fd + } + } + + impl io::Io for EspTls { + type Error = EspIOError; + } + + impl io::Read for EspTls { + fn read(&mut self, buf: &mut [u8]) -> Result { + self.read(buf) + } + } + + impl io::Write for EspTls { + fn write(&mut self, buf: &[u8]) -> Result { + self.write(buf) + } + + fn flush(&mut self) -> Result<(), EspIOError> { + Ok(()) + } + } + + pub struct EspTlsRead { + raw: *mut sys::esp_tls, + } + + impl EspTlsRead { + pub fn read(&mut self, buf: &mut [u8]) -> Result { + if buf.is_empty() { + return Ok(0); + } + + let ret = self.read_raw(buf); + // ESP docs treat 0 as error, but in Rust it's common to return 0 from `Read::read` to indicate eof + if ret >= 0 { + Ok(ret as usize) + } else { + Err(EspIOError(EspError::from(ret as i32).unwrap())) + } + } + + #[cfg(esp_idf_version_major = "4")] + fn read_raw(&mut self, buf: &mut [u8]) -> isize { + // cannot call esp_tls_conn_read bc it's inline in v4 + let esp_tls = unsafe { core::ptr::read_unaligned(self.raw) }; + let read_func = esp_tls.read.unwrap(); + unsafe { read_func(self.raw, buf.as_mut_ptr() as *mut i8, buf.len()) } + } + + #[cfg(not(esp_idf_version_major = "4"))] + fn read_raw(&mut self, buf: &mut [u8]) -> isize { + use core::ffi::c_void; + + unsafe { sys::esp_tls_conn_read(self.raw, buf.as_mut_ptr() as *mut c_void, buf.len()) } + } + } + + impl io::Io for EspTlsRead { + type Error = EspIOError; + } + + impl io::Read for EspTlsRead { + fn read(&mut self, buf: &mut [u8]) -> Result { + self.read(buf) + } + } + + pub struct EspTlsWrite { + raw: *mut sys::esp_tls, + } + + impl EspTlsWrite { + pub fn write(&mut self, buf: &[u8]) -> Result { + if buf.is_empty() { + return Ok(0); + } + + let ret = self.write_raw(buf); + if ret >= 0 { + Ok(ret as usize) + } else { + Err(EspIOError(EspError::from(ret as i32).unwrap())) + } + } + + #[cfg(esp_idf_version_major = "4")] + fn write_raw(&mut self, buf: &[u8]) -> isize { + // cannot call esp_tls_conn_write bc it's inline + let esp_tls = unsafe { core::ptr::read_unaligned(self.raw) }; + let write_func = esp_tls.write.unwrap(); + unsafe { write_func(self.raw, buf.as_ptr() as *const i8, buf.len()) } + } + + #[cfg(not(esp_idf_version_major = "4"))] + fn write_raw(&mut self, buf: &[u8]) -> isize { + use core::ffi::c_void; + + unsafe { sys::esp_tls_conn_write(self.raw, buf.as_ptr() as *const c_void, buf.len()) } + } + } + + impl io::Io for EspTlsWrite { + type Error = EspIOError; + } + + impl io::Write for EspTlsWrite { + fn write(&mut self, buf: &[u8]) -> Result { + self.write(buf) + } + + fn flush(&mut self) -> Result<(), EspIOError> { + Ok(()) + } + } + + impl Drop for EspTls { + fn drop(&mut self) { + unsafe { + sys::esp_tls_conn_destroy(self.reader.raw); + } + } + } + + #[derive(Default)] + pub struct Config<'a> { + /// up to 9 ALPNs allowed, with avg 10 bytes for each name + pub alpn_protos: Option<&'a [&'a str]>, + pub ca_cert: Option>, + pub client_cert: Option>, + pub client_key: Option>, + pub client_key_password: Option<&'a str>, + pub non_block: bool, + pub use_secure_element: bool, + pub timeout_ms: u32, + pub use_global_ca_store: bool, + pub common_name: Option<&'a str>, + pub skip_common_name: bool, + pub keep_alive_cfg: Option, + pub psk_hint_key: Option>, + /// whether to use esp_crt_bundle_attach, see https://docs.espressif.com/projects/esp-idf/en/latest/esp32s2/api-reference/protocols/esp_crt_bundle.html + #[cfg(esp_idf_mbedtls_certificate_bundle)] + pub use_crt_bundle_attach: bool, + // TODO ds_data not implemented + pub is_plain_tcp: bool, + #[cfg(esp_idf_comp_lwip_enabled)] + pub if_name: sys::ifreq, + } + + #[derive(Clone, Debug)] + pub struct KeepAliveConfig { + /// Enable keep-alive timeout + pub enable: bool, + /// Keep-alive idle time (second) + pub idle: Duration, + /// Keep-alive interval time (second) + pub interval: Duration, + /// Keep-alive packet retry send count + pub count: u32, + } + + pub struct PskHintKey<'a> { + pub key: &'a [u8], + pub hint: &'a CStr, + } +}