diff --git a/Cargo.toml b/Cargo.toml index 926ea76..1b5a150 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ tokio = { version = "1.19", default-features = false, features = ["io-util", "ne tokio-rustls = { version = "0.26", optional = true, default-features = false, features = ["logging", "ring", "tls12"] } tracing = "0.1" zstd = { version = "0.13", optional = true } +rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2"]} [dev-dependencies] assert_matches = "1.5" diff --git a/src/client/mod.rs b/src/client/mod.rs index fa69c38..214f3dd 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -22,7 +22,7 @@ use error::{Error, Result}; use self::{controller::ControllerClient, partition::UnknownTopicHandling}; -pub use crate::connection::SaslConfig; +pub use crate::connection::{Credentials, SaslConfig}; #[derive(Debug, Error)] pub enum ProduceError { diff --git a/src/connection.rs b/src/connection.rs index 0fef385..b90a6ba 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -20,6 +20,7 @@ use crate::{ client::metadata_cache::MetadataCache, }; +pub use self::transport::Credentials; pub use self::transport::SaslConfig; pub use self::transport::TlsConfig; @@ -164,9 +165,7 @@ impl ConnectionHandler for BrokerRepresentation { let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id); messenger.sync_versions().await?; if let Some(sasl_config) = sasl_config { - messenger - .sasl_handshake(sasl_config.mechanism(), sasl_config.auth_bytes()) - .await?; + messenger.do_sasl(sasl_config).await?; } Ok(Arc::new(messenger)) } diff --git a/src/connection/transport.rs b/src/connection/transport.rs index 4fa4c4a..7119a19 100644 --- a/src/connection/transport.rs +++ b/src/connection/transport.rs @@ -11,7 +11,7 @@ use tokio::net::TcpStream; use tokio_rustls::{client::TlsStream, TlsConnector}; mod sasl; -pub use sasl::SaslConfig; +pub use sasl::{Credentials, SaslConfig}; #[cfg(feature = "transport-tls")] pub type TlsConfig = Option>; diff --git a/src/connection/transport/sasl.rs b/src/connection/transport/sasl.rs index 6d00c6d..c266b58 100644 --- a/src/connection/transport/sasl.rs +++ b/src/connection/transport/sasl.rs @@ -4,25 +4,45 @@ pub enum SaslConfig { /// /// # References /// - - Plain { username: String, password: String }, + Plain(Credentials), + /// SASL - SCRAM-SHA-256 + /// + /// # References + /// - + ScramSha256(Credentials), + /// SASL - SCRAM-SHA-512 + /// + /// # References + /// - + ScramSha512(Credentials), +} + +#[derive(Debug, Clone)] +pub struct Credentials { + pub username: String, + pub password: String, +} + +impl Credentials { + pub fn new(username: String, password: String) -> Self { + Self { username, password } + } } impl SaslConfig { - pub(crate) fn auth_bytes(&self) -> Vec { + pub(crate) fn credentials(&self) -> Credentials { match self { - Self::Plain { username, password } => { - let mut auth: Vec = vec![0]; - auth.extend(username.bytes()); - auth.push(0); - auth.extend(password.bytes()); - auth - } + Self::Plain(credentials) => credentials.clone(), + Self::ScramSha256(credentials) => credentials.clone(), + Self::ScramSha512(credentials) => credentials.clone(), } } pub(crate) fn mechanism(&self) -> &str { match self { Self::Plain { .. } => "PLAIN", + Self::ScramSha256 { .. } => "SCRAM-SHA-256", + Self::ScramSha512 { .. } => "SCRAM-SHA-512", } } } diff --git a/src/messenger.rs b/src/messenger.rs index 3d97bc4..efad01a 100644 --- a/src/messenger.rs +++ b/src/messenger.rs @@ -12,6 +12,11 @@ use std::{ use futures::future::BoxFuture; use parking_lot::Mutex; +use rsasl::{ + config::SASLConfig, + mechname::MechanismNameError, + prelude::{Mechname, SessionError}, +}; use thiserror::Error; use tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf}, @@ -23,8 +28,6 @@ use tokio::{ }; use tracing::{debug, info, warn}; -use crate::protocol::{api_version::ApiVersionRange, primitives::CompactString}; -use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType}; use crate::{ backoff::ErrorOrThrottle, protocol::{ @@ -34,12 +37,21 @@ use crate::{ frame::{AsyncMessageRead, AsyncMessageWrite}, messages::{ ReadVersionedError, ReadVersionedType, RequestBody, RequestHeader, ResponseHeader, - SaslAuthenticateRequest, SaslHandshakeRequest, WriteVersionedError, WriteVersionedType, + SaslAuthenticateRequest, SaslAuthenticateResponse, SaslHandshakeRequest, + SaslHandshakeResponse, WriteVersionedError, WriteVersionedType, }, primitives::{Int16, Int32, NullableString, TaggedFields}, }, throttle::maybe_throttle, }; +use crate::{ + client::SaslConfig, + protocol::{api_version::ApiVersionRange, primitives::CompactString}, +}; +use crate::{ + connection::Credentials, + protocol::{messages::ApiVersionsRequest, traits::ReadType}, +}; #[derive(Debug)] struct Response { @@ -186,6 +198,15 @@ pub enum SaslError { #[error("API error: {0}")] ApiError(#[from] ApiError), + + #[error("Invalid sasl mechanism: {0}")] + InvalidSaslMechanism(#[from] MechanismNameError), + + #[error("Sasl session error: {0}")] + SaslSessionError(#[from] SessionError), + + #[error("unsupported sasl mechanism")] + UnsupportedSaslMechanism, } impl Messenger @@ -531,16 +552,10 @@ where Err(SyncVersionsError::NoWorkingVersion) } - pub async fn sasl_handshake( + async fn sasl_authentication( &self, - mechanism: &str, auth_bytes: Vec, - ) -> Result<(), SaslError> { - let req = SaslHandshakeRequest::new(mechanism); - let resp = self.request(req).await?; - if let Some(err) = resp.error_code { - return Err(SaslError::ApiError(err)); - } + ) -> Result { let req = SaslAuthenticateRequest::new(auth_bytes); let resp = self.request(req).await?; if let Some(err) = resp.error_code { @@ -549,6 +564,54 @@ where } return Err(SaslError::ApiError(err)); } + + Ok(resp) + } + + async fn sasl_handshake(&self, mechanism: &str) -> Result { + let req = SaslHandshakeRequest::new(mechanism); + let resp = self.request(req).await?; + if let Some(err) = resp.error_code { + return Err(SaslError::ApiError(err)); + } + Ok(resp) + } + + pub async fn do_sasl(&self, config: SaslConfig) -> Result<(), SaslError> { + let mechanism = config.mechanism(); + let resp = self.sasl_handshake(mechanism).await?; + + let Credentials { username, password } = config.credentials(); + let config = SASLConfig::with_credentials(None, username, password).unwrap(); + let sasl = rsasl::prelude::SASLClient::new(config); + let raw_mechanisms = resp.mechanisms.0.unwrap_or_default(); + let mechanisms = raw_mechanisms + .iter() + .map(|mech| Mechname::parse(mech.0.as_bytes()).map_err(SaslError::InvalidSaslMechanism)) + .collect::, SaslError>>()?; + debug!(?mechanisms, "Supported SASL mechanisms"); + let prefer_mechanism = + Mechname::parse(mechanism.as_bytes()).map_err(SaslError::InvalidSaslMechanism)?; + if !mechanisms.contains(&prefer_mechanism) { + return Err(SaslError::UnsupportedSaslMechanism); + } + let mut session = sasl + .start_suggested(&[prefer_mechanism]) + .map_err(|_| SaslError::UnsupportedSaslMechanism)?; + debug!(?mechanism, "Using SASL Mechanism"); + // we step through the auth process, starting on our side with NO data received so far + let mut data_received: Option> = None; + loop { + let mut to_sent = Cursor::new(Vec::new()); + let state = session.step(data_received.as_deref(), &mut to_sent)?; + if !state.is_running() { + break; + } + + let authentication_response = self.sasl_authentication(to_sent.into_inner()).await?; + data_received = Some(authentication_response.auth_bytes.0); + } + Ok(()) } } diff --git a/tests/client.rs b/tests/client.rs index 7ec8486..024dc7b 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -41,10 +41,9 @@ async fn test_sasl() { return; } ClientBuilder::new(vec![env::var("KAFKA_SASL_CONNECT").unwrap()]) - .sasl_config(rskafka::client::SaslConfig::Plain { - username: "admin".to_string(), - password: "admin-secret".to_string(), - }) + .sasl_config(rskafka::client::SaslConfig::Plain( + rskafka::client::Credentials::new("admin".to_string(), "admin-secret".to_string()), + )) .build() .await .unwrap(); @@ -425,7 +424,7 @@ async fn test_get_offset() { // use out-of order timestamps to ensure our "lastest offset" logic works let record_early = record(b""); let record_late = Record { - timestamp: record_early.timestamp + chrono::Duration::seconds(1), + timestamp: record_early.timestamp + chrono::Duration::try_seconds(1).unwrap(), ..record_early.clone() }; let offsets = partition_client diff --git a/tests/produce_consume.rs b/tests/produce_consume.rs index 7efc470..47c91b7 100644 --- a/tests/produce_consume.rs +++ b/tests/produce_consume.rs @@ -270,8 +270,8 @@ async fn assert_produce_consume( // timestamps for records. We'll reorder the messages though to ts2, ts1, ts3 let ts1 = Utc.timestamp_millis_opt(1337).unwrap(); - let ts2 = ts1 + Duration::milliseconds(1); - let ts3 = ts2 + Duration::milliseconds(1); + let ts2 = ts1 + Duration::try_milliseconds(1).unwrap(); + let ts3 = ts2 + Duration::try_milliseconds(1).unwrap(); let record_1 = { let record = Record {