From 9dee07f16d883d343773d89de826e9ef435420ba Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 14 Apr 2026 23:50:36 +0000 Subject: [PATCH] Support where clauses on traits and generic impl blocks Remove rejections in expand_trait and expand_impl, forwarding the generics and where clauses to the generated inner trait, public marker trait, blanket impl, and impl block. Generic impls are always cold (can't be listed as hot types) but fall back to vtable dispatch correctly via the existing shim. Improve rejection messages for associated types and associated constants to explain *why* they're rejected (type parameterization requirement and dyn-compatibility respectively). Add trybuild tests (attr_where_clause, attr_generic_impl, attr_where_impl) and equivalence tests (attr_where_clause_dispatch, attr_generic_impl_dispatch) covering nested generic impls and where-clause-constrained impls through &dyn, &(dyn + Send), and &(dyn + Send + Sync). https://claude.ai/code/session_01BJen3TbKCFohcjNFeYwBiA --- crates/core/tests/equivalence.rs | 144 ++++++++++++++++++ crates/core/tests/ui_attr.rs | 3 + .../core/tests/ui_attr/attr_generic_impl.rs | 42 +++++ .../core/tests/ui_attr/attr_where_clause.rs | 48 ++++++ crates/core/tests/ui_attr/attr_where_impl.rs | 34 +++++ crates/macros/src/lib.rs | 40 ++--- 6 files changed, 282 insertions(+), 29 deletions(-) create mode 100644 crates/core/tests/ui_attr/attr_generic_impl.rs create mode 100644 crates/core/tests/ui_attr/attr_where_clause.rs create mode 100644 crates/core/tests/ui_attr/attr_where_impl.rs diff --git a/crates/core/tests/equivalence.rs b/crates/core/tests/equivalence.rs index 2bbc63d..12d325c 100644 --- a/crates/core/tests/equivalence.rs +++ b/crates/core/tests/equivalence.rs @@ -253,6 +253,150 @@ fn attr_default_body_dispatch() { assert_eq!((&h5 as &dyn Defaulted).describe(), "val=7"); } +// ── Where clauses on traits ──────────────────────────────────────────────── + +#[cfg(feature = "macros")] +mod attr_where_clause { + use core::fmt::Debug; + + #[derive(Debug)] + pub struct WcHot { + pub val: u64, + } + + #[derive(Debug)] + pub struct WcCold { + pub val: u64, + } + + // Uses `where Self: Debug` (semantically equivalent to a supertrait). + #[devirt::devirt(WcHot)] + pub trait Inspectable + where + Self: Debug, + { + fn value(&self) -> u64; + fn inspect(&self) -> String { + format!("{:?}={}", self, self.value()) + } + } + + #[devirt::devirt] + impl Inspectable for WcHot { + fn value(&self) -> u64 { + self.val + } + } + + #[devirt::devirt] + impl Inspectable for WcCold { + fn value(&self) -> u64 { + self.val + 1 + } + } +} + +#[cfg(feature = "macros")] +#[test] +fn attr_where_clause_dispatch() { + use attr_where_clause::{Inspectable, WcCold, WcHot}; + + let h = WcHot { val: 42 }; + assert_eq!((&h as &dyn Inspectable).value(), 42); + assert!((&h as &dyn Inspectable).inspect().contains("42")); + + let c = WcCold { val: 10 }; + assert_eq!((&c as &dyn Inspectable).value(), 11); + assert!((&c as &dyn Inspectable).inspect().contains("11")); +} + +// ── Generic impl blocks ──────────────────────────────────────────────────── + +#[cfg(feature = "macros")] +mod attr_generic_impl { + pub struct GHot { + pub val: u64, + } + + #[devirt::devirt(GHot)] + pub trait Scale { + fn area(&self) -> u64; + } + + #[devirt::devirt] + impl Scale for GHot { + fn area(&self) -> u64 { + self.val + } + } + + pub struct Scaled { + pub inner: T, + pub factor: u64, + } + + // Generic impl — always cold. Delegates to inner through &dyn Scale. + #[devirt::devirt] + impl Scale for Scaled { + fn area(&self) -> u64 { + self.factor * (&self.inner as &dyn Scale).area() + } + } + + // Generic impl with where clause. + pub struct Pair + where + A: Scale, + B: Scale, + { + pub a: A, + pub b: B, + } + + #[devirt::devirt] + impl Scale for Pair + where + A: Scale, + B: Scale, + { + fn area(&self) -> u64 { + (&self.a as &dyn Scale).area() + (&self.b as &dyn Scale).area() + } + } +} + +#[cfg(feature = "macros")] +#[test] +fn attr_generic_impl_dispatch() { + use attr_generic_impl::{GHot, Pair, Scale, Scaled}; + + // Hot + let h = GHot { val: 10 }; + assert_eq!((&h as &dyn Scale).area(), 10); + + // Scaled — cold outer, hot inner + let s = Scaled { inner: GHot { val: 5 }, factor: 3 }; + assert_eq!((&s as &dyn Scale).area(), 15); + + // Scaled> — nested generic + let ss = Scaled { + inner: Scaled { inner: GHot { val: 2 }, factor: 4 }, + factor: 5, + }; + assert_eq!((&ss as &dyn Scale).area(), 40); + + // Pair with where clause + let p = Pair { + a: GHot { val: 3 }, + b: Scaled { inner: GHot { val: 2 }, factor: 4 }, + }; + assert_eq!((&p as &dyn Scale).area(), 3 + 8); + + // Via auto-trait flavors + assert_eq!((&s as &(dyn Scale + Send)).area(), 15); + assert_eq!((&s as &(dyn Scale + Send + Sync)).area(), 15); +} + // ── Extended proc-macro tests: supertraits, method lifetimes, #[must_use] ── #[cfg(feature = "macros")] diff --git a/crates/core/tests/ui_attr.rs b/crates/core/tests/ui_attr.rs index 2913e4a..4cf2369 100644 --- a/crates/core/tests/ui_attr.rs +++ b/crates/core/tests/ui_attr.rs @@ -16,6 +16,9 @@ fn ui_attr() { t.pass("tests/ui_attr/attr_default_body.rs"); t.pass("tests/ui_attr/attr_default_override.rs"); t.pass("tests/ui_attr/attr_default_send.rs"); + t.pass("tests/ui_attr/attr_where_clause.rs"); + t.pass("tests/ui_attr/attr_generic_impl.rs"); + t.pass("tests/ui_attr/attr_where_impl.rs"); t.compile_fail("tests/ui_attr/attr_must_use_unused.rs"); t.compile_fail("tests/ui_attr/attr_missing_args.rs"); t.compile_fail("tests/ui_attr/attr_unsafe_missing_on_impl.rs"); diff --git a/crates/core/tests/ui_attr/attr_generic_impl.rs b/crates/core/tests/ui_attr/attr_generic_impl.rs new file mode 100644 index 0000000..031e6ae --- /dev/null +++ b/crates/core/tests/ui_attr/attr_generic_impl.rs @@ -0,0 +1,42 @@ +struct Hot { + val: f64, +} + +#[devirt::devirt(Hot)] +pub trait Shape { + fn area(&self) -> f64; +} + +#[devirt::devirt] +impl Shape for Hot { + fn area(&self) -> f64 { + self.val + } +} + +// Generic impl — cold type, falls back to vtable +struct Scaled { + inner: T, + factor: f64, +} + +#[devirt::devirt] +impl Shape for Scaled { + fn area(&self) -> f64 { + // Coerce to &dyn Shape for devirt dispatch. The rewriter only + // rewrites `self.method()` calls, not `self.inner.method()`. + self.factor * (&self.inner as &dyn Shape).area() + } +} + +fn total(shapes: &[Box]) -> f64 { + shapes.iter().map(|s| s.area()).sum() +} + +fn main() { + let shapes: Vec> = vec![ + Box::new(Hot { val: 10.0 }), + Box::new(Scaled { inner: Hot { val: 5.0 }, factor: 3.0 }), + ]; + assert_eq!(total(&shapes), 25.0); // 10 + 5*3 +} diff --git a/crates/core/tests/ui_attr/attr_where_clause.rs b/crates/core/tests/ui_attr/attr_where_clause.rs new file mode 100644 index 0000000..154b6ff --- /dev/null +++ b/crates/core/tests/ui_attr/attr_where_clause.rs @@ -0,0 +1,48 @@ +use std::fmt::Debug; + +#[derive(Debug)] +struct Hot { + val: u64, +} + +#[derive(Debug)] +struct Cold { + val: u64, +} + +#[devirt::devirt(Hot)] +pub trait Inspectable +where + Self: Debug, +{ + fn value(&self) -> u64; + fn inspect(&self) -> String { + format!("{:?} = {}", self, self.value()) + } +} + +#[devirt::devirt] +impl Inspectable for Hot { + fn value(&self) -> u64 { + self.val + } +} + +#[devirt::devirt] +impl Inspectable for Cold { + fn value(&self) -> u64 { + self.val + 1 + } +} + +fn check(i: &dyn Inspectable) -> String { + i.inspect() +} + +fn main() { + let h = Hot { val: 42 }; + assert!(check(&h).contains("42")); + + let c = Cold { val: 10 }; + assert!(check(&c).contains("11")); +} diff --git a/crates/core/tests/ui_attr/attr_where_impl.rs b/crates/core/tests/ui_attr/attr_where_impl.rs new file mode 100644 index 0000000..389c69f --- /dev/null +++ b/crates/core/tests/ui_attr/attr_where_impl.rs @@ -0,0 +1,34 @@ +use std::fmt::Display; + +struct Hot { + val: f64, +} + +#[devirt::devirt(Hot)] +pub trait Shape { + fn describe(&self) -> String; +} + +#[devirt::devirt] +impl Shape for Hot { + fn describe(&self) -> String { + format!("hot: {}", self.val) + } +} + +struct Named { + name: String, + inner: T, +} + +#[devirt::devirt] +impl Shape for Named +where + T: Shape + Display, +{ + fn describe(&self) -> String { + format!("{}: {}", self.name, self.inner) + } +} + +fn main() {} diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index 175a9ef..7ed910a 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -82,24 +82,20 @@ fn validate_trait(trait_item: &syn::ItemTrait) -> Result<(), syn::Error> { "#[devirt] does not support generic traits", )); } - if let Some(wc) = &trait_item.generics.where_clause { - return Err(syn::Error::new_spanned( - wc, - "#[devirt] does not support where clauses on traits", - )); - } for item in &trait_item.items { match item { syn::TraitItem::Type(t) => { return Err(syn::Error::new_spanned( t, - "#[devirt] does not support associated types", + "#[devirt] does not yet support associated types — \ + they require type parameterization of the dispatch shim", )); } syn::TraitItem::Const(c) => { return Err(syn::Error::new_spanned( c, - "#[devirt] does not support associated constants", + "#[devirt] does not support associated constants — \ + they make a trait not dyn-compatible", )); } syn::TraitItem::Fn(f) => validate_trait_method(f)?, @@ -159,6 +155,7 @@ fn emit_trait_expansion( let name = &trait_item.ident; let outer_attrs = &trait_item.attrs; let supertraits = &trait_item.supertraits; + let where_clause = &trait_item.generics.where_clause; let inner_name = format_ident!("__{name}Impl"); // __spec_* method declarations for the inner trait (with default @@ -210,7 +207,7 @@ fn emit_trait_expansion( quote! { // (1) Hidden inner trait — carries __spec_* methods. #[doc(hidden)] - #vis #unsafety trait #inner_name #inner_supers { + #vis #unsafety trait #inner_name #inner_supers #where_clause { #(#spec_decls)* } @@ -274,11 +271,11 @@ fn emit_trait_expansion( // (5) Public marker trait. #(#outer_attrs)* - #vis #unsafety trait #name: #public_supers {} + #vis #unsafety trait #name: #public_supers #where_clause {} // (6) Blanket impl. #unsafety impl<__DevirtT: #inner_name + ?Sized> #name - for __DevirtT {} + for __DevirtT #where_clause {} } .into() } @@ -563,23 +560,6 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream { .into(); }; - if !impl_item.generics.params.is_empty() { - return syn::Error::new_spanned( - &impl_item.generics, - "#[devirt] does not support generic impl blocks", - ) - .to_compile_error() - .into(); - } - if let Some(wc) = &impl_item.generics.where_clause { - return syn::Error::new_spanned( - wc, - "#[devirt] does not support where clauses on impl blocks", - ) - .to_compile_error() - .into(); - } - // Reject qualified paths — we need a plain ident to construct // the __TraitNameImpl identifier. if trait_path.leading_colon.is_some() || trait_path.segments.len() > 1 { @@ -600,6 +580,8 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream { .ident; let inner_name = format_ident!("__{trait_name}Impl"); let ty = &impl_item.self_ty; + let (impl_generics, _, where_clause) = + impl_item.generics.split_for_impl(); // Collect method names so sibling calls in impl bodies // (e.g. `self.area()`) are rewritten to `self.__spec_area()`. @@ -640,7 +622,7 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream { .collect(); quote! { - #unsafety impl #inner_name for #ty { + #unsafety impl #impl_generics #inner_name for #ty #where_clause { #(#spec_methods)* } }