diff --git a/rustls-libssl/MATRIX.md b/rustls-libssl/MATRIX.md index 353c573..e473773 100644 --- a/rustls-libssl/MATRIX.md +++ b/rustls-libssl/MATRIX.md @@ -104,7 +104,7 @@ | `SSL_CTX_get_security_callback` | | | | | `SSL_CTX_get_security_level` | | | | | `SSL_CTX_get_ssl_method` | | | | -| `SSL_CTX_get_timeout` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_CTX_get_timeout` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_get_verify_callback` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_get_verify_depth` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_get_verify_mode` | | :white_check_mark: | :white_check_mark: | @@ -119,9 +119,9 @@ | `SSL_CTX_sess_get_get_cb` | | | | | `SSL_CTX_sess_get_new_cb` | | | | | `SSL_CTX_sess_get_remove_cb` | | | | -| `SSL_CTX_sess_set_get_cb` | | :white_check_mark: | :exclamation: [^stub] | -| `SSL_CTX_sess_set_new_cb` | :white_check_mark: | :white_check_mark: | :exclamation: [^stub] | -| `SSL_CTX_sess_set_remove_cb` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_CTX_sess_set_get_cb` | | :white_check_mark: | :white_check_mark: | +| `SSL_CTX_sess_set_new_cb` | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| `SSL_CTX_sess_set_remove_cb` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_sessions` | | | | | `SSL_CTX_set0_CA_list` | | | | | `SSL_CTX_set0_ctlog_store` [^ct] | | | | @@ -191,7 +191,7 @@ | `SSL_CTX_set_ssl_version` [^deprecatedin_3_0] | | | | | `SSL_CTX_set_stateless_cookie_generate_cb` | | | | | `SSL_CTX_set_stateless_cookie_verify_cb` | | | | -| `SSL_CTX_set_timeout` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_CTX_set_timeout` | | :white_check_mark: | :white_check_mark: | | `SSL_CTX_set_tlsext_max_fragment_length` | | | | | `SSL_CTX_set_tlsext_ticket_key_evp_cb` | | | | | `SSL_CTX_set_tlsext_use_srtp` [^srtp] | | | | @@ -216,7 +216,7 @@ | `SSL_CTX_use_serverinfo_ex` | | | | | `SSL_CTX_use_serverinfo_file` | | | | | `SSL_SESSION_dup` | | | | -| `SSL_SESSION_free` | :white_check_mark: | :white_check_mark: | :exclamation: [^stub] | +| `SSL_SESSION_free` | :white_check_mark: | :white_check_mark: | :white_check_mark: | | `SSL_SESSION_get0_alpn_selected` | | | | | `SSL_SESSION_get0_cipher` | | | | | `SSL_SESSION_get0_hostname` | | | | @@ -226,7 +226,7 @@ | `SSL_SESSION_get0_ticket_appdata` | | | | | `SSL_SESSION_get_compress_id` | | | | | `SSL_SESSION_get_ex_data` | | | | -| `SSL_SESSION_get_id` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_SESSION_get_id` | | :white_check_mark: | :white_check_mark: | | `SSL_SESSION_get_master_key` | | | | | `SSL_SESSION_get_max_early_data` | | | | | `SSL_SESSION_get_max_fragment_length` | | | | @@ -252,7 +252,7 @@ | `SSL_SESSION_set_protocol_version` | | | | | `SSL_SESSION_set_time` | | | | | `SSL_SESSION_set_timeout` | | | | -| `SSL_SESSION_up_ref` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_SESSION_up_ref` | | :white_check_mark: | :white_check_mark: | | `SSL_SRP_CTX_free` [^deprecatedin_3_0] [^srp] | | | | | `SSL_SRP_CTX_init` [^deprecatedin_3_0] [^srp] | | | | | `SSL_accept` | | | :white_check_mark: | @@ -316,7 +316,7 @@ | `SSL_get0_security_ex_data` | | | | | `SSL_get0_verified_chain` | | | :white_check_mark: | | `SSL_get1_peer_certificate` | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| `SSL_get1_session` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_get1_session` | | :white_check_mark: | :white_check_mark: | | `SSL_get1_supported_ciphers` | | | | | `SSL_get_SSL_CTX` | | | | | `SSL_get_all_async_fds` | | | | @@ -364,7 +364,7 @@ | `SSL_get_server_random` | | | | | `SSL_get_servername` | | :white_check_mark: | :white_check_mark: | | `SSL_get_servername_type` | | | :white_check_mark: | -| `SSL_get_session` | | :white_check_mark: | :exclamation: [^stub] | +| `SSL_get_session` | | :white_check_mark: | :white_check_mark: | | `SSL_get_shared_ciphers` | | | | | `SSL_get_shared_sigalgs` | | | | | `SSL_get_shutdown` | :white_check_mark: | :white_check_mark: | :white_check_mark: | @@ -461,7 +461,7 @@ | `SSL_set_security_callback` | | | | | `SSL_set_security_level` | | | | | `SSL_set_session` | :white_check_mark: | :white_check_mark: | :exclamation: [^stub] | -| `SSL_set_session_id_context` | | | :exclamation: [^stub] | +| `SSL_set_session_id_context` | | | | | `SSL_set_session_secret_cb` | | | | | `SSL_set_session_ticket_ext` | | | | | `SSL_set_session_ticket_ext_cb` | | | | @@ -519,8 +519,8 @@ | `TLSv1_client_method` [^deprecatedin_1_1_0] [^tls1_method] | | | | | `TLSv1_method` [^deprecatedin_1_1_0] [^tls1_method] | | | | | `TLSv1_server_method` [^deprecatedin_1_1_0] [^tls1_method] | | | | -| `d2i_SSL_SESSION` | | :white_check_mark: | :exclamation: [^stub] | -| `i2d_SSL_SESSION` | | :white_check_mark: | :exclamation: [^stub] | +| `d2i_SSL_SESSION` | | :white_check_mark: | :white_check_mark: | +| `i2d_SSL_SESSION` | | :white_check_mark: | :white_check_mark: | [^stub]: symbol exists, but just returns an error. [^deprecatedin_1_1_0]: deprecated in openssl 1.1.0 diff --git a/rustls-libssl/build.rs b/rustls-libssl/build.rs index 34fb6d2..d676daf 100644 --- a/rustls-libssl/build.rs +++ b/rustls-libssl/build.rs @@ -174,7 +174,6 @@ const ENTRYPOINTS: &[&str] = &[ "SSL_set_post_handshake_auth", "SSL_set_quiet_shutdown", "SSL_set_session", - "SSL_set_session_id_context", "SSL_set_shutdown", "SSL_set_SSL_CTX", "SSL_set_verify", diff --git a/rustls-libssl/src/cache.rs b/rustls-libssl/src/cache.rs new file mode 100644 index 0000000..8371117 --- /dev/null +++ b/rustls-libssl/src/cache.rs @@ -0,0 +1,651 @@ +use core::ptr; +use std::collections::BTreeSet; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::SystemTime; + +use rustls::client::ClientSessionMemoryCache; +use rustls::client::ClientSessionStore; +use rustls::server::StoresServerSessions; + +use crate::entry::{ + SSL_CTX_new_session_cb, SSL_CTX_sess_get_cb, SSL_CTX_sess_remove_cb, SSL_CTX, SSL_SESSION, +}; +use crate::{callbacks, SslSession, SslSessionLookup}; + +/// A container for session caches that can live inside +/// an `SSL_CTX` but outlive a rustls `ServerConfig`/`ClientConfig` +pub struct SessionCaches { + max_size: usize, + + /// the underlying client store. This outlives any given connection. + client: Option>, + + /// the underlying server store. This outlives any given connection. + server: Arc, +} + +impl SessionCaches { + pub fn with_size(max_size: usize) -> Self { + // a user who has one `SSL_CTX` for both clients and servers will end + // up with twice as many sessions as this, since rustls caches + // client and server sessions separately. + // + // the common case is to have those separate (it is, for example, + // impossible to configure certs/keys separately for client and + // servers in a given `SSL_CTX`) so this should be ok. + Self { + max_size, + client: None, + server: Arc::new(ServerSessionStorage::new(max_size)), + } + } + + pub fn set_pointer_to_owning_ssl_ctx(&mut self, ptr: *mut SSL_CTX) { + self.server.set_ssl_ctx(ptr); + } + + /// Get a cache that can be used for an in-construction `ClientConnection` + pub fn get_client(&mut self) -> Arc { + Arc::clone(self.client.get_or_insert_with(|| { + Arc::new(ClientSessionMemoryCache::new(if self.max_size == 0 { + usize::MAX + } else { + self.max_size + })) + })) + } + + /// Get a cache that can be used for a single `ServerConnection` + pub fn get_server(&mut self) -> Arc { + Arc::new(SingleServerCache::new(self.server.clone())) + } + + pub fn set_mode(&mut self, mode: u32) -> u32 { + self.server.set_mode(mode) + } + + pub fn get_timeout(&self) -> u64 { + self.server.get_timeout() + } + + pub fn set_timeout(&mut self, timeout: u64) -> u64 { + self.server.set_timeout(timeout) + } + + pub fn size(&self) -> usize { + self.max_size + } + + pub fn set_size(&mut self, size: usize) -> usize { + let old_size = self.max_size; + self.max_size = size; + self.server + .set_size(if size == 0 { usize::MAX } else { size }); + // divergence: openssl can change the size without emptying the (client) cache + self.client.take(); + old_size + } + + pub fn set_new_callback(&mut self, callback: SSL_CTX_new_session_cb) { + self.server.set_new_callback(callback) + } + + pub fn set_remove_callback(&mut self, callback: SSL_CTX_sess_remove_cb) { + self.server.set_remove_callback(callback) + } + + pub fn set_get_callback(&mut self, callback: SSL_CTX_sess_get_cb) { + self.server.set_get_callback(callback); + } + + pub fn set_context(&mut self, context: &[u8]) { + self.server.set_context(context); + } + + pub fn flush_all(&mut self) { + self.server.flush_all(); + self.client.take(); + } +} + +impl Default for SessionCaches { + fn default() -> Self { + // this is SSL_SESSION_CACHE_MAX_SIZE_DEFAULT + Self::with_size(1024 * 20) + } +} + +#[derive(Debug)] +pub struct ServerSessionStorage { + items: Mutex>>, + parameters: Mutex, + op_count: AtomicUsize, +} + +impl ServerSessionStorage { + fn new(max_size: usize) -> Self { + Self { + items: Mutex::new(BTreeSet::new()), + parameters: Mutex::new(CacheParameters::new(max_size)), + op_count: AtomicUsize::new(0), + } + } + + fn set_mode(&self, mode: u32) -> u32 { + if let Ok(mut inner) = self.parameters.lock() { + let old = inner.mode; + inner.mode = mode; + old + } else { + 0 + } + } + + fn get_timeout(&self) -> u64 { + self.parameters + .lock() + .map(|inner| inner.time_out) + .unwrap_or_default() + } + + fn set_timeout(&self, time_out: u64) -> u64 { + self.parameters + .lock() + .map(|mut inner| { + let old = inner.time_out; + inner.time_out = time_out; + old + }) + .unwrap_or_default() + } + + fn set_size(&self, size: usize) { + if let Ok(mut inner) = self.parameters.lock() { + inner.max_size = size; + } + } + + fn set_new_callback(&self, callback: SSL_CTX_new_session_cb) { + if let Ok(mut inner) = self.parameters.lock() { + inner.callbacks.new_callback = callback; + } + } + + fn set_remove_callback(&self, callback: SSL_CTX_sess_remove_cb) { + if let Ok(mut inner) = self.parameters.lock() { + inner.callbacks.remove_callback = callback; + } + } + + fn set_get_callback(&self, callback: SSL_CTX_sess_get_cb) { + if let Ok(mut inner) = self.parameters.lock() { + inner.callbacks.get_callback = callback; + } + } + + fn set_ssl_ctx(&self, ssl_ctx: *mut SSL_CTX) { + if let Ok(mut inner) = self.parameters.lock() { + inner.callbacks.ssl_ctx = ssl_ctx; + } + } + + fn set_context(&self, context: &[u8]) { + if let Ok(mut inner) = self.parameters.lock() { + context.clone_into(&mut inner.context); + } + } + + fn get_context(&self) -> Vec { + self.parameters + .lock() + .ok() + .map(|inner| inner.context.clone()) + .unwrap_or_default() + } + + fn mode(&self) -> u32 { + self.parameters + .lock() + .map(|inner| inner.mode) + .unwrap_or_default() + } + + fn callbacks(&self) -> CacheCallbacks { + self.parameters + .lock() + .map(|inner| inner.callbacks) + .unwrap_or_default() + } + + fn invoke_new_callback(&self, sess: Arc) -> bool { + callbacks::invoke_session_new_callback(self.callbacks().new_callback, sess) + } + + fn invoke_remove_callback(&self, sess: Arc) { + let callbacks = self.callbacks(); + callbacks::invoke_session_remove_callback( + callbacks.remove_callback, + callbacks.ssl_ctx, + sess, + ); + } + + fn invoke_get_callback(&self, id: &[u8]) -> Option> { + callbacks::invoke_session_get_callback(self.callbacks().get_callback, id) + } + + fn build_new_session(&self, id: Vec, value: Vec) -> Arc { + let context = self.get_context(); + let time_out = ExpiryTime::calculate(TimeBase::now(), self.get_timeout()); + Arc::new(SslSession::new(id, value, context, time_out)) + } + + /// Return `None` if `sess` has the wrong context value. + fn filter_session_context(&self, sess: Arc) -> Option> { + if self.get_context() == sess.context { + Some(sess) + } else { + None + } + } + + fn insert(&self, new: Arc) -> bool { + self.tick(); + + let max_size = self + .parameters + .lock() + .map(|inner| inner.max_size) + .unwrap_or_default(); + + if let Ok(mut items) = self.items.lock() { + let inserted = items.insert(new); + + while items.len() > max_size { + Self::flush_oldest(&mut items); + } + + inserted + } else { + false + } + } + + fn take(&self, id: &[u8]) -> Option> { + self.tick(); + + if let Ok(mut items) = self.items.lock() { + items.take(&SslSessionLookup::for_id(id)) + } else { + None + } + } + + fn find_by_id(&self, id: &[u8]) -> Option> { + self.tick(); + + if let Ok(items) = self.items.lock() { + items.get(&SslSessionLookup::for_id(id)).cloned() + } else { + None + } + } + + fn flush_all(&self) { + if let Ok(mut items) = self.items.lock() { + let callbacks = self.callbacks(); + if let Some(callback) = callbacks.remove_callback { + // if we have a callback to invoke, do it the slow way + while let Some(sess) = items.pop_first() { + callbacks::invoke_session_remove_callback( + Some(callback), + callbacks.ssl_ctx, + sess, + ); + } + } else { + // otherwise, this is quicker. + items.clear(); + } + } + } + + fn flush_expired(&self, at_time: TimeBase) { + if let Ok(mut items) = self.items.lock() { + let callbacks = self.callbacks(); + if let Some(callback) = callbacks.remove_callback { + // if we have a callback to invoke, do it the slow way + let mut removal_list: BTreeSet<_> = items + .iter() + .filter(|item| item.expired(at_time)) + .cloned() + .collect(); + + while let Some(sess) = removal_list.pop_first() { + items.remove(&sess); + callbacks::invoke_session_remove_callback( + Some(callback), + callbacks.ssl_ctx, + sess, + ); + } + } else { + items.retain(|item| !item.expired(at_time)); + } + } + } + + fn tick(&self) { + // Called every cache operation. Every 255 operations, expire + // sessions (unless application opts out with CACHE_MODE_NO_AUTO_CLEAR). + let op_count = self.op_count.fetch_add(1, Ordering::SeqCst); + if self.mode() & CACHE_MODE_NO_AUTO_CLEAR == 0 && op_count & 0xff == 0xff { + self.flush_expired(TimeBase::now()); + } + } + + fn flush_oldest(items: &mut BTreeSet>) { + let oldest = items.iter().min_by_key(|item| item.expiry_time.0); + if let Some(oldest) = oldest.cloned() { + items.take(&oldest); + } + } +} + +#[derive(Debug)] +struct CacheParameters { + callbacks: CacheCallbacks, + mode: u32, + context: Vec, + max_size: usize, + time_out: u64, +} + +impl CacheParameters { + fn new(max_size: usize) -> Self { + Self { + callbacks: CacheCallbacks::default(), + mode: CACHE_MODE_SERVER, + context: vec![], + max_size, + // See + time_out: 300, + } + } +} + +/// A `StoresServerSessions` implementor that is bound to a single `SSL`, +/// and tracks which `SSL_SESSION` was most recently used, to allow +/// `SSL_get_session` to work. +#[derive(Debug)] +pub struct SingleServerCache { + parent: Arc, + most_recent_session: Mutex>>, +} + +impl SingleServerCache { + fn new(parent: Arc) -> Self { + Self { + parent, + most_recent_session: Mutex::new(None), + } + } + + fn is_enabled(&self) -> bool { + self.parent.mode() & CACHE_MODE_SERVER == CACHE_MODE_SERVER + } + + fn save_most_recent_session(&self, sess: Arc) { + if let Ok(mut old) = self.most_recent_session.lock() { + *old = Some(sess); + } + } + + pub fn get_most_recent_session(&self) -> Option> { + self.most_recent_session + .lock() + .ok() + .and_then(|inner| inner.clone()) + } + + pub fn borrow_most_recent_session(&self) -> *mut SSL_SESSION { + if let Ok(inner) = self.most_recent_session.lock() { + inner + .as_ref() + .map(|sess| Arc::as_ptr(sess) as *mut SSL_SESSION) + .unwrap_or_else(ptr::null_mut) + } else { + ptr::null_mut() + } + } +} + +impl StoresServerSessions for SingleServerCache { + fn put(&self, id: Vec, value: Vec) -> bool { + if !self.is_enabled() { + return false; + } + + let sess = self.parent.build_new_session(id, value); + + self.save_most_recent_session(sess.clone()); + + let possibly_stored_elsewhere = self.parent.invoke_new_callback(sess.clone()); + + if self.parent.mode() & CACHE_MODE_NO_INTERNAL_STORE == 0 { + self.parent.insert(sess) || possibly_stored_elsewhere + } else { + possibly_stored_elsewhere + } + } + + fn get(&self, id: &[u8]) -> Option> { + if !self.is_enabled() { + return None; + } + + if self.parent.mode() & CACHE_MODE_NO_INTERNAL_LOOKUP == 0 { + let sess = self + .parent + .find_by_id(id) + .and_then(|sess| self.parent.filter_session_context(sess)); + if let Some(sess) = sess { + self.save_most_recent_session(sess.clone()); + return Some(sess.value.clone()); + } + } + + if let Some(sess) = self + .parent + .invoke_get_callback(id) + .and_then(|sess| self.parent.filter_session_context(sess)) + { + return Some(sess.value.clone()); + } + + None + } + + fn take(&self, id: &[u8]) -> Option> { + if !self.is_enabled() { + return None; + } + + if self.parent.mode() & CACHE_MODE_NO_INTERNAL_LOOKUP == 0 { + let sess = self + .parent + .take(id) + .and_then(|sess| self.parent.filter_session_context(sess)); + + if let Some(sess) = sess { + // inform external cache that this session is being consumed + self.parent.invoke_remove_callback(sess.clone()); + + self.save_most_recent_session(sess.clone()); + return Some(sess.value.clone()); + } + } + + // look up in external cache + if let Some(sess) = self + .parent + .invoke_get_callback(id) + .and_then(|sess| self.parent.filter_session_context(sess)) + { + self.save_most_recent_session(sess.clone()); + self.parent.invoke_remove_callback(sess.clone()); + return Some(sess.value.clone()); + } + + None + } + + fn can_cache(&self) -> bool { + self.is_enabled() + } +} + +const CACHE_MODE_SERVER: u32 = 0x02; +const CACHE_MODE_NO_AUTO_CLEAR: u32 = 0x080; +const CACHE_MODE_NO_INTERNAL_LOOKUP: u32 = 0x100; +const CACHE_MODE_NO_INTERNAL_STORE: u32 = 0x200; + +#[derive(Clone, Copy, Debug)] +struct CacheCallbacks { + new_callback: SSL_CTX_new_session_cb, + remove_callback: SSL_CTX_sess_remove_cb, + get_callback: SSL_CTX_sess_get_cb, + ssl_ctx: *mut SSL_CTX, +} + +impl Default for CacheCallbacks { + fn default() -> Self { + Self { + new_callback: None, + remove_callback: None, + get_callback: None, + ssl_ctx: ptr::null_mut(), + } + } +} + +// `ssl_ctx` is not Send, but we don't dereference it (could +// equally be an integer). +unsafe impl Send for CacheCallbacks {} + +#[derive(Debug, Clone, Copy)] +pub struct ExpiryTime(pub u64); + +impl ExpiryTime { + fn calculate(now: TimeBase, life_time_secs: u64) -> ExpiryTime { + ExpiryTime(now.0.saturating_add(life_time_secs)) + } + + pub fn in_past(&self, time: TimeBase) -> bool { + self.0 < time.0 + } +} + +#[derive(Debug, Clone, Copy)] +pub struct TimeBase(u64); + +impl TimeBase { + pub fn now() -> Self { + Self( + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map(|n| n.as_secs()) + .unwrap_or_default(), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn flush_expired() { + let cache = ServerSessionStorage::new(10); + + for i in 1..=5 { + assert!(cache.insert( + SslSession::new(vec![i], vec![], vec![], ExpiryTime(10 + i as u64)).into() + )); + } + + // expires items 1, 2 + cache.flush_expired(TimeBase(10 + 3)); + + assert!(cache.find_by_id(&[1]).is_none()); + assert!(cache.find_by_id(&[2]).is_none()); + assert!(cache.find_by_id(&[3]).is_some()); + assert!(cache.find_by_id(&[4]).is_some()); + assert!(cache.find_by_id(&[5]).is_some()); + } + + #[test] + fn respects_max_size() { + let cache = ServerSessionStorage::new(4); + + for i in 1..=5 { + assert!(cache.insert( + SslSession::new(vec![i], vec![], vec![], ExpiryTime(10 + i as u64)).into() + )); + } + + assert!(cache.find_by_id(&[1]).is_none()); + assert!(cache.find_by_id(&[2]).is_some()); + assert!(cache.find_by_id(&[3]).is_some()); + assert!(cache.find_by_id(&[4]).is_some()); + assert!(cache.find_by_id(&[5]).is_some()); + } + + #[test] + fn respects_change_in_max_size() { + let cache = ServerSessionStorage::new(5); + + for i in 1..=5 { + assert!(cache.insert( + SslSession::new(vec![i], vec![], vec![], ExpiryTime(10 + i as u64)).into() + )); + } + + assert!(cache.find_by_id(&[1]).is_some()); + assert!(cache.find_by_id(&[2]).is_some()); + assert!(cache.find_by_id(&[3]).is_some()); + assert!(cache.find_by_id(&[4]).is_some()); + assert!(cache.find_by_id(&[5]).is_some()); + + cache.set_size(4); + assert!(cache.insert(SslSession::new(vec![6], vec![], vec![], ExpiryTime(16)).into())); + + assert!(cache.find_by_id(&[1]).is_none()); + assert!(cache.find_by_id(&[2]).is_none()); + assert!(cache.find_by_id(&[3]).is_some()); + assert!(cache.find_by_id(&[4]).is_some()); + assert!(cache.find_by_id(&[5]).is_some()); + assert!(cache.find_by_id(&[6]).is_some()); + } + + #[test] + fn respects_context() { + let cache = ServerSessionStorage::new(5); + cache.set_context(b"hello"); + + assert!(cache + .insert(SslSession::new(vec![1], vec![], b"hello".to_vec(), ExpiryTime(10)).into())); + assert!(cache + .insert(SslSession::new(vec![2], vec![], b"goodbye".to_vec(), ExpiryTime(10)).into())); + + assert!(cache + .find_by_id(&[1]) + .and_then(|sess| cache.filter_session_context(sess)) + .is_some()); + assert!(cache + .find_by_id(&[2]) + .and_then(|sess| cache.filter_session_context(sess)) + .is_none()); + } +} diff --git a/rustls-libssl/src/callbacks.rs b/rustls-libssl/src/callbacks.rs index 05f0b62..b7a727d 100644 --- a/rustls-libssl/src/callbacks.rs +++ b/rustls-libssl/src/callbacks.rs @@ -1,14 +1,18 @@ use core::cell::RefCell; use core::ffi::{c_int, c_uchar, c_void}; use core::{ptr, slice}; +use std::sync::Arc; use openssl_sys::{SSL_TLSEXT_ERR_NOACK, SSL_TLSEXT_ERR_OK}; use rustls::AlertDescription; use crate::entry::{ - SSL_CTX_alpn_select_cb_func, SSL_CTX_cert_cb_func, SSL_CTX_servername_callback_func, SSL, + SSL_CTX_alpn_select_cb_func, SSL_CTX_cert_cb_func, SSL_CTX_new_session_cb, + SSL_CTX_servername_callback_func, SSL_CTX_sess_get_cb, SSL_CTX_sess_remove_cb, + _SSL_SESSION_free, SSL, SSL_CTX, SSL_SESSION, }; use crate::error::Error; +use crate::ffi; /// Smuggling SSL* pointers from the outer entrypoint into the /// callback call site. @@ -179,3 +183,76 @@ impl Default for ServerNameCallbackConfig { } } } + +/// Returns true if a callback was actually called. +/// +/// It is unknowable if this means something was stored externally. +pub fn invoke_session_new_callback( + callback: SSL_CTX_new_session_cb, + sess: Arc, +) -> bool { + let callback = match callback { + Some(callback) => callback, + None => { + return false; + } + }; + + let ssl = SslCallbackContext::ssl_ptr(); + let sess_ptr = Arc::into_raw(sess) as *mut SSL_SESSION; + + let result = unsafe { callback(ssl, sess_ptr) }; + + // "If the callback returns 1, the application retains the reference" + if result == 0 { + _SSL_SESSION_free(sess_ptr); + } + true +} + +pub fn invoke_session_get_callback( + callback: SSL_CTX_sess_get_cb, + id: &[u8], +) -> Option> { + let callback = match callback { + Some(callback) => callback, + None => { + return None; + } + }; + + let ssl_ptr = SslCallbackContext::ssl_ptr(); + let mut copy = 1; + let sess_ptr = unsafe { callback(ssl_ptr, id.as_ptr(), id.len() as c_int, &mut copy) }; + + if sess_ptr.is_null() { + return None; + } + + let maybe_sess = ffi::clone_arc(sess_ptr); + + if copy > 0 { + _SSL_SESSION_free(sess_ptr); + } + + maybe_sess +} + +pub fn invoke_session_remove_callback( + callback: SSL_CTX_sess_remove_cb, + ssl_ctx: *mut SSL_CTX, + sess: Arc, +) { + let callback = match callback { + Some(callback) => callback, + None => { + return; + } + }; + + let sess_ptr = Arc::into_raw(sess) as *mut SSL_SESSION; + + unsafe { callback(ssl_ctx, sess_ptr) }; + + _SSL_SESSION_free(sess_ptr); +} diff --git a/rustls-libssl/src/entry.rs b/rustls-libssl/src/entry.rs index c1105d7..cc8dc15 100644 --- a/rustls-libssl/src/entry.rs +++ b/rustls-libssl/src/entry.rs @@ -6,6 +6,7 @@ use core::{mem, ptr}; use std::io::{self, Read}; use std::os::raw::{c_char, c_int, c_long, c_uchar, c_uint, c_void}; +use std::sync::Arc; use std::{fs, path::PathBuf}; use openssl_sys::{ @@ -20,9 +21,9 @@ use crate::error::{ffi_panic_boundary, Error, MysteriouslyOppositeReturnValue}; use crate::evp_pkey::EvpPkey; use crate::ex_data::ExData; use crate::ffi::{ - clone_arc, free_arc, str_from_cstring, to_arc_mut_ptr, try_clone_arc, try_from, - try_mut_slice_int, try_ref_from_ptr, try_slice, try_slice_int, try_str, Castable, OwnershipArc, - OwnershipRef, + clone_arc, free_arc, free_arc_into_inner, str_from_cstring, to_arc_mut_ptr, try_clone_arc, + try_from, try_mut_slice_int, try_ref_from_ptr, try_slice, try_slice_int, try_str, Castable, + OwnershipArc, OwnershipRef, }; use crate::not_thread_safe::NotThreadSafe; use crate::x509::{load_certs, OwnedX509, OwnedX509Stack}; @@ -113,18 +114,15 @@ pub type SSL_CTX = crate::SslContext; entry! { pub fn _SSL_CTX_new(meth: *const SSL_METHOD) -> *mut SSL_CTX { let method = try_ref_from_ptr!(meth); - let out = to_arc_mut_ptr(NotThreadSafe::new(crate::SslContext::new(method))); - let ex_data = match ExData::new_ssl_ctx(out) { - None => { + let out: *mut SSL_CTX = to_arc_mut_ptr(NotThreadSafe::new(crate::SslContext::new(method))); + // safety: we just made this object, the pointer must be valid + match clone_arc(out).unwrap().get_mut().complete_construction(out) { + Err(err) => { _SSL_CTX_free(out); - return ptr::null_mut(); + err.raise().into() } - Some(ex_data) => ex_data, - }; - - // safety: we just made this object, the pointer must be valid - clone_arc(out).unwrap().get_mut().install_ex_data(ex_data); - out + Ok(()) => out, + } } } @@ -138,7 +136,9 @@ entry! { entry! { pub fn _SSL_CTX_free(ctx: *mut SSL_CTX) { - free_arc(ctx); + if let Some(inner) = free_arc_into_inner(ctx) { + inner.get_mut().flush_all_sessions(); + } } } @@ -228,6 +228,19 @@ entry! { ctx.get_mut().set_servername_callback_context(parg); C_INT_SUCCESS as c_long } + Ok(SslCtrl::SetSessCacheSize) => { + if larg < 0 { + return 0; + } + ctx.get_mut().set_session_cache_size(larg as usize) as c_long + } + Ok(SslCtrl::GetSessCacheSize) => ctx.get().get_session_cache_size() as c_long, + Ok(SslCtrl::SetSessCacheMode) => { + if larg < 0 { + return 0; + } + ctx.get_mut().set_session_cache_mode(larg as u32) as c_long + } Err(()) => { log::warn!("unimplemented _SSL_CTX_ctrl(..., {cmd}, {larg}, ...)"); 0 @@ -629,15 +642,73 @@ entry! { entry! { pub fn _SSL_CTX_set_session_id_context( - _ctx: *mut SSL_CTX, - _sid_ctx: *const c_uchar, - _sid_ctx_len: c_uint, + ctx: *mut SSL_CTX, + sid_ctx: *const c_uchar, + sid_ctx_len: c_uint, ) -> c_int { - log::warn!("SSL_CTX_set_session_id_context not yet implemented"); + let sid_ctx = try_slice!(sid_ctx, sid_ctx_len); + try_clone_arc!(ctx) + .get_mut() + .set_session_id_context(sid_ctx); C_INT_SUCCESS } } +entry! { + pub fn _SSL_CTX_sess_set_new_cb(ctx: *mut SSL_CTX, new_session_cb: SSL_CTX_new_session_cb) { + try_clone_arc!(ctx) + .get_mut() + .set_session_new_cb(new_session_cb) + } +} + +pub type SSL_CTX_new_session_cb = + Option c_int>; + +entry! { + pub fn _SSL_CTX_sess_set_get_cb(ctx: *mut SSL_CTX, get_session_cb: SSL_CTX_sess_get_cb) { + try_clone_arc!(ctx) + .get_mut() + .set_session_get_cb(get_session_cb) + } +} + +pub type SSL_CTX_sess_get_cb = Option< + unsafe extern "C" fn( + ssl: *mut SSL, + data: *const c_uchar, + len: c_int, + copy: *mut c_int, + ) -> *mut SSL_SESSION, +>; + +entry! { + pub fn _SSL_CTX_sess_set_remove_cb( + ctx: *mut SSL_CTX, + remove_session_cb: SSL_CTX_sess_remove_cb, + ) { + try_clone_arc!(ctx) + .get_mut() + .set_session_remove_cb(remove_session_cb) + } +} + +pub type SSL_CTX_sess_remove_cb = + Option; + +entry! { + pub fn _SSL_CTX_get_timeout(ctx: *const SSL_CTX) -> c_long { + try_clone_arc!(ctx).get().get_session_timeout() as c_long + } +} + +entry! { + pub fn _SSL_CTX_set_timeout(ctx: *mut SSL_CTX, t: c_long) -> c_long { + let t = if t < 0 { 0 } else { t as u64 }; + try_clone_arc!(ctx).get_mut().set_session_timeout(t) as c_long + } +} + impl Castable for SSL_CTX { type Ownership = OwnershipArc; type RustType = NotThreadSafe; @@ -748,7 +819,11 @@ entry! { C_INT_SUCCESS as i64 } // not a defined operation in the OpenSSL API - Ok(SslCtrl::SetTlsExtServerNameCallback) | Ok(SslCtrl::SetTlsExtServerNameArg) => 0, + Ok(SslCtrl::SetTlsExtServerNameCallback) + | Ok(SslCtrl::SetTlsExtServerNameArg) + | Ok(SslCtrl::SetSessCacheSize) + | Ok(SslCtrl::GetSessCacheSize) + | Ok(SslCtrl::SetSessCacheMode) => 0, Err(()) => { log::warn!("unimplemented _SSL_ctrl(..., {cmd}, {larg}, ...)"); 0 @@ -1261,6 +1336,22 @@ entry! { } } +entry! { + pub fn _SSL_get1_session(ssl: *mut SSL) -> *mut SSL_SESSION { + try_clone_arc!(ssl) + .get() + .get_current_session() + .map(|sess| Arc::into_raw(sess) as *mut SSL_SESSION) + .unwrap_or_else(ptr::null_mut) + } +} + +entry! { + pub fn _SSL_get_session(ssl: *const SSL) -> *mut SSL_SESSION { + try_clone_arc!(ssl).get().borrow_current_session() + } +} + impl Castable for SSL { type Ownership = OwnershipArc; type RustType = NotThreadSafe; @@ -1431,6 +1522,91 @@ entry! { } } +pub type SSL_SESSION = crate::SslSession; + +entry! { + pub fn _SSL_SESSION_get_id(sess: *const SSL_SESSION, len: *mut c_uint) -> *const c_uchar { + if len.is_null() { + return ptr::null(); + } + + let sess = try_clone_arc!(sess); + let id = sess.get_id(); + unsafe { *len = id.len() as c_uint }; + id.as_ptr() + } +} + +entry! { + pub fn _SSL_SESSION_up_ref(sess: *mut SSL_SESSION) -> c_int { + let sess = try_clone_arc!(sess); + mem::forget(sess.clone()); + C_INT_SUCCESS + } +} + +entry! { + pub fn _d2i_SSL_SESSION( + a: *mut *mut SSL_SESSION, + pp: *mut *const c_uchar, + length: c_long, + ) -> *mut SSL_SESSION { + if !a.is_null() { + return Error::not_supported("d2i_SSL_SESSION with a != NULL") + .raise() + .into(); + } + + if pp.is_null() { + return Error::bad_data("d2i_SSL_SESSION with pp == NULL") + .raise() + .into(); + } + + let ptr = unsafe { ptr::read(pp) }; + let slice = try_slice!(ptr, length); + + let (sess, rest) = match SSL_SESSION::decode(slice) { + Some(r) => r, + None => { + return Error::bad_data("cannot decode SSL_SESSION").raise().into(); + } + }; + let consumed_bytes = slice.len() - rest.len(); + + // move along *pp + unsafe { ptr::write(pp, ptr.add(consumed_bytes)) }; + to_arc_mut_ptr(sess) + } +} + +entry! { + pub fn _i2d_SSL_SESSION(sess: *const SSL_SESSION, pp: *mut *mut c_uchar) -> c_int { + let sess = try_clone_arc!(sess); + let encoded = sess.encode(); + + if !pp.is_null() { + let ptr = unsafe { ptr::read(pp) }; + unsafe { + ptr::copy_nonoverlapping(encoded.as_ptr(), ptr, encoded.len()); + ptr::write(pp, ptr.add(encoded.len())); + } + } + encoded.len() as c_int + } +} + +entry! { + pub fn _SSL_SESSION_free(sess: *mut SSL_SESSION) { + free_arc(sess); + } +} + +impl Castable for SSL_SESSION { + type Ownership = OwnershipArc; + type RustType = SSL_SESSION; +} + /// Normal OpenSSL return value convention success indicator. /// /// Compare [`crate::ffi::MysteriouslyOppositeReturnValue`]. @@ -1475,6 +1651,9 @@ num_enum! { enum SslCtrl { Mode = 33, SetMsgCallbackArg = 16, + SetSessCacheSize = 42, + GetSessCacheSize = 43, + SetSessCacheMode = 44, SetTlsExtServerNameCallback = 53, SetTlsExtServerNameArg = 54, SetTlsExtHostname = 55, @@ -1518,49 +1697,10 @@ entry_stub! { pub fn _SSL_set_session(_ssl: *mut SSL, _session: *mut SSL_SESSION) -> c_int; } -entry_stub! { - pub fn _SSL_get1_session(_ssl: *mut SSL) -> *mut SSL_SESSION; -} - -entry_stub! { - pub fn _SSL_get_session(_ssl: *const SSL) -> *mut SSL_SESSION; -} - entry_stub! { pub fn _SSL_CTX_remove_session(_ssl: *const SSL, _session: *mut SSL_SESSION) -> c_int; } -entry_stub! { - pub fn _SSL_CTX_sess_set_get_cb(_ctx: *mut SSL_CTX, _get_session_cb: SSL_CTX_sess_get_cb); -} - -pub type SSL_CTX_sess_get_cb = Option< - unsafe extern "C" fn( - ssl: *mut SSL, - data: *const c_uchar, - len: c_int, - copy: *mut c_int, - ) -> *mut SSL_SESSION, ->; - -entry_stub! { - pub fn _SSL_CTX_sess_set_remove_cb( - _ctx: *mut SSL_CTX, - _remove_session_cb: SSL_CTX_sess_remove_cb, - ); -} - -pub type SSL_CTX_sess_remove_cb = - Option; - -entry_stub! { - pub fn _SSL_set_session_id_context( - _ssl: *mut SSL, - _sid_ctx: *const c_uchar, - _sid_ctx_len: c_uint, - ) -> c_int; -} - entry_stub! { pub fn _SSL_CTX_set_keylog_callback(_ctx: *mut SSL_CTX, _cb: SSL_CTX_keylog_cb_func); } @@ -1572,33 +1712,6 @@ entry_stub! { pub fn _SSL_CTX_add_client_CA(_ctx: *mut SSL_CTX, _x: *mut X509) -> c_int; } -entry_stub! { - pub fn _SSL_CTX_sess_set_new_cb(_ctx: *mut SSL_CTX, _new_session_cb: SSL_CTX_new_session_cb); -} - -pub type SSL_CTX_new_session_cb = - Option c_int>; - -entry_stub! { - pub fn _SSL_SESSION_get_id(_s: *const SSL_SESSION, _len: *mut c_uint) -> *const c_uchar; -} - -entry_stub! { - pub fn _SSL_SESSION_up_ref(_ses: *mut SSL_SESSION) -> c_int; -} - -entry_stub! { - pub fn _d2i_SSL_SESSION( - _a: *mut *mut SSL_SESSION, - _pp: *mut *const c_uchar, - _length: c_long, - ) -> *mut SSL_SESSION; -} - -entry_stub! { - pub fn _i2d_SSL_SESSION(_in: *const SSL_SESSION, _pp: *mut *mut c_uchar) -> c_int; -} - entry_stub! { pub fn _SSL_CTX_set_ciphersuites(_ctx: *mut SSL_CTX, _s: *const c_char) -> c_int; } @@ -1616,12 +1729,6 @@ entry_stub! { pub fn _SSL_CTX_set_default_verify_store(_ctx: *mut SSL_CTX) -> c_int; } -pub struct SSL_SESSION; - -entry_stub! { - pub fn _SSL_SESSION_free(_sess: *mut SSL_SESSION); -} - entry_stub! { pub fn _SSL_write_early_data( _ssl: *mut SSL, @@ -1640,14 +1747,6 @@ entry_stub! { ) -> c_int; } -entry_stub! { - pub fn _SSL_CTX_get_timeout(_ctx: *const SSL_CTX) -> c_long; -} - -entry_stub! { - pub fn _SSL_CTX_set_timeout(_ctx: *mut SSL_CTX, _t: c_long) -> c_long; -} - entry_stub! { pub fn _SSL_CTX_get_client_CA_list(_ctx: *const SSL_CTX) -> *mut stack_st_X509_NAME; } @@ -2010,4 +2109,29 @@ mod tests { 0 ); } + + #[test] + fn test_SSL_SESSION_roundtrip() { + let sess = crate::SslSession::new( + vec![1; 32], + vec![2; 32], + vec![3; 128], + crate::cache::ExpiryTime(123), + ); + let sess_ptr = to_arc_mut_ptr(sess); + + let mut buffer = [0u8; 1024]; + let mut ptr = buffer.as_mut_ptr(); + let len = _i2d_SSL_SESSION(sess_ptr, &mut ptr); + + println!("encoding: {:?}", &buffer[..len as usize]); + + let mut ptr = buffer.as_ptr(); + let new_sess = _d2i_SSL_SESSION(ptr::null_mut(), &mut ptr, buffer.len() as c_long); + assert!(!new_sess.is_null()); + assert_eq!(len as usize, (ptr as usize) - (buffer.as_ptr() as usize)); + + _SSL_SESSION_free(new_sess); + _SSL_SESSION_free(sess_ptr); + } } diff --git a/rustls-libssl/src/ffi.rs b/rustls-libssl/src/ffi.rs index d19ad08..9ddc8bd 100644 --- a/rustls-libssl/src/ffi.rs +++ b/rustls-libssl/src/ffi.rs @@ -165,6 +165,21 @@ where drop(unsafe { Arc::from_raw(rs_typed) }); } +/// Similar to `free_arc`, but call `into_inner` on the Arc instead of just +/// dropping it. +/// +/// This returns `Some` if this was the last reference. +pub(crate) fn free_arc_into_inner(ptr: *const C) -> Option +where + C: Castable, +{ + if ptr.is_null() { + return None; + } + let rs_typed = cast_const_ptr(ptr); + Arc::into_inner(unsafe { Arc::from_raw(rs_typed) }) +} + /// Convert a mutable pointer to a [`Castable`] to an optional `Box` over the underlying /// [`Castable::RustType`], and immediately let it fall out of scope to be freed. /// diff --git a/rustls-libssl/src/lib.rs b/rustls-libssl/src/lib.rs index eb62060..9307666 100644 --- a/rustls-libssl/src/lib.rs +++ b/rustls-libssl/src/lib.rs @@ -1,5 +1,5 @@ use core::ffi::{c_char, c_int, c_uint, c_void, CStr}; -use core::{mem, ptr}; +use core::{borrow, cmp, fmt, mem, ptr}; use std::ffi::CString; use std::fs; use std::io::{ErrorKind, Read, Write}; @@ -11,6 +11,7 @@ use openssl_sys::{ EVP_PKEY, SSL_ERROR_NONE, SSL_ERROR_SSL, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, X509, X509_STORE, X509_V_ERR_UNSPECIFIED, }; +use rustls::client::Resumption; use rustls::crypto::aws_lc_rs as provider; use rustls::pki_types::{CertificateDer, ServerName}; use rustls::server::{Accepted, Acceptor}; @@ -22,6 +23,7 @@ use rustls::{ use not_thread_safe::NotThreadSafe; mod bio; +mod cache; mod callbacks; #[macro_use] mod constants; @@ -214,10 +216,174 @@ static TLS13_CHACHA20_POLY1305_SHA256: SslCipher = SslCipher { description: c"TLS_CHACHA20_POLY1305_SHA256 TLSv1.3 Kx=any Au=any Enc=CHACHA20/POLY1305(256) Mac=AEAD\n", }; +/// Backs a server-side SSL_SESSION object +/// +/// Note that this has equality and ordering entirely based on the `id` field. +pub struct SslSession { + id: SslSessionLookup, + value: Vec, + context: Vec, + expiry_time: cache::ExpiryTime, +} + +impl SslSession { + /// A magic number for the start of SslSession encodings. + /// + /// Aims to avoid confusion with other SSL_SESSION encodings (eg, from openssl). + /// We are not compatible with these. + const MAGIC: &'static [u8] = b"rustlsv1"; + + pub fn new( + id: Vec, + value: Vec, + context: Vec, + expiry_time: cache::ExpiryTime, + ) -> Self { + Self { + id: SslSessionLookup(id), + value, + context, + expiry_time, + } + } + + /// Encode this session to an opaque binary format. + /// + /// This could be DER (OpenSSL does) but currently is ad-hoc. + pub fn encode(&self) -> Vec { + let id_len = self.id.0.len().to_le_bytes(); + let value_len = self.value.len().to_le_bytes(); + let context_len = self.context.len().to_le_bytes(); + let expiry = self.expiry_time.0.to_le_bytes(); + + let mut ret = Vec::with_capacity( + SslSession::MAGIC.len() + + id_len.len() + + self.id.0.len() + + value_len.len() + + self.value.len() + + context_len.len() + + self.context.len() + + expiry.len(), + ); + ret.extend_from_slice(SslSession::MAGIC); + ret.extend_from_slice(&id_len); + ret.extend_from_slice(&self.id.0); + ret.extend_from_slice(&value_len); + ret.extend_from_slice(&self.value); + ret.extend_from_slice(&context_len); + ret.extend_from_slice(&self.context); + ret.extend_from_slice(&expiry); + ret + } + + /// Decodes from the front of `slice`. Returns the remainder. + pub fn decode(slice: &[u8]) -> Option<(Self, &[u8])> { + fn split_at(slice: &[u8], mid: usize) -> Option<(&[u8], &[u8])> { + if mid <= slice.len() { + Some(slice.split_at(mid)) + } else { + None + } + } + + fn slice_to_usize(slice: &[u8]) -> usize { + // unwrap: `slice` must be `usize_len` in length + usize::from_le_bytes(slice.try_into().unwrap()) + } + + fn slice_to_u64(slice: &[u8]) -> u64 { + // unwrap: `slice` must be `u64_len` in length + u64::from_le_bytes(slice.try_into().unwrap()) + } + + let usize_len = mem::size_of::(); + let u64_len = mem::size_of::(); + + let (magic, slice) = split_at(slice, SslSession::MAGIC.len())?; + if magic != SslSession::MAGIC { + return None; + } + let (id_len, slice) = split_at(slice, usize_len)?; + let (id, slice) = split_at(slice, slice_to_usize(id_len))?; + let (value_len, slice) = split_at(slice, usize_len)?; + let (value, slice) = split_at(slice, slice_to_usize(value_len))?; + let (context_len, slice) = split_at(slice, usize_len)?; + let (context, slice) = split_at(slice, slice_to_usize(context_len))?; + let (expiry, slice) = split_at(slice, u64_len)?; + Some(( + Self { + id: SslSessionLookup(id.to_vec()), + value: value.to_vec(), + context: context.to_vec(), + expiry_time: cache::ExpiryTime(slice_to_u64(expiry)), + }, + slice, + )) + } + + pub fn get_id(&self) -> &[u8] { + &self.id.0 + } + + pub fn expired(&self, at_time: cache::TimeBase) -> bool { + self.expiry_time.in_past(at_time) + } + + pub fn older_than(&self, other: &Self) -> bool { + self.expiry_time.0 < other.expiry_time.0 + } +} + +impl PartialOrd for SslSession { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.id.cmp(&other.id)) + } +} + +impl Ord for SslSession { + fn cmp(&self, other: &Self) -> cmp::Ordering { + self.id.cmp(&other.id) + } +} + +impl PartialEq for SslSession { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl Eq for SslSession {} + +impl borrow::Borrow for Arc { + fn borrow(&self) -> &SslSessionLookup { + &self.id + } +} + +impl fmt::Debug for SslSession { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.debug_struct("SslSession") + .field("id", &self.id) + .field("expiry", &self.expiry_time) + .finish_non_exhaustive() + } +} + +#[derive(Debug, PartialOrd, Ord, Eq, PartialEq)] +pub struct SslSessionLookup(Vec); + +impl SslSessionLookup { + pub fn for_id(id: &[u8]) -> Self { + Self(id.to_vec()) + } +} + pub struct SslContext { method: &'static SslMethod, ex_data: ex_data::ExData, versions: EnabledVersions, + caches: cache::SessionCaches, raw_options: u64, verify_mode: VerifyMode, verify_depth: c_int, @@ -239,6 +405,7 @@ impl SslContext { method, ex_data: ex_data::ExData::default(), versions: EnabledVersions::default(), + caches: cache::SessionCaches::default(), raw_options: 0, verify_mode: VerifyMode::default(), verify_depth: -1, @@ -255,8 +422,16 @@ impl SslContext { } } - fn install_ex_data(&mut self, ex_data: ex_data::ExData) { - self.ex_data = ex_data; + fn complete_construction( + &mut self, + pointer_to_self: *mut entry::SSL_CTX, + ) -> Result<(), error::Error> { + self.caches.set_pointer_to_owning_ssl_ctx(pointer_to_self); + + self.ex_data = ex_data::ExData::new_ssl_ctx(pointer_to_self) + .ok_or_else(|| error::Error::bad_data("ex_data construction failed"))?; + + Ok(()) } fn set_ex_data(&mut self, idx: c_int, data: *mut c_void) -> Result<(), error::Error> { @@ -311,6 +486,46 @@ impl SslContext { .unwrap_or_default() } + fn get_session_cache_size(&self) -> usize { + self.caches.size() + } + + fn set_session_cache_size(&mut self, size: usize) -> usize { + self.caches.set_size(size) + } + + fn set_session_cache_mode(&mut self, mode: u32) -> u32 { + self.caches.set_mode(mode) + } + + fn set_session_id_context(&mut self, context: &[u8]) { + self.caches.set_context(context); + } + + fn set_session_new_cb(&mut self, callback: entry::SSL_CTX_new_session_cb) { + self.caches.set_new_callback(callback); + } + + fn set_session_get_cb(&mut self, callback: entry::SSL_CTX_sess_get_cb) { + self.caches.set_get_callback(callback); + } + + fn set_session_remove_cb(&mut self, callback: entry::SSL_CTX_sess_remove_cb) { + self.caches.set_remove_callback(callback); + } + + fn get_session_timeout(&self) -> u64 { + self.caches.get_timeout() + } + + fn set_session_timeout(&mut self, timeout: u64) -> u64 { + self.caches.set_timeout(timeout) + } + + fn flush_all_sessions(&mut self) { + self.caches.flush_all(); + } + fn set_max_early_data(&mut self, max: u32) { self.max_early_data = max; } @@ -489,7 +704,11 @@ enum ConnState { Client(Connection, Arc), Accepting(Acceptor), Accepted(Accepted), - Server(Connection, Arc), + Server( + Connection, + Arc, + Arc, + ), } impl Ssl { @@ -735,8 +954,6 @@ impl Ssl { None => ServerName::try_from("0.0.0.0").unwrap(), }; - let method = self.ctx.get().method; - let provider = Arc::new(provider::default_provider()); let verifier = Arc::new(verifier::ServerVerifier::new( self.verify_roots.clone().into(), @@ -745,7 +962,9 @@ impl Ssl { &self.verify_server_name, )); - let versions = self.versions.reduce_versions(method.client_versions)?; + let versions = self + .versions + .reduce_versions(self.ctx.get().method.client_versions)?; let wants_resolver = ClientConfig::builder_with_provider(provider) .with_protocol_versions(&versions) @@ -760,6 +979,7 @@ impl Ssl { }; config.alpn_protocols.clone_from(&self.alpn); + config.resumption = Resumption::store(self.ctx.get_mut().caches.get_client()); let client_conn = ClientConnection::new(Arc::new(config), sni_server_name.clone()) .map_err(error::Error::from_rustls)?; @@ -818,8 +1038,6 @@ impl Ssl { } fn init_server_conn(&mut self) -> Result<(), error::Error> { - let method = self.ctx.get().method; - let provider = Arc::new(provider::default_provider()); let verifier = Arc::new( verifier::ClientVerifier::new( @@ -835,7 +1053,9 @@ impl Ssl { .server_resolver() .ok_or_else(|| error::Error::bad_data("missing server keys"))?; - let versions = self.versions.reduce_versions(method.server_versions)?; + let versions = self + .versions + .reduce_versions(self.ctx.get().method.server_versions)?; let mut config = ServerConfig::builder_with_provider(provider) .with_protocol_versions(&versions) @@ -846,6 +1066,8 @@ impl Ssl { config.alpn_protocols = mem::take(&mut self.alpn); config.max_early_data_size = self.max_early_data; config.send_tls13_tickets = 2; // match OpenSSL default: see `man SSL_CTX_set_num_tickets` + let cache = self.ctx.get_mut().caches.get_server(); + config.session_storage = cache.clone(); let accepted = match mem::replace(&mut self.conn, ConnState::Nothing) { ConnState::Accepted(accepted) => accepted, @@ -857,27 +1079,27 @@ impl Ssl { .into_connection(Arc::new(config)) .map_err(|(err, _alert)| error::Error::from_rustls(err))?; - self.conn = ConnState::Server(server_conn.into(), verifier); + self.conn = ConnState::Server(server_conn.into(), verifier, cache); Ok(()) } fn conn(&self) -> Option<&Connection> { match &self.conn { - ConnState::Client(conn, _) | ConnState::Server(conn, _) => Some(conn), + ConnState::Client(conn, _) | ConnState::Server(conn, _, _) => Some(conn), _ => None, } } fn conn_mut(&mut self) -> Option<&mut Connection> { match &mut self.conn { - ConnState::Client(conn, _) | ConnState::Server(conn, _) => Some(conn), + ConnState::Client(conn, _) | ConnState::Server(conn, _, _) => Some(conn), _ => None, } } fn want(&self) -> Want { match &self.conn { - ConnState::Client(conn, _) | ConnState::Server(conn, _) => Want { + ConnState::Client(conn, _) | ConnState::Server(conn, _, _) => Want { read: conn.wants_read(), write: conn.wants_write(), }, @@ -934,7 +1156,7 @@ impl Ssl { }; match &mut self.conn { - ConnState::Client(conn, _) | ConnState::Server(conn, _) => { + ConnState::Client(conn, _) | ConnState::Server(conn, _, _) => { match conn.complete_io(bio) { Ok(_) => {} Err(e) => { @@ -1085,7 +1307,7 @@ impl Ssl { fn get_last_verification_result(&self) -> i64 { match &self.conn { ConnState::Client(_, verifier) => verifier.last_result(), - ConnState::Server(_, verifier) => verifier.last_result(), + ConnState::Server(_, verifier, _) => verifier.last_result(), _ => X509_V_ERR_UNSPECIFIED as i64, } } @@ -1175,6 +1397,22 @@ impl Ssl { None => false, } } + + fn get_current_session(&self) -> Option> { + match &self.conn { + ConnState::Server(_, _, cache) => cache.get_most_recent_session(), + // divergence: `SSL_get1_session` etc only work for server SSLs + _ => None, + } + } + + fn borrow_current_session(&self) -> *mut entry::SSL_SESSION { + match &self.conn { + ConnState::Server(_, _, cache) => cache.borrow_most_recent_session(), + // divergence: `SSL_get_session` etc only work for server SSLs + _ => ptr::null_mut(), + } + } } /// Encode rustls's internal representation in the wire format. diff --git a/rustls-libssl/tests/nginx.conf b/rustls-libssl/tests/nginx.conf index 4acc9ed..bbdd78e 100644 --- a/rustls-libssl/tests/nginx.conf +++ b/rustls-libssl/tests/nginx.conf @@ -10,10 +10,11 @@ http { access_log access.log; server { + # no resumption (default) listen 8443 ssl; - server_name localhost; ssl_certificate ../../../test-ca/rsa/server.cert; ssl_certificate_key ../../../test-ca/rsa/server.key; + server_name localhost; location = / { return 200 "hello world\n"; @@ -44,4 +45,81 @@ http { return 200 "s-dn:$ssl_client_s_dn\ni-dn:$ssl_client_i_dn\nserial:$ssl_client_serial\nfp:$ssl_client_fingerprint\nverify:$ssl_client_verify\nv-start:$ssl_client_v_start\nv-end:$ssl_client_v_end\nv-remain:$ssl_client_v_remain\ncert:\n$ssl_client_cert\n"; } } + + server { + # per-worker resumption + listen 8444 ssl; + ssl_session_cache builtin; + ssl_certificate ../../../test-ca/rsa/server.cert; + ssl_certificate_key ../../../test-ca/rsa/server.key; + server_name localhost; + + location = / { + return 200 "hello world\n"; + } + + location /ssl-agreed { + return 200 "protocol:$ssl_protocol,cipher:$ssl_cipher\n"; + } + + location /ssl-server-name { + return 200 "server-name:$ssl_server_name\n"; + } + + location /ssl-was-reused { + return 200 "reused:$ssl_session_reused\n"; + } + } + + server { + # per-worker & per-server resumption + listen 8445 ssl; + ssl_session_cache builtin shared:port8445:1M; + ssl_certificate ../../../test-ca/rsa/server.cert; + ssl_certificate_key ../../../test-ca/rsa/server.key; + server_name localhost; + + + location = / { + return 200 "hello world\n"; + } + + location /ssl-agreed { + return 200 "protocol:$ssl_protocol,cipher:$ssl_cipher\n"; + } + + location /ssl-server-name { + return 200 "server-name:$ssl_server_name\n"; + } + + location /ssl-was-reused { + return 200 "reused:$ssl_session_reused\n"; + } + + } + + server { + # per-server resumption + listen 8446 ssl; + ssl_session_cache shared:port8446:1M; + ssl_certificate ../../../test-ca/rsa/server.cert; + ssl_certificate_key ../../../test-ca/rsa/server.key; + server_name localhost; + + location = / { + return 200 "hello world\n"; + } + + location /ssl-agreed { + return 200 "protocol:$ssl_protocol,cipher:$ssl_cipher\n"; + } + + location /ssl-server-name { + return 200 "server-name:$ssl_server_name\n"; + } + + location /ssl-was-reused { + return 200 "reused:$ssl_session_reused\n"; + } + } } diff --git a/rustls-libssl/tests/runner.rs b/rustls-libssl/tests/runner.rs index 03e6075..b3bd610 100644 --- a/rustls-libssl/tests/runner.rs +++ b/rustls-libssl/tests/runner.rs @@ -371,6 +371,39 @@ fn nginx() { b"hello world\n" ); + for (port, reused) in [(8443, '.'), (8444, 'r'), (8445, 'r'), (8446, 'r')] { + // multiple requests without http connection reuse + // (second should be a TLS resumption if possible) + assert_eq!( + Command::new("curl") + .env("LD_LIBRARY_PATH", "") + .args([ + "--verbose", + "--cacert", + "test-ca/rsa/ca.cert", + "-H", + "connection: close", + &format!("https://localhost:{port}/"), + &format!("https://localhost:{port}/ssl-agreed"), + &format!("https://localhost:{port}/ssl-server-name"), + &format!("https://localhost:{port}/ssl-was-reused") + ]) + .stdout(Stdio::piped()) + .output() + .map(print_output) + .unwrap() + .stdout, + format!( + "hello world\n\ + protocol:TLSv1.3,cipher:TLS_AES_256_GCM_SHA384\n\ + server-name:localhost\n\ + reused:{reused}\n" + ) + .as_bytes(), + ); + println!("PASS: resumption test for port={port} reused={reused}"); + } + // big download (throttled by curl to ensure non-blocking writes work) assert_eq!( Command::new("curl") diff --git a/rustls-libssl/tests/server.c b/rustls-libssl/tests/server.c index d329960..747a18e 100644 --- a/rustls-libssl/tests/server.c +++ b/rustls-libssl/tests/server.c @@ -69,6 +69,32 @@ static int sni_callback(SSL *ssl, int *al, void *arg) { return SSL_TLSEXT_ERR_OK; } +static int sess_new_callback(SSL *ssl, SSL_SESSION *sess) { + printf("in sess_new_callback\n"); + assert(ssl != NULL); + assert(sess != NULL); + unsigned id_len = 0; + SSL_SESSION_get_id(sess, &id_len); + printf(" SSL_SESSION_get_id len=%u\n", id_len); + return 0; +} + +static SSL_SESSION *sess_get_callback(SSL *ssl, const uint8_t *id, int id_len, + int *copy) { + (void)id; + printf("in sess_get_callback\n"); + assert(ssl != NULL); + printf(" id_len=%d\n", id_len); + *copy = 0; + return NULL; +} + +static void sess_remove_callback(SSL_CTX *ctx, SSL_SESSION *sess) { + printf("in sess_remove_callback\n"); + assert(ctx != NULL); + assert(sess != NULL); +} + int main(int argc, char **argv) { if (argc != 5) { printf("%s |unauth\n\n", @@ -123,6 +149,13 @@ int main(int argc, char **argv) { SSL_CTX_set_tlsext_servername_arg(ctx, &sni_cookie); dump_openssl_error_stack(); + SSL_CTX_sess_set_new_cb(ctx, sess_new_callback); + SSL_CTX_sess_set_get_cb(ctx, sess_get_callback); + SSL_CTX_sess_set_remove_cb(ctx, sess_remove_callback); + TRACE(SSL_CTX_sess_set_cache_size(ctx, 10)); + TRACE(SSL_CTX_sess_get_cache_size(ctx)); + TRACE(SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_SERVER)); + X509 *server_cert = NULL; EVP_PKEY *server_key = NULL; TRACE(SSL_CTX_use_certificate_chain_file(ctx, certfile));