
Commit df3e5c9

fix: do not OOM in carefully_decompress_snappy
1 parent 948aee4 commit df3e5c9

File tree

1 file changed: +219, -117 lines

src/protocol/record.rs

Lines changed: 219 additions & 117 deletions
@@ -619,6 +619,8 @@ where
             }
             #[cfg(feature = "compression-snappy")]
             RecordBatchCompression::Snappy => {
+                use crate::protocol::vec_builder::DEFAULT_BLOCK_SIZE;
+
                 // Construct the input for the raw decoder.
                 let mut input = vec![];
                 reader.read_to_end(&mut input)?;
@@ -655,13 +657,13 @@ where
                         let mut chunk_data = vec![0u8; chunk_length];
                         cursor.read_exact(&mut chunk_data)?;

-                        let mut buf = carefully_decompress_snappy(&chunk_data)?;
+                        let mut buf = carefully_decompress_snappy(&chunk_data, DEFAULT_BLOCK_SIZE)?;
                         output.append(&mut buf);
                     }

                     output
                 } else {
-                    carefully_decompress_snappy(&input)?
+                    carefully_decompress_snappy(&input, DEFAULT_BLOCK_SIZE)?
                 };

                 // Read uncompressed records.
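
Both call sites now seed the retry loop with `DEFAULT_BLOCK_SIZE` from `vec_builder` rather than letting the function pick it, which is also what lets the tests further down start at a 1-byte block to exercise the growth path. Below is a minimal sketch of the resulting growth schedule; the concrete 4096-byte start is an assumption for illustration (the real constant lives in `src/protocol/vec_builder.rs`), and the real loop only grows the buffer when decompression fails in a way that indicates an undersized destination.

```rust
// Best-case schedule of trial buffer sizes, mirroring the loop's
// `uncompressed_size.min(max_uncompressed_size)` cap and its `*= 2` growth.
fn trial_sizes(uncompressed_size: usize, start_block_size: usize) -> Vec<usize> {
    let mut sizes = Vec::new();
    let mut max_uncompressed_size = start_block_size;
    loop {
        let trial = uncompressed_size.min(max_uncompressed_size);
        sizes.push(trial);
        if trial == uncompressed_size {
            return sizes;
        }
        max_uncompressed_size *= 2;
    }
}

fn main() {
    // A 10 KiB record decoded with an assumed 4 KiB starting block needs three trials.
    assert_eq!(trial_sizes(10 * 1024, 4096), vec![4096, 8192, 10 * 1024]);
}
```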
@@ -889,39 +891,103 @@ where
     }
 }

+/// Try to decompress a snappy message without blindly believing the uncompressed size encoded at the start of the
+/// message (and therefore potentially OOMing).
 #[cfg(feature = "compression-snappy")]
-fn carefully_decompress_snappy(input: &[u8]) -> Result<Vec<u8>, ReadError> {
-    use crate::protocol::vec_builder::DEFAULT_BLOCK_SIZE;
+fn carefully_decompress_snappy(
+    input: &[u8],
+    start_block_size: usize,
+) -> Result<Vec<u8>, ReadError> {
+    use crate::protocol::primitives::UnsignedVarint;
     use snap::raw::{decompress_len, Decoder};

+    // Early exit, otherwise `uncompressed_size_encoded_length` will be 1 even though there was no input.
+    if input.is_empty() {
+        return Err(ReadError::Malformed(Box::new(snap::Error::Empty)));
+    }
+
     // The snappy compression used here is unframed aka "raw". So we first need to figure out the
     // uncompressed length. See
     //
     // - https://github.com/edenhill/librdkafka/blob/2b76b65212e5efda213961d5f84e565038036270/src/rdkafka_msgset_reader.c#L345-L348
     // - https://github.com/edenhill/librdkafka/blob/747f77c98fbddf7dc6508f76398e0fc9ee91450f/src/snappy.c#L779
     let uncompressed_size = decompress_len(input).map_err(|e| ReadError::Malformed(Box::new(e)))?;

+    // Figure out how long the encoded size was.
+    let uncompressed_size_encoded_length = {
+        let mut buf = Vec::with_capacity(100);
+        UnsignedVarint(uncompressed_size as u64)
+            .write(&mut buf)
+            .expect("this write should never fail");
+        buf.len()
+    };
+
     // Decode snappy payload.
     // The uncompressed length is unchecked and can be up to 2^32-1 bytes. To avoid a DDoS vector we try to
     // limit it to a small size and if that fails we double that size.
-    let mut max_uncompressed_size = DEFAULT_BLOCK_SIZE;
+    let mut max_uncompressed_size = start_block_size;

+    // Try to decode the message with growing output buffers.
     loop {
         let try_uncompressed_size = uncompressed_size.min(max_uncompressed_size);

+        // We need to lie to the snap decoder about the target length, otherwise it will reject our shortened test
+        // straight away. Luckily that's rather easy and we just need to fake the length stored right at the beginning
+        // of the message.
+        let try_input = {
+            let mut buf = Cursor::new(Vec::with_capacity(input.len()));
+            UnsignedVarint(try_uncompressed_size as u64)
+                .write(&mut buf)
+                .expect("this write should never fail");
+            buf.write_all(&input[uncompressed_size_encoded_length..])
+                .expect("this write should never fail");
+            buf.into_inner()
+        };
+
         let mut decoder = Decoder::new();
         let mut output = vec![0; try_uncompressed_size];
-        let actual_uncompressed_size = match decoder.decompress(input, &mut output) {
+        let actual_uncompressed_size = match decoder.decompress(&try_input, &mut output) {
             Ok(size) => size,
-            Err(snap::Error::BufferTooSmall { .. })
-                if max_uncompressed_size < uncompressed_size =>
-            {
-                // try larger buffer
-                max_uncompressed_size *= 2;
-                continue;
-            }
             Err(e) => {
-                return Err(ReadError::Malformed(Box::new(e)));
+                let looks_like_dst_too_small = match e {
+                    // `CopyWrite` only occurs when the dst buffer is too small.
+                    snap::Error::CopyWrite { .. } => true,
+
+                    // `Literal` may occur due to src or dst errors, so we need to check.
+                    snap::Error::Literal {
+                        len,
+                        dst_len,
+                        src_len,
+                    } => (dst_len < len) && (src_len >= len),
+
+                    // `HeaderMismatch` may also occur when the output was smaller than we predicted, in which case the
+                    // header would actually be broken.
+                    snap::Error::HeaderMismatch {
+                        expected_len,
+                        got_len,
+                    } => expected_len < got_len,
+
+                    // `BufferTooSmall` cannot happen by construction, because we just allocated the right buffer.
+                    snap::Error::BufferTooSmall { .. } => {
+                        unreachable!("Just allocated a correctly-sized output buffer.")
+                    }
+
+                    // `Offset` does NOT occur due to an undersized dst but due to invalid offset calculations. Instead
+                    // `CopyWrite` would be used.
+                    snap::Error::Offset { .. } => false,
+
+                    // All other errors are real errors.
+                    _ => false,
+                };
+                let used_smaller_dst = max_uncompressed_size < uncompressed_size;
+
+                if looks_like_dst_too_small && used_smaller_dst {
+                    // try larger buffer
+                    max_uncompressed_size *= 2;
+                    continue;
+                } else {
+                    return Err(ReadError::Malformed(Box::new(e)));
+                }
             }
         };
         if actual_uncompressed_size != uncompressed_size {
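
Raw (unframed) Snappy data begins with a varint that merely claims the uncompressed length, and `snap::raw::decompress_len` reports that claim without validating it against the payload. The old retry loop could only ever observe `BufferTooSmall`, which `snap` returns before doing any work whenever the output buffer is shorter than the claimed length, so the loop kept doubling until it had allocated the full, attacker-controlled size anyway. The reworked function rewrites the varint prefix for each trial so the decoder genuinely attempts decompression into the smaller buffer, and only grows the buffer when the resulting error looks like an undersized destination. A minimal sketch of why the claimed length cannot be trusted (the hostile bytes below are hand-crafted for illustration and are not part of this change):

```rust
use snap::raw::decompress_len;

fn main() {
    // Hand-crafted raw-snappy preamble: a varint claiming 1 GiB (1 << 30) of
    // uncompressed data, with no payload behind it.
    let hostile = [0x80, 0x80, 0x80, 0x80, 0x04];

    // `decompress_len` only parses the claimed size; it does not check it against
    // the actual payload, so `vec![0; claimed]` would be a 1 GiB allocation driven
    // entirely by untrusted input.
    let claimed = decompress_len(&hostile).unwrap();
    assert_eq!(claimed, 1_usize << 30);
}
```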
@@ -1129,86 +1195,34 @@ mod tests {
     }

     #[cfg(feature = "compression-snappy")]
-    #[test]
-    fn test_decode_fixture_snappy() {
-        // This data was obtained by watching rdkafka.
-        let data = [
-            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x58\x00\x00\x00\x00".to_vec(),
-            b"\x02\xad\x86\xf4\xf4\x00\x02\x00\x00\x00\x00\x00\x00\x01\x7e\xb6".to_vec(),
-            b"\x45\x0e\x52\x00\x00\x01\x7e\xb6\x45\x0e\x52\xff\xff\xff\xff\xff".to_vec(),
-            b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x01\x80\x01\x1c".to_vec(),
-            b"\xfc\x01\x00\x00\x00\xc8\x01\x78\xfe\x01\x00\x8a\x01\x00\x50\x16".to_vec(),
-            b"\x68\x65\x6c\x6c\x6f\x20\x6b\x61\x66\x6b\x61\x02\x06\x66\x6f\x6f".to_vec(),
-            b"\x06\x62\x61\x72".to_vec(),
-        ]
-        .concat();
-
-        let actual = RecordBatch::read(&mut Cursor::new(data)).unwrap();
-        let expected = RecordBatch {
-            base_offset: 0,
-            partition_leader_epoch: 0,
-            last_offset_delta: 0,
-            first_timestamp: 1643735486034,
-            max_timestamp: 1643735486034,
-            producer_id: -1,
-            producer_epoch: -1,
-            base_sequence: -1,
-            records: ControlBatchOrRecords::Records(vec![Record {
-                timestamp_delta: 0,
-                offset_delta: 0,
-                key: Some(vec![b'x'; 100]),
-                value: Some(b"hello kafka".to_vec()),
-                headers: vec![RecordHeader {
-                    key: "foo".to_owned(),
-                    value: b"bar".to_vec(),
-                }],
-            }]),
-            compression: RecordBatchCompression::Snappy,
-            is_transactional: false,
-            timestamp_type: RecordBatchTimestampType::CreateTime,
-        };
-        assert_eq!(actual, expected);
-
-        let mut data2 = vec![];
-        actual.write(&mut data2).unwrap();
-
-        // don't compare if the data is equal because compression encoder might work slightly differently, use another
-        // roundtrip instead
-        let actual2 = RecordBatch::read(&mut Cursor::new(data2)).unwrap();
-        assert_eq!(actual2, expected);
-    }
-
-    #[cfg(feature = "compression-snappy")]
-    #[test]
-    fn test_decode_fixture_snappy_java() {
-        // This data was obtained by watching Kafka returning a recording to rskafka that was produced by the official
-        // Java client.
-        let data = [
-            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x8c\x00\x00\x00\x00".to_vec(),
-            b"\x02\x79\x1e\x2d\xce\x00\x02\x00\x00\x00\x01\x00\x00\x01\x7f\x07".to_vec(),
-            b"\x25\x7a\xb1\x00\x00\x01\x7f\x07\x25\x7a\xb1\xff\xff\xff\xff\xff".to_vec(),
-            b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x02\x82\x53\x4e".to_vec(),
-            b"\x41\x50\x50\x59\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00".to_vec(),
-            b"\x47\xff\x01\x1c\xfc\x01\x00\x00\x00\xc8\x01\x78\xfe\x01\x00\x8a".to_vec(),
-            b"\x01\x00\x64\x16\x68\x65\x6c\x6c\x6f\x20\x6b\x61\x66\x6b\x61\x02".to_vec(),
-            b"\x06\x66\x6f\x6f\x06\x62\x61\x72\xfa\x01\x00\x00\x02\xfe\x80\x00".to_vec(),
-            b"\x96\x80\x00\x4c\x14\x73\x6f\x6d\x65\x20\x76\x61\x6c\x75\x65\x02".to_vec(),
-            b"\x06\x66\x6f\x6f\x06\x62\x61\x72".to_vec(),
-        ]
-        .concat();
-
-        let actual = RecordBatch::read(&mut Cursor::new(data)).unwrap();
-        let expected = RecordBatch {
-            base_offset: 0,
-            partition_leader_epoch: 0,
-            last_offset_delta: 1,
-            first_timestamp: 1645092371121,
-            max_timestamp: 1645092371121,
-            producer_id: -1,
-            producer_epoch: -1,
-            base_sequence: -1,
-            records: ControlBatchOrRecords::Records(vec![
-                Record {
+    mod snappy {
+        use super::*;
+
+        #[test]
+        fn test_decode_fixture_snappy() {
+            // This data was obtained by watching rdkafka.
+            let data = [
+                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x58\x00\x00\x00\x00".to_vec(),
+                b"\x02\xad\x86\xf4\xf4\x00\x02\x00\x00\x00\x00\x00\x00\x01\x7e\xb6".to_vec(),
+                b"\x45\x0e\x52\x00\x00\x01\x7e\xb6\x45\x0e\x52\xff\xff\xff\xff\xff".to_vec(),
+                b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x01\x80\x01\x1c".to_vec(),
+                b"\xfc\x01\x00\x00\x00\xc8\x01\x78\xfe\x01\x00\x8a\x01\x00\x50\x16".to_vec(),
+                b"\x68\x65\x6c\x6c\x6f\x20\x6b\x61\x66\x6b\x61\x02\x06\x66\x6f\x6f".to_vec(),
+                b"\x06\x62\x61\x72".to_vec(),
+            ]
+            .concat();
+
+            let actual = RecordBatch::read(&mut Cursor::new(data)).unwrap();
+            let expected = RecordBatch {
+                base_offset: 0,
+                partition_leader_epoch: 0,
+                last_offset_delta: 0,
+                first_timestamp: 1643735486034,
+                max_timestamp: 1643735486034,
+                producer_id: -1,
+                producer_epoch: -1,
+                base_sequence: -1,
+                records: ControlBatchOrRecords::Records(vec![Record {
                     timestamp_delta: 0,
                     offset_delta: 0,
                     key: Some(vec![b'x'; 100]),
@@ -1217,31 +1231,119 @@ mod tests {
                         key: "foo".to_owned(),
                         value: b"bar".to_vec(),
                     }],
-                },
-                Record {
-                    timestamp_delta: 0,
-                    offset_delta: 1,
-                    key: Some(vec![b'x'; 100]),
-                    value: Some(b"some value".to_vec()),
-                    headers: vec![RecordHeader {
-                        key: "foo".to_owned(),
-                        value: b"bar".to_vec(),
-                    }],
-                },
-            ]),
-            compression: RecordBatchCompression::Snappy,
-            is_transactional: false,
-            timestamp_type: RecordBatchTimestampType::CreateTime,
-        };
-        assert_eq!(actual, expected);
+                }]),
+                compression: RecordBatchCompression::Snappy,
+                is_transactional: false,
+                timestamp_type: RecordBatchTimestampType::CreateTime,
+            };
+            assert_eq!(actual, expected);
+
+            let mut data2 = vec![];
+            actual.write(&mut data2).unwrap();
+
+            // don't compare if the data is equal because compression encoder might work slightly differently, use another
+            // roundtrip instead
+            let actual2 = RecordBatch::read(&mut Cursor::new(data2)).unwrap();
+            assert_eq!(actual2, expected);
+        }

-        let mut data2 = vec![];
-        actual.write(&mut data2).unwrap();
+        #[test]
+        fn test_decode_fixture_snappy_java() {
+            // This data was obtained by watching Kafka returning a recording to rskafka that was produced by the official
+            // Java client.
+            let data = [
+                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x8c\x00\x00\x00\x00".to_vec(),
+                b"\x02\x79\x1e\x2d\xce\x00\x02\x00\x00\x00\x01\x00\x00\x01\x7f\x07".to_vec(),
+                b"\x25\x7a\xb1\x00\x00\x01\x7f\x07\x25\x7a\xb1\xff\xff\xff\xff\xff".to_vec(),
+                b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x02\x82\x53\x4e".to_vec(),
+                b"\x41\x50\x50\x59\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00".to_vec(),
+                b"\x47\xff\x01\x1c\xfc\x01\x00\x00\x00\xc8\x01\x78\xfe\x01\x00\x8a".to_vec(),
+                b"\x01\x00\x64\x16\x68\x65\x6c\x6c\x6f\x20\x6b\x61\x66\x6b\x61\x02".to_vec(),
+                b"\x06\x66\x6f\x6f\x06\x62\x61\x72\xfa\x01\x00\x00\x02\xfe\x80\x00".to_vec(),
+                b"\x96\x80\x00\x4c\x14\x73\x6f\x6d\x65\x20\x76\x61\x6c\x75\x65\x02".to_vec(),
+                b"\x06\x66\x6f\x6f\x06\x62\x61\x72".to_vec(),
+            ]
+            .concat();
+
+            let actual = RecordBatch::read(&mut Cursor::new(data)).unwrap();
+            let expected = RecordBatch {
+                base_offset: 0,
+                partition_leader_epoch: 0,
+                last_offset_delta: 1,
+                first_timestamp: 1645092371121,
+                max_timestamp: 1645092371121,
+                producer_id: -1,
+                producer_epoch: -1,
+                base_sequence: -1,
+                records: ControlBatchOrRecords::Records(vec![
+                    Record {
+                        timestamp_delta: 0,
+                        offset_delta: 0,
+                        key: Some(vec![b'x'; 100]),
+                        value: Some(b"hello kafka".to_vec()),
+                        headers: vec![RecordHeader {
+                            key: "foo".to_owned(),
+                            value: b"bar".to_vec(),
+                        }],
+                    },
+                    Record {
+                        timestamp_delta: 0,
+                        offset_delta: 1,
+                        key: Some(vec![b'x'; 100]),
+                        value: Some(b"some value".to_vec()),
+                        headers: vec![RecordHeader {
+                            key: "foo".to_owned(),
+                            value: b"bar".to_vec(),
+                        }],
+                    },
+                ]),
+                compression: RecordBatchCompression::Snappy,
+                is_transactional: false,
+                timestamp_type: RecordBatchTimestampType::CreateTime,
+            };
+            assert_eq!(actual, expected);
+
+            let mut data2 = vec![];
+            actual.write(&mut data2).unwrap();
+
+            // don't compare if the data is equal because compression encoder might work slightly differently, use another
+            // roundtrip instead
+            let actual2 = RecordBatch::read(&mut Cursor::new(data2)).unwrap();
+            assert_eq!(actual2, expected);
+        }

-        // don't compare if the data is equal because compression encoder might work slightly differently, use another
-        // roundtrip instead
-        let actual2 = RecordBatch::read(&mut Cursor::new(data2)).unwrap();
-        assert_eq!(actual2, expected);
+        #[test]
+        fn test_carefully_decompress_snappy_empty_input() {
+            let err = carefully_decompress_snappy(&[], 1).unwrap_err();
+            assert_matches!(err, ReadError::Malformed(_));
+        }
+
+        #[test]
+        fn test_carefully_decompress_snappy_empty_payload() {
+            let compressed = compress(&[]);
+            let data = carefully_decompress_snappy(&compressed, 1).unwrap();
+            assert!(data.is_empty());
+        }
+
+        proptest! {
+            #![proptest_config(ProptestConfig{cases: 200, ..Default::default()})]
+            #[test]
+            fn test_carefully_decompress_snappy(input in prop::collection::vec(any::<u8>(), 0..10_000)) {
+                let compressed = compress(&input);
+                let input2 = carefully_decompress_snappy(&compressed, 1).unwrap();
+                assert_eq!(input, input2);
+            }
+        }
+
+        fn compress(data: &[u8]) -> Vec<u8> {
+            use snap::raw::{max_compress_len, Encoder};
+
+            let mut encoder = Encoder::new();
+            let mut output = vec![0; max_compress_len(data.len())];
+            let l = encoder.compress(data, &mut output).unwrap();
+
+            output[..l].to_vec()
+        }
     }

     #[cfg(feature = "compression-zstd")]
