Skip to content

Commit ef7b953

Browse files
committed
feat: support SASL SCRAM-SHA-256 and SCRAM-SHA-512
1 parent a331d09 commit ef7b953

File tree

8 files changed

+109
-32
lines changed

8 files changed

+109
-32
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ tokio = { version = "1.19", default-features = false, features = ["io-util", "ne
3636
tokio-rustls = { version = "0.26", optional = true, default-features = false, features = ["logging", "ring", "tls12"] }
3737
tracing = "0.1"
3838
zstd = { version = "0.13", optional = true }
39+
rsasl = { version = "2.0", default-features = false, features = ["config_builder", "provider", "login", "plain", "scram-sha-1", "scram-sha-2"]}
3940

4041
[dev-dependencies]
4142
assert_matches = "1.5"

src/client/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use error::{Error, Result};
2222

2323
use self::{controller::ControllerClient, partition::UnknownTopicHandling};
2424

25-
pub use crate::connection::SaslConfig;
25+
pub use crate::connection::{Credentials, SaslConfig};
2626

2727
#[derive(Debug, Error)]
2828
pub enum ProduceError {

src/connection.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use crate::{
2020
client::metadata_cache::MetadataCache,
2121
};
2222

23+
pub use self::transport::Credentials;
2324
pub use self::transport::SaslConfig;
2425
pub use self::transport::TlsConfig;
2526

@@ -164,9 +165,7 @@ impl ConnectionHandler for BrokerRepresentation {
164165
let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id);
165166
messenger.sync_versions().await?;
166167
if let Some(sasl_config) = sasl_config {
167-
messenger
168-
.sasl_handshake(sasl_config.mechanism(), sasl_config.auth_bytes())
169-
.await?;
168+
messenger.do_sasl(sasl_config).await?;
170169
}
171170
Ok(Arc::new(messenger))
172171
}

src/connection/transport.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use tokio::net::TcpStream;
1111
use tokio_rustls::{client::TlsStream, TlsConnector};
1212

1313
mod sasl;
14-
pub use sasl::SaslConfig;
14+
pub use sasl::{Credentials, SaslConfig};
1515

1616
#[cfg(feature = "transport-tls")]
1717
pub type TlsConfig = Option<Arc<rustls::ClientConfig>>;

src/connection/transport/sasl.rs

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,45 @@ pub enum SaslConfig {
44
///
55
/// # References
66
/// - <https://datatracker.ietf.org/doc/html/rfc4616>
7-
Plain { username: String, password: String },
7+
Plain(Credentials),
8+
/// SASL - SCRAM-SHA-256
9+
///
10+
/// # References
11+
/// - <https://datatracker.ietf.org/doc/html/rfc7677>
12+
ScramSha256(Credentials),
13+
/// SASL - SCRAM-SHA-512
14+
///
15+
/// # References
16+
/// - <https://datatracker.ietf.org/doc/html/rfc5802>
17+
ScramSha512(Credentials),
18+
}
19+
20+
#[derive(Debug, Clone)]
21+
pub struct Credentials {
22+
pub username: String,
23+
pub password: String,
24+
}
25+
26+
impl Credentials {
27+
pub fn new(username: String, password: String) -> Self {
28+
Self { username, password }
29+
}
830
}
931

1032
impl SaslConfig {
11-
pub(crate) fn auth_bytes(&self) -> Vec<u8> {
33+
pub(crate) fn credentials(&self) -> Credentials {
1234
match self {
13-
Self::Plain { username, password } => {
14-
let mut auth: Vec<u8> = vec![0];
15-
auth.extend(username.bytes());
16-
auth.push(0);
17-
auth.extend(password.bytes());
18-
auth
19-
}
35+
Self::Plain(credentials) => credentials.clone(),
36+
Self::ScramSha256(credentials) => credentials.clone(),
37+
Self::ScramSha512(credentials) => credentials.clone(),
2038
}
2139
}
2240

2341
pub(crate) fn mechanism(&self) -> &str {
2442
match self {
2543
Self::Plain { .. } => "PLAIN",
44+
Self::ScramSha256 { .. } => "SCRAM-SHA-256",
45+
Self::ScramSha512 { .. } => "SCRAM-SHA-512",
2646
}
2747
}
2848
}

src/messenger.rs

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::{
1212

1313
use futures::future::BoxFuture;
1414
use parking_lot::Mutex;
15+
use rsasl::{config::SASLConfig, mechname::MechanismNameError, prelude::Mechname};
1516
use thiserror::Error;
1617
use tokio::{
1718
io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
@@ -23,8 +24,6 @@ use tokio::{
2324
};
2425
use tracing::{debug, info, warn};
2526

26-
use crate::protocol::{api_version::ApiVersionRange, primitives::CompactString};
27-
use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType};
2827
use crate::{
2928
backoff::ErrorOrThrottle,
3029
protocol::{
@@ -34,12 +33,21 @@ use crate::{
3433
frame::{AsyncMessageRead, AsyncMessageWrite},
3534
messages::{
3635
ReadVersionedError, ReadVersionedType, RequestBody, RequestHeader, ResponseHeader,
37-
SaslAuthenticateRequest, SaslHandshakeRequest, WriteVersionedError, WriteVersionedType,
36+
SaslAuthenticateRequest, SaslAuthenticateResponse, SaslHandshakeRequest,
37+
SaslHandshakeResponse, WriteVersionedError, WriteVersionedType,
3838
},
3939
primitives::{Int16, Int32, NullableString, TaggedFields},
4040
},
4141
throttle::maybe_throttle,
4242
};
43+
use crate::{
44+
client::SaslConfig,
45+
protocol::{api_version::ApiVersionRange, primitives::CompactString},
46+
};
47+
use crate::{
48+
connection::Credentials,
49+
protocol::{messages::ApiVersionsRequest, traits::ReadType},
50+
};
4351

4452
#[derive(Debug)]
4553
struct Response {
@@ -186,6 +194,9 @@ pub enum SaslError {
186194

187195
#[error("API error: {0}")]
188196
ApiError(#[from] ApiError),
197+
198+
#[error("Invalid sasl mechanism: {0}")]
199+
InvalidSaslMechanism(#[from] MechanismNameError),
189200
}
190201

191202
impl<RW> Messenger<RW>
@@ -531,16 +542,10 @@ where
531542
Err(SyncVersionsError::NoWorkingVersion)
532543
}
533544

534-
pub async fn sasl_handshake(
545+
async fn sasl_authentication(
535546
&self,
536-
mechanism: &str,
537547
auth_bytes: Vec<u8>,
538-
) -> Result<(), SaslError> {
539-
let req = SaslHandshakeRequest::new(mechanism);
540-
let resp = self.request(req).await?;
541-
if let Some(err) = resp.error_code {
542-
return Err(SaslError::ApiError(err));
543-
}
548+
) -> Result<SaslAuthenticateResponse, SaslError> {
544549
let req = SaslAuthenticateRequest::new(auth_bytes);
545550
let resp = self.request(req).await?;
546551
if let Some(err) = resp.error_code {
@@ -549,6 +554,59 @@ where
549554
}
550555
return Err(SaslError::ApiError(err));
551556
}
557+
558+
Ok(resp)
559+
}
560+
561+
async fn sasl_handshake(&self, mechanism: &str) -> Result<SaslHandshakeResponse, SaslError> {
562+
let req = SaslHandshakeRequest::new(mechanism);
563+
let resp = self.request(req).await?;
564+
if let Some(err) = resp.error_code {
565+
return Err(SaslError::ApiError(err));
566+
}
567+
Ok(resp)
568+
}
569+
570+
pub async fn do_sasl(&self, config: SaslConfig) -> Result<(), SaslError> {
571+
let mechanism = config.mechanism();
572+
let resp = self.sasl_handshake(mechanism).await?;
573+
574+
let Credentials { username, password } = config.credentials();
575+
let config = SASLConfig::with_credentials(None, username, password).unwrap();
576+
let sasl = rsasl::prelude::SASLClient::new(config);
577+
let raw_mechanisms = resp.mechanisms.0.unwrap_or_default();
578+
let mechanisms = raw_mechanisms
579+
.iter()
580+
.map(|mech| {
581+
debug!("{:?}", mech);
582+
Mechname::parse(mech.0.as_bytes()).map_err(SaslError::InvalidSaslMechanism)
583+
})
584+
.collect::<Result<Vec<_>, SaslError>>()?;
585+
debug!("Supported mechanisms {:?}", mechanisms);
586+
let mut session = sasl.start_suggested(&mechanisms).unwrap();
587+
let selected_mechanism = session.get_mechname();
588+
debug!("Using {:?} for the SASL Mechanism", selected_mechanism);
589+
let mut data: Option<Vec<u8>> = None;
590+
591+
// Stepping the authentication exchange to completion.
592+
while {
593+
let mut out = Cursor::new(Vec::new());
594+
// The each call to step writes the generated auth data into the provided writer.
595+
// Normally this data would then have to be sent to the other party, but this goes
596+
// beyond the scope of this example.
597+
let state = session
598+
.step(data.as_deref(), &mut out)
599+
.expect("step errored!");
600+
601+
data = Some(out.into_inner());
602+
603+
// Returns `true` if step needs to be called again with another batch of data.
604+
state.is_running()
605+
} {
606+
let authentication_response = self.sasl_authentication(data.unwrap().to_vec()).await?;
607+
data = Some(authentication_response.auth_bytes.0);
608+
}
609+
552610
Ok(())
553611
}
554612
}

tests/client.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,9 @@ async fn test_sasl() {
4141
return;
4242
}
4343
ClientBuilder::new(vec![env::var("KAFKA_SASL_CONNECT").unwrap()])
44-
.sasl_config(rskafka::client::SaslConfig::Plain {
45-
username: "admin".to_string(),
46-
password: "admin-secret".to_string(),
47-
})
44+
.sasl_config(rskafka::client::SaslConfig::Plain(
45+
rskafka::client::Credentials::new("admin".to_string(), "admin-secret".to_string()),
46+
))
4847
.build()
4948
.await
5049
.unwrap();
@@ -425,7 +424,7 @@ async fn test_get_offset() {
425424
// use out-of order timestamps to ensure our "lastest offset" logic works
426425
let record_early = record(b"");
427426
let record_late = Record {
428-
timestamp: record_early.timestamp + chrono::Duration::seconds(1),
427+
timestamp: record_early.timestamp + chrono::Duration::try_seconds(1).unwrap(),
429428
..record_early.clone()
430429
};
431430
let offsets = partition_client

tests/produce_consume.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ async fn assert_produce_consume<F1, G1, F2, G2>(
270270

271271
// timestamps for records. We'll reorder the messages though to ts2, ts1, ts3
272272
let ts1 = Utc.timestamp_millis_opt(1337).unwrap();
273-
let ts2 = ts1 + Duration::milliseconds(1);
274-
let ts3 = ts2 + Duration::milliseconds(1);
273+
let ts2 = ts1 + Duration::try_milliseconds(1).unwrap();
274+
let ts3 = ts2 + Duration::try_milliseconds(1).unwrap();
275275

276276
let record_1 = {
277277
let record = Record {

0 commit comments

Comments
 (0)