@@ -1,19 +1,23 @@
 use std::{
     collections::HashMap,
+    future::Future,
     io::Cursor,
     ops::DerefMut,
     sync::{
         atomic::{AtomicI32, Ordering},
         Arc, RwLock,
     },
+    task::Poll,
 };
 
+use futures::future::BoxFuture;
+use parking_lot::Mutex;
 use thiserror::Error;
 use tokio::{
     io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
     sync::{
         oneshot::{channel, Sender},
-        Mutex,
+        Mutex as AsyncMutex,
     },
     task::JoinHandle,
 };
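// Context for the import changes above, as a minimal sketch that is NOT part of the diff:
// the shared state moves to a synchronous `parking_lot::Mutex` (lockable without `.await`,
// e.g. inside `Drop` or `poll`), while the write half keeps tokio's async `Mutex` (imported
// as `AsyncMutex`) because its guard is held across `.await` points. The names below
// (`Example`, `remove_pending`, `send`) are illustrative only.

use std::{collections::HashMap, sync::Arc};

use parking_lot::Mutex; // sync lock: fine in `Drop`/`poll`, never held across `.await`
use tokio::sync::Mutex as AsyncMutex; // async lock: its guard may live across `.await`

struct Example {
    state: Arc<Mutex<HashMap<i32, String>>>,
    writer: Arc<AsyncMutex<Vec<u8>>>,
}

impl Example {
    fn remove_pending(&self, id: i32) {
        // No `.await` required, so this also works inside a `Drop` impl.
        self.state.lock().remove(&id);
    }

    async fn send(&self, bytes: &[u8]) {
        // The guard stays alive across the (potentially awaiting) write, hence the async mutex.
        let mut writer = self.writer.lock().await;
        writer.extend_from_slice(bytes);
    }
}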
@@ -57,7 +61,7 @@ enum MessengerState {
 }
 
 impl MessengerState {
-    async fn poison(&mut self, err: RequestError) -> Arc<RequestError> {
+    fn poison(&mut self, err: RequestError) -> Arc<RequestError> {
         match self {
             Self::RequestMap(map) => {
                 let err = Arc::new(err);
@@ -91,7 +95,7 @@ pub struct Messenger<RW> {
     /// The half of the stream that we use to send data TO the broker.
     ///
     /// This will be used by [`request`](Self::request) to queue up messages.
-    stream_write: Arc<Mutex<WriteHalf<RW>>>,
+    stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,
 
     /// The next correlation ID.
     ///
@@ -195,7 +199,7 @@ where
                            }
                        };
 
-                        let active_request = match state_captured.lock().await.deref_mut() {
+                        let active_request = match state_captured.lock().deref_mut() {
                            MessengerState::RequestMap(map) => {
                                if let Some(active_request) = map.remove(&header.correlation_id.0) {
                                    active_request
@@ -240,17 +244,15 @@ where
                    Err(e) => {
                        state_captured
                            .lock()
-                            .await
-                            .poison(RequestError::ReadFramedMessageError(e))
-                            .await;
+                            .poison(RequestError::ReadFramedMessageError(e));
                        return;
                    }
                }
            }
        });
 
        Self {
-            stream_write: Arc::new(Mutex::new(stream_write)),
+            stream_write: Arc::new(AsyncMutex::new(stream_write)),
            correlation_id: AtomicI32::new(0),
            version_ranges: RwLock::new(HashMap::new()),
            state,
@@ -315,7 +317,12 @@ where
 
         let (tx, rx) = channel();
 
-        match self.state.lock().await.deref_mut() {
+        // to prevent stale data in the inner state, ensure that we remove the request again if we are cancelled while
+        // sending the request
+        let cleanup_on_cancel =
+            CleanupRequestStateOnCancel::new(Arc::clone(&self.state), correlation_id);
+
+        match self.state.lock().deref_mut() {
             MessengerState::RequestMap(map) => {
                 map.insert(
                     correlation_id,
@@ -331,6 +338,7 @@ where
         }
 
         self.send_message(buf).await?;
+        cleanup_on_cancel.message_sent();
 
         let mut response = rx.await.expect("Who closed this channel?!")?;
         let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?;
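// The `cleanup_on_cancel` guard above follows a "disarm on success" drop pattern: if `request`
// is cancelled between registering the correlation ID and finishing the send, `Drop` removes the
// entry again; calling `message_sent` consumes the guard and skips the cleanup. A self-contained
// sketch of that pattern, NOT part of the diff; `RemoveOnDrop` and `disarm` are illustrative names.

use std::collections::HashMap;

use parking_lot::Mutex;

/// Removes `key` from the shared map on drop unless the guard was disarmed.
struct RemoveOnDrop<'a> {
    map: &'a Mutex<HashMap<i32, String>>,
    key: i32,
    armed: bool,
}

impl<'a> RemoveOnDrop<'a> {
    fn new(map: &'a Mutex<HashMap<i32, String>>, key: i32) -> Self {
        Self { map, key, armed: true }
    }

    /// The operation succeeded: keep the entry.
    fn disarm(mut self) {
        self.armed = false; // `drop` still runs below, but is now a no-op
    }
}

impl Drop for RemoveOnDrop<'_> {
    fn drop(&mut self) {
        if self.armed {
            self.map.lock().remove(&self.key);
        }
    }
}

fn main() {
    let map = Mutex::new(HashMap::from([(42, "pending".to_string())]));

    // Success path: the guard is disarmed, the entry stays.
    RemoveOnDrop::new(&map, 42).disarm();
    assert!(map.lock().contains_key(&42));

    // Cancellation path: the guard is dropped without `disarm`, the entry is removed again.
    drop(RemoveOnDrop::new(&map, 42));
    assert!(!map.lock().contains_key(&42));
}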
@@ -355,23 +363,23 @@ where
            Ok(()) => Ok(()),
            Err(e) => {
                // need to poison the stream because message framing might be out-of-sync
-                let mut state = self.state.lock().await;
-                Err(RequestError::Poisoned(state.poison(e).await))
+                let mut state = self.state.lock();
+                Err(RequestError::Poisoned(state.poison(e)))
            }
        }
    }
 
    async fn send_message_inner(&self, msg: Vec<u8>) -> Result<(), RequestError> {
        let mut stream_write = Arc::clone(&self.stream_write).lock_owned().await;
 
-        // use a task so that cancelation doesn't cancel the send operation and leaves half-send messages on the wire
-        let handle = tokio::spawn(async move {
+        // use a wrapper so that cancellation doesn't cancel the send operation and leave half-sent messages on the wire
+        let fut = CancellationSafeFuture::new(async move {
            stream_write.write_message(&msg).await?;
            stream_write.flush().await?;
            Ok(())
        });
 
-        handle.await.expect("background task died")
+        fut.await
    }
 
    pub async fn sync_versions(&self) -> Result<(), SyncVersionsError> {
@@ -495,6 +503,106 @@ fn match_versions(range_a: ApiVersionRange, range_b: ApiVersionRange) -> Option<
     }
 }
 
+/// Helper that ensures that a request is removed when it is cancelled before it was actually sent out.
+struct CleanupRequestStateOnCancel {
+    state: Arc<Mutex<MessengerState>>,
+    correlation_id: i32,
+    message_sent: bool,
+}
+
+impl CleanupRequestStateOnCancel {
+    /// Create new helper.
+    ///
+    /// You must call [`message_sent`](Self::message_sent) when the request was sent.
+    fn new(state: Arc<Mutex<MessengerState>>, correlation_id: i32) -> Self {
+        Self {
+            state,
+            correlation_id,
+            message_sent: false,
+        }
+    }
+
+    /// Request was sent. Do NOT clean the state any longer.
+    fn message_sent(mut self) {
+        self.message_sent = true;
+    }
+}
+
+impl Drop for CleanupRequestStateOnCancel {
+    fn drop(&mut self) {
+        if !self.message_sent {
+            if let MessengerState::RequestMap(map) = self.state.lock().deref_mut() {
+                map.remove(&self.correlation_id);
+            }
+        }
+    }
+}
+
+/// Wrapper around a future that cannot be cancelled.
+///
+/// When the future is dropped/cancelled, we'll spawn a tokio task to _rescue_ it.
+struct CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    /// Mark if the inner future finished. If not, we must spawn a helper task on drop.
+    done: bool,
+
+    /// Inner future.
+    ///
+    /// Wrapped in an `Option` so we can extract it during drop. Inside that option however we also need a pinned
+    /// box because once this wrapper is polled, it will be pinned in memory -- even during drop. Now the inner
+    /// future does not necessarily implement `Unpin`, so we need a heap allocation to pin it in memory even when we
+    /// move it out of this option.
+    inner: Option<BoxFuture<'static, F::Output>>,
+}
+
+impl<F> Drop for CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    fn drop(&mut self) {
+        if !self.done {
+            let inner = self.inner.take().expect("Double-drop?");
+            tokio::task::spawn(async move {
+                inner.await;
+            });
+        }
+    }
+}
+
+impl<F> CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    fn new(fut: F) -> Self {
+        Self {
+            done: false,
+            inner: Some(Box::pin(fut)),
+        }
+    }
+}
+
+impl<F> Future for CancellationSafeFuture<F>
+where
+    F: Future + Send + 'static,
+{
+    type Output = F::Output;
+
+    fn poll(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Self::Output> {
+        match self.inner.as_mut().expect("not dropped").as_mut().poll(cx) {
+            Poll::Ready(res) => {
+                self.done = true;
+                Poll::Ready(res)
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::{ops::Deref, time::Duration};
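// A final sketch of the runtime behaviour of `CancellationSafeFuture`, NOT part of the diff.
// It assumes access to the (module-private) type added above; `do_send`, the buffer, and the
// timings are illustrative only.

use std::time::Duration;

async fn do_send(buf: Vec<u8>) {
    // Stands in for `write_message` + `flush` on the real stream.
    tokio::time::sleep(Duration::from_millis(50)).await;
    println!("sent {} bytes", buf.len());
}

#[tokio::main]
async fn main() {
    let fut = CancellationSafeFuture::new(do_send(vec![0u8; 16]));

    // Dropping the wrapper before completion (e.g. the caller lost a `select!` race or hit a
    // timeout) does not abort the write: `Drop` spawns the inner future onto the runtime, so
    // the message is still fully written instead of leaving half-sent bytes on the wire.
    drop(fut);

    // Give the rescued task a moment to finish before the runtime shuts down.
    tokio::time::sleep(Duration::from_millis(100)).await;
}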