Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fuzz/fuzz_targets/protocol_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ where
let transport = MockTransport::new(transport_data);

// setup messenger
let messenger = Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID));
messenger.override_version_ranges(HashMap::from([(
api_key,
ApiVersionRange::new(api_version, api_version),
Expand Down
8 changes: 2 additions & 6 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,9 @@ impl ConnectionHandler for BrokerRepresentation {
error,
})?;

let messenger = Arc::new(Messenger::new(
BufStream::new(transport),
max_message_size,
client_id,
));
let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id);
messenger.sync_versions().await?;
Ok(messenger)
Ok(Arc::new(messenger))
}
}

Expand Down
75 changes: 46 additions & 29 deletions src/messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
ops::DerefMut,
sync::{
atomic::{AtomicI32, Ordering},
Arc, RwLock,
Arc,
},
task::Poll,
};
Expand Down Expand Up @@ -108,7 +108,7 @@ pub struct Messenger<RW> {
/// Version ranges that we think are supported by the broker.
///
/// This needs to be bootstrapped by [`sync_versions`](Self::sync_versions).
version_ranges: RwLock<HashMap<ApiKey, ApiVersionRange>>,
version_ranges: HashMap<ApiKey, ApiVersionRange>,

/// Current stream state.
///
Expand Down Expand Up @@ -260,30 +260,41 @@ where
stream_write: Arc::new(AsyncMutex::new(stream_write)),
client_id,
correlation_id: AtomicI32::new(0),
version_ranges: RwLock::new(HashMap::new()),
version_ranges: HashMap::new(),
state,
join_handle,
}
}

#[cfg(feature = "unstable-fuzzing")]
pub fn override_version_ranges(&self, ranges: HashMap<ApiKey, ApiVersionRange>) {
pub fn override_version_ranges(&mut self, ranges: HashMap<ApiKey, ApiVersionRange>) {
self.set_version_ranges(ranges);
}

fn set_version_ranges(&self, ranges: HashMap<ApiKey, ApiVersionRange>) {
*self.version_ranges.write().expect("lock poisoned") = ranges;
/// Set supported version range.
fn set_version_ranges(&mut self, ranges: HashMap<ApiKey, ApiVersionRange>) {
self.version_ranges = ranges;
}

pub async fn request<R>(&self, msg: R) -> Result<R::ResponseBody, RequestError>
where
R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
{
let body_api_version = self
.version_ranges
.read()
.expect("lock poisoned")
self.request_with_version_ranges(msg, &self.version_ranges)
.await
}

async fn request_with_version_ranges<R>(
&self,
msg: R,
version_ranges: &HashMap<ApiKey, ApiVersionRange>,
) -> Result<R::ResponseBody, RequestError>
where
R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
{
let body_api_version = version_ranges
.get(&R::API_KEY)
.and_then(|range_server| match_versions(*range_server, R::API_VERSION_RANGE))
.ok_or(RequestError::NoVersionMatch {
Expand Down Expand Up @@ -390,18 +401,21 @@ where
fut.await
}

pub async fn sync_versions(&self) -> Result<(), SyncVersionsError> {
/// Sync supported version range.
///
/// Takes `&self mut` to ensure exclusive access.
pub async fn sync_versions(&mut self) -> Result<(), SyncVersionsError> {
for upper_bound in (ApiVersionsRequest::API_VERSION_RANGE.min().0 .0
..=ApiVersionsRequest::API_VERSION_RANGE.max().0 .0)
.rev()
{
self.set_version_ranges(HashMap::from([(
let version_ranges = HashMap::from([(
ApiKey::ApiVersions,
ApiVersionRange::new(
ApiVersionsRequest::API_VERSION_RANGE.min(),
ApiVersion(Int16(upper_bound)),
),
)]));
)]);

let body = ApiVersionsRequest {
client_software_name: Some(CompactString(String::from(env!("CARGO_PKG_NAME")))),
Expand All @@ -411,7 +425,10 @@ where
tagged_fields: Some(TaggedFields::default()),
};

match self.request(body).await {
match self
.request_with_version_ranges(body, &version_ranges)
.await
{
Ok(response) => {
if let Some(e) = response.error_code {
debug!(
Expand Down Expand Up @@ -613,7 +630,7 @@ where

#[cfg(test)]
mod tests {
use std::{ops::Deref, time::Duration};
use std::time::Duration;

use assert_matches::assert_matches;
use futures::{pin_mut, FutureExt};
Expand Down Expand Up @@ -673,7 +690,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ok() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct response
let mut msg = vec![];
Expand Down Expand Up @@ -704,13 +721,13 @@ mod tests {
(ApiKey::Produce),
ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
)]);
assert_eq!(messenger.version_ranges.read().unwrap().deref(), &expected);
assert_eq!(messenger.version_ranges, expected);
}

#[tokio::test]
async fn test_sync_versions_ignores_error_code() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct error response
let mut msg = vec![];
Expand Down Expand Up @@ -767,13 +784,13 @@ mod tests {
(ApiKey::Produce),
ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
)]);
assert_eq!(messenger.version_ranges.read().unwrap().deref(), &expected);
assert_eq!(messenger.version_ranges, expected);
}

#[tokio::test]
async fn test_sync_versions_ignores_read_code() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct error response
let mut msg = vec![];
Expand Down Expand Up @@ -818,13 +835,13 @@ mod tests {
(ApiKey::Produce),
ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
)]);
assert_eq!(messenger.version_ranges.read().unwrap().deref(), &expected);
assert_eq!(messenger.version_ranges, expected);
}

#[tokio::test]
async fn test_sync_versions_err_flipped_range() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct response
let mut msg = vec![];
Expand Down Expand Up @@ -857,7 +874,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ignores_garbage() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct response
let mut msg = vec![];
Expand Down Expand Up @@ -915,13 +932,13 @@ mod tests {
(ApiKey::Produce),
ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
)]);
assert_eq!(messenger.version_ranges.read().unwrap().deref(), &expected);
assert_eq!(messenger.version_ranges, expected);
}

#[tokio::test]
async fn test_sync_versions_err_no_working_version() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// construct error response
for (i, v) in ((ApiVersionsRequest::API_VERSION_RANGE.min().0 .0)
Expand Down Expand Up @@ -960,7 +977,7 @@ mod tests {
#[tokio::test]
async fn test_poison_hangup() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
messenger.set_version_ranges(HashMap::from([(
ApiKey::ListOffsets,
ListOffsetsRequest::API_VERSION_RANGE,
Expand All @@ -982,7 +999,7 @@ mod tests {
#[tokio::test]
async fn test_poison_negative_message_size() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
messenger.set_version_ranges(HashMap::from([(
ApiKey::ListOffsets,
ListOffsetsRequest::API_VERSION_RANGE,
Expand Down Expand Up @@ -1015,7 +1032,7 @@ mod tests {
#[tokio::test]
async fn test_broken_msg_header_does_not_poison() {
let (sim, rx) = MessageSimulator::new();
let messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
messenger.set_version_ranges(HashMap::from([(
ApiKey::ApiVersions,
ApiVersionsRequest::API_VERSION_RANGE,
Expand Down Expand Up @@ -1060,7 +1077,7 @@ mod tests {
let (tx_front, rx_middle) = tokio::io::duplex(1);
let (tx_middle, mut rx_back) = tokio::io::duplex(1);

let messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));

// create two barriers:
// - pause: will be passed after 3 bytes were sent by the client
Expand Down