diff --git a/Cargo.toml b/Cargo.toml index d20ff8c..985f3cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ libfuzzer-sys = "0.4" arbitrary = { version = "1", features = ["derive"] } devirt = { path = "crates/core" } devirt-macros = { path = "crates/macros", version = "0.2.0" } -syn = { version = "2", features = ["full", "visit-mut"] } +syn = { version = "2", features = ["full", "visit", "visit-mut"] } quote = "1" proc-macro2 = "1" vstd = { version = "=0.0.0-2026-04-12-0118", default-features = false } diff --git a/crates/core/tests/equivalence.rs b/crates/core/tests/equivalence.rs index 12d325c..b02f56b 100644 --- a/crates/core/tests/equivalence.rs +++ b/crates/core/tests/equivalence.rs @@ -491,3 +491,107 @@ fn attr_dispatch() { assert_eq!(h.val, 99); assert_eq!(c.val, 100); } + +// ── Associated types ────────────────────────────────────────────────────── + +#[cfg(feature = "macros")] +mod attr_assoc_types { + pub struct Circle; + pub struct Rect; + + #[devirt::devirt(Circle)] + pub trait Drawable { + type Color; + fn name(&self) -> &str; + fn draw(&self, color: Self::Color) -> String; + } + + #[devirt::devirt] + impl Drawable for Circle { + type Color = String; + fn name(&self) -> &str { "circle" } + fn draw(&self, color: String) -> String { format!("circle: {color}") } + } + + #[devirt::devirt] + impl Drawable for Rect { + type Color = u32; + fn name(&self) -> &str { "rect" } + fn draw(&self, color: u32) -> String { format!("rect: #{color:06x}") } + } +} + +#[cfg(feature = "macros")] +#[test] +fn attr_assoc_type_dispatch() { + use attr_assoc_types::{Circle, Drawable, Rect}; + + let c = Circle; + assert_eq!((&c as &dyn Drawable).name(), "circle"); + assert_eq!( + (&c as &dyn Drawable).draw("red".into()), + "circle: red" + ); + assert_eq!( + (&c as &(dyn Drawable + Send)).name(), + "circle" + ); + assert_eq!( + (&c as &(dyn Drawable + Send + Sync)).name(), + "circle" + ); + + let r = Rect; + let d: &dyn Drawable = &r; + assert_eq!(d.name(), "rect"); + assert_eq!(d.draw(0x00FF_0000_u32), "rect: #ff0000"); +} + +// ── Generic trait parameters ────────────────────────────────────────────── + +#[cfg(feature = "macros")] +mod attr_generic_trait { + pub struct Handler; + + #[devirt::devirt(Handler)] + pub trait Processor { + fn process(&self, input: T) -> String; + fn name(&self) -> &str; + } + + #[devirt::devirt] + impl Processor for Handler { + fn process(&self, input: String) -> String { format!("str: {input}") } + fn name(&self) -> &str { "handler" } + } + + #[devirt::devirt] + impl Processor for Handler { + fn process(&self, input: u32) -> String { format!("num: {input}") } + fn name(&self) -> &str { "handler" } + } +} + +#[cfg(feature = "macros")] +#[test] +fn attr_generic_trait_dispatch() { + use attr_generic_trait::{Handler, Processor}; + + let h = Handler; + assert_eq!( + (&h as &dyn Processor).process("hello".into()), + "str: hello" + ); + assert_eq!( + (&h as &dyn Processor).process(42), + "num: 42" + ); + assert_eq!( + (&h as &(dyn Processor + Send)).name(), + "handler" + ); + assert_eq!( + (&h as &(dyn Processor + Send + Sync)).name(), + "handler" + ); +} diff --git a/crates/core/tests/ui_attr.rs b/crates/core/tests/ui_attr.rs index 4cf2369..5317188 100644 --- a/crates/core/tests/ui_attr.rs +++ b/crates/core/tests/ui_attr.rs @@ -19,6 +19,9 @@ fn ui_attr() { t.pass("tests/ui_attr/attr_where_clause.rs"); t.pass("tests/ui_attr/attr_generic_impl.rs"); t.pass("tests/ui_attr/attr_where_impl.rs"); + t.pass("tests/ui_attr/attr_assoc_type.rs"); + t.pass("tests/ui_attr/attr_assoc_default.rs"); + t.pass("tests/ui_attr/attr_generic_trait.rs"); t.compile_fail("tests/ui_attr/attr_must_use_unused.rs"); t.compile_fail("tests/ui_attr/attr_missing_args.rs"); t.compile_fail("tests/ui_attr/attr_unsafe_missing_on_impl.rs"); diff --git a/crates/core/tests/ui_attr/attr_assoc_default.rs b/crates/core/tests/ui_attr/attr_assoc_default.rs new file mode 100644 index 0000000..9aa10c0 --- /dev/null +++ b/crates/core/tests/ui_attr/attr_assoc_default.rs @@ -0,0 +1,24 @@ +struct Circle; + +#[devirt::devirt(Circle)] +pub trait Drawable { + type Color; + fn name(&self) -> &str; + fn draw(&self, color: Self::Color) -> String; + fn describe(&self) -> String { + format!("I am {}", self.name()) + } +} + +#[devirt::devirt] +impl Drawable for Circle { + type Color = String; + fn name(&self) -> &str { "circle" } + fn draw(&self, color: String) -> String { format!("circle: {color}") } +} + +fn main() { + let c = Circle; + let d: &dyn Drawable = &c; + assert_eq!(d.describe(), "I am circle"); +} diff --git a/crates/core/tests/ui_attr/attr_assoc_type.rs b/crates/core/tests/ui_attr/attr_assoc_type.rs new file mode 100644 index 0000000..89258ca --- /dev/null +++ b/crates/core/tests/ui_attr/attr_assoc_type.rs @@ -0,0 +1,39 @@ +struct Circle; +struct Rect; + +#[devirt::devirt(Circle)] +pub trait Drawable { + type Color; + fn name(&self) -> &str; + fn draw(&self, color: Self::Color) -> String; +} + +#[devirt::devirt] +impl Drawable for Circle { + type Color = String; + fn name(&self) -> &str { "circle" } + fn draw(&self, color: String) -> String { format!("circle: {color}") } +} + +#[devirt::devirt] +impl Drawable for Rect { + type Color = u32; + fn name(&self) -> &str { "rect" } + fn draw(&self, color: u32) -> String { format!("rect: #{color:06x}") } +} + +fn check_name(d: &dyn Drawable) -> &str { d.name() } +fn check_draw(d: &dyn Drawable, c: String) -> String { d.draw(c) } +fn check_send(d: &(dyn Drawable + Send)) -> &str { d.name() } + +fn main() { + let c = Circle; + assert_eq!(check_name(&c), "circle"); + assert_eq!(check_draw(&c, "red".into()), "circle: red"); + assert_eq!(check_send(&c), "circle"); + + let r = Rect; + let d: &dyn Drawable = &r; + assert_eq!(d.name(), "rect"); + assert_eq!(d.draw(0x00FF_0000), "rect: #ff0000"); +} diff --git a/crates/core/tests/ui_attr/attr_generic_trait.rs b/crates/core/tests/ui_attr/attr_generic_trait.rs new file mode 100644 index 0000000..9f10fe7 --- /dev/null +++ b/crates/core/tests/ui_attr/attr_generic_trait.rs @@ -0,0 +1,30 @@ +struct Handler; + +#[devirt::devirt(Handler)] +pub trait Processor { + fn process(&self, input: T) -> String; + fn name(&self) -> &str; +} + +#[devirt::devirt] +impl Processor for Handler { + fn process(&self, input: String) -> String { format!("str: {input}") } + fn name(&self) -> &str { "handler" } +} + +#[devirt::devirt] +impl Processor for Handler { + fn process(&self, input: u32) -> String { format!("num: {input}") } + fn name(&self) -> &str { "handler" } +} + +fn use_str(p: &dyn Processor) -> String { p.process("hello".into()) } +fn use_u32(p: &dyn Processor) -> String { p.process(42) } +fn use_send(p: &(dyn Processor + Send)) -> &str { p.name() } + +fn main() { + let h = Handler; + assert_eq!(use_str(&h), "str: hello"); + assert_eq!(use_u32(&h), "num: 42"); + assert_eq!(use_send(&h), "handler"); +} diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index 7ed910a..211c393 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -5,11 +5,12 @@ //! is an implementation detail of `devirt` and should not be used //! directly. -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::punctuated::Punctuated; +use syn::visit::Visit; use syn::visit_mut::VisitMut; use syn::{Token, parse_macro_input}; @@ -76,21 +77,8 @@ fn expand_trait(attr: TokenStream, trait_item: &syn::ItemTrait) -> TokenStream { } fn validate_trait(trait_item: &syn::ItemTrait) -> Result<(), syn::Error> { - if !trait_item.generics.params.is_empty() { - return Err(syn::Error::new_spanned( - &trait_item.generics, - "#[devirt] does not support generic traits", - )); - } for item in &trait_item.items { match item { - syn::TraitItem::Type(t) => { - return Err(syn::Error::new_spanned( - t, - "#[devirt] does not yet support associated types — \ - they require type parameterization of the dispatch shim", - )); - } syn::TraitItem::Const(c) => { return Err(syn::Error::new_spanned( c, @@ -146,89 +134,243 @@ fn validate_trait_method(f: &syn::TraitItemFn) -> Result<(), syn::Error> { Ok(()) } -fn emit_trait_expansion( - trait_item: &syn::ItemTrait, - hot_types: &[syn::Type], -) -> TokenStream { - let unsafety = &trait_item.unsafety; - let vis = &trait_item.vis; - let name = &trait_item.ident; - let outer_attrs = &trait_item.attrs; - let supertraits = &trait_item.supertraits; - let where_clause = &trait_item.generics.where_clause; - let inner_name = format_ident!("__{name}Impl"); - - // __spec_* method declarations for the inner trait (with default - // bodies rewritten so `self.method()` → `self.__spec_method()`). - let spec_decls = generate_spec_decls(trait_item); +struct AssocTypeInfo { + names: HashSet, + idents: Vec, + generics: Vec, + decls: Vec, + rewrites: HashMap, +} - // Dispatch methods for the inherent impl on `dyn Trait`. - let dispatch_methods: Vec<_> = trait_item +fn collect_assoc_types(trait_item: &syn::ItemTrait) -> AssocTypeInfo { + let items: Vec<&syn::TraitItemType> = trait_item .items .iter() .filter_map(|item| { - let syn::TraitItem::Fn(m) = item else { - return None; - }; - Some(generate_dispatch_method(m, name, &inner_name, hot_types)) + if let syn::TraitItem::Type(t) = item { + Some(t) + } else { + None + } }) .collect(); + let names = items.iter().map(|t| t.ident.to_string()).collect(); + let idents: Vec = items.iter().map(|t| t.ident.clone()).collect(); + let generics: Vec = + idents.iter().map(|id| format_ident!("__{id}")).collect(); + let decls = items + .iter() + .map(|t| { + let t = *t; + quote! { #t } + }) + .collect(); + let rewrites = idents + .iter() + .zip(generics.iter()) + .map(|(id, gp)| (id.to_string(), gp.clone())) + .collect(); + AssocTypeInfo { names, idents, generics, decls, rewrites } +} + +fn build_trait_dyn_ref( + name: &syn::Ident, + trait_generic_params: &Punctuated, + assoc_type_idents: &[syn::Ident], + assoc_type_generics: &[syn::Ident], +) -> proc_macro2::TokenStream { + let mut type_args: Vec = Vec::new(); + for param in trait_generic_params { + match param { + syn::GenericParam::Type(t) => { + let id = &t.ident; + type_args.push(quote! { #id }); + } + syn::GenericParam::Lifetime(l) => { + let lt = &l.lifetime; + type_args.push(quote! { #lt }); + } + syn::GenericParam::Const(c) => { + let id = &c.ident; + type_args.push(quote! { #id }); + } + } + } + for (id, gp) in assoc_type_idents.iter().zip(assoc_type_generics.iter()) { + type_args.push(quote! { #id = #gp }); + } + if type_args.is_empty() { + quote! { #name } + } else { + quote! { #name<#(#type_args),*> } + } +} + +fn build_fat_ptr_assertion( + trait_item: &syn::ItemTrait, +) -> proc_macro2::TokenStream { + let name = &trait_item.ident; + let params = &trait_item.generics.params; - // Delegating methods for auto-trait inherent impls (Send, Sync, - // Send + Sync). Each method coerces `self` to the base `dyn Trait` - // and calls the dispatch method. - let delegating_methods: Vec<_> = trait_item + let assoc_types: Vec<&syn::TraitItemType> = trait_item .items .iter() .filter_map(|item| { - let syn::TraitItem::Fn(m) = item else { - return None; - }; - Some(generate_delegating_method(m, name)) + if let syn::TraitItem::Type(t) = item { + Some(t) + } else { + None + } }) .collect(); - // Inner trait supertraits: `__FooImpl: Debug + Clone` - let inner_supers = if supertraits.is_empty() { - quote! {} - } else { - quote! { : #supertraits } - }; + if params.is_empty() && assoc_types.is_empty() { + return quote! { + const _: () = assert!( + ::core::mem::size_of::<*const dyn #name>() + == 2 * ::core::mem::size_of::() + ); + }; + } - // Public trait supertraits: `Foo: __FooImpl + Debug + Clone` - // The `+ Debug + Clone` is redundant (implied by `__FooImpl`) but - // makes the bounds visible in rustdoc and compiler diagnostics. - let public_supers = if supertraits.is_empty() { - quote! { #inner_name } + let mut fn_params: Vec = Vec::new(); + let mut dyn_args: Vec = Vec::new(); + + for param in params { + fn_params.push(strip_param_defaults(param)); + match param { + syn::GenericParam::Type(t) => { + let id = &t.ident; + dyn_args.push(quote! { #id }); + } + syn::GenericParam::Lifetime(l) => { + let lt = &l.lifetime; + dyn_args.push(quote! { #lt }); + } + syn::GenericParam::Const(c) => { + let id = &c.ident; + dyn_args.push(quote! { #id }); + } + } + } + + for assoc in &assoc_types { + let id = &assoc.ident; + let assoc_param = format_ident!("__Assoc{}", id); + let bounds = &assoc.bounds; + if bounds.is_empty() { + fn_params.push(quote! { #assoc_param }); + } else { + fn_params.push(quote! { #assoc_param: #bounds }); + } + dyn_args.push(quote! { #id = #assoc_param }); + } + + let where_preds: Vec<_> = trait_item + .generics + .where_clause + .as_ref() + .map(|wc| { + wc.predicates + .iter() + .filter(|pred| !predicate_references_self(pred)) + .collect::>() + }) + .unwrap_or_default(); + + let where_clause = if where_preds.is_empty() { + quote! {} } else { - quote! { #inner_name + #supertraits } + quote! { where #(#where_preds),* } }; + // For generic/associated-type traits the assertion lives inside a + // generic function that is intentionally never called. This means + // the assert is not monomorphised and therefore not evaluated at + // compile time — it only verifies that `*const dyn Trait<…>` is a + // well-formed type. The actual size invariant (fat pointers are + // always two `usize`s) is guaranteed by the Rust ABI and is + // compile-time-checked for every non-generic trait expansion. quote! { - // (1) Hidden inner trait — carries __spec_* methods. - #[doc(hidden)] - #vis #unsafety trait #inner_name #inner_supers #where_clause { - #(#spec_decls)* + const _: () = { + fn __devirt_assert<#(#fn_params),*>() #where_clause { + assert!( + ::core::mem::size_of::<*const dyn #name<#(#dyn_args),*>>() + == 2 * ::core::mem::size_of::() + ); + } + }; + } +} + +fn strip_param_defaults(param: &syn::GenericParam) -> proc_macro2::TokenStream { + match param { + syn::GenericParam::Type(t) => { + let id = &t.ident; + let bounds = &t.bounds; + if bounds.is_empty() { + quote! { #id } + } else { + quote! { #id: #bounds } + } + } + syn::GenericParam::Lifetime(l) => quote! { #l }, + syn::GenericParam::Const(c) => { + let id = &c.ident; + let ty = &c.ty; + quote! { const #id: #ty } } + } +} - // (2) Compile-time fat pointer assertion. - const _: () = assert!( - ::core::mem::size_of::<*const dyn #name>() - == 2 * ::core::mem::size_of::() - ); +fn predicate_references_self(pred: &syn::WherePredicate) -> bool { + struct SelfFinder { + found: bool, + } + impl Visit<'_> for SelfFinder { + fn visit_path(&mut self, i: &syn::Path) { + if i.segments.first().is_some_and(|s| s.ident == "Self") { + self.found = true; + } + syn::visit::visit_path(self, i); + } + } + let mut finder = SelfFinder { found: false }; + syn::visit::visit_where_predicate(&mut finder, pred); + finder.found +} - // (3) Vtable helpers on `dyn Trait`. - impl<'__devirt> dyn #name + '__devirt { - /// Split a fat pointer into `[data, vtable]`. +fn build_vtable_helpers( + can_devirt: bool, + name: &syn::Ident, + inner_name: &syn::Ident, + inherent_impl_generics: &proc_macro2::TokenStream, + trait_dyn_ref: &proc_macro2::TokenStream, + assoc_type_idents: &[syn::Ident], +) -> proc_macro2::TokenStream { + if !can_devirt { + return quote! {}; + } + let vtable_coerce_ty = if assoc_type_idents.is_empty() { + quote! { *const Self } + } else { + quote! { + *const dyn #name< + #(#assoc_type_idents = + <__DevirtT as #inner_name>::#assoc_type_idents),* + > + } + }; + quote! { + impl #inherent_impl_generics dyn #trait_dyn_ref + '__devirt { #[doc(hidden)] #[inline(always)] pub fn __devirt_raw_parts(this: &Self) -> [usize; 2] { - // SAFETY: `&dyn Trait` is a two-`usize` fat pointer - // (verified by the compile-time assertion above). - unsafe { ::core::mem::transmute::<&Self, [usize; 2]>(this) } + unsafe { + ::core::mem::transmute::<&Self, [usize; 2]>(this) + } } - /// Vtable pointer for the `(T, Trait)` pair. #[doc(hidden)] #[inline(always)] pub fn __devirt_vtable_for< @@ -238,44 +380,189 @@ fn emit_trait_expansion( ::core::ptr::without_provenance( ::core::mem::align_of::<__DevirtT>(), ); - let fat: *const Self = fake; - // SAFETY: `*const dyn Trait` is two `usize`s. We read - // only the vtable half; the dangling data half is - // discarded. + let fat: #vtable_coerce_ty = fake; let __parts: [usize; 2] = unsafe { - ::core::mem::transmute::<*const Self, [usize; 2]>(fat) + ::core::mem::transmute::< + #vtable_coerce_ty, [usize; 2] + >(fat) }; __parts[1] } } + } +} - // (4) Inherent dispatch methods. - impl<'__devirt> dyn #name + '__devirt { - #(#dispatch_methods)* +fn build_blanket_impl( + unsafety: Option<&syn::token::Unsafe>, + has_trait_generics: bool, + name: &syn::Ident, + inner_name: &syn::Ident, + trait_generic_params: &Punctuated, + trait_ty_generics: &syn::TypeGenerics<'_>, + trait_where_clause: Option<&syn::WhereClause>, +) -> proc_macro2::TokenStream { + if has_trait_generics { + quote! { + #unsafety impl< + __DevirtT: #inner_name #trait_ty_generics + ?Sized, + #trait_generic_params + > #name #trait_ty_generics for __DevirtT #trait_where_clause {} } - - // (4a) dyn Trait + Send — delegate to base dispatch. - impl<'__devirt> dyn #name + ::core::marker::Send + '__devirt { - #(#delegating_methods)* + } else { + quote! { + #unsafety impl<__DevirtT: #inner_name + ?Sized> #name + for __DevirtT #trait_where_clause {} } + } +} - // (4b) dyn Trait + Sync — delegate to base dispatch. - impl<'__devirt> dyn #name + ::core::marker::Sync + '__devirt { - #(#delegating_methods)* - } +fn build_dispatch_methods( + trait_item: &syn::ItemTrait, + can_devirt: bool, + assoc_info: &AssocTypeInfo, + inner_name: &syn::Ident, + trait_dyn_ref: &proc_macro2::TokenStream, + hot_types: &[syn::Type], +) -> Vec { + trait_item + .items + .iter() + .filter_map(|item| { + let syn::TraitItem::Fn(m) = item else { + return None; + }; + let references_assoc = + method_references_assoc_types(&m.sig, &assoc_info.names); + if !can_devirt || references_assoc { + Some(generate_fallback_method(m, inner_name, &assoc_info.rewrites)) + } else { + Some(generate_dispatch_method( + m, trait_dyn_ref, inner_name, hot_types, &assoc_info.rewrites, + )) + } + }) + .collect() +} + +fn build_delegating_methods( + trait_item: &syn::ItemTrait, + trait_dyn_ref: &proc_macro2::TokenStream, + assoc_rewrites: &HashMap, +) -> Vec { + trait_item + .items + .iter() + .filter_map(|item| { + let syn::TraitItem::Fn(m) = item else { + return None; + }; + Some(generate_delegating_method(m, trait_dyn_ref, assoc_rewrites)) + }) + .collect() +} + +fn emit_trait_expansion( + trait_item: &syn::ItemTrait, + hot_types: &[syn::Type], +) -> TokenStream { + let unsafety = &trait_item.unsafety; + let vis = &trait_item.vis; + let name = &trait_item.ident; + let outer_attrs = &trait_item.attrs; + let supertraits = &trait_item.supertraits; + let inner_name = format_ident!("__{name}Impl"); + + // ── Trait-level generics ────────────────────────────────────── + let has_trait_generics = !trait_item.generics.params.is_empty(); + let trait_generic_params = &trait_item.generics.params; + let trait_where_clause = &trait_item.generics.where_clause; + let (_, trait_ty_generics, _) = trait_item.generics.split_for_impl(); + + let assoc_info = collect_assoc_types(trait_item); + let can_devirt = !has_trait_generics; + + let trait_dyn_ref = build_trait_dyn_ref( + name, + trait_generic_params, + &assoc_info.idents, + &assoc_info.generics, + ); + let spec_decls = generate_spec_decls(trait_item); + let dispatch_methods = build_dispatch_methods( + trait_item, can_devirt, &assoc_info, &inner_name, &trait_dyn_ref, hot_types, + ); + let delegating_methods = build_delegating_methods( + trait_item, &trait_dyn_ref, &assoc_info.rewrites, + ); + + let inner_supers = if supertraits.is_empty() { + quote! {} + } else { + quote! { : #supertraits } + }; + let public_supers = if supertraits.is_empty() { + quote! { #inner_name #trait_ty_generics } + } else { + quote! { #inner_name #trait_ty_generics + #supertraits } + }; + + let mut extra_params: Vec = Vec::new(); + for param in trait_generic_params { + extra_params.push(quote! { #param }); + } + for gp in &assoc_info.generics { + extra_params.push(quote! { #gp }); + } + let inherent_impl_generics = if extra_params.is_empty() { + quote! { <'__devirt> } + } else { + quote! { <'__devirt, #(#extra_params),*> } + }; + let trait_def_generics = if has_trait_generics { + quote! { <#trait_generic_params> } + } else { + quote! {} + }; + + let fat_ptr_assertion = build_fat_ptr_assertion(trait_item); + let vtable_helpers = build_vtable_helpers( + can_devirt, name, &inner_name, &inherent_impl_generics, + &trait_dyn_ref, &assoc_info.idents, + ); + let blanket_impl = build_blanket_impl( + unsafety.as_ref(), has_trait_generics, name, &inner_name, + trait_generic_params, &trait_ty_generics, trait_where_clause.as_ref(), + ); + let assoc_type_decls = &assoc_info.decls; + + quote! { + #[doc(hidden)] + #vis #unsafety trait #inner_name #trait_def_generics + #inner_supers #trait_where_clause + { #(#assoc_type_decls)* #(#spec_decls)* } - // (4c) dyn Trait + Send + Sync — delegate to base dispatch. - impl<'__devirt> dyn #name + ::core::marker::Send + ::core::marker::Sync + '__devirt { - #(#delegating_methods)* + #fat_ptr_assertion + + #vtable_helpers + + impl #inherent_impl_generics dyn #trait_dyn_ref + '__devirt { + #(#dispatch_methods)* } + impl #inherent_impl_generics + dyn #trait_dyn_ref + ::core::marker::Send + '__devirt + { #(#delegating_methods)* } + impl #inherent_impl_generics + dyn #trait_dyn_ref + ::core::marker::Sync + '__devirt + { #(#delegating_methods)* } + impl #inherent_impl_generics + dyn #trait_dyn_ref + ::core::marker::Send + + ::core::marker::Sync + '__devirt + { #(#delegating_methods)* } - // (5) Public marker trait. #(#outer_attrs)* - #vis #unsafety trait #name: #public_supers #where_clause {} - - // (6) Blanket impl. - #unsafety impl<__DevirtT: #inner_name + ?Sized> #name - for __DevirtT #where_clause {} + #vis #unsafety trait #name #trait_def_generics + : #public_supers #trait_where_clause {} + #blanket_impl } .into() } @@ -376,13 +663,109 @@ fn rewrite_sig_with_named_args( (sig, arg_names) } +// ── Associated type helpers ──────────────────────────────────────────────── + +struct AssocTypeFinder<'a> { + assoc_names: &'a HashSet, + found: bool, +} + +impl Visit<'_> for AssocTypeFinder<'_> { + fn visit_path(&mut self, i: &syn::Path) { + if i.segments.len() >= 2 + && i.segments[0].ident == "Self" + && self + .assoc_names + .contains(&i.segments[1].ident.to_string()) + { + self.found = true; + } + syn::visit::visit_path(self, i); + } +} + +fn method_references_assoc_types( + sig: &syn::Signature, + assoc_names: &HashSet, +) -> bool { + if assoc_names.is_empty() { + return false; + } + let mut finder = AssocTypeFinder { assoc_names, found: false }; + syn::visit::visit_signature(&mut finder, sig); + finder.found +} + +struct RewriteSelfAssocTypes { + rewrites: HashMap, +} + +impl VisitMut for RewriteSelfAssocTypes { + fn visit_path_mut(&mut self, i: &mut syn::Path) { + syn::visit_mut::visit_path_mut(self, i); + if i.segments.len() >= 2 + && i.segments[0].ident == "Self" + { + let name = i.segments[1].ident.to_string(); + if let Some(replacement) = self.rewrites.get(&name) { + let remaining: Vec = + i.segments.iter().skip(2).cloned().collect(); + let mut first = syn::PathSegment::from(replacement.clone()); + first.arguments = i.segments[1].arguments.clone(); + let mut new_segments = Punctuated::new(); + new_segments.push(first); + for seg in remaining { + new_segments.push(seg); + } + i.segments = new_segments; + } + } + } +} + +fn generate_fallback_method( + method: &syn::TraitItemFn, + inner_name: &syn::Ident, + assoc_rewrites: &HashMap, +) -> proc_macro2::TokenStream { + let sig = &method.sig; + let attrs = &method.attrs; + let spec_name = format_ident!("__spec_{}", sig.ident); + let is_unsafe = sig.unsafety.is_some(); + + let (mut dispatch_sig, arg_names) = rewrite_sig_with_named_args(sig); + if !assoc_rewrites.is_empty() { + let mut rewriter = RewriteSelfAssocTypes { + rewrites: assoc_rewrites.clone(), + }; + rewriter.visit_signature_mut(&mut dispatch_sig); + } + + let call = quote! { #inner_name::#spec_name(self, #(#arg_names),*) }; + let body = if is_unsafe { + quote! { unsafe { #call } } + } else { + call + }; + + quote! { + #(#attrs)* + #[doc(hidden)] + #[inline] + pub #dispatch_sig { + #body + } + } +} + // ── Dispatch method generation ────────────────────────────────────────────── fn generate_dispatch_method( method: &syn::TraitItemFn, - trait_name: &syn::Ident, + trait_dyn_ref: &proc_macro2::TokenStream, inner_name: &syn::Ident, hot_types: &[syn::Type], + assoc_rewrites: &HashMap, ) -> proc_macro2::TokenStream { let sig = &method.sig; let attrs = &method.attrs; @@ -403,16 +786,22 @@ fn generate_dispatch_method( let is_mut = receiver.mutability.is_some(); let is_unsafe = sig.unsafety.is_some(); - let (dispatch_sig, arg_names) = rewrite_sig_with_named_args(sig); + let (mut dispatch_sig, arg_names) = rewrite_sig_with_named_args(sig); + if !assoc_rewrites.is_empty() { + let mut rewriter = RewriteSelfAssocTypes { + rewrites: assoc_rewrites.clone(), + }; + rewriter.visit_signature_mut(&mut dispatch_sig); + } let raw_parts = if is_mut { - quote! { let __raw = ::__devirt_raw_parts(&*self); } + quote! { let __raw = ::__devirt_raw_parts(&*self); } } else { - quote! { let __raw = ::__devirt_raw_parts(self); } + quote! { let __raw = ::__devirt_raw_parts(self); } }; let hot_checks = gen_hot_checks( - hot_types, trait_name, &spec_name, &arg_names, is_mut, + hot_types, trait_dyn_ref, &spec_name, &arg_names, is_mut, ); let fallback = if is_unsafe { @@ -435,7 +824,7 @@ fn generate_dispatch_method( fn gen_hot_checks( hot_types: &[syn::Type], - trait_name: &syn::Ident, + trait_dyn_ref: &proc_macro2::TokenStream, spec_name: &syn::Ident, arg_names: &[syn::Ident], is_mut: bool, @@ -446,7 +835,7 @@ fn gen_hot_checks( if is_mut { quote! { if __raw[1] - == ::__devirt_vtable_for::<#hot>() + == ::__devirt_vtable_for::<#hot>() { let __p: *mut #hot = __raw[0] as *mut #hot; // SAFETY: vtable identity implies type identity. @@ -461,7 +850,7 @@ fn gen_hot_checks( } else { quote! { if __raw[1] - == ::__devirt_vtable_for::<#hot>() + == ::__devirt_vtable_for::<#hot>() { let __p: *const #hot = __raw[0] as *const #hot; // SAFETY: vtable identity implies type identity. @@ -486,7 +875,8 @@ fn gen_hot_checks( /// with `#[inline(always)]`. fn generate_delegating_method( method: &syn::TraitItemFn, - trait_name: &syn::Ident, + trait_dyn_ref: &proc_macro2::TokenStream, + assoc_rewrites: &HashMap, ) -> proc_macro2::TokenStream { let sig = &method.sig; let method_name = &sig.ident; @@ -506,23 +896,26 @@ fn generate_delegating_method( let is_mut = receiver.mutability.is_some(); let is_unsafe = sig.unsafety.is_some(); - let (dispatch_sig, arg_names) = rewrite_sig_with_named_args(sig); + let (mut dispatch_sig, arg_names) = rewrite_sig_with_named_args(sig); + if !assoc_rewrites.is_empty() { + let mut rewriter = RewriteSelfAssocTypes { + rewrites: assoc_rewrites.clone(), + }; + rewriter.visit_signature_mut(&mut dispatch_sig); + } - // Coerce to the base `dyn Trait` and call the dispatch method. let coerce_and_call = if is_mut { quote! { - let __devirt_base: &mut (dyn #trait_name + '__devirt) = self; + let __devirt_base: &mut (dyn #trait_dyn_ref + '__devirt) = self; __devirt_base.#method_name(#(#arg_names),*) } } else { quote! { - let __devirt_base: &(dyn #trait_name + '__devirt) = self; + let __devirt_base: &(dyn #trait_dyn_ref + '__devirt) = self; __devirt_base.#method_name(#(#arg_names),*) } }; - // When the method is `unsafe fn`, the base dispatch method is also - // `unsafe fn`, so the call must be inside an `unsafe` block. let delegation = if is_unsafe { quote! { unsafe { #coerce_and_call } } } else { @@ -573,16 +966,29 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream { } let unsafety = &impl_item.unsafety; - let trait_name = &trait_path + let trait_segment = trait_path .segments .last() - .expect("validated: path non-empty") - .ident; + .expect("validated: path non-empty"); + let trait_name = &trait_segment.ident; let inner_name = format_ident!("__{trait_name}Impl"); + let trait_args = &trait_segment.arguments; let ty = &impl_item.self_ty; let (impl_generics, _, where_clause) = impl_item.generics.split_for_impl(); + let type_items: Vec<_> = impl_item + .items + .iter() + .filter_map(|item| { + if let syn::ImplItem::Type(t) = item { + Some(quote! { #t }) + } else { + None + } + }) + .collect(); + // Collect method names so sibling calls in impl bodies // (e.g. `self.area()`) are rewritten to `self.__spec_area()`. let method_names: HashSet = impl_item @@ -622,7 +1028,10 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream { .collect(); quote! { - #unsafety impl #impl_generics #inner_name for #ty #where_clause { + #unsafety impl #impl_generics #inner_name #trait_args + for #ty #where_clause + { + #(#type_items)* #(#spec_methods)* } }