diff --git a/crates/core/tests/equivalence.rs b/crates/core/tests/equivalence.rs index 2bbc63d..12d325c 100644 --- a/crates/core/tests/equivalence.rs +++ b/crates/core/tests/equivalence.rs @@ -253,6 +253,150 @@ fn attr_default_body_dispatch() { assert_eq!((&h5 as &dyn Defaulted).describe(), "val=7"); } +// ── Where clauses on traits ──────────────────────────────────────────────── + +#[cfg(feature = "macros")] +mod attr_where_clause { + use core::fmt::Debug; + + #[derive(Debug)] + pub struct WcHot { + pub val: u64, + } + + #[derive(Debug)] + pub struct WcCold { + pub val: u64, + } + + // Uses `where Self: Debug` (semantically equivalent to a supertrait). + #[devirt::devirt(WcHot)] + pub trait Inspectable + where + Self: Debug, + { + fn value(&self) -> u64; + fn inspect(&self) -> String { + format!("{:?}={}", self, self.value()) + } + } + + #[devirt::devirt] + impl Inspectable for WcHot { + fn value(&self) -> u64 { + self.val + } + } + + #[devirt::devirt] + impl Inspectable for WcCold { + fn value(&self) -> u64 { + self.val + 1 + } + } +} + +#[cfg(feature = "macros")] +#[test] +fn attr_where_clause_dispatch() { + use attr_where_clause::{Inspectable, WcCold, WcHot}; + + let h = WcHot { val: 42 }; + assert_eq!((&h as &dyn Inspectable).value(), 42); + assert!((&h as &dyn Inspectable).inspect().contains("42")); + + let c = WcCold { val: 10 }; + assert_eq!((&c as &dyn Inspectable).value(), 11); + assert!((&c as &dyn Inspectable).inspect().contains("11")); +} + +// ── Generic impl blocks ──────────────────────────────────────────────────── + +#[cfg(feature = "macros")] +mod attr_generic_impl { + pub struct GHot { + pub val: u64, + } + + #[devirt::devirt(GHot)] + pub trait Scale { + fn area(&self) -> u64; + } + + #[devirt::devirt] + impl Scale for GHot { + fn area(&self) -> u64 { + self.val + } + } + + pub struct Scaled { + pub inner: T, + pub factor: u64, + } + + // Generic impl — always cold. Delegates to inner through &dyn Scale. + #[devirt::devirt] + impl Scale for Scaled { + fn area(&self) -> u64 { + self.factor * (&self.inner as &dyn Scale).area() + } + } + + // Generic impl with where clause. + pub struct Pair + where + A: Scale, + B: Scale, + { + pub a: A, + pub b: B, + } + + #[devirt::devirt] + impl Scale for Pair + where + A: Scale, + B: Scale, + { + fn area(&self) -> u64 { + (&self.a as &dyn Scale).area() + (&self.b as &dyn Scale).area() + } + } +} + +#[cfg(feature = "macros")] +#[test] +fn attr_generic_impl_dispatch() { + use attr_generic_impl::{GHot, Pair, Scale, Scaled}; + + // Hot + let h = GHot { val: 10 }; + assert_eq!((&h as &dyn Scale).area(), 10); + + // Scaled — cold outer, hot inner + let s = Scaled { inner: GHot { val: 5 }, factor: 3 }; + assert_eq!((&s as &dyn Scale).area(), 15); + + // Scaled> — nested generic + let ss = Scaled { + inner: Scaled { inner: GHot { val: 2 }, factor: 4 }, + factor: 5, + }; + assert_eq!((&ss as &dyn Scale).area(), 40); + + // Pair with where clause + let p = Pair { + a: GHot { val: 3 }, + b: Scaled { inner: GHot { val: 2 }, factor: 4 }, + }; + assert_eq!((&p as &dyn Scale).area(), 3 + 8); + + // Via auto-trait flavors + assert_eq!((&s as &(dyn Scale + Send)).area(), 15); + assert_eq!((&s as &(dyn Scale + Send + Sync)).area(), 15); +} + // ── Extended proc-macro tests: supertraits, method lifetimes, #[must_use] ── #[cfg(feature = "macros")] diff --git a/crates/core/tests/ui_attr.rs b/crates/core/tests/ui_attr.rs index 2913e4a..4cf2369 100644 --- a/crates/core/tests/ui_attr.rs +++ b/crates/core/tests/ui_attr.rs @@ -16,6 +16,9 @@ fn ui_attr() { t.pass("tests/ui_attr/attr_default_body.rs"); t.pass("tests/ui_attr/attr_default_override.rs"); t.pass("tests/ui_attr/attr_default_send.rs"); + t.pass("tests/ui_attr/attr_where_clause.rs"); + t.pass("tests/ui_attr/attr_generic_impl.rs"); + t.pass("tests/ui_attr/attr_where_impl.rs"); t.compile_fail("tests/ui_attr/attr_must_use_unused.rs"); t.compile_fail("tests/ui_attr/attr_missing_args.rs"); t.compile_fail("tests/ui_attr/attr_unsafe_missing_on_impl.rs"); diff --git a/crates/core/tests/ui_attr/attr_generic_impl.rs b/crates/core/tests/ui_attr/attr_generic_impl.rs new file mode 100644 index 0000000..031e6ae --- /dev/null +++ b/crates/core/tests/ui_attr/attr_generic_impl.rs @@ -0,0 +1,42 @@ +struct Hot { + val: f64, +} + +#[devirt::devirt(Hot)] +pub trait Shape { + fn area(&self) -> f64; +} + +#[devirt::devirt] +impl Shape for Hot { + fn area(&self) -> f64 { + self.val + } +} + +// Generic impl — cold type, falls back to vtable +struct Scaled { + inner: T, + factor: f64, +} + +#[devirt::devirt] +impl Shape for Scaled { + fn area(&self) -> f64 { + // Coerce to &dyn Shape for devirt dispatch. The rewriter only + // rewrites `self.method()` calls, not `self.inner.method()`. + self.factor * (&self.inner as &dyn Shape).area() + } +} + +fn total(shapes: &[Box]) -> f64 { + shapes.iter().map(|s| s.area()).sum() +} + +fn main() { + let shapes: Vec> = vec![ + Box::new(Hot { val: 10.0 }), + Box::new(Scaled { inner: Hot { val: 5.0 }, factor: 3.0 }), + ]; + assert_eq!(total(&shapes), 25.0); // 10 + 5*3 +} diff --git a/crates/core/tests/ui_attr/attr_where_clause.rs b/crates/core/tests/ui_attr/attr_where_clause.rs new file mode 100644 index 0000000..154b6ff --- /dev/null +++ b/crates/core/tests/ui_attr/attr_where_clause.rs @@ -0,0 +1,48 @@ +use std::fmt::Debug; + +#[derive(Debug)] +struct Hot { + val: u64, +} + +#[derive(Debug)] +struct Cold { + val: u64, +} + +#[devirt::devirt(Hot)] +pub trait Inspectable +where + Self: Debug, +{ + fn value(&self) -> u64; + fn inspect(&self) -> String { + format!("{:?} = {}", self, self.value()) + } +} + +#[devirt::devirt] +impl Inspectable for Hot { + fn value(&self) -> u64 { + self.val + } +} + +#[devirt::devirt] +impl Inspectable for Cold { + fn value(&self) -> u64 { + self.val + 1 + } +} + +fn check(i: &dyn Inspectable) -> String { + i.inspect() +} + +fn main() { + let h = Hot { val: 42 }; + assert!(check(&h).contains("42")); + + let c = Cold { val: 10 }; + assert!(check(&c).contains("11")); +} diff --git a/crates/core/tests/ui_attr/attr_where_impl.rs b/crates/core/tests/ui_attr/attr_where_impl.rs new file mode 100644 index 0000000..389c69f --- /dev/null +++ b/crates/core/tests/ui_attr/attr_where_impl.rs @@ -0,0 +1,34 @@ +use std::fmt::Display; + +struct Hot { + val: f64, +} + +#[devirt::devirt(Hot)] +pub trait Shape { + fn describe(&self) -> String; +} + +#[devirt::devirt] +impl Shape for Hot { + fn describe(&self) -> String { + format!("hot: {}", self.val) + } +} + +struct Named { + name: String, + inner: T, +} + +#[devirt::devirt] +impl Shape for Named +where + T: Shape + Display, +{ + fn describe(&self) -> String { + format!("{}: {}", self.name, self.inner) + } +} + +fn main() {} diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index 175a9ef..7ed910a 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -82,24 +82,20 @@ fn validate_trait(trait_item: &syn::ItemTrait) -> Result<(), syn::Error> { "#[devirt] does not support generic traits", )); } - if let Some(wc) = &trait_item.generics.where_clause { - return Err(syn::Error::new_spanned( - wc, - "#[devirt] does not support where clauses on traits", - )); - } for item in &trait_item.items { match item { syn::TraitItem::Type(t) => { return Err(syn::Error::new_spanned( t, - "#[devirt] does not support associated types", + "#[devirt] does not yet support associated types — \ + they require type parameterization of the dispatch shim", )); } syn::TraitItem::Const(c) => { return Err(syn::Error::new_spanned( c, - "#[devirt] does not support associated constants", + "#[devirt] does not support associated constants — \ + they make a trait not dyn-compatible", )); } syn::TraitItem::Fn(f) => validate_trait_method(f)?, @@ -159,6 +155,7 @@ fn emit_trait_expansion( let name = &trait_item.ident; let outer_attrs = &trait_item.attrs; let supertraits = &trait_item.supertraits; + let where_clause = &trait_item.generics.where_clause; let inner_name = format_ident!("__{name}Impl"); // __spec_* method declarations for the inner trait (with default @@ -210,7 +207,7 @@ fn emit_trait_expansion( quote! { // (1) Hidden inner trait — carries __spec_* methods. #[doc(hidden)] - #vis #unsafety trait #inner_name #inner_supers { + #vis #unsafety trait #inner_name #inner_supers #where_clause { #(#spec_decls)* } @@ -274,11 +271,11 @@ fn emit_trait_expansion( // (5) Public marker trait. #(#outer_attrs)* - #vis #unsafety trait #name: #public_supers {} + #vis #unsafety trait #name: #public_supers #where_clause {} // (6) Blanket impl. #unsafety impl<__DevirtT: #inner_name + ?Sized> #name - for __DevirtT {} + for __DevirtT #where_clause {} } .into() } @@ -563,23 +560,6 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream { .into(); }; - if !impl_item.generics.params.is_empty() { - return syn::Error::new_spanned( - &impl_item.generics, - "#[devirt] does not support generic impl blocks", - ) - .to_compile_error() - .into(); - } - if let Some(wc) = &impl_item.generics.where_clause { - return syn::Error::new_spanned( - wc, - "#[devirt] does not support where clauses on impl blocks", - ) - .to_compile_error() - .into(); - } - // Reject qualified paths — we need a plain ident to construct // the __TraitNameImpl identifier. if trait_path.leading_colon.is_some() || trait_path.segments.len() > 1 { @@ -600,6 +580,8 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream { .ident; let inner_name = format_ident!("__{trait_name}Impl"); let ty = &impl_item.self_ty; + let (impl_generics, _, where_clause) = + impl_item.generics.split_for_impl(); // Collect method names so sibling calls in impl bodies // (e.g. `self.area()`) are rewritten to `self.__spec_area()`. @@ -640,7 +622,7 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream { .collect(); quote! { - #unsafety impl #inner_name for #ty { + #unsafety impl #impl_generics #inner_name for #ty #where_clause { #(#spec_methods)* } }