diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index 79c007d7aa..4a4b5c5db9 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -159,3 +159,28 @@ class Particle: vel: qd.types.vector(3, qd.f32) mass: qd.f32 ``` + +### Inheritance + +A `@qd.dataclass` may subclass one or more other `@qd.dataclass` types. The subclass inherits the parent's members and `@qd.func` methods, and can add its own or override inherited ones: + +```python +@qd.dataclass +class Body: + pos: qd.types.vector(3, qd.f32) + mass: qd.f32 + + @qd.func + def momentum(self): + return self.mass * self.pos + +@qd.dataclass +class ChargedBody(Body): + charge: qd.f32 # added to inherited `pos` and `mass` + + @qd.func + def charged_mass(self): + return self.mass + self.charge +``` + +`ChargedBody` ends up with members `pos`, `mass`, `charge` (parent members first, in declaration order) and both `momentum` and `charged_mass` methods. A member or method declared on the subclass overrides the inherited one of the same name. diff --git a/python/quadrants/lang/struct.py b/python/quadrants/lang/struct.py index c86efb1a99..5ecfdb61e6 100644 --- a/python/quadrants/lang/struct.py +++ b/python/quadrants/lang/struct.py @@ -617,6 +617,13 @@ def __init__(self, **kwargs): elements.append([dtype, k]) self.dtype = _qd_core.get_type_factory_instance().get_struct_type(elements) + def __mro_entries__(self, bases): + # A ``@qd.dataclass`` is a ``StructType`` *instance*, not a class, so it cannot normally appear in a base-class + # list. PEP 560 lets us return a real placeholder class to stand in for it; that placeholder carries a + # back-reference to this ``StructType`` so ``dataclass`` can merge our members and methods into the subclass. + placeholder = type("_QuadrantsDataclassBase", (object,), {"_quadrants_struct_type": self}) + return (placeholder,) + def __call__(self, *args, **kwargs): """Create an instance of this struct type.""" d = {} @@ -817,18 +824,37 @@ def dataclass(cls): A quadrants struct with the annotations as fields and methods from the class attached. """ + # Merge in members and methods from any @qd.dataclass base classes. A @qd.dataclass parent is a StructType + # instance, so it appears in __bases__ as the placeholder class injected by StructType.__mro_entries__, which + # back-references the parent via _quadrants_struct_type. Bases are visited left-to-right and an already-seen name + # is never overwritten, so on a conflict the leftmost base wins — matching Python's MRO precedence for + # ``class C(A, B)`` — while members keep their natural leftmost-base-first order. + inherited_fields = {} + inherited_methods = {} + for base in getattr(cls, "__bases__", ()): + parent_struct = getattr(base, "_quadrants_struct_type", None) + if parent_struct is not None: + for member_name, member_type in parent_struct.members.items(): + inherited_fields.setdefault(member_name, member_type) + for method_name, method in parent_struct.methods.items(): + inherited_methods.setdefault(method_name, method) + # save the annotation fields for the struct - fields = getattr(cls, "__annotations__", {}) + own_fields = dict(getattr(cls, "__annotations__", {})) # raise error if there are default values - for k in fields.keys(): + for k in own_fields.keys(): if hasattr(cls, k): raise QuadrantsSyntaxError("Default value in @dataclass is not supported.") + fields = {**inherited_fields, **own_fields} # get the class methods to be attached to the struct types - fields["__struct_methods"] = { + own_methods = { attribute: getattr(cls, attribute) for attribute in dir(cls) - if callable(getattr(cls, attribute)) and not attribute.startswith("__") + if callable(getattr(cls, attribute)) + and not attribute.startswith("__") + and attribute != "_quadrants_struct_type" } + fields["__struct_methods"] = {**inherited_methods, **own_methods} return StructType(**fields) diff --git a/tests/python/test_struct.py b/tests/python/test_struct.py index de6d249970..9e178f0f75 100644 --- a/tests/python/test_struct.py +++ b/tests/python/test_struct.py @@ -132,6 +132,106 @@ def foo(x: Foo) -> Foo: assert c.y == 8 +@test_utils.test() +def test_data_class_inheritance(): + @qd.dataclass + class Base: + x: qd.f32 + y: qd.f32 + + @qd.func + def sum_xy(self): + return self.x + self.y + + @qd.dataclass + class Child(Base): + z: qd.f32 + + @qd.func + def sum_xyz(self): + return self.sum_xy() + self.z + + # Inherited members come first, in declaration order, followed by own members. + assert list(Child.members.keys()) == ["x", "y", "z"] + # Both inherited and own methods are attached. + assert "sum_xy" in Child.methods + assert "sum_xyz" in Child.methods + + @qd.kernel + def use_inherited(c: Child) -> qd.f32: + return c.sum_xy() + + @qd.kernel + def use_own(c: Child) -> qd.f32: + return c.sum_xyz() + + assert use_inherited(Child(1, 2, 4)) == 3 + assert use_own(Child(1, 2, 4)) == 7 + + +@test_utils.test() +def test_data_class_inheritance_override(): + @qd.dataclass + class Base: + x: qd.f32 + + @qd.func + def value(self): + return self.x + + @qd.dataclass + class Child(Base): + x: qd.f32 + y: qd.f32 + + @qd.func + def value(self): + return self.x + self.y + + assert list(Child.members.keys()) == ["x", "y"] + + @qd.kernel + def k(c: Child) -> qd.f32: + return c.value() + + assert k(Child(1, 2)) == 3 + + +@test_utils.test() +def test_data_class_multiple_inheritance_leftmost_wins(): + @qd.dataclass + class A: + x: qd.i32 + + @qd.func + def who(self): + return 1 + + @qd.dataclass + class B: + x: qd.f32 + y: qd.f32 + + @qd.func + def who(self): + return 2 + + @qd.dataclass + class C(A, B): + pass + + # On a conflict the leftmost base (A) wins, matching Python's MRO; distinct members keep + # leftmost-base-first order (A's `x`, then B's new `y`). + assert list(C.members.keys()) == ["x", "y"] + assert C.members["x"] == qd.i32 + + @qd.kernel + def k(c: C) -> qd.i32: + return c.who() + + assert k(C(x=0, y=0)) == 1 + + @test_utils.test() def test_nested_data_class_func(): @qd.dataclass