HyperParallel provides `DFunction`, a base class for writing custom distributed
functions that integrate with the `DTensor` dispatch system. Users subclass
`DFunction` (which inherits from the platform's autograd `Function`) and
implement `forward` / `backward` as plain local-tensor operations. When the
inputs are `DTensor`s, the dispatch system handles layout inference,
local-tensor extraction, and output wrapping; user code needs no changes for
the multi-card path.
`DFunction` is the base class for user-defined distributed autograd functions.
```python
class DFunction(platform.Function):
    _op_name: str = None  # Set this to the registered DistributedOp name

    @staticmethod
    def forward(ctx, *args, **kwargs) -> Tensor: ...

    @staticmethod
    def backward(ctx, *grad_outputs) -> ...: ...

    @classmethod
    def apply(cls, *args, **kwargs) -> Tensor | DTensor: ...
```

Dispatch behaviour:
| Input type | Route |
|---|---|
| Plain `Tensor` | `super().apply()`: platform autograd, single-device path |
| At least one `DTensor` | `_OP_DISPATCHER.dispatch()`: layout inference + `DTensor` wrapping |

`_op_name` must be set when `DTensor` inputs are expected. It must match the
`op_name` passed to a registered `DistributedOp` subclass. Omitting `_op_name`
while passing a `DTensor` raises `ValueError`.
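
The routing rule above can be sketched in plain Python. This is a minimal illustration, not HyperParallel's implementation: `DTensor`, `PlatformFunction`, and `SketchDFunction` are hypothetical stand-ins, and the "dispatch" branch simply unwraps and re-wraps values where the real dispatcher would also infer layouts.

```python
class DTensor:
    """Hypothetical stand-in for a distributed tensor wrapping a local shard."""

    def __init__(self, local):
        self._local = local

    def to_local(self):
        return self._local


class PlatformFunction:
    """Stand-in for the platform autograd Function (single-device path)."""

    @classmethod
    def apply(cls, *args, **kwargs):
        return cls.forward(None, *args, **kwargs)


class SketchDFunction(PlatformFunction):
    _op_name = None

    @classmethod
    def apply(cls, *args, **kwargs):
        values = list(args) + list(kwargs.values())
        if not any(isinstance(v, DTensor) for v in values):
            # Plain inputs: go straight to the platform autograd path.
            return super().apply(*args, **kwargs)
        if cls._op_name is None:
            raise ValueError(f"{cls.__name__}: _op_name must be set for DTensor inputs")
        # Stand-in for the dispatcher: unwrap, compute locally, re-wrap.
        local_args = [a.to_local() if isinstance(a, DTensor) else a for a in args]
        local_kwargs = {
            k: v.to_local() if isinstance(v, DTensor) else v for k, v in kwargs.items()
        }
        return DTensor(super().apply(*local_args, **local_kwargs))


class Add(SketchDFunction):
    _op_name = "Add"

    @staticmethod
    def forward(ctx, x, y):
        return x + y


class Unnamed(SketchDFunction):  # _op_name deliberately left unset
    @staticmethod
    def forward(ctx, x):
        return x


plain_result = Add.apply(1, 2)                    # single-device route
dist_result = Add.apply(DTensor(1), DTensor(2))   # dispatch route, wrapped output
```

In this sketch `Unnamed.apply(DTensor(1))` raises `ValueError`, mirroring the documented behaviour when `_op_name` is omitted.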

`forward(ctx, *args)` operates on local tensors. When the distributed
path is taken, the dispatcher extracts the local shard from each input `DTensor`
and passes it here. Non-tensor positional arguments are forwarded unchanged only
via the `preprocess` → `_dispatch_new` path (see `DistributedOp.preprocess`).

`backward(ctx, *grad_outputs)` likewise operates on local tensors. It is
identical to a standard `torch.autograd.Function.backward` or MindSpore's
`_Function.backward`.

`DistributedOp` is the base class for layout inference and optional pre/post-processing.
```python
class DistributedOp:
    def __init__(self, op_name: str): ...
    def preprocess(self, args: tuple, kwargs: dict) -> None | tuple: ...
    def infer_layout(self, layouts_or_cache, extra_args=None) -> Layout | tuple: ...
    def get_expand_impl(self, func, infer_result, layouts, extra_args=None) -> None | Callable: ...
```

Instantiating a `DistributedOp` subclass automatically registers it under
`op_name`. Exactly one instance must exist per `op_name` before the first call
to the corresponding `DFunction.apply`.
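
Registration as a side effect of instantiation could look roughly like the sketch below. The names `_REGISTRY`, `SketchDistributedOp`, and `lookup` are hypothetical, and rejecting a duplicate `op_name` is an assumption made here for illustration; the real registry may handle duplicates differently.

```python
_REGISTRY = {}  # hypothetical module-level table: op_name -> instance


class SketchDistributedOp:
    def __init__(self, op_name):
        if op_name in _REGISTRY:
            # Assumption: this sketch enforces "one instance per op_name" eagerly.
            raise ValueError(f"op_name {op_name!r} is already registered")
        self.op_name = op_name
        _REGISTRY[op_name] = self  # registration happens in the constructor


def lookup(op_name):
    """Return the registered instance, or fail if none was instantiated."""
    try:
        return _REGISTRY[op_name]
    except KeyError:
        raise ValueError(f"no DistributedOp registered for {op_name!r}") from None


my_add_op = SketchDistributedOp("MyAdd")  # instantiation registers the op
found = lookup("MyAdd")
```

The point of the pattern is that a bare `MyAddDistOp()` statement at module level is enough to make the op visible to the dispatcher, as long as it runs before the first `apply` call.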

`preprocess` is optional. It is called once before layout inference on the first dispatch (cached thereafter) and returns either:

- `None`: fall through to the legacy dispatch path.
- `(local_args, local_kwargs, cache_values)`: take the new dispatch path.

`local_args` / `local_kwargs` are the local-tensor positional / keyword arguments
to pass to the user's `forward`. `cache_values` is an ordered list of values used
as the layout cache key (typically `Layout` objects plus scalars such as `bool`,
`int`).
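
To see why `cache_values` should contain only values that affect the output layout, consider the following sketch of a layout cache keyed by them. Everything here is a hypothetical stand-in (`Layout`, `_layout_cache`, `cached_infer`); it assumes the cache key is the op name plus the tuple of `cache_values`, which the source does not spell out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Layout:
    """Hypothetical hashable stand-in for a real Layout object."""
    placements: tuple


_layout_cache = {}
infer_calls = []  # records every real inference, to make cache hits visible


def infer_layout(cache_values):
    infer_calls.append(tuple(cache_values))
    return ((cache_values[0],), None)  # output layout follows the first input


def cached_infer(op_name, cache_values):
    key = (op_name, tuple(cache_values))  # ordered values form the key
    if key not in _layout_cache:
        _layout_cache[key] = infer_layout(cache_values)
    return _layout_cache[key]


x_layout = Layout(("Shard(0)", "Replicate"))
w_layout = Layout(("Replicate", "Shard(1)"))

first = cached_infer("LinRow", [x_layout, w_layout, True])
second = cached_infer("LinRow", [x_layout, w_layout, True])   # cache hit
third = cached_infer("LinRow", [x_layout, w_layout, False])   # new key, new inference
```

Under this assumption, a scalar such as `bias is not None` belongs in `cache_values` whenever it changes the inference result, while per-call tensor data does not.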

Use `preprocess` when:

- Non-tensor positional arguments must be forwarded to `forward`.
- Finer control over what ends up in `local_args` is needed.

`infer_layout` computes the output layout(s).

| Dispatch path | Signature |
|---|---|
| Legacy (no `preprocess`) | `infer_layout(input_layouts: tuple, extra_args: list) -> Layout` |
| New (`preprocess` returned a tuple) | `infer_layout(cache_values: list) -> (out_layouts_tuple, None)` |

For multi-output ops, return a tuple of `Layout` objects.

`get_expand_impl` is optional. It returns `None` (default) or a callable that
replaces the default `func(*local_args)` call. Use it to modify how local
arguments are combined before the computation; the canonical example is bias
scaling in row-parallel linear, where each rank needs `bias / tp_size` instead
of the full bias.
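
The arithmetic behind the bias scaling can be checked with a small single-process simulation. This is a plain-Python sketch with made-up numbers, not library code: the "ranks" are list slices, and the final `sum` plays the role of the reduction over partial results.

```python
tp_size = 4
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]       # one input row, in = 8
w = [0.5, -1.0, 2.0, 0.25, 1.0, 1.5, -0.5, 3.0]    # one weight row, in = 8
bias = 10.0


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


# Single-device reference: y = x . w + bias
reference = dot(x, w) + bias

# Shard the contracting dimension across tp_size simulated ranks.
chunk = len(x) // tp_size
shards = [(x[i * chunk:(i + 1) * chunk], w[i * chunk:(i + 1) * chunk])
          for i in range(tp_size)]

# Every rank adds the full bias: after the reduction it is counted tp_size times.
naive = sum(dot(xs, ws) + bias for xs, ws in shards)

# Every rank adds bias / tp_size: the reduction restores exactly one bias.
scaled = sum(dot(xs, ws) + bias / tp_size for xs, ws in shards)
```

Here `naive` exceeds `reference` by `(tp_size - 1) * bias`, while `scaled` matches it, which is the effect a replacement callable from `get_expand_impl` is meant to achieve.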

```python
from hyper_parallel import init_device_mesh, DFunction
from hyper_parallel.core.dtensor.dtensor import distribute_tensor
from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
from hyper_parallel.core.shard.ops.parallel_ops import DistributedOp

# ── Step 1: Register a DistributedOp ─────────────────────────────────────────
class MyAddDistOp(DistributedOp):
    def __init__(self):
        super().__init__("MyAdd")

    def infer_layout(self, layouts, extra_args=None):
        return layouts[0]  # element-wise: output layout = first input layout

MyAddDistOp()  # instantiation registers the op

# ── Step 2: Implement DFunction ───────────────────────────────────────────
class MyAdd(DFunction):
    _op_name = "MyAdd"

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x + y

    @staticmethod
    def backward(ctx, grad):
        return grad, grad

# ── Step 3: Call ─────────────────────────────────────────────────────────────
# x_local, y_local, x, y are tensors created by the caller (not shown).
# Single-device (plain tensors)
result = MyAdd.apply(x_local, y_local)

# Multi-device (DTensors, dispatched automatically)
mesh = init_device_mesh("npu", (2, 4), mesh_dim_names=("dp", "tp"))
x_dist = distribute_tensor(x, mesh, (Shard(0), Replicate()))
y_dist = distribute_tensor(y, mesh, (Shard(0), Replicate()))
result_dist = MyAdd.apply(x_dist, y_dist)  # returns DTensor
```

No `preprocess` override (simplest case): `infer_layout` receives the
input `Layout` objects and returns the output `Layout`.

```python
class _ScaleDistOp(DistributedOp):
    def __init__(self):
        super().__init__("Scale")

    def infer_layout(self, layouts, extra_args=None):
        return layouts[0]  # scale is element-wise

_ScaleDistOp()

class ScaleFunc(DFunction):
    _op_name = "Scale"

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None  # None for the non-tensor scale
```

Note: Non-tensor positional arguments (e.g. `scale`) are not forwarded to
`forward` on the legacy dispatch path. Either pass them as `kwargs`, or use the
new dispatch path by implementing `preprocess`.

Implement `preprocess` to extract local tensors and build `cache_values`.
`infer_layout` then receives `cache_values` and returns `(out_layouts_tuple, None)`.

```python
import torch.nn.functional as F  # assumes the PyTorch backend

class LinColDistOp(DistributedOp):
    def __init__(self):
        super().__init__("LinCol")

    def preprocess(self, args, kwargs):
        x, w = args[0], args[1]
        local_args = (x.to_local(), w.to_local())
        cache_values = [x.layout, w.layout]
        return local_args, {}, cache_values

    def infer_layout(self, cache_values):
        x_layout, w_layout = cache_values[0], cache_values[1]
        # derive output layout from x[.., in] and w[out, in] ...
        # _linear_output_layout is a user-supplied helper (not shown)
        out_layout = _linear_output_layout(x_layout, w_layout)
        return ((out_layout,), None)

    def get_expand_impl(self, func, infer_result, cache_values, extra_args=None):
        return None  # no bias → no scaling needed

LinColDistOp()

class LinColFunc(DFunction):
    _op_name = "LinCol"

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return F.linear(x, w)

    @staticmethod
    def backward(ctx, grad):
        x, w = ctx.saved_tensors
        return grad @ w, grad.t() @ x  # gradients for 2-D x[batch, in], w[out, in]
```

When the contracting dimension is sharded across TP ranks, each rank computes a
partial sum. Adding the full bias on every rank would over-count it.
`get_expand_impl` returns a replacement callable that pre-scales `bias` by
`1 / tp_size`:

```python
class LinRowDistOp(DistributedOp):
    def __init__(self):
        super().__init__("LinRow")

    def preprocess(self, args, kwargs):
        x, w, bias = args[0], args[1], args[2]
        local_args = (x.to_local(), w.to_local(), bias)
        cache_values = [x.layout, w.layout, bias is not None]
        return local_args, {}, cache_values

    def infer_layout(self, cache_values):
        x_layout, w_layout = cache_values[0], cache_values[1]
        out_layout = _linear_output_layout(x_layout, w_layout)
        return ((out_layout,), None)

    def get_expand_impl(self, func, infer_result, cache_values):
        x_layout = cache_values[0]
        bias_present = cache_values[2]
        contract = x_layout.alias_tensor_map[-1]
        if contract == "None" or not bias_present:
            return None
        # tp_size: number of TP ranks sharing the contracting dimension
        tp_size = x_layout.mesh.get_device_num_along_axis(contract)

        def expand_impl(x, w, bias):
            return func(x, w, bias / tp_size)  # scale bias down

        return expand_impl

LinRowDistOp()

class LinRowFunc(DFunction):
    _op_name = "LinRow"

    @staticmethod
    def forward(ctx, x, w, bias):
        ctx.save_for_backward(x, w, bias)
        return F.linear(x, w, bias)

    @staticmethod
    def backward(ctx, grad):
        x, w, bias = ctx.saved_tensors
        return grad @ w, grad.t() @ x, grad.sum(0)

# Usage: contracting (in) dimension sharded across TP
mesh = init_device_mesh("npu", (2, 4), mesh_dim_names=("dp", "tp"))
x_dist = distribute_tensor(x, mesh, (Replicate(), Shard(1)))
w_dist = distribute_tensor(w, mesh, (Replicate(), Shard(1)))
result = LinRowFunc.apply(x_dist, w_dist, bias)
# result is a partial DTensor; reduce before redistribution
output = result.reduce_partial()
```

Inside distributed tests, accessing gradients for non-leaf local tensors requires
`retain_grad()` before the forward pass:

```python
x_dist = distribute_tensor(x_leaf, mesh, placements)
x_local = x_dist.to_local()   # returns the internal local tensor (non-leaf)
x_local.retain_grad()         # allow grad accumulation on non-leaf

result = MyFunc.apply(x_dist)
result.to_local().sum().backward()

print(x_local.grad)           # gradient available here
```

`x_dist.to_local()` returns the same `_local_tensor` object every call, so
calling `retain_grad()` once before the forward is sufficient.

```
DFunction.apply(dtensor1, ...)
    │  has_dtensor=True
    ▼
_OP_DISPATCHER.dispatch(local_callable, args, kwargs)
    │
    ├─ preprocess() defined?
    │     Yes → _dispatch_new: local_args, cache_values → infer_layout(cache_values)
    │     No  → _with_layout_infer: to_local() each DTensor → infer_layout(layouts)
    │
    ├─ get_expand_impl() → op_impl  (None → use local_callable directly)
    │
    └─ py_output = op_impl(*local_args)
           │
           └─ local_callable(*local_args)
                  │  has_dtensor=False
                  └─ super().apply(*local_args)   ← platform autograd, no recursion
                         └─ MyFunc.forward(ctx, local_tensor1, ...)
```
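
The flow above can be traced with a toy dispatcher. The sketch below is an assumption-laden simplification: `FakeDTensor`, `BaseOp`, and `dispatch` are hypothetical names, layouts are plain strings, outputs are not re-wrapped, and the legacy branch simply drops every positional argument that is not a `FakeDTensor`.

```python
class FakeDTensor:
    """Hypothetical DTensor stand-in holding a local value and a layout tag."""

    def __init__(self, local, layout):
        self._local, self.layout = local, layout

    def to_local(self):
        return self._local


class BaseOp:
    def preprocess(self, args, kwargs):
        return None  # default: legacy path

    def infer_layout(self, layouts_or_cache, extra_args=None):
        raise NotImplementedError

    def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
        return None  # default: call func directly


def dispatch(op, local_callable, args, kwargs):
    trace = []
    pre = op.preprocess(args, kwargs)
    if pre is not None:
        trace.append("new")
        local_args, local_kwargs, layouts = pre  # layouts here are cache_values
        infer_result = op.infer_layout(layouts)
    else:
        trace.append("legacy")
        dtensors = [a for a in args if isinstance(a, FakeDTensor)]
        local_args = tuple(d.to_local() for d in dtensors)  # non-tensors dropped
        local_kwargs = dict(kwargs)
        layouts = tuple(d.layout for d in dtensors)
        infer_result = op.infer_layout(layouts, [])
    op_impl = op.get_expand_impl(local_callable, infer_result, layouts)
    if op_impl is None:
        trace.append("default impl")
        op_impl = local_callable
    else:
        trace.append("expand impl")
    return op_impl(*local_args, **local_kwargs), infer_result, trace


class LegacyAdd(BaseOp):
    def infer_layout(self, layouts, extra_args=None):
        return layouts[0]


class NewScale(BaseOp):
    def preprocess(self, args, kwargs):
        x, scale = args
        return (x.to_local(), scale), {}, [x.layout, scale]

    def infer_layout(self, cache_values, extra_args=None):
        return ((cache_values[0],), None)


legacy_out = dispatch(LegacyAdd(), lambda x, y: x + y,
                      (FakeDTensor(1, "L0"), FakeDTensor(2, "L0")), {})
new_out = dispatch(NewScale(), lambda x, s: x * s, (FakeDTensor(3, "L0"), 2), {})
```

In this sketch, sending the `(FakeDTensor(3, "L0"), 2)` arguments through `LegacyAdd` instead fails with a `TypeError`, because the scalar never reaches the local callable; that is the toy equivalent of the legacy-path restriction on non-tensor positional arguments.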

- `_op_name` must be set on the `DFunction` subclass and match the registered `DistributedOp` exactly.
- Each `op_name` can have only one registered `DistributedOp` instance at a time.
- Non-tensor positional arguments are not forwarded to `forward` on the legacy dispatch path. Use `kwargs` or implement `preprocess` instead.
- When using `get_expand_impl` and the output has partial status, call `result.reduce_partial()` before redistribution.
- Both `forward` and `backward` must operate on local tensors; do not call `DFunction.apply` recursively from within them.

`DFunction` is platform-agnostic. It inherits from `platform.Function`, which
resolves to the correct base class at import time.

| Platform | `platform.Function` | `forward` / `backward` tensor type |
|---|---|---|
| PyTorch (GPU / NPU) | `torch.autograd.Function` | `torch.Tensor` |
| MindSpore (Ascend NPU) | `mindspore._Function` | `mindspore.Tensor` |

Cross-platform code runs identically on both backends. The `ctx.save_for_backward`
/ `ctx.saved_tensors` contract is uniform across platforms.