Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/source/user_guide/compound_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,28 @@ class Particle:
vel: qd.types.vector(3, qd.f32)
mass: qd.f32
```

### Inheritance

A `@qd.dataclass` may subclass one or more other `@qd.dataclass` types. The subclass inherits the parent's members and `@qd.func` methods, and can add its own or override inherited ones:

```python
@qd.dataclass
class Body:
pos: qd.types.vector(3, qd.f32)
mass: qd.f32

@qd.func
def momentum(self):
return self.mass * self.pos

@qd.dataclass
class ChargedBody(Body):
charge: qd.f32 # added to inherited `pos` and `mass`

@qd.func
def charged_mass(self):
return self.mass + self.charge
```

`ChargedBody` ends up with members `pos`, `mass`, `charge` (parent members first, in declaration order) and both `momentum` and `charged_mass` methods. A member or method declared on the subclass overrides the inherited one of the same name.
34 changes: 30 additions & 4 deletions python/quadrants/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,13 @@ def __init__(self, **kwargs):
elements.append([dtype, k])
self.dtype = _qd_core.get_type_factory_instance().get_struct_type(elements)

def __mro_entries__(self, bases):
# A ``@qd.dataclass`` is a ``StructType`` *instance*, not a class, so it cannot normally appear in a base-class
# list. PEP 560 lets us return a real placeholder class to stand in for it; that placeholder carries a
# back-reference to this ``StructType`` so ``dataclass`` can merge our members and methods into the subclass.
placeholder = type("_QuadrantsDataclassBase", (object,), {"_quadrants_struct_type": self})
return (placeholder,)

def __call__(self, *args, **kwargs):
"""Create an instance of this struct type."""
d = {}
Expand Down Expand Up @@ -817,18 +824,37 @@ def dataclass(cls):
A quadrants struct with the annotations as fields
and methods from the class attached.
"""
# Merge in members and methods from any @qd.dataclass base classes. A @qd.dataclass parent is a StructType
# instance, so it appears in __bases__ as the placeholder class injected by StructType.__mro_entries__, which
# back-references the parent via _quadrants_struct_type. Bases are visited left-to-right and an already-seen name
# is never overwritten, so on a conflict the leftmost base wins — matching Python's MRO precedence for
# ``class C(A, B)`` — while members keep their natural leftmost-base-first order.
inherited_fields = {}
inherited_methods = {}
for base in getattr(cls, "__bases__", ()):
parent_struct = getattr(base, "_quadrants_struct_type", None)
if parent_struct is not None:
for member_name, member_type in parent_struct.members.items():
inherited_fields.setdefault(member_name, member_type)
for method_name, method in parent_struct.methods.items():
inherited_methods.setdefault(method_name, method)

# save the annotation fields for the struct
fields = getattr(cls, "__annotations__", {})
own_fields = dict(getattr(cls, "__annotations__", {}))
# raise error if there are default values
for k in fields.keys():
for k in own_fields.keys():
if hasattr(cls, k):
raise QuadrantsSyntaxError("Default value in @dataclass is not supported.")
fields = {**inherited_fields, **own_fields}
# get the class methods to be attached to the struct types
fields["__struct_methods"] = {
own_methods = {
attribute: getattr(cls, attribute)
for attribute in dir(cls)
if callable(getattr(cls, attribute)) and not attribute.startswith("__")
if callable(getattr(cls, attribute))
and not attribute.startswith("__")
and attribute != "_quadrants_struct_type"
}
fields["__struct_methods"] = {**inherited_methods, **own_methods}
return StructType(**fields)


Expand Down
100 changes: 100 additions & 0 deletions tests/python/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,106 @@ def foo(x: Foo) -> Foo:
assert c.y == 8


@test_utils.test()
def test_data_class_inheritance():
@qd.dataclass
class Base:
x: qd.f32
y: qd.f32

@qd.func
def sum_xy(self):
return self.x + self.y

@qd.dataclass
class Child(Base):
z: qd.f32

@qd.func
def sum_xyz(self):
return self.sum_xy() + self.z

# Inherited members come first, in declaration order, followed by own members.
assert list(Child.members.keys()) == ["x", "y", "z"]
# Both inherited and own methods are attached.
assert "sum_xy" in Child.methods
assert "sum_xyz" in Child.methods

@qd.kernel
def use_inherited(c: Child) -> qd.f32:
return c.sum_xy()

@qd.kernel
def use_own(c: Child) -> qd.f32:
return c.sum_xyz()

assert use_inherited(Child(1, 2, 4)) == 3
assert use_own(Child(1, 2, 4)) == 7


@test_utils.test()
def test_data_class_inheritance_override():
@qd.dataclass
class Base:
x: qd.f32

@qd.func
def value(self):
return self.x

@qd.dataclass
class Child(Base):
x: qd.f32
y: qd.f32

@qd.func
def value(self):
return self.x + self.y

assert list(Child.members.keys()) == ["x", "y"]

@qd.kernel
def k(c: Child) -> qd.f32:
return c.value()

assert k(Child(1, 2)) == 3


@test_utils.test()
def test_data_class_multiple_inheritance_leftmost_wins():
@qd.dataclass
class A:
x: qd.i32

@qd.func
def who(self):
return 1

@qd.dataclass
class B:
x: qd.f32
y: qd.f32

@qd.func
def who(self):
return 2

@qd.dataclass
class C(A, B):
pass

# On a conflict the leftmost base (A) wins, matching Python's MRO; distinct members keep
# leftmost-base-first order (A's `x`, then B's new `y`).
assert list(C.members.keys()) == ["x", "y"]
assert C.members["x"] == qd.i32

@qd.kernel
def k(c: C) -> qd.i32:
return c.who()

assert k(C(x=0, y=0)) == 1


@test_utils.test()
def test_nested_data_class_func():
@qd.dataclass
Expand Down
Loading