Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions wincode/benches/benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,81 @@ fn bench_pod_struct_comparison(c: &mut Criterion) {
group.finish();
}

/// Benchmarks `ShortU16` decoding, serialization and deserialization, pairing every wincode
/// measurement with the equivalent `solana_short_vec` / `bincode` call on the same input.
#[cfg(feature = "solana-short-vec")]
fn bench_short_u16_comparison(c: &mut Criterion) {
    use {
        solana_short_vec::ShortU16,
        wincode::{len::short_vec::decode_short_u16, serialize_into},
    };

    let mut group = c.benchmark_group("ShortU16");

    // Output buffer shared by both serialize benchmarks; three bytes is the longest
    // encoding among the samples below.
    let mut scratch = [0u8; 3];

    // Each sample pairs a value with its wire bytes: one, two and three bytes long.
    let samples = [
        (0x7f_u16, &[0x7f][..]),
        (0x3fff_u16, &[0xff, 0x7f][..]),
        (0xffff_u16, &[0xff, 0xff, 0x03][..]),
    ];

    for (value, wire) in samples {
        group.throughput(Throughput::Bytes(wire.len() as u64));

        // Raw slice decoders.
        group.bench_with_input(
            BenchmarkId::new("wincode:decode_short_u16", value),
            &wire,
            |bencher, input| bencher.iter(|| decode_short_u16(black_box(input)).unwrap()),
        );

        group.bench_with_input(
            BenchmarkId::new("solana_short_vec:decode_shortu16_len", value),
            &wire,
            |bencher, input| {
                bencher.iter(|| solana_short_vec::decode_shortu16_len(black_box(input)).unwrap())
            },
        );

        // Sanity check before timing: both libraries must produce the same bytes and
        // read back the same value.
        let wrapped = ShortU16(value);
        let reference = bincode::serialize(&wrapped).unwrap();
        assert_eq!(serialize(&wrapped).unwrap(), reference);
        assert_eq!(
            deserialize::<ShortU16>(&reference).unwrap().0,
            bincode::deserialize::<ShortU16>(&reference).unwrap().0
        );

        // Serialization into the shared scratch buffer.
        group.bench_with_input(
            BenchmarkId::new("wincode:serialize", value),
            &wrapped,
            |bencher, item| {
                bencher.iter(|| {
                    serialize_into(black_box(&mut scratch.as_mut_slice()), black_box(item)).unwrap()
                })
            },
        );

        group.bench_with_input(
            BenchmarkId::new("bincode:serialize", value),
            &wrapped,
            |bencher, item| {
                bencher.iter(|| {
                    bincode::serialize_into(black_box(&mut scratch.as_mut_slice()), black_box(item))
                        .unwrap()
                })
            },
        );

        // Deserialization from the bincode-produced reference bytes.
        group.bench_with_input(
            BenchmarkId::new("wincode:deserialize", value),
            &reference,
            |bencher, encoded| bencher.iter(|| deserialize::<ShortU16>(black_box(encoded)).unwrap()),
        );

        group.bench_with_input(
            BenchmarkId::new("bincode:deserialize", value),
            &reference,
            |bencher, encoded| {
                bencher.iter(|| bincode::deserialize::<ShortU16>(black_box(encoded)).unwrap())
            },
        );
    }

    group.finish();
}

criterion_group!(
benches,
bench_primitives_comparison,
Expand All @@ -297,6 +372,7 @@ criterion_group!(
bench_hashmap_comparison,
bench_hashmap_pod_comparison,
bench_pod_struct_comparison,
bench_short_u16_comparison,
);

criterion_main!(benches);
2 changes: 2 additions & 0 deletions wincode/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub enum ReadError {
InvalidBoolEncoding(u8),
#[error("Sequence length would overflow length encoding scheme: {0}")]
LengthEncodingOverflow(&'static str),
#[error("Invalid value: {0}")]
InvalidValue(&'static str),
#[error("Invalid char lead: {0}")]
InvalidCharLead(u8),
#[error("Custom error: {0}")]
Expand Down
141 changes: 123 additions & 18 deletions wincode/src/len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,31 +66,23 @@ impl<const MAX_SIZE: usize> SeqLen for BincodeLen<MAX_SIZE> {
pub mod short_vec {
use {
super::*,
crate::error::{read_length_encoding_overflow, write_length_encoding_overflow},
crate::error::write_length_encoding_overflow,
core::{
mem::{transmute, MaybeUninit},
ptr,
},
solana_short_vec::{decode_shortu16_len, ShortU16},
solana_short_vec::ShortU16,
};

impl<'de> SchemaRead<'de> for ShortU16 {
type Dst = Self;

#[inline]
fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
let Ok((len, read)) = decode_shortu16_len(reader.fill_buf(3)?) else {
return Err(read_length_encoding_overflow("u16::MAX"));
};

// SAFETY: `read` is the number of bytes visited by `decode_shortu16_len` to decode the length,
// which implies the reader had at least `read` bytes available.
unsafe { reader.consume_unchecked(read) };

let len = decode_short_u16_from_reader(reader)?;
// SAFETY: `dst` is a valid pointer to a `MaybeUninit<ShortU16>`.
let slot = unsafe { &mut *(&raw mut (*dst.as_mut_ptr()).0).cast::<MaybeUninit<u16>>() };
// SAFETY: `len` is always a valid u16. `decode_shortu16_len` casts it to a usize before returning,
// so no risk of overflow.
slot.write(len as u16);
slot.write(len);
Ok(())
}
}
Expand Down Expand Up @@ -167,14 +159,108 @@ pub mod short_vec {
}
}

/// Parses a Solana-style short u16 from the front of `bytes`.
///
/// On success the result is `(value, consumed)`, where `consumed` is the number of leading
/// bytes (1, 2 or 3) that formed the encoding. Bytes after those are not inspected.
///
/// Every byte carries seven payload bits, least-significant group first; a set high bit
/// means another byte follows. Decoding is strict (canonical form only, at most three
/// bytes, overflow rejected) and is meant to be bit-for-bit compatible with Solana's
/// encoding rules.
///
/// # Errors
///
/// - `ReadError::InvalidValue` if `bytes` ends before the encoding is complete, or if the
///   second or third byte is zero (non-canonical encoding).
/// - `ReadError::LengthEncodingOverflow` if the third byte is greater than `3`, i.e. the
///   value would not fit in a `u16` or a fourth byte would be required.
///
/// # Examples
///
/// ```
/// use wincode::len::decode_short_u16;
///
/// let bytes = [0x7f];
/// let (len, read) = decode_short_u16(&bytes).unwrap();
/// assert_eq!(len, 127);
/// assert_eq!(read, 1);
/// ```
///
/// ```
/// use wincode::len::decode_short_u16;
///
/// let bytes = [0x80, 0x01];
/// let (len, read) = decode_short_u16(&bytes).unwrap();
/// assert_eq!(len, 128);
/// assert_eq!(read, 2);
/// ```
///
/// ```
/// use wincode::len::decode_short_u16;
///
/// let bytes = [0x80, 0x80, 0x01];
/// let (len, read) = decode_short_u16(&bytes).unwrap();
/// assert_eq!(len, 16384);
/// assert_eq!(read, 3);
/// ```
#[inline]
pub const fn decode_short_u16(bytes: &[u8]) -> ReadResult<(u16, usize)> {
    use crate::error::ReadError;

    // High bit of a byte: set when another byte follows.
    const CONTINUE_BIT: u8 = 0x80;
    // Low seven bits of a byte: the payload.
    const PAYLOAD_MASK: u8 = 0x7f;

    // Error constructors are kept out of line so the success path stays compact.
    #[cold]
    const fn truncated() -> ReadError {
        ReadError::InvalidValue("short u16: unexpected end of input")
    }

    #[cold]
    const fn non_canonical() -> ReadError {
        ReadError::InvalidValue("short u16: non-canonical encoding")
    }

    #[cold]
    const fn overflow() -> ReadError {
        ReadError::LengthEncodingOverflow("u16::MAX")
    }

    let available = bytes.len();

    // First byte: bits 0..7 of the value.
    if available == 0 {
        return Err(truncated());
    }
    let first = bytes[0];
    if first & CONTINUE_BIT == 0 {
        return Ok((first as u16, 1));
    }
    let low = (first & PAYLOAD_MASK) as u16;

    // Second byte: bits 7..14. A zero here would only pad the first byte.
    if available == 1 {
        return Err(truncated());
    }
    let second = bytes[1];
    if second == 0 {
        return Err(non_canonical());
    }
    let mid = ((second & PAYLOAD_MASK) as u16) << 7;
    if second & CONTINUE_BIT == 0 {
        return Ok((low | mid, 2));
    }

    // Third byte: bits 14..16, so only the values 1, 2 and 3 are acceptable.
    if available == 2 {
        return Err(truncated());
    }
    match bytes[2] {
        0 => Err(non_canonical()),
        third @ 1..=3 => Ok((low | mid | ((third as u16) << 14), 3)),
        _ => Err(overflow()),
    }
}

/// Decodes a short u16 from the front of `reader` and advances the reader past it.
///
/// Requests a window via `fill_buf(3)` (three bytes is the longest encoding), decodes it
/// with `decode_short_u16`, and consumes only the bytes the decoder reported as used.
/// Errors from `fill_buf` or from decoding are propagated with `?` before anything is
/// consumed here.
///
/// NOTE(review): one- and two-byte values at the end of input only decode if `fill_buf(3)`
/// may hand back fewer than three bytes -- confirm against the `Reader` contract.
#[inline]
fn decode_short_u16_from_reader<'de>(reader: &mut impl Reader<'de>) -> ReadResult<u16> {
    let window = reader.fill_buf(3)?;
    let (value, consumed) = decode_short_u16(window)?;
    // SAFETY: `consumed` is the number of bytes `decode_short_u16` read from the slice
    // returned by `fill_buf`, so the reader had at least `consumed` bytes available.
    unsafe { reader.consume_unchecked(consumed) };
    Ok(value)
}

impl SeqLen for ShortU16Len {
#[inline(always)]
fn read<'de, T>(reader: &mut impl Reader<'de>) -> ReadResult<usize> {
let Ok((len, read)) = decode_shortu16_len(reader.fill_buf(3)?) else {
return Err(read_length_encoding_overflow("u16::MAX"));
};
unsafe { reader.consume_unchecked(read) };
Ok(len)
Ok(decode_short_u16_from_reader(reader)? as usize)
}

#[inline(always)]
Expand Down Expand Up @@ -264,6 +350,25 @@ pub mod short_vec {
prop_assert_eq!(short_vec_struct, schema_deserialized);
}

// Round-trip: every u16, once encoded, must decode back to itself through both this
// crate's decoder and the SDK decoder, and both must consume the same number of bytes.
#[test]
fn encode_decode_short_u16_roundtrip(len in 0..=u16::MAX) {
    let encoded = our_short_u16_encode(len);

    let (ours_value, ours_read) = decode_short_u16(&encoded).unwrap();
    let (sdk_value, sdk_read) = solana_short_vec::decode_shortu16_len(&encoded).unwrap();

    prop_assert_eq!(len, ours_value);
    // The SDK returns the value as a `usize`; narrow it for the comparison.
    prop_assert_eq!(len, sdk_value as u16);
    prop_assert_eq!(ours_read, sdk_read);
}

// Differential check against the SDK decoder on arbitrary inputs of zero to three bytes.
// Both decoders must agree on whether the input is rejected, and on the decoded value and
// consumed byte count when it is accepted.
#[test]
fn decode_short_u16_err_equivalence(bytes in prop::collection::vec(any::<u8>(), 0..=3)) {
    let our_decode = decode_short_u16(&bytes);
    let sdk_decode = solana_short_vec::decode_shortu16_len(&bytes);

    // Accept/reject agreement, asserted on its own so a failure here is easy to read.
    prop_assert_eq!(our_decode.is_err(), sdk_decode.is_err());

    // Payload agreement. Error payloads are dropped because the two error types are
    // unrelated; our `u16` is widened losslessly to match the SDK's `usize`.
    prop_assert_eq!(
        our_decode.ok().map(|(len, read)| (usize::from(len), read)),
        sdk_decode.ok()
    );
}

#[test]
fn test_short_vec_as_schema(sv in any::<u16>()) {
let val = ShortVecAsSchema { short_u16: ShortU16(sv) };
Expand Down