Skip to content

Commit 55755ba

Browse files
Merge pull request #180 from influxdata/crepererum/custom_client_id
feat: allow setting a custom client ID
2 parents 00988a5 + 97b2bec commit 55755ba

File tree

6 files changed

+71
-22
lines changed

6 files changed

+71
-22
lines changed

fuzz/fuzz_targets/protocol_reader.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#![no_main]
2-
use std::{collections::HashMap, io::Cursor, time::Duration};
2+
use std::{collections::HashMap, io::Cursor, sync::Arc, time::Duration};
33

44
use libfuzzer_sys::fuzz_target;
55
use pin_project_lite::pin_project;
66
use rskafka::{
7+
build_info::DEFAULT_CLIENT_ID,
78
messenger::Messenger,
89
protocol::{
910
api_key::ApiKey,
@@ -135,7 +136,7 @@ where
135136
let transport = MockTransport::new(transport_data);
136137

137138
// setup messenger
138-
let messenger = Messenger::new(transport, message_size);
139+
let messenger = Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID));
139140
messenger.override_version_ranges(HashMap::from([(
140141
api_key,
141142
ApiVersionRange::new(api_version, api_version),

src/build_info.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
//! Static information that is determined at build time.
2+
3+
/// Default client ID that is used when the user does not specify one.
4+
///
5+
/// Technically we don't need to send a client_id, but newer redpanda version fail to parse the message
6+
/// without it. See <https://github.com/influxdata/rskafka/issues/169>.
7+
pub const DEFAULT_CLIENT_ID: &str = env!("CARGO_PKG_NAME");

src/client/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::sync::Arc;
33
use thiserror::Error;
44

55
use crate::{
6+
build_info::DEFAULT_CLIENT_ID,
67
client::partition::PartitionClient,
78
connection::{BrokerConnector, MetadataLookupMode, TlsConfig},
89
protocol::primitives::Boolean,
@@ -38,6 +39,7 @@ pub enum ProduceError {
3839
/// Builder for [`Client`].
3940
pub struct ClientBuilder {
4041
bootstrap_brokers: Vec<String>,
42+
client_id: Option<Arc<str>>,
4143
max_message_size: usize,
4244
socks5_proxy: Option<String>,
4345
tls_config: TlsConfig,
@@ -48,12 +50,19 @@ impl ClientBuilder {
4850
pub fn new(bootstrap_brokers: Vec<String>) -> Self {
4951
Self {
5052
bootstrap_brokers,
53+
client_id: None,
5154
max_message_size: 100 * 1024 * 1024, // 100MB
5255
socks5_proxy: None,
5356
tls_config: TlsConfig::default(),
5457
}
5558
}
5659

60+
/// Sets client ID.
61+
pub fn client_id(mut self, client_id: impl Into<Arc<str>>) -> Self {
62+
self.client_id = Some(client_id.into());
63+
self
64+
}
65+
5766
/// Set maximum size (in bytes) of message frames that can be received from a broker.
5867
///
5968
/// Setting this to larger sizes allows you to specify larger size limits in [`PartitionClient::fetch_records`],
@@ -82,6 +91,8 @@ impl ClientBuilder {
8291
pub async fn build(self) -> Result<Client> {
8392
let brokers = Arc::new(BrokerConnector::new(
8493
self.bootstrap_brokers,
94+
self.client_id
95+
.unwrap_or_else(|| Arc::from(DEFAULT_CLIENT_ID)),
8596
self.tls_config,
8697
self.socks5_proxy,
8798
self.max_message_size,

src/connection.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ trait ConnectionHandler {
5353

5454
async fn connect(
5555
&self,
56+
client_id: Arc<str>,
5657
tls_config: TlsConfig,
5758
socks5_proxy: Option<String>,
5859
max_message_size: usize,
@@ -103,6 +104,7 @@ impl ConnectionHandler for BrokerRepresentation {
103104

104105
async fn connect(
105106
&self,
107+
client_id: Arc<str>,
106108
tls_config: TlsConfig,
107109
socks5_proxy: Option<String>,
108110
max_message_size: usize,
@@ -120,7 +122,11 @@ impl ConnectionHandler for BrokerRepresentation {
120122
error,
121123
})?;
122124

123-
let messenger = Arc::new(Messenger::new(BufStream::new(transport), max_message_size));
125+
let messenger = Arc::new(Messenger::new(
126+
BufStream::new(transport),
127+
max_message_size,
128+
client_id,
129+
));
124130
messenger.sync_versions().await?;
125131
Ok(messenger)
126132
}
@@ -136,6 +142,9 @@ pub struct BrokerConnector {
136142
/// Broker URLs used to boostrap this pool
137143
bootstrap_brokers: Vec<String>,
138144

145+
/// Client ID.
146+
client_id: Arc<str>,
147+
139148
/// Discovered brokers in the cluster, including bootstrap brokers
140149
topology: BrokerTopology,
141150

@@ -165,12 +174,14 @@ pub struct BrokerConnector {
165174
impl BrokerConnector {
166175
pub fn new(
167176
bootstrap_brokers: Vec<String>,
177+
client_id: Arc<str>,
168178
tls_config: TlsConfig,
169179
socks5_proxy: Option<String>,
170180
max_message_size: usize,
171181
) -> Self {
172182
Self {
173183
bootstrap_brokers,
184+
client_id,
174185
topology: Default::default(),
175186
cached_arbitrary_broker: Mutex::new(None),
176187
cached_metadata: Default::default(),
@@ -265,6 +276,7 @@ impl BrokerConnector {
265276
Some(broker) => {
266277
let connection = BrokerRepresentation::Topology(broker)
267278
.connect(
279+
Arc::clone(&self.client_id),
268280
self.tls_config.clone(),
269281
self.socks5_proxy.clone(),
270282
self.max_message_size,
@@ -350,6 +362,7 @@ impl BrokerCache for &BrokerConnector {
350362

351363
let connection = connect_to_a_broker_with_retry(
352364
self.brokers(),
365+
Arc::clone(&self.client_id),
353366
&self.backoff_config,
354367
self.tls_config.clone(),
355368
self.socks5_proxy.clone(),
@@ -369,6 +382,7 @@ impl BrokerCache for &BrokerConnector {
369382

370383
async fn connect_to_a_broker_with_retry<B>(
371384
mut brokers: Vec<B>,
385+
client_id: Arc<str>,
372386
backoff_config: &BackoffConfig,
373387
tls_config: TlsConfig,
374388
socks5_proxy: Option<String>,
@@ -385,7 +399,12 @@ where
385399
.retry_with_backoff("broker_connect", || async {
386400
for broker in &brokers {
387401
let conn = broker
388-
.connect(tls_config.clone(), socks5_proxy.clone(), max_message_size)
402+
.connect(
403+
Arc::clone(&client_id),
404+
tls_config.clone(),
405+
socks5_proxy.clone(),
406+
max_message_size,
407+
)
389408
.await;
390409

391410
let connection = match conn {
@@ -457,7 +476,7 @@ where
457476
#[cfg(test)]
458477
mod tests {
459478
use super::*;
460-
use crate::protocol::api_key::ApiKey;
479+
use crate::{build_info::DEFAULT_CLIENT_ID, protocol::api_key::ApiKey};
461480
use std::sync::atomic::{AtomicBool, Ordering};
462481

463482
struct FakeBroker(Box<dyn Fn() -> Result<MetadataResponse, RequestError> + Send + Sync>);
@@ -684,6 +703,7 @@ mod tests {
684703

685704
async fn connect(
686705
&self,
706+
_client_id: Arc<str>,
687707
_tls_config: TlsConfig,
688708
_socks5_proxy: Option<String>,
689709
_max_message_size: usize,
@@ -709,6 +729,7 @@ mod tests {
709729
// connects successfully.
710730
let conn = connect_to_a_broker_with_retry(
711731
brokers,
732+
Arc::from(DEFAULT_CLIENT_ID),
712733
&Default::default(),
713734
Default::default(),
714735
Default::default(),

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
mod backoff;
2121

22+
pub mod build_info;
23+
2224
pub mod client;
2325

2426
mod connection;

src/messenger.rs

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ pub struct Messenger<RW> {
9797
/// This will be used by [`request`](Self::request) to queue up messages.
9898
stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,
9999

100+
/// Client ID.
101+
client_id: Arc<str>,
102+
100103
/// The next correlation ID.
101104
///
102105
/// This is used to map responses to active requests.
@@ -175,7 +178,7 @@ impl<RW> Messenger<RW>
175178
where
176179
RW: AsyncRead + AsyncWrite + Send + 'static,
177180
{
178-
pub fn new(stream: RW, max_message_size: usize) -> Self {
181+
pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
179182
let (stream_read, stream_write) = tokio::io::split(stream);
180183
let state = Arc::new(Mutex::new(MessengerState::RequestMap(HashMap::default())));
181184
let state_captured = Arc::clone(&state);
@@ -255,6 +258,7 @@ where
255258

256259
Self {
257260
stream_write: Arc::new(AsyncMutex::new(stream_write)),
261+
client_id,
258262
correlation_id: AtomicI32::new(0),
259263
version_ranges: RwLock::new(HashMap::new()),
260264
state,
@@ -304,7 +308,7 @@ where
304308
correlation_id: Int32(correlation_id),
305309
// Technically we don't need to send a client_id, but newer redpanda version fail to parse the message
306310
// without it. See https://github.com/influxdata/rskafka/issues/169 .
307-
client_id: Some(NullableString(Some(String::from(env!("CARGO_PKG_NAME"))))),
311+
client_id: Some(NullableString(Some(String::from(self.client_id.as_ref())))),
308312
tagged_fields: Some(TaggedFields::default()),
309313
};
310314
let header_version = if use_tagged_fields_in_request {
@@ -620,12 +624,15 @@ mod tests {
620624

621625
use super::*;
622626

623-
use crate::protocol::{
624-
error::Error as ApiError,
625-
messages::{
626-
ApiVersionsResponse, ApiVersionsResponseApiKey, ListOffsetsRequest, NORMAL_CONSUMER,
627+
use crate::{
628+
build_info::DEFAULT_CLIENT_ID,
629+
protocol::{
630+
error::Error as ApiError,
631+
messages::{
632+
ApiVersionsResponse, ApiVersionsResponseApiKey, ListOffsetsRequest, NORMAL_CONSUMER,
633+
},
634+
traits::WriteType,
627635
},
628-
traits::WriteType,
629636
};
630637

631638
#[test]
@@ -666,7 +673,7 @@ mod tests {
666673
#[tokio::test]
667674
async fn test_sync_versions_ok() {
668675
let (sim, rx) = MessageSimulator::new();
669-
let messenger = Messenger::new(rx, 1_000);
676+
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
670677

671678
// construct response
672679
let mut msg = vec![];
@@ -703,7 +710,7 @@ mod tests {
703710
#[tokio::test]
704711
async fn test_sync_versions_ignores_error_code() {
705712
let (sim, rx) = MessageSimulator::new();
706-
let messenger = Messenger::new(rx, 1_000);
713+
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
707714

708715
// construct error response
709716
let mut msg = vec![];
@@ -766,7 +773,7 @@ mod tests {
766773
#[tokio::test]
767774
async fn test_sync_versions_ignores_read_code() {
768775
let (sim, rx) = MessageSimulator::new();
769-
let messenger = Messenger::new(rx, 1_000);
776+
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
770777

771778
// construct error response
772779
let mut msg = vec![];
@@ -817,7 +824,7 @@ mod tests {
817824
#[tokio::test]
818825
async fn test_sync_versions_err_flipped_range() {
819826
let (sim, rx) = MessageSimulator::new();
820-
let messenger = Messenger::new(rx, 1_000);
827+
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
821828

822829
// construct response
823830
let mut msg = vec![];
@@ -850,7 +857,7 @@ mod tests {
850857
#[tokio::test]
851858
async fn test_sync_versions_ignores_garbage() {
852859
let (sim, rx) = MessageSimulator::new();
853-
let messenger = Messenger::new(rx, 1_000);
860+
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
854861

855862
// construct response
856863
let mut msg = vec![];
@@ -914,7 +921,7 @@ mod tests {
914921
#[tokio::test]
915922
async fn test_sync_versions_err_no_working_version() {
916923
let (sim, rx) = MessageSimulator::new();
917-
let messenger = Messenger::new(rx, 1_000);
924+
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
918925

919926
// construct error response
920927
for (i, v) in ((ApiVersionsRequest::API_VERSION_RANGE.min().0 .0)
@@ -953,7 +960,7 @@ mod tests {
953960
#[tokio::test]
954961
async fn test_poison_hangup() {
955962
let (sim, rx) = MessageSimulator::new();
956-
let messenger = Messenger::new(rx, 1_000);
963+
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
957964
messenger.set_version_ranges(HashMap::from([(
958965
ApiKey::ListOffsets,
959966
ListOffsetsRequest::API_VERSION_RANGE,
@@ -975,7 +982,7 @@ mod tests {
975982
#[tokio::test]
976983
async fn test_poison_negative_message_size() {
977984
let (sim, rx) = MessageSimulator::new();
978-
let messenger = Messenger::new(rx, 1_000);
985+
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
979986
messenger.set_version_ranges(HashMap::from([(
980987
ApiKey::ListOffsets,
981988
ListOffsetsRequest::API_VERSION_RANGE,
@@ -1008,7 +1015,7 @@ mod tests {
10081015
#[tokio::test]
10091016
async fn test_broken_msg_header_does_not_poison() {
10101017
let (sim, rx) = MessageSimulator::new();
1011-
let messenger = Messenger::new(rx, 1_000);
1018+
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
10121019
messenger.set_version_ranges(HashMap::from([(
10131020
ApiKey::ApiVersions,
10141021
ApiVersionsRequest::API_VERSION_RANGE,
@@ -1053,7 +1060,7 @@ mod tests {
10531060
let (tx_front, rx_middle) = tokio::io::duplex(1);
10541061
let (tx_middle, mut rx_back) = tokio::io::duplex(1);
10551062

1056-
let messenger = Messenger::new(tx_front, 1_000);
1063+
let messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));
10571064

10581065
// create two barriers:
10591066
// - pause: will be passed after 3 bytes were sent by the client

0 commit comments

Comments
 (0)