Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions wincode-derive/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub(crate) trait TypeExt {
/// ```ignore
/// &'a str -> &'de str
/// ```
fn with_lifetime(&self, ident: &'static str) -> Type;
fn with_lifetime(&self, ident: &str) -> Type;

/// Replace any inference tokens on this type with the fully qualified generic arguments
/// of the given `infer` type.
Expand All @@ -62,10 +62,13 @@ pub(crate) trait TypeExt {
/// assert_eq!(target.with_infer(actual), parse_quote!(Pod<[u8; u64]>));
/// ```
fn with_infer(&self, infer: &Type) -> Type;

/// Gather all the lifetimes on this type.
fn lifetimes(&self) -> Vec<&Lifetime>;
}

impl TypeExt for Type {
fn with_lifetime(&self, ident: &'static str) -> Type {
fn with_lifetime(&self, ident: &str) -> Type {
let mut this = self.clone();
ReplaceLifetimes(ident).visit_type_mut(&mut this);
this
Expand All @@ -86,6 +89,12 @@ impl TypeExt for Type {
infer.visit_type_mut(&mut this);
this
}

fn lifetimes(&self) -> Vec<&Lifetime> {
let mut lifetimes = Vec::new();
GatherLifetimes(&mut lifetimes).visit_type(self);
lifetimes
}
}

#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -680,9 +689,9 @@ impl<'ast> VisitMut for InferGeneric<'ast> {
}

/// Visitor to recursively replace a given type's lifetimes with the given lifetime name.
struct ReplaceLifetimes(&'static str);
struct ReplaceLifetimes<'a>(&'a str);

impl ReplaceLifetimes {
impl ReplaceLifetimes<'_> {
/// Replace the lifetime with `'de`, preserving the span.
fn replace(&self, t: &mut Lifetime) {
t.ident = Ident::new(self.0, t.ident.span());
Expand All @@ -696,7 +705,7 @@ impl ReplaceLifetimes {
}
}

impl VisitMut for ReplaceLifetimes {
impl VisitMut for ReplaceLifetimes<'_> {
fn visit_type_reference_mut(&mut self, t: &mut TypeReference) {
match &mut t.lifetime {
Some(l) => self.replace(l),
Expand Down Expand Up @@ -735,6 +744,14 @@ impl VisitMut for ReplaceLifetimes {
}
}

struct GatherLifetimes<'a, 'ast>(&'a mut Vec<&'ast Lifetime>);

impl<'ast> Visit<'ast> for GatherLifetimes<'_, 'ast> {
fn visit_lifetime(&mut self, l: &'ast Lifetime) {
self.0.push(l);
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -812,4 +829,15 @@ mod tests {
assert_eq!(iter.next().unwrap().to_string(), "b0");
assert_eq!(iter.nth(24).unwrap().to_string(), "a1");
}

#[test]
fn test_gather_lifetimes() {
let ty: Type = parse_quote!(&'a Foo);
let lt: Lifetime = parse_quote!('a);
assert_eq!(ty.lifetimes(), vec![&lt]);

let ty: Type = parse_quote!(&'a Foo<'b, 'c>);
let (a, b, c) = (parse_quote!('a), parse_quote!('b), parse_quote!('c));
assert_eq!(ty.lifetimes(), vec![&a, &b, &c]);
}
}
32 changes: 23 additions & 9 deletions wincode-derive/src/schema_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,10 @@ fn impl_struct_extensions(args: &SchemaArgs, crate_name: &Path) -> Result<TokenS
/// Calling this when the content is not yet fully initialized causes undefined behavior: it is up to the caller
/// to guarantee that the `MaybeUninit<T>` really is in an initialized state.
#[inline]
#vis const unsafe fn into_assume_init_mut(mut self) -> &'_wincode_inner mut #builder_dst {
#vis unsafe fn into_assume_init_mut(mut self) -> &'_wincode_inner mut #builder_dst {
let mut this = ManuallyDrop::new(self);
// SAFETY: reference lives beyond the scope of the builder, and builder is forgotten.
let inner = unsafe { ptr::read(&mut self.inner) };
mem::forget(self);
let inner = unsafe { ptr::read(&mut this.inner) };
// SAFETY: Caller asserts the `MaybeUninit<T>` is in an initialized state.
unsafe {
inner.assume_init_mut()
Expand All @@ -377,15 +377,29 @@ fn impl_struct_extensions(args: &SchemaArgs, crate_name: &Path) -> Result<TokenS

// Generate the helper methods for the builder.
let builder_helpers = fields.iter().enumerate().map(|(i, field)| {
let target = field.target_resolved();
let target_reader_bound = target.with_lifetime("de");
let ty = &field.ty;
let target_reader_bound = field.target_resolved().with_lifetime("de");
let ident = field.struct_member_ident(i);
let ident_string = field.struct_member_ident_to_string(i);
let uninit_mut_ident = format_ident!("uninit_{ident_string}_mut");
let read_field_ident = format_ident!("read_{ident_string}");
let write_uninit_field_ident = format_ident!("write_{ident_string}");
let assume_init_field_ident = format_ident!("assume_init_{ident_string}");
let init_with_field_ident = format_ident!("init_{ident_string}_with");
let lifetimes = ty.lifetimes();
// We must always extract the `Dst` from the type because `SchemaRead` implementations need
// not necessarily write to `Self` -- they write to `Self::Dst`, which isn't necessarily `Self`
// (e.g., in the case of container types).
let field_projection_type = if lifetimes.is_empty() {
quote!(<#ty as SchemaRead<'_>>::Dst)
} else {
let lt = lifetimes[0];
// Even though a type may have multiple distinct lifetimes, we force them to be uniform
// for a `SchemaRead` cast because an implementation of `SchemaRead` must bind all lifetimes
// to the lifetime of the reader (and will not be implemented over multiple distinct lifetimes).
let ty = ty.with_lifetime(&lt.ident.to_string());
quote!(<#ty as SchemaRead<#lt>>::Dst)
};

// The bit index for the field.
let index_bit = LitInt::new(&(1u128 << i).to_string(), Span::call_site());
Expand All @@ -396,7 +410,7 @@ fn impl_struct_extensions(args: &SchemaArgs, crate_name: &Path) -> Result<TokenS
quote! {
/// Get a mutable reference to the maybe uninitialized field.
#[inline]
#vis const fn #uninit_mut_ident(&mut self) -> &mut MaybeUninit<#target> {
#vis const fn #uninit_mut_ident(&mut self) -> &mut MaybeUninit<#field_projection_type> {
// SAFETY:
// - `self.inner` is a valid reference to a `MaybeUninit<#builder_dst>`.
// - We return the field as `&mut MaybeUninit<#target>`, so
Expand All @@ -406,7 +420,7 @@ fn impl_struct_extensions(args: &SchemaArgs, crate_name: &Path) -> Result<TokenS

/// Write a value to the maybe uninitialized field.
#[inline]
#vis const fn #write_uninit_field_ident(&mut self, val: #target) -> &mut Self {
#vis const fn #write_uninit_field_ident(&mut self, val: #field_projection_type) -> &mut Self {
self.#uninit_mut_ident().write(val);
#set_index_bit
self
Expand All @@ -431,7 +445,7 @@ fn impl_struct_extensions(args: &SchemaArgs, crate_name: &Path) -> Result<TokenS
///
/// The caller must guarantee that the initializer function fully initializes the field.
#[inline]
#vis unsafe fn #init_with_field_ident(&mut self, mut initializer: impl FnMut(&mut MaybeUninit<#target>) -> ReadResult<()>) -> ReadResult<&mut Self> {
#vis unsafe fn #init_with_field_ident(&mut self, mut initializer: impl FnMut(&mut MaybeUninit<#field_projection_type>) -> ReadResult<()>) -> ReadResult<&mut Self> {
initializer(self.#uninit_mut_ident())?;
#set_index_bit
Ok(self)
Expand All @@ -453,7 +467,7 @@ fn impl_struct_extensions(args: &SchemaArgs, crate_name: &Path) -> Result<TokenS
Ok(quote! {
const _: () = {
use {
core::{mem::{MaybeUninit, self}, ptr, marker::PhantomData},
core::{mem::{MaybeUninit, ManuallyDrop, self}, ptr, marker::PhantomData},
#crate_name::{SchemaRead, ReadResult, TypeMeta, io::Reader, error,},
};
impl #impl_generics #struct_ident #ty_generics #where_clause {
Expand Down
90 changes: 90 additions & 0 deletions wincode/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,96 @@ mod tests {
});
}

#[test]
fn test_struct_extensions_with_container() {
#[derive(SchemaWrite, SchemaRead, Debug, PartialEq, Eq, proptest_derive::Arbitrary)]
#[wincode(internal, struct_extensions)]
struct Test {
Copy link
Collaborator

@kskalski kskalski Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since there is a non-trivial SchemaRead lifetime handling, could you add a test that works on struct with lifetimes, say

struct TestRef<'a> {
    a: &'a [u8],
    b: Option<'a u8>
}

it should be possible to initialize it with a (locally scoped) slice, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, 5ed3c6b

#[wincode(with = "containers::Vec<Pod<_>>")]
a: Vec<u8>,
#[wincode(with = "containers::Pod<_>")]
b: [u8; 32],
c: u64,
}

proptest!(proptest_cfg(), |(test: Test)| {
let mut uninit = MaybeUninit::<Test>::uninit();
let mut builder = TestUninitBuilder::from_maybe_uninit_mut(&mut uninit);
builder
.write_a(test.a.clone())
.write_b(test.b)
.write_c(test.c);
prop_assert!(builder.is_init());
let init_mut = unsafe { builder.into_assume_init_mut() };
prop_assert_eq!(&test, init_mut);
// Ensure `uninit` is marked initialized so fields are dropped.
let init = unsafe { uninit.assume_init() };
prop_assert_eq!(test, init);
});
}

#[test]
fn test_struct_extensions_with_reference() {
#[derive(SchemaWrite, SchemaRead, Debug, PartialEq, Eq, proptest_derive::Arbitrary)]
#[wincode(internal)]
struct Test {
a: Vec<u8>,
b: Option<String>,
}

#[derive(SchemaWrite, SchemaRead, Debug, PartialEq, Eq)]
#[wincode(internal, struct_extensions)]
struct TestRef<'a> {
a: &'a [u8],
b: Option<&'a str>,
}

proptest!(proptest_cfg(), |(test: Test)| {
let mut uninit = MaybeUninit::<TestRef>::uninit();
let mut builder = TestRefUninitBuilder::from_maybe_uninit_mut(&mut uninit);
builder
.write_a(test.a.as_slice())
.write_b(test.b.as_deref());
prop_assert!(builder.is_init());
builder.finish();
let init = unsafe { uninit.assume_init() };
prop_assert_eq!(test.a.as_slice(), init.a);
prop_assert_eq!(test.b.as_deref(), init.b);
});
}

#[test]
fn test_struct_extensions_with_mapped_type() {
#[derive(SchemaWrite, SchemaRead, Debug, PartialEq, Eq, proptest_derive::Arbitrary)]
#[wincode(internal)]
struct Test {
a: Vec<u8>,
b: [u8; 32],
c: u64,
}

#[derive(SchemaWrite, SchemaRead)]
#[wincode(internal, from = "Test", struct_extensions)]
struct TestMapped {
a: containers::Vec<containers::Pod<u8>>,
b: containers::Pod<[u8; 32]>,
c: u64,
}

proptest!(proptest_cfg(), |(test: Test)| {
let mut uninit = MaybeUninit::<Test>::uninit();
let mut builder = TestMappedUninitBuilder::from_maybe_uninit_mut(&mut uninit);
builder
.write_a(test.a.clone())
.write_b(test.b)
.write_c(test.c);
prop_assert!(builder.is_init());
builder.finish();
let init = unsafe { uninit.assume_init() };
prop_assert_eq!(test, init);
});
}

#[test]
fn test_struct_extensions_builder_fully_initialized() {
#[derive(SchemaWrite, SchemaRead, Debug, PartialEq, Eq, proptest_derive::Arbitrary)]
Expand Down