Commit b0852fd

fix: do not OOM in carefully_decompress_snappy

1 parent 948aee4 commit b0852fd

File tree: 1 file changed, +215 -117 lines changed

src/protocol/record.rs (215 additions & 117 deletions)
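As background for the change below, a minimal standalone sketch (illustrative only, not part of the commit; the crafted bytes and the wrapping main are assumptions) of why the size prefix of a raw snappy message cannot be trusted: snap's decompress_len merely parses the unsigned-varint header, so a handful of attacker-controlled bytes can claim an arbitrarily large output buffer.

    // Sketch: the claimed uncompressed size of a raw snappy payload is attacker-controlled.
    use snap::raw::decompress_len;

    fn main() {
        // 1 GiB (2^30) encoded as an unsigned LEB128 varint, followed by no compressed data at all.
        let crafted = [0x80, 0x80, 0x80, 0x80, 0x04];
        let claimed = decompress_len(&crafted).expect("the header itself parses fine");
        assert_eq!(claimed, 1 << 30);
        // Allocating the full claimed size up front (`vec![0u8; claimed]`) is the OOM vector this
        // commit removes: the fixed function starts with a small buffer and only grows it when the
        // actual compressed data demands it.
    }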
@@ -619,6 +619,8 @@ where
             }
             #[cfg(feature = "compression-snappy")]
             RecordBatchCompression::Snappy => {
+                use crate::protocol::vec_builder::DEFAULT_BLOCK_SIZE;
+
                 // Construct the input for the raw decoder.
                 let mut input = vec![];
                 reader.read_to_end(&mut input)?;
@@ -655,13 +657,13 @@ where
                         let mut chunk_data = vec![0u8; chunk_length];
                         cursor.read_exact(&mut chunk_data)?;
 
-                        let mut buf = carefully_decompress_snappy(&chunk_data)?;
+                        let mut buf = carefully_decompress_snappy(&chunk_data, DEFAULT_BLOCK_SIZE)?;
                         output.append(&mut buf);
                     }
 
                     output
                 } else {
-                    carefully_decompress_snappy(&input)?
+                    carefully_decompress_snappy(&input, DEFAULT_BLOCK_SIZE)?
                 };
 
                 // Read uncompressed records.
@@ -889,39 +891,99 @@ where
     }
 }
 
+/// Try to decompress a snappy message without blindly believing the uncompressed size encoded at the start of the
+/// message (and therefore potentially OOMing).
 #[cfg(feature = "compression-snappy")]
-fn carefully_decompress_snappy(input: &[u8]) -> Result<Vec<u8>, ReadError> {
-    use crate::protocol::vec_builder::DEFAULT_BLOCK_SIZE;
+fn carefully_decompress_snappy(
+    input: &[u8],
+    start_block_size: usize,
+) -> Result<Vec<u8>, ReadError> {
+    use crate::protocol::primitives::UnsignedVarint;
     use snap::raw::{decompress_len, Decoder};
 
+    // early exit, otherwise `uncompressed_size_encoded_length` will be 1 even though there was no input
+    if input.is_empty() {
+        return Err(ReadError::Malformed(Box::new(snap::Error::Empty)));
+    }
+
     // The snappy compression used here is unframed aka "raw". So we first need to figure out the
     // uncompressed length. See
     //
     // - https://github.com/edenhill/librdkafka/blob/2b76b65212e5efda213961d5f84e565038036270/src/rdkafka_msgset_reader.c#L345-L348
     // - https://github.com/edenhill/librdkafka/blob/747f77c98fbddf7dc6508f76398e0fc9ee91450f/src/snappy.c#L779
     let uncompressed_size = decompress_len(input).map_err(|e| ReadError::Malformed(Box::new(e)))?;
 
+    // figure out how long the encoded size was
+    let uncompressed_size_encoded_length = {
+        let mut buf = Vec::with_capacity(100);
+        UnsignedVarint(uncompressed_size as u64)
+            .write(&mut buf)
+            .expect("this write should never fail");
+        buf.len()
+    };
+
     // Decode snappy payload.
     // The uncompressed length is unchecked and can be up to 2^32-1 bytes. To avoid a DDoS vector we try to
     // limit it to a small size and if that fails we double that size.
-    let mut max_uncompressed_size = DEFAULT_BLOCK_SIZE;
+    let mut max_uncompressed_size = start_block_size;
 
+    // Try to decode the message with growing output buffers.
     loop {
         let try_uncompressed_size = uncompressed_size.min(max_uncompressed_size);
 
+        // We need to lie to the snap decoder about the target length, otherwise it will reject our shortened test
+        // straight away. Luckily that's rather easy and we just need to fake the length stored right at the beginning
+        // of the message.
+        let try_input = {
+            let mut buf = Cursor::new(Vec::with_capacity(input.len()));
+            UnsignedVarint(try_uncompressed_size as u64)
+                .write(&mut buf)
+                .expect("this write should never fail");
+            buf.write_all(&input[uncompressed_size_encoded_length..])
+                .expect("this write should never fail");
+            buf.into_inner()
+        };
+
         let mut decoder = Decoder::new();
         let mut output = vec![0; try_uncompressed_size];
-        let actual_uncompressed_size = match decoder.decompress(input, &mut output) {
+        let actual_uncompressed_size = match decoder.decompress(&try_input, &mut output) {
             Ok(size) => size,
-            Err(snap::Error::BufferTooSmall { .. })
-                if max_uncompressed_size < uncompressed_size =>
-            {
-                // try larger buffer
-                max_uncompressed_size *= 2;
-                continue;
-            }
             Err(e) => {
-                return Err(ReadError::Malformed(Box::new(e)));
+                let looks_like_dst_too_small = match e {
+                    // `CopyWrite` only occurs when the dst buffer is too small.
+                    snap::Error::CopyWrite { .. } => true,
+
+                    // `Literal` may occur due to src or dst errors, so need to check
+                    snap::Error::Literal { len, dst_len, .. } => dst_len < len,
+
+                    // `HeaderMismatch` may also occur when the output was smaller than we predicted, in which case the
+                    // header would actually be broken
+                    snap::Error::HeaderMismatch {
+                        expected_len,
+                        got_len,
+                    } => expected_len < got_len,
+
+                    // `BufferTooSmall` cannot happen by construction, because we just allocated the right buffer
+                    snap::Error::BufferTooSmall { .. } => {
+                        unreachable!("Just allocated a correctly-sized output buffer.")
+                    }
+
+                    // `Offset` does NOT occur due to an undersized dst but due to invalid offset calculations. Instead
+                    // `CopyWrite` would be used.
+                    snap::Error::Offset { .. } => false,
+
+                    // All other errors are real errors
+                    _ => false,
+                };
+                let used_smaller_dst = max_uncompressed_size < uncompressed_size;
+
+                if looks_like_dst_too_small && used_smaller_dst {
+                    // try larger buffer
+                    max_uncompressed_size *= 2;
+                    continue;
+                } else {
+                    return Err(ReadError::Malformed(Box::new(e)));
+                }
             }
         };
         if actual_uncompressed_size != uncompressed_size {
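For readers unfamiliar with raw snappy framing, here is a minimal standalone sketch (not taken from the diff; the write_varint helper, the payload, and the concrete sizes are illustrative assumptions) of the prefix-rewriting trick the retry loop relies on: a raw snappy payload starts with its uncompressed length as an unsigned varint, and rewriting that prefix lets the decoder attempt decompression into a deliberately undersized buffer instead of rejecting it up front with BufferTooSmall.

    use snap::raw::{decompress_len, Decoder, Encoder};

    /// Hypothetical helper: encode `value` as an unsigned LEB128 varint
    /// (the same wire format the crate's `UnsignedVarint` type writes).
    fn write_varint(mut value: u64, out: &mut Vec<u8>) {
        loop {
            let mut byte = (value & 0x7f) as u8;
            value >>= 7;
            if value != 0 {
                byte |= 0x80;
            }
            out.push(byte);
            if value == 0 {
                break;
            }
        }
    }

    fn main() {
        let payload = vec![b'x'; 1000];
        let mut encoder = Encoder::new();
        let compressed = encoder
            .compress_vec(&payload)
            .expect("compressing an in-memory buffer should not fail");

        // The claimed uncompressed size is just the varint prefix; nothing has verified it yet.
        assert_eq!(decompress_len(&compressed).unwrap(), payload.len());

        // Rewrite the prefix to a smaller "fake" length, as the retry loop does, so the decoder
        // accepts a smaller output buffer instead of demanding 1000 bytes up front.
        let real_prefix_len = {
            let mut buf = Vec::new();
            write_varint(payload.len() as u64, &mut buf);
            buf.len()
        };
        let fake_len = 64usize;
        let mut try_input = Vec::new();
        write_varint(fake_len as u64, &mut try_input);
        try_input.extend_from_slice(&compressed[real_prefix_len..]);

        // Decompressing into the smaller buffer now fails with a data-level error rather than an
        // up-front `BufferTooSmall`; that error is what the new loop inspects to decide whether to
        // double the buffer and retry.
        let mut small = vec![0u8; fake_len];
        let mut decoder = Decoder::new();
        assert!(decoder.decompress(&try_input, &mut small).is_err());
    }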
@@ -1129,86 +1191,34 @@ mod tests {
     }
 
     #[cfg(feature = "compression-snappy")]
-    #[test]
-    fn test_decode_fixture_snappy() {
-        // This data was obtained by watching rdkafka.
-        let data = [
-            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x58\x00\x00\x00\x00".to_vec(),
-            b"\x02\xad\x86\xf4\xf4\x00\x02\x00\x00\x00\x00\x00\x00\x01\x7e\xb6".to_vec(),
-            b"\x45\x0e\x52\x00\x00\x01\x7e\xb6\x45\x0e\x52\xff\xff\xff\xff\xff".to_vec(),
-            b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x01\x80\x01\x1c".to_vec(),
-            b"\xfc\x01\x00\x00\x00\xc8\x01\x78\xfe\x01\x00\x8a\x01\x00\x50\x16".to_vec(),
-            b"\x68\x65\x6c\x6c\x6f\x20\x6b\x61\x66\x6b\x61\x02\x06\x66\x6f\x6f".to_vec(),
-            b"\x06\x62\x61\x72".to_vec(),
-        ]
-        .concat();
-
-        let actual = RecordBatch::read(&mut Cursor::new(data)).unwrap();
-        let expected = RecordBatch {
-            base_offset: 0,
-            partition_leader_epoch: 0,
-            last_offset_delta: 0,
-            first_timestamp: 1643735486034,
-            max_timestamp: 1643735486034,
-            producer_id: -1,
-            producer_epoch: -1,
-            base_sequence: -1,
-            records: ControlBatchOrRecords::Records(vec![Record {
-                timestamp_delta: 0,
-                offset_delta: 0,
-                key: Some(vec![b'x'; 100]),
-                value: Some(b"hello kafka".to_vec()),
-                headers: vec![RecordHeader {
-                    key: "foo".to_owned(),
-                    value: b"bar".to_vec(),
-                }],
-            }]),
-            compression: RecordBatchCompression::Snappy,
-            is_transactional: false,
-            timestamp_type: RecordBatchTimestampType::CreateTime,
-        };
-        assert_eq!(actual, expected);
-
-        let mut data2 = vec![];
-        actual.write(&mut data2).unwrap();
-
-        // don't compare if the data is equal because compression encoder might work slightly differently, use another
-        // roundtrip instead
-        let actual2 = RecordBatch::read(&mut Cursor::new(data2)).unwrap();
-        assert_eq!(actual2, expected);
-    }
-
-    #[cfg(feature = "compression-snappy")]
-    #[test]
-    fn test_decode_fixture_snappy_java() {
-        // This data was obtained by watching Kafka returning a recording to rskafka that was produced by the official
-        // Java client.
-        let data = [
-            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x8c\x00\x00\x00\x00".to_vec(),
-            b"\x02\x79\x1e\x2d\xce\x00\x02\x00\x00\x00\x01\x00\x00\x01\x7f\x07".to_vec(),
-            b"\x25\x7a\xb1\x00\x00\x01\x7f\x07\x25\x7a\xb1\xff\xff\xff\xff\xff".to_vec(),
-            b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x02\x82\x53\x4e".to_vec(),
-            b"\x41\x50\x50\x59\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00".to_vec(),
-            b"\x47\xff\x01\x1c\xfc\x01\x00\x00\x00\xc8\x01\x78\xfe\x01\x00\x8a".to_vec(),
-            b"\x01\x00\x64\x16\x68\x65\x6c\x6c\x6f\x20\x6b\x61\x66\x6b\x61\x02".to_vec(),
-            b"\x06\x66\x6f\x6f\x06\x62\x61\x72\xfa\x01\x00\x00\x02\xfe\x80\x00".to_vec(),
-            b"\x96\x80\x00\x4c\x14\x73\x6f\x6d\x65\x20\x76\x61\x6c\x75\x65\x02".to_vec(),
-            b"\x06\x66\x6f\x6f\x06\x62\x61\x72".to_vec(),
-        ]
-        .concat();
-
-        let actual = RecordBatch::read(&mut Cursor::new(data)).unwrap();
-        let expected = RecordBatch {
-            base_offset: 0,
-            partition_leader_epoch: 0,
-            last_offset_delta: 1,
-            first_timestamp: 1645092371121,
-            max_timestamp: 1645092371121,
-            producer_id: -1,
-            producer_epoch: -1,
-            base_sequence: -1,
-            records: ControlBatchOrRecords::Records(vec![
-                Record {
+    mod snappy {
+        use super::*;
+
+        #[test]
+        fn test_decode_fixture_snappy() {
+            // This data was obtained by watching rdkafka.
+            let data = [
+                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x58\x00\x00\x00\x00".to_vec(),
+                b"\x02\xad\x86\xf4\xf4\x00\x02\x00\x00\x00\x00\x00\x00\x01\x7e\xb6".to_vec(),
+                b"\x45\x0e\x52\x00\x00\x01\x7e\xb6\x45\x0e\x52\xff\xff\xff\xff\xff".to_vec(),
+                b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x01\x80\x01\x1c".to_vec(),
+                b"\xfc\x01\x00\x00\x00\xc8\x01\x78\xfe\x01\x00\x8a\x01\x00\x50\x16".to_vec(),
+                b"\x68\x65\x6c\x6c\x6f\x20\x6b\x61\x66\x6b\x61\x02\x06\x66\x6f\x6f".to_vec(),
+                b"\x06\x62\x61\x72".to_vec(),
+            ]
+            .concat();
+
+            let actual = RecordBatch::read(&mut Cursor::new(data)).unwrap();
+            let expected = RecordBatch {
+                base_offset: 0,
+                partition_leader_epoch: 0,
+                last_offset_delta: 0,
+                first_timestamp: 1643735486034,
+                max_timestamp: 1643735486034,
+                producer_id: -1,
+                producer_epoch: -1,
+                base_sequence: -1,
+                records: ControlBatchOrRecords::Records(vec![Record {
                     timestamp_delta: 0,
                     offset_delta: 0,
                     key: Some(vec![b'x'; 100]),
@@ -1217,31 +1227,119 @@ mod tests {
                     key: "foo".to_owned(),
                     value: b"bar".to_vec(),
                 }],
-            },
-            Record {
-                timestamp_delta: 0,
-                offset_delta: 1,
-                key: Some(vec![b'x'; 100]),
-                value: Some(b"some value".to_vec()),
-                headers: vec![RecordHeader {
-                    key: "foo".to_owned(),
-                    value: b"bar".to_vec(),
-                }],
-            },
-            ]),
-            compression: RecordBatchCompression::Snappy,
-            is_transactional: false,
-            timestamp_type: RecordBatchTimestampType::CreateTime,
-        };
-        assert_eq!(actual, expected);
+                }]),
+                compression: RecordBatchCompression::Snappy,
+                is_transactional: false,
+                timestamp_type: RecordBatchTimestampType::CreateTime,
+            };
+            assert_eq!(actual, expected);
+
+            let mut data2 = vec![];
+            actual.write(&mut data2).unwrap();
+
+            // don't compare if the data is equal because compression encoder might work slightly differently, use another
+            // roundtrip instead
+            let actual2 = RecordBatch::read(&mut Cursor::new(data2)).unwrap();
+            assert_eq!(actual2, expected);
+        }
 
-        let mut data2 = vec![];
-        actual.write(&mut data2).unwrap();
+        #[test]
+        fn test_decode_fixture_snappy_java() {
+            // This data was obtained by watching Kafka returning a recording to rskafka that was produced by the official
+            // Java client.
+            let data = [
+                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x8c\x00\x00\x00\x00".to_vec(),
+                b"\x02\x79\x1e\x2d\xce\x00\x02\x00\x00\x00\x01\x00\x00\x01\x7f\x07".to_vec(),
+                b"\x25\x7a\xb1\x00\x00\x01\x7f\x07\x25\x7a\xb1\xff\xff\xff\xff\xff".to_vec(),
+                b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x02\x82\x53\x4e".to_vec(),
+                b"\x41\x50\x50\x59\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00".to_vec(),
+                b"\x47\xff\x01\x1c\xfc\x01\x00\x00\x00\xc8\x01\x78\xfe\x01\x00\x8a".to_vec(),
+                b"\x01\x00\x64\x16\x68\x65\x6c\x6c\x6f\x20\x6b\x61\x66\x6b\x61\x02".to_vec(),
+                b"\x06\x66\x6f\x6f\x06\x62\x61\x72\xfa\x01\x00\x00\x02\xfe\x80\x00".to_vec(),
+                b"\x96\x80\x00\x4c\x14\x73\x6f\x6d\x65\x20\x76\x61\x6c\x75\x65\x02".to_vec(),
+                b"\x06\x66\x6f\x6f\x06\x62\x61\x72".to_vec(),
+            ]
+            .concat();
+
+            let actual = RecordBatch::read(&mut Cursor::new(data)).unwrap();
+            let expected = RecordBatch {
+                base_offset: 0,
+                partition_leader_epoch: 0,
+                last_offset_delta: 1,
+                first_timestamp: 1645092371121,
+                max_timestamp: 1645092371121,
+                producer_id: -1,
+                producer_epoch: -1,
+                base_sequence: -1,
+                records: ControlBatchOrRecords::Records(vec![
+                    Record {
+                        timestamp_delta: 0,
+                        offset_delta: 0,
+                        key: Some(vec![b'x'; 100]),
+                        value: Some(b"hello kafka".to_vec()),
+                        headers: vec![RecordHeader {
+                            key: "foo".to_owned(),
+                            value: b"bar".to_vec(),
+                        }],
+                    },
+                    Record {
+                        timestamp_delta: 0,
+                        offset_delta: 1,
+                        key: Some(vec![b'x'; 100]),
+                        value: Some(b"some value".to_vec()),
+                        headers: vec![RecordHeader {
+                            key: "foo".to_owned(),
+                            value: b"bar".to_vec(),
+                        }],
+                    },
+                ]),
+                compression: RecordBatchCompression::Snappy,
+                is_transactional: false,
+                timestamp_type: RecordBatchTimestampType::CreateTime,
+            };
+            assert_eq!(actual, expected);
+
+            let mut data2 = vec![];
+            actual.write(&mut data2).unwrap();
+
+            // don't compare if the data is equal because compression encoder might work slightly differently, use another
+            // roundtrip instead
+            let actual2 = RecordBatch::read(&mut Cursor::new(data2)).unwrap();
+            assert_eq!(actual2, expected);
+        }
 
-        // don't compare if the data is equal because compression encoder might work slightly differently, use another
-        // roundtrip instead
-        let actual2 = RecordBatch::read(&mut Cursor::new(data2)).unwrap();
-        assert_eq!(actual2, expected);
+        #[test]
+        fn test_carefully_decompress_snappy_empty_input() {
+            let err = carefully_decompress_snappy(&[], 1).unwrap_err();
+            assert_matches!(err, ReadError::Malformed(_));
+        }
+
+        #[test]
+        fn test_carefully_decompress_snappy_empty_payload() {
+            let compressed = compress(&[]);
+            let data = carefully_decompress_snappy(&compressed, 1).unwrap();
+            assert!(data.is_empty());
+        }
+
+        proptest! {
+            #![proptest_config(ProptestConfig{cases: 200, ..Default::default()})]
+            #[test]
+            fn test_carefully_decompress_snappy(input in prop::collection::vec(any::<u8>(), 0..10_000)) {
+                let compressed = compress(&input);
+                let input2 = carefully_decompress_snappy(&compressed, 1).unwrap();
+                assert_eq!(input, input2);
+            }
+        }
+
+        fn compress(data: &[u8]) -> Vec<u8> {
+            use snap::raw::{max_compress_len, Encoder};
+
+            let mut encoder = Encoder::new();
+            let mut output = vec![0; max_compress_len(data.len())];
+            let l = encoder.compress(data, &mut output).unwrap();
+
+            output[..l].to_vec()
+        }
     }
 
     #[cfg(feature = "compression-zstd")]
