diff --git a/Cargo.lock b/Cargo.lock index 97732ea..be0c08d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -110,13 +110,13 @@ dependencies = [ name = "dispatch-bundle" version = "0.1.0" dependencies = [ - "embedded-command-macros 0.3.0", + "embedded-command-macros 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "serac", ] [[package]] name = "embedded-command-macros" -version = "0.3.0" +version = "0.4.0" dependencies = [ "Inflector", "proc-macro2", @@ -126,9 +126,9 @@ dependencies = [ [[package]] name = "embedded-command-macros" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f424c060861bb940cf4960cb69acd819a5c9748d6eb21b5adee2d2931355badd" +checksum = "36092ddd7791b6ed78443892d78fdeb878ce9c5fa3f47ae72a724a2254c1b01c" dependencies = [ "Inflector", "proc-macro2", @@ -285,12 +285,12 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serac" -version = "0.3.0" +version = "0.4.1" dependencies = [ "cortex-m", "cortex-m-rt", "defmt", - "embedded-command-macros 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "embedded-command-macros 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "fill-array", "panic-halt", ] diff --git a/dispatch-bundle/Cargo.toml b/dispatch-bundle/Cargo.toml index 7fb6804..52c16ac 100644 --- a/dispatch-bundle/Cargo.toml +++ b/dispatch-bundle/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" packit = [] [dependencies] -macros = { package = "embedded-command-macros", path = "../macros" } +macros = { package = "embedded-command-macros", version = "0.4.0" } [dev-dependencies] serac = { path = "../serac" } diff --git a/dispatch-bundle/src/lib.rs b/dispatch-bundle/src/lib.rs index 706e2fa..bd86322 100644 --- a/dispatch-bundle/src/lib.rs +++ b/dispatch-bundle/src/lib.rs @@ -81,17 +81,17 @@ mod tests { #[test] fn cookie_cutter() { - #[derive(vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive(vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf)] struct A { val: u8, } - #[derive(vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive(vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf)] struct B { val: u16, } - #[derive(vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive(vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf)] struct C { val: u8, other: A, @@ -118,7 +118,7 @@ mod tests { const TEN: u8 = 10; #[bundle(Foo)] - #[derive(vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive(vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf)] #[repr(u8)] enum MyBundle { A, diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 5fc6850..e81c8f7 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "embedded-command-macros" -version = "0.3.0" +version = "0.4.0" edition = "2024" description = "Macros for the embedded command crate family." license = "CC-BY-NC-SA-4.0" @@ -16,4 +16,4 @@ packit = [] Inflector = "0.11.4" proc-macro2 = "1.0.78" quote = "1.0.35" -syn = { version = "2.0.50", features = ["full"] } +syn = { version = "2.0.50", features = ["full", "extra-traits"] } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8368c72..177c91a 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -26,7 +26,17 @@ pub fn impl_serialize_iter_vanilla(item: TokenStream) -> TokenStream { serac::vanilla::serialize_iter(item) } -/// Generates the implementation block for conforming to `SerializeBuf` of the "vanilla" flavor. +/// Generates the implementation block for conforming to `Size` of the "vanilla" flavor. +/// +/// # Note +/// +/// Requires `serac` to be in scope with that name. +#[proc_macro_derive(Size)] +pub fn impl_size_vanilla(item: TokenStream) -> TokenStream { + serac::vanilla::impl_size(item) +} + +/// Generates the implementation block for conforming to `SerializeBuf`. /// /// As of now, generic types *cannot* implement `SerializeBuf` on stable. /// @@ -34,6 +44,18 @@ pub fn impl_serialize_iter_vanilla(item: TokenStream) -> TokenStream { /// /// Requires `serac` to be in scope with that name. #[proc_macro_derive(SerializeBuf)] -pub fn impl_serialize_buf_vanilla(item: TokenStream) -> TokenStream { - serac::vanilla::impl_serialize_buf(item) +pub fn impl_serialize_buf(item: TokenStream) -> TokenStream { + serac::impl_serialize_buf(item) +} + +/// Generates the implementation block for conforming to `SerializeBuf` for a type +/// alias. This is used for implementing `SerializeBuf` for concretely specified +/// generic types. +/// +/// # Note +/// +/// Requires `serac` to be in scope with that name. +#[proc_macro_attribute] +pub fn impl_serialize_buf_alias(attrs: TokenStream, item: TokenStream) -> TokenStream { + serac::impl_serialize_buf_alias(attrs, item) } diff --git a/macros/src/serac.rs b/macros/src/serac.rs index 38050c2..8468e99 100644 --- a/macros/src/serac.rs +++ b/macros/src/serac.rs @@ -1 +1,54 @@ pub(crate) mod vanilla; + +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::{DeriveInput, Generics, Ident, ItemType, Path}; + +#[derive(Clone)] +struct BodyInfo { + ident: Ident, + generics: Generics, + path: Path, +} + +pub fn impl_serialize_buf(item: TokenStream) -> TokenStream { + let item: DeriveInput = syn::parse2(item.into()).unwrap(); + + if !item.generics.params.is_empty() { + panic!("SerializeBuf is incompatible with generic types. You may still use SerializeIter."); + } + + let info = BodyInfo { + ident: item.ident, + generics: item.generics, + path: syn::parse2(quote! { serac }).unwrap(), + }; + + let path = info.path; + let ident = info.ident; + + quote! { + unsafe impl #path::SerializeBuf<{ <#ident as #path::Size>::SIZE }> for #ident {} + } + .into() +} + +pub fn impl_serialize_buf_alias(attrs: TokenStream, item: TokenStream) -> TokenStream { + let item = item.into(); + let attrs: TokenStream2 = attrs.into(); + + let original = quote! { #attrs #item }; + + let item: ItemType = syn::parse2(item).expect("a"); + + let path = syn::parse2::(quote! { serac }).expect("b"); + let ident = item.ident; + + quote! { + #original + + unsafe impl #path::SerializeBuf<{ <#ident as #path::Size>::SIZE }> for #ident {} + } + .into() +} diff --git a/macros/src/serac/vanilla.rs b/macros/src/serac/vanilla.rs index 317655a..321afe9 100644 --- a/macros/src/serac/vanilla.rs +++ b/macros/src/serac/vanilla.rs @@ -1,17 +1,11 @@ +use std::collections::HashSet; + use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote}; -use syn::{ - Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Generics, Ident, Index, Path, Type, - Variant, -}; - -#[derive(Clone)] -struct BodyInfo { - ident: Ident, - generics: Generics, - path: Path, -} +use syn::{Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Index, Type, Variant}; + +use crate::serac::BodyInfo; fn get_repr<'a>(mut attrs: impl Iterator) -> Type { attrs @@ -118,6 +112,19 @@ fn serialize_struct(s: DataStruct, info: &BodyInfo) -> TokenStream2 { } }; + let (.., types) = size_of_struct(s, info); + + let where_clause = { + let constraints = types.iter().map(|ty| { + quote! { #ty: #path::SerializeIter } + }); + + match where_clause { + Some(w) => quote! { #w #(#constraints,)* }, + None => quote! { where #(#constraints,)* }, + } + }; + quote! { impl #impl_generics #path::SerializeIter for #implementer #ty_generics #where_clause { fn serialize_iter<'a>(&self, dst: impl IntoIterator::Word>) -> Result @@ -137,15 +144,18 @@ fn serialize_struct(s: DataStruct, info: &BodyInfo) -> TokenStream2 { } } -fn size_of_struct(s: DataStruct, info: &BodyInfo) -> TokenStream2 { - let types: Vec<_> = s.fields.iter().map(|field| &field.ty).collect(); +fn size_of_struct(s: DataStruct, info: &BodyInfo) -> (TokenStream2, HashSet) { + let types: Vec<_> = s.fields.iter().map(|field| field.ty.clone()).collect(); let path = &info.path; - if types.is_empty() { - quote! { 0 } - } else { - quote! { #( <#types as #path::Size>::SIZE )+* } - } + ( + if types.is_empty() { + quote! { 0 } + } else { + quote! { #( <#types as #path::Size>::SIZE )+* } + }, + HashSet::from_iter(types), + ) } fn serialize_enum(e: DataEnum, info: &BodyInfo, repr: Type) -> TokenStream2 { @@ -260,6 +270,19 @@ fn serialize_enum(e: DataEnum, info: &BodyInfo, repr: Type) -> TokenStream2 { }) .collect(); + let (.., types) = size_of_enum(e, info, repr.clone()); + + let where_clause = { + let constraints = types.iter().map(|ty| { + quote! { #ty: #path::SerializeIter } + }); + + match where_clause { + Some(w) => quote! { #w #(#constraints,)* }, + None => quote! { where #(#constraints,)* }, + } + }; + quote! { impl #impl_generics #path::SerializeIter for #implementer #ty_generics #where_clause { fn serialize_iter<'a>(&self, dst: impl IntoIterator::Word>) -> Result @@ -303,33 +326,44 @@ fn serialize_enum(e: DataEnum, info: &BodyInfo, repr: Type) -> TokenStream2 { } } -fn size_of_enum(e: DataEnum, info: &BodyInfo, repr: Type) -> TokenStream2 { +fn size_of_enum(e: DataEnum, info: &BodyInfo, repr: Type) -> (TokenStream2, HashSet) { + let mut types = HashSet::new(); + let path = &info.path; let sizes: Vec<_> = e .variants .iter() .filter_map(|variant| { if !variant.fields.is_empty() { - let types: Vec<_> = variant.fields.iter().map(|field| &field.ty).collect(); + let variant_types: Vec<_> = variant + .fields + .iter() + .map(|field| field.ty.clone()) + .collect(); - Some(quote! { #(<#types as #path::Size>::SIZE)+* }) + types.extend(variant_types.iter().cloned()); + + Some(quote! { #(<#variant_types as #path::Size>::SIZE)+* }) } else { None } }) .collect(); - quote! {{ - let mut max = 0; + ( + quote! {{ + let mut max = 0; - #( - if #sizes > max { - max = #sizes; - } - )* + #( + if #sizes > max { + max = #sizes; + } + )* - max + <#repr as #path::Size>::SIZE - }} + max + <#repr as #path::Size>::SIZE + }}, + types, + ) } pub fn serialize_iter(item: TokenStream) -> TokenStream { @@ -350,34 +384,41 @@ pub fn serialize_iter(item: TokenStream) -> TokenStream { implementation.into() } -pub fn impl_serialize_buf(item: TokenStream) -> TokenStream { +pub fn impl_size(item: TokenStream) -> TokenStream { let item: DeriveInput = syn::parse2(item.into()).unwrap(); - if !item.generics.params.is_empty() { - panic!("SerializeBuf is incompatible with generic types. You may still use SerializeIter."); - } - let info = BodyInfo { ident: item.ident, generics: item.generics, path: syn::parse2(quote! { serac }).unwrap(), }; - let size = match item.data { + let (size, types) = match item.data { Data::Struct(s) => size_of_struct(s, &info), Data::Enum(e) => size_of_enum(e, &info, get_repr(item.attrs.iter())), _ => panic!("Vanilla serializer is only implemented for structs and enums."), }; + let (impl_generics, ty_generics, where_clause) = info.generics.split_for_impl(); + let path = info.path; let ident = info.ident; + let where_clause = { + let constraints = types.iter().map(|ty| { + quote! { #ty: #path::Size } + }); + + match where_clause { + Some(w) => quote! { #w #(#constraints,)* }, + None => quote! { where #(#constraints,)* }, + } + }; + quote! { - unsafe impl #path::Size for #ident { + unsafe impl #impl_generics #path::Size for #ident #ty_generics #where_clause { const SIZE: usize = #size; } - - impl #path::SerializeBuf<{ <#ident as #path::Size>::SIZE }> for #ident {} } .into() } diff --git a/serac/Cargo.toml b/serac/Cargo.toml index a917ec9..a2ac23a 100644 --- a/serac/Cargo.toml +++ b/serac/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "serac" -version = "0.3.0" +version = "0.4.1" edition = "2024" description = "A static, modular, and light-weight serialization framework." license = "CC-BY-NC-SA-4.0" repository = "https://github.com/adinack/embedded-command" [dependencies] -macros = { package = "embedded-command-macros", version = "0.3.0" } +macros = { package = "embedded-command-macros", version = "0.4.0" } fill-array = "0.2.1" # for binary diff --git a/serac/README.md b/serac/README.md index a2c9435..d7e34b2 100644 --- a/serac/README.md +++ b/serac/README.md @@ -59,7 +59,7 @@ use serac::{buf, encoding::vanilla, SerializeBuf}; const BE: u8 = 0xbe; #[repr(u8)] -#[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::SerializeBuf)] +#[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::Size, SerializeBuf)] enum Foo { A, B(u8, i16) = 0xde, @@ -78,3 +78,50 @@ assert_eq(foo, readback); This example shows a crazy enum with lots of fancy things going on, which is able to be serialized by serac. + +### Serialize a custom generic type + +Mostly, the serialization of generic types is the same: + +```rust +use serac::{buf, encoding::vanilla, SerializeIter, SerializeBuf}; + +const BE: u8 = 0xbe; + +#[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::Size)] +#[repr(u16)] +enum Foo { + A(u8, T), + B { woah: U } = BE as u16, +} + +let foo = Foo::B { woah: 42i16 }; + +let mut buf = buf!(Foo); +foo.serialize_iter(&mut buf).unwrap(); + +let readback = SerializeIter::deserialize_iter(&buf).unwrap(); +assert_eq(foo, readback); + +// ... +``` + +But `SerializeBuf` is not derivable on generic types. + +You can, however, implement `SerializeBuf` for concretely specified aliases of +generic types: + +```rust +// ... + +#[serac::serialize_buf] +type ConcreteFoo = Foo; + +let foo = Foo::B { woah: 42i16 }; + +let mut buf = buf!(ConcreteFoo); +foo.serialize_buf(&mut buf); + +let readback = SerializeBuf::deserialize_buf(&buf).unwrap(); +assert_eq(foo, readback); +``` diff --git a/serac/src/encoding/vanilla.rs b/serac/src/encoding/vanilla.rs index e63ecce..c01f295 100644 --- a/serac/src/encoding/vanilla.rs +++ b/serac/src/encoding/vanilla.rs @@ -7,7 +7,7 @@ use super::Encoding; use crate::{Medium, SerializeBuf, SerializeIter, Size, error}; // reexport proc macros -pub use macros::{SerializeBuf, SerializeIter}; +pub use macros::{SerializeIter, Size}; pub struct Vanilla; @@ -66,7 +66,7 @@ macro_rules! impl_number { const SIZE: usize = $SIZE; } - impl SerializeBuf<{ <$TYPE as Size>::SIZE }> for $TYPE {} + unsafe impl SerializeBuf<{ <$TYPE as Size>::SIZE }> for $TYPE {} }; } @@ -250,6 +250,12 @@ impl SerializeIter for PhantomData { } } +unsafe impl Size for PhantomData { + const SIZE: usize = 0; +} + +unsafe impl SerializeBuf<0> for PhantomData {} + #[cfg(test)] mod tests { mod primitives { @@ -319,19 +325,27 @@ mod tests { mod structs { use super::*; - #[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive( + Debug, PartialEq, vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf, + )] struct Foo { a: u8, b: i16, } - #[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive( + Debug, PartialEq, vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf, + )] struct Nothing; - #[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive( + Debug, PartialEq, vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf, + )] struct Bar(u8, Nothing, i16); - #[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive( + Debug, PartialEq, vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf, + )] struct Baz { numbers: [f32; 16], flags: (bool, u8), @@ -415,7 +429,9 @@ mod tests { const BE: u8 = 0xbe; - #[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive( + Debug, PartialEq, vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf, + )] #[repr(u8)] enum Foo { A, @@ -469,27 +485,23 @@ mod tests { fn generics() { const BE: u8 = 0xbe; - #[derive(Debug, PartialEq, vanilla::SerializeIter)] + #[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::Size)] #[repr(u16)] - enum FooGen - where - T: SerializeIter, - U: SerializeIter, - { + enum FooGen { A(u8, T), B { woah: U } = BE as u16, // arbitrary expression in discriminant! } - #[derive(Debug, PartialEq, vanilla::SerializeIter)] - struct BarGen - where - T: SerializeIter, - { + #[derive(Debug, PartialEq, vanilla::SerializeIter, vanilla::Size)] + struct BarGen { a: T, b: FooGen, c: PhantomData, } + #[serac::serialize_buf] + type ConcreteFoo = FooGen; + let mut buf = [0; 4]; let test_bar = BarGen { @@ -501,13 +513,23 @@ mod tests { // buf is too small assert!(test_bar.serialize_iter(&mut buf).is_err()); - let mut buf = [0; 8]; + let mut buf = buf!(BarGen); assert_eq!(6, test_bar.serialize_iter(&mut buf).unwrap()); let read_bar = SerializeIter::deserialize_iter(&buf).unwrap(); assert_eq!(test_bar, read_bar); // comparison provides type inference for deserialization! + + let mut buf = buf!(ConcreteFoo); + + let test_foo = ConcreteFoo::B { woah: -42 }; + + assert_eq!(4, test_foo.serialize_buf(&mut buf)); + + let read_foo = SerializeBuf::deserialize_buf(&buf).unwrap(); + + assert_eq!(test_foo, read_foo); // comparison provides type inference for deserialization! } } } diff --git a/serac/src/lib.rs b/serac/src/lib.rs index aa15279..9200e91 100644 --- a/serac/src/lib.rs +++ b/serac/src/lib.rs @@ -7,6 +7,7 @@ pub mod medium; pub use encoding::Encoding; use encoding::vanilla::Vanilla; +pub use macros::{SerializeBuf, impl_serialize_buf_alias as serialize_buf}; pub use medium::Medium; pub mod error { @@ -27,7 +28,7 @@ pub mod error { /// Deserialization failed. #[repr(u8)] - #[derive(Debug, Clone, Copy, vanilla::SerializeIter, vanilla::SerializeBuf)] + #[derive(Debug, Clone, Copy, vanilla::SerializeIter, vanilla::Size, serac::SerializeBuf)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Error { /// The encoder reached the end of the input before deserialization was @@ -77,7 +78,14 @@ pub trait SerializeIter: Sized { /// implementer type. /// /// To implement this trait, the type must already implement [`SerializeIter`] and [`Size`]. -pub trait SerializeBuf: SerializeIter + Size { +/// +/// # Safety +/// +/// This trait must only be implemented for all `T` where `T: Size` and +/// `N == T::SIZE`. +pub unsafe trait SerializeBuf: + SerializeIter + Size +{ fn serialize_buf<'a>(&self, buf: &'a mut E::Serialized) -> usize where &'a mut E::Serialized: IntoIterator,