Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/protocol/messages/create_topics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,12 @@ where
assert!(v <= 5);

self.partition_index.write(writer)?;
self.broker_ids.write(writer)?;

if v >= 5 {
CompactArrayRef(self.broker_ids.0.as_deref()).write(writer)?;
} else {
self.broker_ids.write(writer)?;
}

if v >= 5 {
match self.tagged_fields.as_ref() {
Expand Down
110 changes: 95 additions & 15 deletions src/protocol/primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,74 @@ where
}
}

/// Represents a sequence of objects of a given type T.
///
/// Type T can be either a primitive type (e.g. STRING) or a structure. First, the length N + 1 is given as an
/// UNSIGNED_VARINT. Then N instances of type T follow. A null array is represented with a length of 0. In protocol
/// documentation an array of T instances is referred to as `[T]`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct CompactArray<T>(pub Option<Vec<T>>);

impl<R, T> ReadType<R> for CompactArray<T>
where
R: Read,
T: ReadType<R>,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
let len = UnsignedVarint::read(reader)?.0;
match len {
0 => Ok(Self(None)),
n => {
let len = usize::try_from(n - 1).map_err(ReadError::Overflow)?;
let mut builder = VecBuilder::new(len);
for _ in 0..len {
builder.push(T::read(reader)?);
}
Ok(Self(Some(builder.into())))
}
}
}
}

impl<W, T> WriteType<W> for CompactArray<T>
where
W: Write,
T: WriteType<W>,
{
fn write(&self, writer: &mut W) -> Result<(), WriteError> {
CompactArrayRef(self.0.as_deref()).write(writer)
}
}

/// Same as [`CompactArray`] but contains referenced data.
///
/// This only supports writing.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CompactArrayRef<'a, T>(pub Option<&'a [T]>);

impl<'a, W, T> WriteType<W> for CompactArrayRef<'a, T>
where
W: Write,
T: WriteType<W>,
{
fn write(&self, writer: &mut W) -> Result<(), WriteError> {
match self.0 {
None => UnsignedVarint(0).write(writer),
Some(inner) => {
let len = u64::try_from(inner.len() + 1).map_err(WriteError::from)?;
UnsignedVarint(len).write(writer)?;

for element in inner {
element.write(writer)?;
}

Ok(())
}
}
}
}

/// Represents a sequence of Kafka records as NULLABLE_BYTES.
///
/// This primitive actually depends on the message version and evolved twice in [KIP-32] and [KIP-98]. We only support
Expand Down Expand Up @@ -933,23 +1001,19 @@ mod tests {
Int32(i32::MAX).write(&mut buf).unwrap();
buf.set_position(0);

// Use a rather large struct here to trigger OOM
#[derive(Debug)]
struct Large {
_inner: [u8; 1024],
}
let err = Array::<Large>::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}

impl<R> ReadType<R> for Large
where
R: Read,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
Int32::read(reader)?;
unreachable!()
}
}
test_roundtrip!(CompactArray<Int32>, test_compact_array_roundtrip);

let err = Array::<Large>::read(&mut buf).unwrap_err();
#[test]
fn test_compact_array_blowup_memory() {
let mut buf = Cursor::new(Vec::<u8>::new());
UnsignedVarint(u64::MAX).write(&mut buf).unwrap();
buf.set_position(0);

let err = CompactArray::<Large>::read(&mut buf).unwrap_err();
assert_matches!(err, ReadError::IO(_));
}

Expand Down Expand Up @@ -989,4 +1053,20 @@ mod tests {
timestamp_type: RecordBatchTimestampType::CreateTime,
}
}

/// A rather large struct here to trigger OOM.
#[derive(Debug)]
struct Large {
_inner: [u8; 1024],
}

impl<R> ReadType<R> for Large
where
R: Read,
{
fn read(reader: &mut R) -> Result<Self, ReadError> {
Int32::read(reader)?;
unreachable!()
}
}
}