1
+ use std:: { fmt:: Debug , sync:: Arc } ;
2
+
3
+ use futures:: future:: BoxFuture ;
4
+ use rsasl:: {
5
+ callback:: SessionCallback ,
6
+ config:: SASLConfig ,
7
+ property:: { AuthzId , OAuthBearerKV , OAuthBearerToken } ,
8
+ } ;
9
+
10
+ use crate :: messenger:: SaslError ;
11
+
1
12
#[ derive( Debug , Clone ) ]
2
13
pub enum SaslConfig {
3
14
/// SASL - PLAIN
@@ -15,6 +26,11 @@ pub enum SaslConfig {
15
26
/// # References
16
27
/// - <https://datatracker.ietf.org/doc/html/draft-melnikov-scram-sha-512-04>
17
28
ScramSha512 ( Credentials ) ,
29
+ /// SASL - OAUTHBEARER
30
+ ///
31
+ /// # References
32
+ /// - <https://datatracker.ietf.org/doc/html/rfc7628>
33
+ Oauthbearer ( OauthBearerCredentials ) ,
18
34
}
19
35
20
36
#[ derive( Debug , Clone ) ]
@@ -30,19 +46,104 @@ impl Credentials {
30
46
}
31
47
32
48
impl SaslConfig {
33
- pub ( crate ) fn credentials ( & self ) -> Credentials {
49
+ pub ( crate ) async fn get_sasl_config ( & self ) -> Result < Arc < SASLConfig > , SaslError > {
34
50
match self {
35
- Self :: Plain ( credentials) => credentials. clone ( ) ,
36
- Self :: ScramSha256 ( credentials) => credentials. clone ( ) ,
37
- Self :: ScramSha512 ( credentials) => credentials. clone ( ) ,
51
+ Self :: Plain ( credentials)
52
+ | Self :: ScramSha256 ( credentials)
53
+ | Self :: ScramSha512 ( credentials) => Ok ( SASLConfig :: with_credentials (
54
+ None ,
55
+ credentials. username . clone ( ) ,
56
+ credentials. password . clone ( ) ,
57
+ ) ?) ,
58
+ Self :: Oauthbearer ( credentials) => {
59
+ // Fetch the token first, since that's an async call.
60
+ let token = ( * credentials. callback ) ( )
61
+ . await
62
+ . map_err ( SaslError :: Callback ) ?;
63
+
64
+ struct OauthProvider {
65
+ authz_id : Option < String > ,
66
+ bearer_kvs : Vec < ( String , String ) > ,
67
+ token : String ,
68
+ }
69
+
70
+ // Define a callback that is called while stepping through the SASL client
71
+ // to provide necessary data for oauth.
72
+ // Since this callback is synchronous, we fetch the token first. Generally
73
+ // speaking the SASL process should not take long enough for the token to
74
+ // expire, but we do need to check for token expiry each time we authenticate.
75
+ impl SessionCallback for OauthProvider {
76
+ fn callback (
77
+ & self ,
78
+ _session_data : & rsasl:: callback:: SessionData ,
79
+ _context : & rsasl:: callback:: Context < ' _ > ,
80
+ request : & mut rsasl:: callback:: Request < ' _ > ,
81
+ ) -> Result < ( ) , rsasl:: prelude:: SessionError > {
82
+ request
83
+ . satisfy :: < OAuthBearerKV > (
84
+ & self
85
+ . bearer_kvs
86
+ . iter ( )
87
+ . map ( |( k, v) | ( k. as_str ( ) , v. as_str ( ) ) )
88
+ . collect :: < Vec < _ > > ( ) ,
89
+ ) ?
90
+ . satisfy :: < OAuthBearerToken > ( & self . token ) ?;
91
+ if let Some ( authz_id) = & self . authz_id {
92
+ request. satisfy :: < AuthzId > ( authz_id) ?;
93
+ }
94
+ Ok ( ( ) )
95
+ }
96
+ }
97
+
98
+ Ok ( SASLConfig :: builder ( )
99
+ . with_default_mechanisms ( )
100
+ . with_callback ( OauthProvider {
101
+ authz_id : credentials. authz_id . clone ( ) ,
102
+ bearer_kvs : credentials. bearer_kvs . clone ( ) ,
103
+ token,
104
+ } ) ?)
105
+ }
38
106
}
39
107
}
40
108
41
109
pub ( crate ) fn mechanism ( & self ) -> & str {
110
+ use rsasl:: mechanisms:: * ;
42
111
match self {
43
- Self :: Plain { .. } => "PLAIN" ,
44
- Self :: ScramSha256 { .. } => "SCRAM-SHA-256" ,
45
- Self :: ScramSha512 { .. } => "SCRAM-SHA-512" ,
112
+ Self :: Plain { .. } => plain:: PLAIN . mechanism . as_str ( ) ,
113
+ Self :: ScramSha256 { .. } => scram:: SCRAM_SHA256 . mechanism . as_str ( ) ,
114
+ Self :: ScramSha512 { .. } => scram:: SCRAM_SHA512 . mechanism . as_str ( ) ,
115
+ Self :: Oauthbearer { .. } => oauthbearer:: OAUTHBEARER . mechanism . as_str ( ) ,
46
116
}
47
117
}
48
118
}
119
+
120
+ type DynError = Box < dyn std:: error:: Error + Send + Sync > ;
121
+
122
+ /// Callback for fetching an OAUTH token. This should cache tokens and only request a new token
123
+ /// when the old is close to expiring.
124
+ pub type OauthCallback =
125
+ Arc < dyn Fn ( ) -> BoxFuture < ' static , Result < String , DynError > > + Send + Sync > ;
126
+
127
+ #[ derive( Clone ) ]
128
+ pub struct OauthBearerCredentials {
129
+ /// Callback that should return a token that is valid and will remain valid for
130
+ /// long enough to complete authentication. This should cache the token and only request
131
+ /// a new one when the old is close to expiring.
132
+ /// The token must be on [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) format.
133
+ pub callback : OauthCallback ,
134
+ /// ID of a user to impersonate. Can be left as `None` to authenticate using
135
+ /// the user for the token returned by `callback`.
136
+ pub authz_id : Option < String > ,
137
+ /// Custom key-value pairs sent as part of the SASL request. Most normal usage
138
+ /// can let this be an empty list.
139
+ pub bearer_kvs : Vec < ( String , String ) > ,
140
+ }
141
+
142
+ impl Debug for OauthBearerCredentials {
143
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
144
+ f. debug_struct ( "OauthBearerCredentials" )
145
+ . field ( "authz_id" , & self . authz_id )
146
+ . field ( "bearer_kvs" , & self . bearer_kvs )
147
+ . finish_non_exhaustive ( )
148
+ }
149
+ }
0 commit comments