Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions crates/core/tests/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,150 @@ fn attr_default_body_dispatch() {
assert_eq!((&h5 as &dyn Defaulted).describe(), "val=7");
}

// ── Where clauses on traits ────────────────────────────────────────────────

#[cfg(feature = "macros")]
mod attr_where_clause {
use core::fmt::Debug;

#[derive(Debug)]
pub struct WcHot {
pub val: u64,
}

#[derive(Debug)]
pub struct WcCold {
pub val: u64,
}

// Uses `where Self: Debug` (semantically equivalent to a supertrait).
#[devirt::devirt(WcHot)]
pub trait Inspectable
where
Self: Debug,
{
fn value(&self) -> u64;
fn inspect(&self) -> String {
format!("{:?}={}", self, self.value())
}
}

#[devirt::devirt]
impl Inspectable for WcHot {
fn value(&self) -> u64 {
self.val
}
}

#[devirt::devirt]
impl Inspectable for WcCold {
fn value(&self) -> u64 {
self.val + 1
}
}
}

#[cfg(feature = "macros")]
#[test]
fn attr_where_clause_dispatch() {
use attr_where_clause::{Inspectable, WcCold, WcHot};

let h = WcHot { val: 42 };
assert_eq!((&h as &dyn Inspectable).value(), 42);
assert!((&h as &dyn Inspectable).inspect().contains("42"));

let c = WcCold { val: 10 };
assert_eq!((&c as &dyn Inspectable).value(), 11);
assert!((&c as &dyn Inspectable).inspect().contains("11"));
}

// ── Generic impl blocks ────────────────────────────────────────────────────

#[cfg(feature = "macros")]
mod attr_generic_impl {
pub struct GHot {
pub val: u64,
}

#[devirt::devirt(GHot)]
pub trait Scale {
fn area(&self) -> u64;
}

#[devirt::devirt]
impl Scale for GHot {
fn area(&self) -> u64 {
self.val
}
}

pub struct Scaled<T> {
pub inner: T,
pub factor: u64,
}

// Generic impl — always cold. Delegates to inner through &dyn Scale.
#[devirt::devirt]
impl<T: Scale> Scale for Scaled<T> {
fn area(&self) -> u64 {
self.factor * (&self.inner as &dyn Scale).area()
}
}

// Generic impl with where clause.
pub struct Pair<A, B>
where
A: Scale,
B: Scale,
{
pub a: A,
pub b: B,
}

#[devirt::devirt]
impl<A, B> Scale for Pair<A, B>
where
A: Scale,
B: Scale,
{
fn area(&self) -> u64 {
(&self.a as &dyn Scale).area() + (&self.b as &dyn Scale).area()
}
}
}

#[cfg(feature = "macros")]
#[test]
fn attr_generic_impl_dispatch() {
use attr_generic_impl::{GHot, Pair, Scale, Scaled};

// Hot
let h = GHot { val: 10 };
assert_eq!((&h as &dyn Scale).area(), 10);

// Scaled<GHot> — cold outer, hot inner
let s = Scaled { inner: GHot { val: 5 }, factor: 3 };
assert_eq!((&s as &dyn Scale).area(), 15);

// Scaled<Scaled<GHot>> — nested generic
let ss = Scaled {
inner: Scaled { inner: GHot { val: 2 }, factor: 4 },
factor: 5,
};
assert_eq!((&ss as &dyn Scale).area(), 40);

// Pair with where clause
let p = Pair {
a: GHot { val: 3 },
b: Scaled { inner: GHot { val: 2 }, factor: 4 },
};
assert_eq!((&p as &dyn Scale).area(), 3 + 8);

// Via auto-trait flavors
assert_eq!((&s as &(dyn Scale + Send)).area(), 15);
assert_eq!((&s as &(dyn Scale + Send + Sync)).area(), 15);
}

// ── Extended proc-macro tests: supertraits, method lifetimes, #[must_use] ──

#[cfg(feature = "macros")]
Expand Down
3 changes: 3 additions & 0 deletions crates/core/tests/ui_attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ fn ui_attr() {
t.pass("tests/ui_attr/attr_default_body.rs");
t.pass("tests/ui_attr/attr_default_override.rs");
t.pass("tests/ui_attr/attr_default_send.rs");
t.pass("tests/ui_attr/attr_where_clause.rs");
t.pass("tests/ui_attr/attr_generic_impl.rs");
t.pass("tests/ui_attr/attr_where_impl.rs");
t.compile_fail("tests/ui_attr/attr_must_use_unused.rs");
t.compile_fail("tests/ui_attr/attr_missing_args.rs");
t.compile_fail("tests/ui_attr/attr_unsafe_missing_on_impl.rs");
Expand Down
42 changes: 42 additions & 0 deletions crates/core/tests/ui_attr/attr_generic_impl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
struct Hot {
val: f64,
}

#[devirt::devirt(Hot)]
pub trait Shape {
fn area(&self) -> f64;
}

#[devirt::devirt]
impl Shape for Hot {
fn area(&self) -> f64 {
self.val
}
}

// Generic impl — cold type, falls back to vtable
struct Scaled<T> {
inner: T,
factor: f64,
}

#[devirt::devirt]
impl<T: Shape> Shape for Scaled<T> {
fn area(&self) -> f64 {
// Coerce to &dyn Shape for devirt dispatch. The rewriter only
// rewrites `self.method()` calls, not `self.inner.method()`.
self.factor * (&self.inner as &dyn Shape).area()
}
}

fn total(shapes: &[Box<dyn Shape>]) -> f64 {
shapes.iter().map(|s| s.area()).sum()
}

fn main() {
let shapes: Vec<Box<dyn Shape>> = vec![
Box::new(Hot { val: 10.0 }),
Box::new(Scaled { inner: Hot { val: 5.0 }, factor: 3.0 }),
];
assert_eq!(total(&shapes), 25.0); // 10 + 5*3
}
48 changes: 48 additions & 0 deletions crates/core/tests/ui_attr/attr_where_clause.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use std::fmt::Debug;

#[derive(Debug)]
struct Hot {
val: u64,
}

#[derive(Debug)]
struct Cold {
val: u64,
}

#[devirt::devirt(Hot)]
pub trait Inspectable
where
Self: Debug,
{
fn value(&self) -> u64;
fn inspect(&self) -> String {
format!("{:?} = {}", self, self.value())
}
}

#[devirt::devirt]
impl Inspectable for Hot {
fn value(&self) -> u64 {
self.val
}
}

#[devirt::devirt]
impl Inspectable for Cold {
fn value(&self) -> u64 {
self.val + 1
}
}

fn check(i: &dyn Inspectable) -> String {
i.inspect()
}

fn main() {
let h = Hot { val: 42 };
assert!(check(&h).contains("42"));

let c = Cold { val: 10 };
assert!(check(&c).contains("11"));
}
34 changes: 34 additions & 0 deletions crates/core/tests/ui_attr/attr_where_impl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::fmt::Display;

struct Hot {
val: f64,
}

#[devirt::devirt(Hot)]
pub trait Shape {
fn describe(&self) -> String;
}

#[devirt::devirt]
impl Shape for Hot {
fn describe(&self) -> String {
format!("hot: {}", self.val)
}
}

struct Named<T> {
name: String,
inner: T,
}

#[devirt::devirt]
impl<T> Shape for Named<T>
where
T: Shape + Display,
{
fn describe(&self) -> String {
format!("{}: {}", self.name, self.inner)
}
}

fn main() {}
40 changes: 11 additions & 29 deletions crates/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,20 @@ fn validate_trait(trait_item: &syn::ItemTrait) -> Result<(), syn::Error> {
"#[devirt] does not support generic traits",
));
}
if let Some(wc) = &trait_item.generics.where_clause {
return Err(syn::Error::new_spanned(
wc,
"#[devirt] does not support where clauses on traits",
));
}
for item in &trait_item.items {
match item {
syn::TraitItem::Type(t) => {
return Err(syn::Error::new_spanned(
t,
"#[devirt] does not support associated types",
"#[devirt] does not yet support associated types — \
they require type parameterization of the dispatch shim",
));
}
syn::TraitItem::Const(c) => {
return Err(syn::Error::new_spanned(
c,
"#[devirt] does not support associated constants",
"#[devirt] does not support associated constants — \
they make a trait not dyn-compatible",
));
}
syn::TraitItem::Fn(f) => validate_trait_method(f)?,
Expand Down Expand Up @@ -159,6 +155,7 @@ fn emit_trait_expansion(
let name = &trait_item.ident;
let outer_attrs = &trait_item.attrs;
let supertraits = &trait_item.supertraits;
let where_clause = &trait_item.generics.where_clause;
let inner_name = format_ident!("__{name}Impl");

// __spec_* method declarations for the inner trait (with default
Expand Down Expand Up @@ -210,7 +207,7 @@ fn emit_trait_expansion(
quote! {
// (1) Hidden inner trait — carries __spec_* methods.
#[doc(hidden)]
#vis #unsafety trait #inner_name #inner_supers {
#vis #unsafety trait #inner_name #inner_supers #where_clause {
#(#spec_decls)*
}

Expand Down Expand Up @@ -274,11 +271,11 @@ fn emit_trait_expansion(

// (5) Public marker trait.
#(#outer_attrs)*
#vis #unsafety trait #name: #public_supers {}
#vis #unsafety trait #name: #public_supers #where_clause {}

// (6) Blanket impl.
#unsafety impl<__DevirtT: #inner_name + ?Sized> #name
for __DevirtT {}
for __DevirtT #where_clause {}
}
.into()
}
Expand Down Expand Up @@ -563,23 +560,6 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream {
.into();
};

if !impl_item.generics.params.is_empty() {
return syn::Error::new_spanned(
&impl_item.generics,
"#[devirt] does not support generic impl blocks",
)
.to_compile_error()
.into();
}
if let Some(wc) = &impl_item.generics.where_clause {
return syn::Error::new_spanned(
wc,
"#[devirt] does not support where clauses on impl blocks",
)
.to_compile_error()
.into();
}

// Reject qualified paths — we need a plain ident to construct
// the __TraitNameImpl identifier.
if trait_path.leading_colon.is_some() || trait_path.segments.len() > 1 {
Expand All @@ -600,6 +580,8 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream {
.ident;
let inner_name = format_ident!("__{trait_name}Impl");
let ty = &impl_item.self_ty;
let (impl_generics, _, where_clause) =
impl_item.generics.split_for_impl();

// Collect method names so sibling calls in impl bodies
// (e.g. `self.area()`) are rewritten to `self.__spec_area()`.
Expand Down Expand Up @@ -640,7 +622,7 @@ fn expand_impl(attr: &TokenStream, impl_item: &syn::ItemImpl) -> TokenStream {
.collect();

quote! {
#unsafety impl #inner_name for #ty {
#unsafety impl #impl_generics #inner_name for #ty #where_clause {
#(#spec_methods)*
}
}
Expand Down