Commit 1ca7c47

Merge pull request #156 from influxdata/crepererum/speed_up_io
refactor: speed up IO
2 parents: eb79c45 + f61059b


src/messenger.rs

Lines changed: 122 additions & 14 deletions
@@ -1,19 +1,23 @@
 use std::{
     collections::HashMap,
+    future::Future,
     io::Cursor,
     ops::DerefMut,
     sync::{
         atomic::{AtomicI32, Ordering},
         Arc, RwLock,
     },
+    task::Poll,
 };

+use futures::future::BoxFuture;
+use parking_lot::Mutex;
 use thiserror::Error;
 use tokio::{
     io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
     sync::{
         oneshot::{channel, Sender},
-        Mutex,
+        Mutex as AsyncMutex,
     },
     task::JoinHandle,
 };
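The import swap above carries the heart of the speed-up: the messenger's shared state is now only locked for short critical sections that never cross an `.await`, so the asynchronous `tokio::sync::Mutex` is replaced by the cheaper synchronous `parking_lot::Mutex` (this is also why `poison` loses its `async` below). Only the write half of the stream, whose lock is held across `await` points, keeps the async mutex under the alias `AsyncMutex`. A minimal sketch of that rule of thumb, with hypothetical types:

    use std::collections::HashMap;
    use std::sync::Arc;

    /// Hypothetical state guarded like `MessengerState`: its critical sections
    /// never `.await`, so a synchronous lock is sufficient and faster.
    struct SyncState {
        map: parking_lot::Mutex<HashMap<i32, String>>,
    }

    impl SyncState {
        fn insert(&self, id: i32, value: String) {
            // `parking_lot::Mutex::lock` blocks briefly without involving the executor.
            self.map.lock().insert(id, value);
        }
    }

    /// Hypothetical writer guarded like `stream_write`: the guard is held across
    /// `.await` points, which requires an async mutex.
    struct SharedWriter {
        buf: Arc<tokio::sync::Mutex<Vec<u8>>>,
    }

    impl SharedWriter {
        async fn send(&self, msg: &[u8]) {
            // Async lock: other tasks can make progress while we wait for it.
            let mut buf = self.buf.lock().await;
            buf.extend_from_slice(msg);
            tokio::task::yield_now().await; // the guard stays held across this `.await`
        }
    }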
@@ -57,7 +61,7 @@ enum MessengerState {
 }

 impl MessengerState {
-    async fn poison(&mut self, err: RequestError) -> Arc<RequestError> {
+    fn poison(&mut self, err: RequestError) -> Arc<RequestError> {
         match self {
             Self::RequestMap(map) => {
                 let err = Arc::new(err);
@@ -91,7 +95,7 @@ pub struct Messenger<RW> {
     /// The half of the stream that we use to send data TO the broker.
     ///
     /// This will be used by [`request`](Self::request) to queue up messages.
-    stream_write: Arc<Mutex<WriteHalf<RW>>>,
+    stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,

     /// The next correlation ID.
     ///
@@ -195,7 +199,7 @@ where
                     }
                 };

-                let active_request = match state_captured.lock().await.deref_mut() {
+                let active_request = match state_captured.lock().deref_mut() {
                     MessengerState::RequestMap(map) => {
                         if let Some(active_request) = map.remove(&header.correlation_id.0) {
                             active_request
@@ -240,17 +244,15 @@ where
                     Err(e) => {
                         state_captured
                             .lock()
-                            .await
-                            .poison(RequestError::ReadFramedMessageError(e))
-                            .await;
+                            .poison(RequestError::ReadFramedMessageError(e));
                         return;
                     }
                 }
             }
         });

         Self {
-            stream_write: Arc::new(Mutex::new(stream_write)),
+            stream_write: Arc::new(AsyncMutex::new(stream_write)),
             correlation_id: AtomicI32::new(0),
             version_ranges: RwLock::new(HashMap::new()),
             state,
@@ -315,7 +317,12 @@ where

         let (tx, rx) = channel();

-        match self.state.lock().await.deref_mut() {
+        // To prevent stale data in the inner state, ensure that the request is removed again if we are
+        // cancelled while sending it.
+        let cleanup_on_cancel =
+            CleanupRequestStateOnCancel::new(Arc::clone(&self.state), correlation_id);
+
+        match self.state.lock().deref_mut() {
             MessengerState::RequestMap(map) => {
                 map.insert(
                     correlation_id,
@@ -331,6 +338,7 @@ where
         }

         self.send_message(buf).await?;
+        cleanup_on_cancel.message_sent();

         let mut response = rx.await.expect("Who closed this channel?!")?;
         let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?;
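Because `request` can be cancelled at any `.await` (for example by a caller-side timeout), the entry inserted into the request map above would otherwise linger if cancellation hits between `map.insert` and the completed send. `CleanupRequestStateOnCancel` (defined at the bottom of this diff) is a drop guard: dropping it before `message_sent()` removes the entry again. A stand-alone sketch of the pattern, with hypothetical names:

    use std::collections::HashMap;
    use std::sync::Arc;

    use parking_lot::Mutex;

    /// Hypothetical drop guard mirroring `CleanupRequestStateOnCancel`.
    struct RemoveOnDrop {
        map: Arc<Mutex<HashMap<i32, String>>>,
        id: i32,
        armed: bool,
    }

    impl RemoveOnDrop {
        fn new(map: Arc<Mutex<HashMap<i32, String>>>, id: i32) -> Self {
            Self { map, id, armed: true }
        }

        /// Disarm once the happy path succeeded; dropping the guard is then a no-op.
        fn disarm(mut self) {
            self.armed = false;
        }
    }

    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            if self.armed {
                // We got here without `disarm`, i.e. the owning future was cancelled
                // (or returned early): undo the registration.
                self.map.lock().remove(&self.id);
            }
        }
    }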
@@ -355,23 +363,23 @@ where
             Ok(()) => Ok(()),
             Err(e) => {
                 // need to poison the stream because message framing might be out-of-sync
-                let mut state = self.state.lock().await;
-                Err(RequestError::Poisoned(state.poison(e).await))
+                let mut state = self.state.lock();
+                Err(RequestError::Poisoned(state.poison(e)))
             }
         }
     }

     async fn send_message_inner(&self, msg: Vec<u8>) -> Result<(), RequestError> {
         let mut stream_write = Arc::clone(&self.stream_write).lock_owned().await;

-        // use a task so that cancellation doesn't cancel the send operation and leave half-sent messages on the wire
-        let handle = tokio::spawn(async move {
+        // use a wrapper so that cancellation doesn't cancel the send operation and leave half-sent messages on the wire
+        let fut = CancellationSafeFuture::new(async move {
             stream_write.write_message(&msg).await?;
             stream_write.flush().await?;
             Ok(())
         });

-        handle.await.expect("background task died")
+        fut.await
     }

     pub async fn sync_versions(&self) -> Result<(), SyncVersionsError> {
@@ -495,6 +503,106 @@ fn match_versions(range_a: ApiVersionRange, range_b: ApiVersionRange) -> Option<
     }
 }

+/// Helper that ensures that a request is removed from the state when it is cancelled before it was actually sent out.
+struct CleanupRequestStateOnCancel {
+    state: Arc<Mutex<MessengerState>>,
+    correlation_id: i32,
+    message_sent: bool,
+}
+
+impl CleanupRequestStateOnCancel {
+    /// Create a new helper.
+    ///
+    /// You must call [`message_sent`](Self::message_sent) when the request was sent.
+    fn new(state: Arc<Mutex<MessengerState>>, correlation_id: i32) -> Self {
+        Self {
+            state,
+            correlation_id,
+            message_sent: false,
+        }
+    }
+
+    /// The request was sent; do NOT clean up the state any longer.
+    fn message_sent(mut self) {
+        self.message_sent = true;
+    }
+}
+
+impl Drop for CleanupRequestStateOnCancel {
+    fn drop(&mut self) {
+        if !self.message_sent {
+            if let MessengerState::RequestMap(map) = self.state.lock().deref_mut() {
+                map.remove(&self.correlation_id);
+            }
+        }
+    }
+}
+
+/// Wrapper around a future that cannot be cancelled.
+///
+/// When the future is dropped/cancelled, we'll spawn a tokio task to _rescue_ it.
+struct CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    /// Marks if the inner future finished. If not, we must spawn a helper task on drop.
+    done: bool,
+
+    /// Inner future.
+    ///
+    /// Wrapped in an `Option` so we can extract it during drop. Inside that option we also need a pinned
+    /// box, because once this wrapper is polled it is pinned in memory, even during drop. The inner
+    /// future does not necessarily implement `Unpin`, so we need a heap allocation to keep it pinned even
+    /// when we move it out of this option.
+    inner: Option<BoxFuture<'static, F::Output>>,
+}
+
+impl<F> Drop for CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    fn drop(&mut self) {
+        if !self.done {
+            let inner = self.inner.take().expect("Double-drop?");
+            tokio::task::spawn(async move {
+                inner.await;
+            });
+        }
+    }
+}
+
+impl<F> CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    fn new(fut: F) -> Self {
+        Self {
+            done: false,
+            inner: Some(Box::pin(fut)),
+        }
+    }
+}
+
+impl<F> Future for CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    type Output = F::Output;
+
+    fn poll(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Self::Output> {
+        match self.inner.as_mut().expect("not dropped").as_mut().poll(cx) {
+            Poll::Ready(res) => {
+                self.done = true;
+                Poll::Ready(res)
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::{ops::Deref, time::Duration};
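`CancellationSafeFuture` forwards `poll` to the pinned inner future and records completion in `done`; if the wrapper is dropped while `done` is still `false`, its `Drop` impl detaches the inner future onto a fresh tokio task, so an in-flight write still runs to completion instead of leaving a half-sent frame on the wire. A hypothetical demonstration of that rescue behaviour (assuming the type above is in scope):

    use std::time::Duration;

    #[tokio::main]
    async fn main() {
        let fut = CancellationSafeFuture::new(async {
            // Stand-in for `write_message` + `flush`.
            tokio::time::sleep(Duration::from_millis(50)).await;
            println!("send completed");
        });

        // Cancel the wrapper after 10ms. The timeout drops `fut`, but its `Drop`
        // impl spawns the unfinished inner future, so "send completed" still prints.
        let _ = tokio::time::timeout(Duration::from_millis(10), fut).await;

        // Give the rescued task time to finish before the runtime shuts down.
        tokio::time::sleep(Duration::from_millis(100)).await;
    }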
