Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions fuzz/fuzz_targets/protocol_reader.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#![no_main]
use std::{collections::HashMap, io::Cursor, time::Duration};
use std::{collections::HashMap, io::Cursor, sync::Arc, time::Duration};

use libfuzzer_sys::fuzz_target;
use pin_project_lite::pin_project;
use rskafka::{
build_info::DEFAULT_CLIENT_ID,
messenger::Messenger,
protocol::{
api_key::ApiKey,
Expand Down Expand Up @@ -135,7 +136,7 @@ where
let transport = MockTransport::new(transport_data);

// setup messenger
let messenger = Messenger::new(transport, message_size);
let messenger = Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID));
messenger.override_version_ranges(HashMap::from([(
api_key,
ApiVersionRange::new(api_version, api_version),
Expand Down
7 changes: 7 additions & 0 deletions src/build_info.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//! Static information that is determined at build time.

/// Default client ID that is used when the user does not specify one.
///
/// Technically we don't need to send a client_id, but newer redpanda version fail to parse the message
/// without it. See <https://github.com/influxdata/rskafka/issues/169>.
pub const DEFAULT_CLIENT_ID: &str = env!("CARGO_PKG_NAME");
11 changes: 11 additions & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;
use thiserror::Error;

use crate::{
build_info::DEFAULT_CLIENT_ID,
client::partition::PartitionClient,
connection::{BrokerConnector, MetadataLookupMode, TlsConfig},
protocol::primitives::Boolean,
Expand Down Expand Up @@ -38,6 +39,7 @@ pub enum ProduceError {
/// Builder for [`Client`].
pub struct ClientBuilder {
bootstrap_brokers: Vec<String>,
client_id: Option<Arc<str>>,
max_message_size: usize,
socks5_proxy: Option<String>,
tls_config: TlsConfig,
Expand All @@ -48,12 +50,19 @@ impl ClientBuilder {
pub fn new(bootstrap_brokers: Vec<String>) -> Self {
Self {
bootstrap_brokers,
client_id: None,
max_message_size: 100 * 1024 * 1024, // 100MB
socks5_proxy: None,
tls_config: TlsConfig::default(),
}
}

/// Sets client ID.
pub fn client_id(mut self, client_id: impl Into<Arc<str>>) -> Self {
self.client_id = Some(client_id.into());
self
}

/// Set maximum size (in bytes) of message frames that can be received from a broker.
///
/// Setting this to larger sizes allows you to specify larger size limits in [`PartitionClient::fetch_records`],
Expand Down Expand Up @@ -82,6 +91,8 @@ impl ClientBuilder {
pub async fn build(self) -> Result<Client> {
let brokers = Arc::new(BrokerConnector::new(
self.bootstrap_brokers,
self.client_id
.unwrap_or_else(|| Arc::from(DEFAULT_CLIENT_ID)),
self.tls_config,
self.socks5_proxy,
self.max_message_size,
Expand Down
27 changes: 24 additions & 3 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ trait ConnectionHandler {

async fn connect(
&self,
client_id: Arc<str>,
tls_config: TlsConfig,
socks5_proxy: Option<String>,
max_message_size: usize,
Expand Down Expand Up @@ -103,6 +104,7 @@ impl ConnectionHandler for BrokerRepresentation {

async fn connect(
&self,
client_id: Arc<str>,
tls_config: TlsConfig,
socks5_proxy: Option<String>,
max_message_size: usize,
Expand All @@ -120,7 +122,11 @@ impl ConnectionHandler for BrokerRepresentation {
error,
})?;

let messenger = Arc::new(Messenger::new(BufStream::new(transport), max_message_size));
let messenger = Arc::new(Messenger::new(
BufStream::new(transport),
max_message_size,
client_id,
));
messenger.sync_versions().await?;
Ok(messenger)
}
Expand All @@ -136,6 +142,9 @@ pub struct BrokerConnector {
/// Broker URLs used to boostrap this pool
bootstrap_brokers: Vec<String>,

/// Client ID.
client_id: Arc<str>,

/// Discovered brokers in the cluster, including bootstrap brokers
topology: BrokerTopology,

Expand Down Expand Up @@ -165,12 +174,14 @@ pub struct BrokerConnector {
impl BrokerConnector {
pub fn new(
bootstrap_brokers: Vec<String>,
client_id: Arc<str>,
tls_config: TlsConfig,
socks5_proxy: Option<String>,
max_message_size: usize,
) -> Self {
Self {
bootstrap_brokers,
client_id,
topology: Default::default(),
cached_arbitrary_broker: Mutex::new(None),
cached_metadata: Default::default(),
Expand Down Expand Up @@ -265,6 +276,7 @@ impl BrokerConnector {
Some(broker) => {
let connection = BrokerRepresentation::Topology(broker)
.connect(
Arc::clone(&self.client_id),
self.tls_config.clone(),
self.socks5_proxy.clone(),
self.max_message_size,
Expand Down Expand Up @@ -350,6 +362,7 @@ impl BrokerCache for &BrokerConnector {

let connection = connect_to_a_broker_with_retry(
self.brokers(),
Arc::clone(&self.client_id),
&self.backoff_config,
self.tls_config.clone(),
self.socks5_proxy.clone(),
Expand All @@ -369,6 +382,7 @@ impl BrokerCache for &BrokerConnector {

async fn connect_to_a_broker_with_retry<B>(
mut brokers: Vec<B>,
client_id: Arc<str>,
backoff_config: &BackoffConfig,
tls_config: TlsConfig,
socks5_proxy: Option<String>,
Expand All @@ -385,7 +399,12 @@ where
.retry_with_backoff("broker_connect", || async {
for broker in &brokers {
let conn = broker
.connect(tls_config.clone(), socks5_proxy.clone(), max_message_size)
.connect(
Arc::clone(&client_id),
tls_config.clone(),
socks5_proxy.clone(),
max_message_size,
)
.await;

let connection = match conn {
Expand Down Expand Up @@ -457,7 +476,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::api_key::ApiKey;
use crate::{build_info::DEFAULT_CLIENT_ID, protocol::api_key::ApiKey};
use std::sync::atomic::{AtomicBool, Ordering};

struct FakeBroker(Box<dyn Fn() -> Result<MetadataResponse, RequestError> + Send + Sync>);
Expand Down Expand Up @@ -684,6 +703,7 @@ mod tests {

async fn connect(
&self,
_client_id: Arc<str>,
_tls_config: TlsConfig,
_socks5_proxy: Option<String>,
_max_message_size: usize,
Expand All @@ -709,6 +729,7 @@ mod tests {
// connects successfully.
let conn = connect_to_a_broker_with_retry(
brokers,
Arc::from(DEFAULT_CLIENT_ID),
&Default::default(),
Default::default(),
Default::default(),
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

mod backoff;

pub mod build_info;

pub mod client;

mod connection;
Expand Down
41 changes: 24 additions & 17 deletions src/messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ pub struct Messenger<RW> {
/// This will be used by [`request`](Self::request) to queue up messages.
stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,

/// Client ID.
client_id: Arc<str>,

/// The next correlation ID.
///
/// This is used to map responses to active requests.
Expand Down Expand Up @@ -175,7 +178,7 @@ impl<RW> Messenger<RW>
where
RW: AsyncRead + AsyncWrite + Send + 'static,
{
pub fn new(stream: RW, max_message_size: usize) -> Self {
pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
let (stream_read, stream_write) = tokio::io::split(stream);
let state = Arc::new(Mutex::new(MessengerState::RequestMap(HashMap::default())));
let state_captured = Arc::clone(&state);
Expand Down Expand Up @@ -255,6 +258,7 @@ where

Self {
stream_write: Arc::new(AsyncMutex::new(stream_write)),
client_id,
correlation_id: AtomicI32::new(0),
version_ranges: RwLock::new(HashMap::new()),
state,
Expand Down Expand Up @@ -304,7 +308,7 @@ where
correlation_id: Int32(correlation_id),
// Technically we don't need to send a client_id, but newer redpanda version fail to parse the message
// without it. See https://github.com/influxdata/rskafka/issues/169 .
client_id: Some(NullableString(Some(String::from(env!("CARGO_PKG_NAME"))))),
client_id: Some(NullableString(Some(String::from(self.client_id.as_ref())))),
tagged_fields: Some(TaggedFields::default()),
};
let header_version = if use_tagged_fields_in_request {
Expand Down Expand Up @@ -620,12 +624,15 @@ mod tests {

use super::*;

use crate::protocol::{
error::Error as ApiError,
messages::{
ApiVersionsResponse, ApiVersionsResponseApiKey, ListOffsetsRequest, NORMAL_CONSUMER,
use crate::{
build_info::DEFAULT_CLIENT_ID,
protocol::{
error::Error as ApiError,
messages::{
ApiVersionsResponse, ApiVersionsResponseApiKey, ListOffsetsRequest, NORMAL_CONSUMER,
},
traits::WriteType,
},
traits::WriteType,
};

#[test]
Expand Down Expand Up @@ -666,7 +673,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ok() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000);
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct response
let mut msg = vec![];
Expand Down Expand Up @@ -703,7 +710,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ignores_error_code() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000);
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct error response
let mut msg = vec![];
Expand Down Expand Up @@ -766,7 +773,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ignores_read_code() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000);
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct error response
let mut msg = vec![];
Expand Down Expand Up @@ -817,7 +824,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_err_flipped_range() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000);
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct response
let mut msg = vec![];
Expand Down Expand Up @@ -850,7 +857,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ignores_garbage() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000);
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct response
let mut msg = vec![];
Expand Down Expand Up @@ -914,7 +921,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_err_no_working_version() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000);
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct error response
for (i, v) in ((ApiVersionsRequest::API_VERSION_RANGE.min().0 .0)
Expand Down Expand Up @@ -953,7 +960,7 @@ mod tests {
#[tokio::test]
async fn test_poison_hangup() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000);
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
messenger.set_version_ranges(HashMap::from([(
ApiKey::ListOffsets,
ListOffsetsRequest::API_VERSION_RANGE,
Expand All @@ -975,7 +982,7 @@ mod tests {
#[tokio::test]
async fn test_poison_negative_message_size() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000);
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
messenger.set_version_ranges(HashMap::from([(
ApiKey::ListOffsets,
ListOffsetsRequest::API_VERSION_RANGE,
Expand Down Expand Up @@ -1008,7 +1015,7 @@ mod tests {
#[tokio::test]
async fn test_broken_msg_header_does_not_poison() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000);
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
messenger.set_version_ranges(HashMap::from([(
ApiKey::ApiVersions,
ApiVersionsRequest::API_VERSION_RANGE,
Expand Down Expand Up @@ -1053,7 +1060,7 @@ mod tests {
let (tx_front, rx_middle) = tokio::io::duplex(1);
let (tx_middle, mut rx_back) = tokio::io::duplex(1);

let messenger = Messenger::new(tx_front, 1_000);
let messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// create two barriers:
// - pause: will be passed after 3 bytes were sent by the client
Expand Down