@@ -12,6 +12,7 @@ use std::{
12
12
13
13
use futures:: future:: BoxFuture ;
14
14
use parking_lot:: Mutex ;
15
+ use rsasl:: { config:: SASLConfig , mechname:: MechanismNameError , prelude:: Mechname } ;
15
16
use thiserror:: Error ;
16
17
use tokio:: {
17
18
io:: { AsyncRead , AsyncWrite , AsyncWriteExt , WriteHalf } ,
@@ -23,8 +24,6 @@ use tokio::{
23
24
} ;
24
25
use tracing:: { debug, info, warn} ;
25
26
26
- use crate :: protocol:: { api_version:: ApiVersionRange , primitives:: CompactString } ;
27
- use crate :: protocol:: { messages:: ApiVersionsRequest , traits:: ReadType } ;
28
27
use crate :: {
29
28
backoff:: ErrorOrThrottle ,
30
29
protocol:: {
@@ -34,12 +33,21 @@ use crate::{
34
33
frame:: { AsyncMessageRead , AsyncMessageWrite } ,
35
34
messages:: {
36
35
ReadVersionedError , ReadVersionedType , RequestBody , RequestHeader , ResponseHeader ,
37
- SaslAuthenticateRequest , SaslHandshakeRequest , WriteVersionedError , WriteVersionedType ,
36
+ SaslAuthenticateRequest , SaslAuthenticateResponse , SaslHandshakeRequest ,
37
+ SaslHandshakeResponse , WriteVersionedError , WriteVersionedType ,
38
38
} ,
39
39
primitives:: { Int16 , Int32 , NullableString , TaggedFields } ,
40
40
} ,
41
41
throttle:: maybe_throttle,
42
42
} ;
43
+ use crate :: {
44
+ client:: SaslConfig ,
45
+ protocol:: { api_version:: ApiVersionRange , primitives:: CompactString } ,
46
+ } ;
47
+ use crate :: {
48
+ connection:: Credentials ,
49
+ protocol:: { messages:: ApiVersionsRequest , traits:: ReadType } ,
50
+ } ;
43
51
44
52
#[ derive( Debug ) ]
45
53
struct Response {
@@ -186,6 +194,12 @@ pub enum SaslError {
186
194
187
195
#[ error( "API error: {0}" ) ]
188
196
ApiError ( #[ from] ApiError ) ,
197
+
198
+ #[ error( "Invalid sasl mechanism: {0}" ) ]
199
+ InvalidSaslMechanism ( #[ from] MechanismNameError ) ,
200
+
201
+ #[ error( "unsupported sasl mechanism" ) ]
202
+ UnsupportedSaslMechanism ,
189
203
}
190
204
191
205
impl < RW > Messenger < RW >
@@ -531,16 +545,10 @@ where
531
545
Err ( SyncVersionsError :: NoWorkingVersion )
532
546
}
533
547
534
- pub async fn sasl_handshake (
548
+ async fn sasl_authentication (
535
549
& self ,
536
- mechanism : & str ,
537
550
auth_bytes : Vec < u8 > ,
538
- ) -> Result < ( ) , SaslError > {
539
- let req = SaslHandshakeRequest :: new ( mechanism) ;
540
- let resp = self . request ( req) . await ?;
541
- if let Some ( err) = resp. error_code {
542
- return Err ( SaslError :: ApiError ( err) ) ;
543
- }
551
+ ) -> Result < SaslAuthenticateResponse , SaslError > {
544
552
let req = SaslAuthenticateRequest :: new ( auth_bytes) ;
545
553
let resp = self . request ( req) . await ?;
546
554
if let Some ( err) = resp. error_code {
@@ -549,6 +557,63 @@ where
549
557
}
550
558
return Err ( SaslError :: ApiError ( err) ) ;
551
559
}
560
+
561
+ Ok ( resp)
562
+ }
563
+
564
+ async fn sasl_handshake ( & self , mechanism : & str ) -> Result < SaslHandshakeResponse , SaslError > {
565
+ let req = SaslHandshakeRequest :: new ( mechanism) ;
566
+ let resp = self . request ( req) . await ?;
567
+ if let Some ( err) = resp. error_code {
568
+ return Err ( SaslError :: ApiError ( err) ) ;
569
+ }
570
+ Ok ( resp)
571
+ }
572
+
573
+ pub async fn do_sasl ( & self , config : SaslConfig ) -> Result < ( ) , SaslError > {
574
+ let mechanism = config. mechanism ( ) ;
575
+ let resp = self . sasl_handshake ( mechanism) . await ?;
576
+
577
+ let Credentials { username, password } = config. credentials ( ) ;
578
+ let config = SASLConfig :: with_credentials ( None , username, password) . unwrap ( ) ;
579
+ let sasl = rsasl:: prelude:: SASLClient :: new ( config) ;
580
+ let raw_mechanisms = resp. mechanisms . 0 . unwrap_or_default ( ) ;
581
+ let mechanisms = raw_mechanisms
582
+ . iter ( )
583
+ . map ( |mech| Mechname :: parse ( mech. 0 . as_bytes ( ) ) . map_err ( SaslError :: InvalidSaslMechanism ) )
584
+ . collect :: < Result < Vec < _ > , SaslError > > ( ) ?;
585
+ debug ! ( "Supported mechanisms {:?}" , mechanisms) ;
586
+ let prefer_mechanism =
587
+ Mechname :: parse ( mechanism. as_bytes ( ) ) . map_err ( SaslError :: InvalidSaslMechanism ) ?;
588
+ if !mechanisms. contains ( & prefer_mechanism) {
589
+ return Err ( SaslError :: UnsupportedSaslMechanism ) ;
590
+ }
591
+ let mut session = sasl
592
+ . start_suggested ( & [ prefer_mechanism] )
593
+ . map_err ( |_| SaslError :: UnsupportedSaslMechanism ) ?;
594
+ // let selected_mechanism = session.get_mechname();
595
+ debug ! ( "Using {:?} for the SASL Mechanism" , mechanism) ;
596
+ let mut data: Option < Vec < u8 > > = None ;
597
+
598
+ // Stepping the authentication exchange to completion.
599
+ while {
600
+ let mut out = Cursor :: new ( Vec :: new ( ) ) ;
601
+ // The each call to step writes the generated auth data into the provided writer.
602
+ // Normally this data would then have to be sent to the other party, but this goes
603
+ // beyond the scope of this example.
604
+ let state = session
605
+ . step ( data. as_deref ( ) , & mut out)
606
+ . expect ( "step errored!" ) ;
607
+
608
+ data = Some ( out. into_inner ( ) ) ;
609
+
610
+ // Returns `true` if step needs to be called again with another batch of data.
611
+ state. is_running ( )
612
+ } {
613
+ let authentication_response = self . sasl_authentication ( data. take ( ) . unwrap ( ) ) . await ?;
614
+ data = Some ( authentication_response. auth_bytes . 0 ) ;
615
+ }
616
+
552
617
Ok ( ( ) )
553
618
}
554
619
}
0 commit comments