@@ -12,6 +12,7 @@ use std::{
12
12
13
13
use futures:: future:: BoxFuture ;
14
14
use parking_lot:: Mutex ;
15
+ use rsasl:: { config:: SASLConfig , mechname:: MechanismNameError , prelude:: Mechname } ;
15
16
use thiserror:: Error ;
16
17
use tokio:: {
17
18
io:: { AsyncRead , AsyncWrite , AsyncWriteExt , WriteHalf } ,
@@ -23,8 +24,6 @@ use tokio::{
23
24
} ;
24
25
use tracing:: { debug, info, warn} ;
25
26
26
- use crate :: protocol:: { api_version:: ApiVersionRange , primitives:: CompactString } ;
27
- use crate :: protocol:: { messages:: ApiVersionsRequest , traits:: ReadType } ;
28
27
use crate :: {
29
28
backoff:: ErrorOrThrottle ,
30
29
protocol:: {
@@ -34,12 +33,21 @@ use crate::{
34
33
frame:: { AsyncMessageRead , AsyncMessageWrite } ,
35
34
messages:: {
36
35
ReadVersionedError , ReadVersionedType , RequestBody , RequestHeader , ResponseHeader ,
37
- SaslAuthenticateRequest , SaslHandshakeRequest , WriteVersionedError , WriteVersionedType ,
36
+ SaslAuthenticateRequest , SaslAuthenticateResponse , SaslHandshakeRequest ,
37
+ SaslHandshakeResponse , WriteVersionedError , WriteVersionedType ,
38
38
} ,
39
39
primitives:: { Int16 , Int32 , NullableString , TaggedFields } ,
40
40
} ,
41
41
throttle:: maybe_throttle,
42
42
} ;
43
+ use crate :: {
44
+ client:: SaslConfig ,
45
+ protocol:: { api_version:: ApiVersionRange , primitives:: CompactString } ,
46
+ } ;
47
+ use crate :: {
48
+ connection:: Credentials ,
49
+ protocol:: { messages:: ApiVersionsRequest , traits:: ReadType } ,
50
+ } ;
43
51
44
52
#[ derive( Debug ) ]
45
53
struct Response {
@@ -186,6 +194,9 @@ pub enum SaslError {
186
194
187
195
#[ error( "API error: {0}" ) ]
188
196
ApiError ( #[ from] ApiError ) ,
197
+
198
+ #[ error( "Invalid sasl mechanism: {0}" ) ]
199
+ InvalidSaslMechanism ( #[ from] MechanismNameError ) ,
189
200
}
190
201
191
202
impl < RW > Messenger < RW >
@@ -531,16 +542,10 @@ where
531
542
Err ( SyncVersionsError :: NoWorkingVersion )
532
543
}
533
544
534
- pub async fn sasl_handshake (
545
+ async fn sasl_authentication (
535
546
& self ,
536
- mechanism : & str ,
537
547
auth_bytes : Vec < u8 > ,
538
- ) -> Result < ( ) , SaslError > {
539
- let req = SaslHandshakeRequest :: new ( mechanism) ;
540
- let resp = self . request ( req) . await ?;
541
- if let Some ( err) = resp. error_code {
542
- return Err ( SaslError :: ApiError ( err) ) ;
543
- }
548
+ ) -> Result < SaslAuthenticateResponse , SaslError > {
544
549
let req = SaslAuthenticateRequest :: new ( auth_bytes) ;
545
550
let resp = self . request ( req) . await ?;
546
551
if let Some ( err) = resp. error_code {
@@ -549,6 +554,59 @@ where
549
554
}
550
555
return Err ( SaslError :: ApiError ( err) ) ;
551
556
}
557
+
558
+ Ok ( resp)
559
+ }
560
+
561
+ async fn sasl_handshake ( & self , mechanism : & str ) -> Result < SaslHandshakeResponse , SaslError > {
562
+ let req = SaslHandshakeRequest :: new ( mechanism) ;
563
+ let resp = self . request ( req) . await ?;
564
+ if let Some ( err) = resp. error_code {
565
+ return Err ( SaslError :: ApiError ( err) ) ;
566
+ }
567
+ Ok ( resp)
568
+ }
569
+
570
+ pub async fn do_sasl ( & self , config : SaslConfig ) -> Result < ( ) , SaslError > {
571
+ let mechanism = config. mechanism ( ) ;
572
+ let resp = self . sasl_handshake ( mechanism) . await ?;
573
+
574
+ let Credentials { username, password } = config. credentials ( ) ;
575
+ let config = SASLConfig :: with_credentials ( None , username, password) . unwrap ( ) ;
576
+ let sasl = rsasl:: prelude:: SASLClient :: new ( config) ;
577
+ let raw_mechanisms = resp. mechanisms . 0 . unwrap_or_default ( ) ;
578
+ let mechanisms = raw_mechanisms
579
+ . iter ( )
580
+ . map ( |mech| {
581
+ debug ! ( "{:?}" , mech) ;
582
+ Mechname :: parse ( mech. 0 . as_bytes ( ) ) . map_err ( SaslError :: InvalidSaslMechanism )
583
+ } )
584
+ . collect :: < Result < Vec < _ > , SaslError > > ( ) ?;
585
+ debug ! ( "Supported mechanisms {:?}" , mechanisms) ;
586
+ let mut session = sasl. start_suggested ( & mechanisms) . unwrap ( ) ;
587
+ let selected_mechanism = session. get_mechname ( ) ;
588
+ debug ! ( "Using {:?} for the SASL Mechanism" , selected_mechanism) ;
589
+ let mut data: Option < Vec < u8 > > = None ;
590
+
591
+ // Stepping the authentication exchange to completion.
592
+ while {
593
+ let mut out = Cursor :: new ( Vec :: new ( ) ) ;
594
+ // The each call to step writes the generated auth data into the provided writer.
595
+ // Normally this data would then have to be sent to the other party, but this goes
596
+ // beyond the scope of this example.
597
+ let state = session
598
+ . step ( data. as_deref ( ) , & mut out)
599
+ . expect ( "step errored!" ) ;
600
+
601
+ data = Some ( out. into_inner ( ) ) ;
602
+
603
+ // Returns `true` if step needs to be called again with another batch of data.
604
+ state. is_running ( )
605
+ } {
606
+ let authentication_response = self . sasl_authentication ( data. unwrap ( ) . to_vec ( ) ) . await ?;
607
+ data = Some ( authentication_response. auth_bytes . 0 ) ;
608
+ }
609
+
552
610
Ok ( ( ) )
553
611
}
554
612
}
0 commit comments