src/messenger.rs (136 changes: 122 additions & 14 deletions)
@@ -1,19 +1,23 @@
 use std::{
     collections::HashMap,
+    future::Future,
     io::Cursor,
     ops::DerefMut,
     sync::{
         atomic::{AtomicI32, Ordering},
         Arc, RwLock,
     },
+    task::Poll,
 };
 
+use futures::future::BoxFuture;
+use parking_lot::Mutex;
 use thiserror::Error;
 use tokio::{
     io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
     sync::{
         oneshot::{channel, Sender},
-        Mutex,
+        Mutex as AsyncMutex,
     },
     task::JoinHandle,
 };
@@ -57,7 +61,7 @@ enum MessengerState {
 }
 
 impl MessengerState {
-    async fn poison(&mut self, err: RequestError) -> Arc<RequestError> {
+    fn poison(&mut self, err: RequestError) -> Arc<RequestError> {
         match self {
             Self::RequestMap(map) => {
                 let err = Arc::new(err);
@@ -91,7 +95,7 @@ pub struct Messenger<RW> {
     /// The half of the stream that we use to send data TO the broker.
     ///
     /// This will be used by [`request`](Self::request) to queue up messages.
-    stream_write: Arc<Mutex<WriteHalf<RW>>>,
+    stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,
 
     /// The next correlation ID.
     ///
@@ -195,7 +199,7 @@ where
                     }
                 };
 
-                let active_request = match state_captured.lock().await.deref_mut() {
+                let active_request = match state_captured.lock().deref_mut() {
                     MessengerState::RequestMap(map) => {
                         if let Some(active_request) = map.remove(&header.correlation_id.0) {
                             active_request
@@ -240,17 +244,15 @@ where
                     Err(e) => {
                         state_captured
                             .lock()
-                            .await
-                            .poison(RequestError::ReadFramedMessageError(e))
-                            .await;
+                            .poison(RequestError::ReadFramedMessageError(e));
                         return;
                     }
                 }
             }
         });
 
         Self {
-            stream_write: Arc::new(Mutex::new(stream_write)),
+            stream_write: Arc::new(AsyncMutex::new(stream_write)),
             correlation_id: AtomicI32::new(0),
             version_ranges: RwLock::new(HashMap::new()),
             state,
@@ -315,7 +317,12 @@ where
 
         let (tx, rx) = channel();
 
-        match self.state.lock().await.deref_mut() {
+        // to prevent stale data in the inner state, ensure that we remove the request again if we are cancelled while
+        // sending the request
+        let cleanup_on_cancel =
+            CleanupRequestStateOnCancel::new(Arc::clone(&self.state), correlation_id);
+
+        match self.state.lock().deref_mut() {
             MessengerState::RequestMap(map) => {
                 map.insert(
                     correlation_id,
@@ -331,6 +338,7 @@
         }
 
         self.send_message(buf).await?;
+        cleanup_on_cancel.message_sent();
 
         let mut response = rx.await.expect("Who closed this channel?!")?;
         let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?;
@@ -355,23 +363,23 @@ where
             Ok(()) => Ok(()),
             Err(e) => {
                 // need to poison the stream because message framing might be out-of-sync
-                let mut state = self.state.lock().await;
-                Err(RequestError::Poisoned(state.poison(e).await))
+                let mut state = self.state.lock();
+                Err(RequestError::Poisoned(state.poison(e)))
             }
         }
     }
 
     async fn send_message_inner(&self, msg: Vec<u8>) -> Result<(), RequestError> {
         let mut stream_write = Arc::clone(&self.stream_write).lock_owned().await;
 
-        // use a task so that cancelation doesn't cancel the send operation and leaves half-send messages on the wire
-        let handle = tokio::spawn(async move {
+        // use a wrapper so that cancellation doesn't cancel the send operation and leave half-sent messages on the wire
+        let fut = CancellationSafeFuture::new(async move {
             stream_write.write_message(&msg).await?;
             stream_write.flush().await?;
             Ok(())
         });
 
-        handle.await.expect("background task died")
+        fut.await
     }
 
     pub async fn sync_versions(&self) -> Result<(), SyncVersionsError> {
@@ -495,6 +503,106 @@ fn match_versions(range_a: ApiVersionRange, range_b: ApiVersionRange) -> Option<
     }
 }
 
+/// Helper that ensures that a request is removed from the state when it is cancelled before it was actually sent out.
+struct CleanupRequestStateOnCancel {
+    state: Arc<Mutex<MessengerState>>,
+    correlation_id: i32,
+    message_sent: bool,
+}
+
+impl CleanupRequestStateOnCancel {
+    /// Create a new helper.
+    ///
+    /// You must call [`message_sent`](Self::message_sent) when the request was sent.
+    fn new(state: Arc<Mutex<MessengerState>>, correlation_id: i32) -> Self {
+        Self {
+            state,
+            correlation_id,
+            message_sent: false,
+        }
+    }
+
+    /// The request was sent. Do NOT clean up the state any longer.
+    fn message_sent(mut self) {
+        self.message_sent = true;
+    }
+}
+
+impl Drop for CleanupRequestStateOnCancel {
+    fn drop(&mut self) {
+        if !self.message_sent {
+            if let MessengerState::RequestMap(map) = self.state.lock().deref_mut() {
+                map.remove(&self.correlation_id);
+            }
+        }
+    }
+}
+
+/// Wrapper around a future that cannot be cancelled.
+///
+/// When the future is dropped/cancelled, we'll spawn a tokio task to _rescue_ it.
+struct CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    /// Marks if the inner future finished. If not, we must spawn a helper task on drop.
+    done: bool,
+
+    /// Inner future.
+    ///
+    /// Wrapped in an `Option` so we can extract it during drop. Inside that option, however, we also need a pinned
+    /// box, because once this wrapper is polled it will be pinned in memory -- even during drop. The inner
+    /// future does not necessarily implement `Unpin`, so we need a heap allocation to keep it pinned even when we
+    /// move it out of this option.
+    inner: Option<BoxFuture<'static, F::Output>>,
+}
+
+impl<F> Drop for CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    fn drop(&mut self) {
+        if !self.done {
+            let inner = self.inner.take().expect("Double-drop?");
+            tokio::task::spawn(async move {
+                inner.await;
+            });
+        }
+    }
+}
+
+impl<F> CancellationSafeFuture<F>
+where
+    F: Future + Send,
+{
+    fn new(fut: F) -> Self {
+        Self {
+            done: false,
+            inner: Some(Box::pin(fut)),
+        }
+    }
+}
+
+impl<F> Future for CancellationSafeFuture<F>
+where
+    F: Future + Send,
+{
+    type Output = F::Output;
+
+    fn poll(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Self::Output> {
+        match self.inner.as_mut().expect("not dropped").as_mut().poll(cx) {
+            Poll::Ready(res) => {
+                self.done = true;
+                Poll::Ready(res)
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::{ops::Deref, time::Duration};
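
The change to `request()` pairs the entry in the request map with `CleanupRequestStateOnCancel`, a drop guard: if the request future is dropped between registering the correlation ID and finishing the send, `Drop` removes the entry again, so the map never accumulates stale requests. Below is a minimal, self-contained sketch of the same drop-guard pattern; the names (`Guard`, `disarm`) and the plain `HashMap` state are hypothetical simplifications, not part of this PR.

use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

/// Drop guard: removes the entry again unless `disarm` was called first.
struct Guard {
    map: Arc<Mutex<HashMap<i32, &'static str>>>,
    key: i32,
    armed: bool,
}

impl Guard {
    fn new(map: Arc<Mutex<HashMap<i32, &'static str>>>, key: i32) -> Self {
        Self { map, key, armed: true }
    }

    /// The message made it out; keep the entry (plays the role of `message_sent`).
    fn disarm(mut self) {
        self.armed = false;
    }
}

impl Drop for Guard {
    fn drop(&mut self) {
        if self.armed {
            self.map.lock().unwrap().remove(&self.key);
        }
    }
}

fn main() {
    let map = Arc::new(Mutex::new(HashMap::new()));

    // Cancelled path: the guard is dropped while still armed, so the entry is removed.
    map.lock().unwrap().insert(1, "pending");
    drop(Guard::new(Arc::clone(&map), 1));
    assert!(!map.lock().unwrap().contains_key(&1));

    // Sent path: `disarm` was called, so the entry survives the drop.
    map.lock().unwrap().insert(2, "pending");
    Guard::new(Arc::clone(&map), 2).disarm();
    assert!(map.lock().unwrap().contains_key(&2));
}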
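
`CancellationSafeFuture` is what lets `send_message_inner` drop the `tokio::spawn` round-trip while keeping its cancellation guarantee: the wrapper polls the inner future inline, and only if it is dropped unfinished does `Drop` hand the partially-polled future to `tokio::task::spawn`, so the write still completes and no half-sent frame is left on the wire. A usage sketch of that behaviour, assuming the `CancellationSafeFuture` from this diff (with its `futures` dependency) is in scope and `tokio` with the "full" feature set is available:

use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::Duration,
};

#[tokio::main]
async fn main() {
    let sent = Arc::new(AtomicBool::new(false));
    let sent_captured = Arc::clone(&sent);

    let fut = CancellationSafeFuture::new(async move {
        // stand-in for `write_message` + `flush`
        tokio::time::sleep(Duration::from_millis(50)).await;
        sent_captured.store(true, Ordering::SeqCst);
    });

    // Cancel the wrapper early: `timeout` drops `fut` after 10ms ...
    let _ = tokio::time::timeout(Duration::from_millis(10), fut).await;

    // ... but the inner future was rescued onto a task and still finishes.
    tokio::time::sleep(Duration::from_millis(100)).await;
    assert!(sent.load(Ordering::SeqCst));
}

The rescue only holds while the runtime is alive; a runtime shutdown still drops the spawned task, which is the usual caveat of this pattern.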