From a695267a5a4c529e5f3df9fd436cb73d49116979 Mon Sep 17 00:00:00 2001 From: Courtney Golden Date: Fri, 12 Jun 2026 12:08:14 -0700 Subject: [PATCH] arbitrary map/reduce operators --- accelforge/frontend/arch/__init__.py | 1 + accelforge/frontend/arch/components.py | 47 ++++- accelforge/frontend/workload.py | 36 ++++ accelforge/model/_looptree/energy.py | 26 ++- accelforge/model/_looptree/latency/latency.py | 9 +- accelforge/model/_looptree/latency/memory.py | 23 ++- .../model/_looptree/reuse/symbolic/_stats.py | 60 ++++-- .../_looptree/reuse/symbolic/_symbolic.py | 97 ++++++++-- examples/arches/simple_add_max.yaml | 31 +++ examples/workloads/tropical_gemm.yaml | 33 ++++ tests/test_compute_fusion.py | 176 ++++++++++++++++++ 11 files changed, 492 insertions(+), 47 deletions(-) create mode 100644 examples/arches/simple_add_max.yaml create mode 100644 examples/workloads/tropical_gemm.yaml create mode 100644 tests/test_compute_fusion.py diff --git a/accelforge/frontend/arch/__init__.py b/accelforge/frontend/arch/__init__.py index 7c0c81ca..30cd4b31 100644 --- a/accelforge/frontend/arch/__init__.py +++ b/accelforge/frontend/arch/__init__.py @@ -27,6 +27,7 @@ "Comparison", "Component", "Compute", + "ComputeAction", "Container", "Fork", "Array", diff --git a/accelforge/frontend/arch/components.py b/accelforge/frontend/arch/components.py index 554d4d37..16004e27 100644 --- a/accelforge/frontend/arch/components.py +++ b/accelforge/frontend/arch/components.py @@ -225,6 +225,30 @@ def __call__(self, field, value, evaluated, symbol_table): return super()._eval_expressions(*args, **kwargs, post_calls=(MyPostCall(),)) +class ComputeAction(Action): + op_kind: str = "mac" + """ The semantic category of operation this action models (e.g., "mul", "add", + "mac", "max"). Einsums declare `map_op`/`reduce_op`; the analysis derives an + op_profile from those and binds each op_kind to the Compute action that declares + it. The default "mac" preserves legacy single-action behavior. """ + + fuses: EvalableList[str] = [] + """ op_kinds this action coalesces into a single fire. When an einsum's op_profile + contains all listed op_kinds with EQUAL counts, those entries collapse into one + entry keyed under this action's `op_kind` with the shared count -- e.g. a fused MAC + pairs (mul, add) into one charge per iteration. Left unset, an `op_kind="mac"` + action defaults to fusing `[mul, add]` (so legacy single-MAC arches need no + change); any other op_kind defaults to no fusion. See `effective_fuses`. """ + + @property + def effective_fuses(self) -> list[str]: + """`fuses` if set, else the op_kind-derived default: a bare `mac` action + fuses `[mul, add]` (legacy fused-MAC); every other op_kind fuses nothing.""" + if self.fuses: + return list(self.fuses) + return ["mul", "add"] if self.op_kind == "mac" else [] + + _COMPONENT_MODEL_CACHE: dict[tuple, "Component"] = {} @@ -888,7 +912,7 @@ def _copy_for_component_modeling(self) -> Self: COMPUTE_ACTIONS = EvalableList( [ - Action(name="compute"), + ComputeAction(name="compute", op_kind="mac"), ] ) @@ -1279,8 +1303,9 @@ def _render_node_color(self) -> str: class Compute(Component, Leaf): - actions: EvalableList[Action] = COMPUTE_ACTIONS - """ The actions that this `Compute` can perform. """ + actions: EvalableList[ComputeAction] = COMPUTE_ACTIONS + """ The actions that this `Compute` can perform. Each `ComputeAction` declares an + `op_kind` that einsums bind to via their `map_op`/`reduce_op`. """ skip_initial_output_write: bool = True """ @@ -1288,8 +1313,20 @@ class Compute(Component, Leaf): initalize outputs. If True, this initial fetch and fill is skipped. """ - def model_post_init(self, __context__=None) -> None: - self._update_actions(COMPUTE_ACTIONS) + def action_for_op_kind(self, op_kind: str) -> ComputeAction: + """Return the `ComputeAction` on this Compute whose `op_kind` matches. + + Raises EvaluationError if no action declares this op_kind. + """ + for action in self.actions: + if getattr(action, "op_kind", None) == op_kind: + return action + declared = sorted({getattr(a, "op_kind", None) for a in self.actions}) + raise EvaluationError( + f"Compute component {self.name!r} has no action with op_kind " + f"{op_kind!r}. Declared op_kinds: {declared}.", + source_field=f"{self.name}.actions", + ) def _render_node_shape(self) -> str: return "ellipse" diff --git a/accelforge/frontend/workload.py b/accelforge/frontend/workload.py index 5d5a3fb8..325ecea9 100755 --- a/accelforge/frontend/workload.py +++ b/accelforge/frontend/workload.py @@ -456,6 +456,18 @@ class Einsum(EvalableModel): and directly place them at the location of the output tensor(s) without any computation. If the destination tensor is at the same location, then this is a no-op.""" + map_op: str = "mul" + """ Binary operator applied to paired input-tensor values at each iteration-space + point (e.g. "mul", "add", "max", "square"). Combined with `reduce_op` to derive the + einsum's op_profile: the map and reduce ops are each charged once per + iteration-space point. The default "mul" + "add" describes a standard + sum-of-products; on an arch that declares a fused-MAC compute action this pair + collapses into one MAC (see `ComputeAction.fuses`), so legacy arches stay + bit-identical. """ + reduce_op: str = "add" + """ Operator that folds mapped values into the output tensor across the reduction + ranks (e.g. "add", "max"). Charged once per iteration-space point alongside + `map_op`. Ignored for copy operations. """ renames: RenameList[Rename] = RenameList() """ Renames of the Einsum. Renames here can be used to rename rank variables or tensors. When this Einsum is executed on an architecture, the architecture can use @@ -582,6 +594,30 @@ def tensor2irrelevant_rank_variables( for t in self.tensor_accesses } + def effective_op_profile(self) -> dict[str, int]: + """Per-iteration-space-point op counts keyed by op_kind, derived from + `map_op` and `reduce_op`. + + Copy operations have no ops. Every other einsum is charged one map and + one reduce per point (matching the legacy uniform "1 op per iter" + attribution); if `map_op` and `reduce_op` are the same op_kind the two + entries collapse into one with count 2. These are the *raw* ops -- a + fused-MAC arch coalesces a `{mul: N, add: N}` profile back into + `{mac: N}` downstream via `ComputeAction.fuses`, so the default mul+add + einsum stays bit-identical on legacy single-MAC arches. + + `square` is treated as `mul` here (x*x runs on any multiplier), so a + square+add reduction fuses into a MAC like an ordinary product. An arch + declaring a dedicated `op_kind="square"` action would therefore not + bind -- the substitution erases the distinction. + """ + if self.is_copy_operation: + return {} + map_op = "mul" if self.map_op == "square" else self.map_op + if map_op == self.reduce_op: + return {map_op: 2} + return {map_op: 1, self.reduce_op: 1} + def _to_formatted_string(self, compress: bool = False) -> str: """ Returns a string representation of this Einsum for use in a Pydot graph. diff --git a/accelforge/model/_looptree/energy.py b/accelforge/model/_looptree/energy.py index 1728c9a9..8a4b0044 100755 --- a/accelforge/model/_looptree/energy.py +++ b/accelforge/model/_looptree/energy.py @@ -52,12 +52,22 @@ def gather_actions( actions[key].total += accesses.net_total_write_actions() actions[key].max_per_unit += accesses.net_max_per_unit_write_actions() + # `ops.total_ops` is a per-op-kind dict ({op_kind: count}). Emit one action + # key per (level, op_kind), where the action *name* is resolved from the + # Compute's ComputeAction whose op_kind matches. This is what lets the + # downstream `compute_energy_from_actions` look up energy via + # `component.actions[key.action].energy`. With the legacy single-action + # arch ({op_kind: "mac"}, name: "compute") and the default einsum profile + # ({"mac": 1}), this collapses to exactly one ("compute") key per level, + # bit-identical to prior behavior. for compute, ops in looptree_results.compute_stats.items(): - key = compute_keyer(compute, "compute") - if key not in actions: - actions[key] = ActionCount.default() - actions[key].total += ops.total_ops - actions[key].max_per_unit += ops.max_per_unit_ops + for op_kind, total in ops.total_ops.items(): + action_name = _resolve_compute_action_name(spec, compute.level, op_kind) + key = compute_keyer(compute, action_name) + if key not in actions: + actions[key] = ActionCount.default() + actions[key].total += total + actions[key].max_per_unit += ops.max_per_unit_ops.get(op_kind, 0) for network, stats in looptree_results.network_stats.items(): key = network_keyer(network, "hops") @@ -70,7 +80,6 @@ def gather_actions( return actions - def _apply_actions_scale(actions, spec): components = {} for key, count in actions.items(): @@ -80,6 +89,11 @@ def _apply_actions_scale(actions, spec): count.total *= scale count.max_per_unit *= scale +def _resolve_compute_action_name(spec: Spec, level: str, op_kind: str) -> str: + """Map (compute level, op_kind) to the matching ComputeAction's name. + """ + component = spec.arch.find(level) + return component.action_for_op_kind(op_kind).name def _get_buffet_keyer(verbose, use_name, bindings): if not verbose: diff --git a/accelforge/model/_looptree/latency/latency.py b/accelforge/model/_looptree/latency/latency.py index ab1810bd..c87fd909 100755 --- a/accelforge/model/_looptree/latency/latency.py +++ b/accelforge/model/_looptree/latency/latency.py @@ -47,10 +47,17 @@ def calculate_compute_latency(reuse_analysis_results, mapping, workload): def compute_summarized_latency(compute_stats, mapping, workload): # TODO: this is only for single-Einsum!!! + # `stats.max_latency` is a dict[op_kind, cycles]. Sum across op_kinds + # within a ComputeStats entry (matching the Compute's default + # total_latency = sum(*action2latency.values())), then take the max + # across entries -- i.e., sum-then-max. The cross-stats max here mirrors + # max_nonzero(comp_latency, ...) in get_latency(), keeping this code path + # consistent with the per-component path in latency/memory.py. longest_compute_latency = 0 for stats in compute_stats.values(): + per_iter_latency = sum(stats.max_latency.values(), 0) longest_compute_latency = max_nonzero( - longest_compute_latency, stats.max_latency + longest_compute_latency, per_iter_latency ) return longest_compute_latency diff --git a/accelforge/model/_looptree/latency/memory.py b/accelforge/model/_looptree/latency/memory.py index c96ec9b1..c642e4af 100755 --- a/accelforge/model/_looptree/latency/memory.py +++ b/accelforge/model/_looptree/latency/memory.py @@ -103,10 +103,25 @@ def component_latency( f"Component {component} is not a TensorHolder or Compute" ) - longest_compute_latency = Max( - 0, *[s.max_latency for s in looptree_results.compute_stats.values()] - ) - component_to_actions[compute_obj.name]["compute"] = longest_compute_latency + # `max_latency` is now a per-op-kind dict ({op_kind: cycles-per-iter-of-worst-iter}). + # For each op_kind, take the max across compute_stats entries (different + # (einsum, compute-level) keys) and inject it under the action's name, + # where action_name is the ComputeAction on this Compute whose op_kind matches. + # The Compute's `total_latency` expression (default sum(*action2latency.values())) + # then turns the per-kind counts into per-kind latency contributions and combines + # them. This implements sum-then-max: sum across op_kinds within a Compute + # (via total_latency), max across compute levels (via the per-kind max here + # and the Max(...) over component latencies at the get_latency layer). + per_kind_max_latency: dict[str, float] = {} + for s in looptree_results.compute_stats.values(): + for op_kind, val in s.max_latency.items(): + if op_kind in per_kind_max_latency: + per_kind_max_latency[op_kind] = Max(per_kind_max_latency[op_kind], val) + else: + per_kind_max_latency[op_kind] = val + for op_kind, count in per_kind_max_latency.items(): + action = compute_obj.action_for_op_kind(op_kind) + component_to_actions[compute_obj.name][action.name] = count new_component_to_actions: dict[str, list] = {} for component, action_counts in component_to_actions.items(): diff --git a/accelforge/model/_looptree/reuse/symbolic/_stats.py b/accelforge/model/_looptree/reuse/symbolic/_stats.py index 307ad69a..c88a0a1e 100644 --- a/accelforge/model/_looptree/reuse/symbolic/_stats.py +++ b/accelforge/model/_looptree/reuse/symbolic/_stats.py @@ -175,13 +175,39 @@ def blank(cls): stats.n_loops_above = None # Inherit from whoever is added to this return stats +def _scale_op_dict(d: dict[str, Any], factor: Any) -> dict[str, Any]: + """Multiply every per-op-kind value by `factor`. Identity at factor==1.""" + if factor == 1: + return dict(d) + if isinstance(factor, float) and factor == int(factor): + factor = int(factor) + return {k: v * factor for k, v in d.items()} + + +def _sum_op_dicts(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: + """Per-op-kind sum. Keys present in only one operand are kept as-is.""" + out = dict(a) + for k, v in b.items(): + out[k] = out[k] + v if k in out else v + return out + + +def _max_op_dicts(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: + """Per-op-kind max_nonzero. Keys present in only one operand are kept as-is.""" + out = dict(a) + for k, v in b.items(): + out[k] = max_nonzero(out[k], v) if k in out else v + return out @dataclass class ComputeStats: - total_ops: Any = field(default=0) - max_per_unit_ops: Any = field(default=0) + # Per-op-kind counts. Keys are op_kind strings (e.g. "mul", "add", "mac") + # matching the ComputeAction.op_kind values declared on the arch's Compute + # component. An empty dict means no contribution. + total_ops: dict[str, Any] = field(default_factory=dict) + max_per_unit_ops: dict[str, Any] = field(default_factory=dict) # "max" below refers to the longest latency of any iteration - max_latency: Any = field(default=0) + max_latency: dict[str, Any] = field(default_factory=dict) # Mapping from the loop-index (0 at top) to the latency of the first # iteration of that loop. "Max" because we may have loops above that and we # will take the maximum of the firsts. @@ -193,9 +219,9 @@ def repeat_temporal(self, factor: int) -> "ComputeStats": return new if type(factor) is float and factor == int(factor): factor = int(factor) - new.total_ops = new.total_ops * factor - new.max_per_unit_ops = new.max_per_unit_ops * factor - new.max_latency = new.max_latency * factor + new.total_ops = _scale_op_dict(new.total_ops, factor) + new.max_per_unit_ops = _scale_op_dict(new.max_per_unit_ops, factor) + new.max_latency = _scale_op_dict(new.max_latency, factor) # NOTE: max_first_latency does not change return new @@ -205,14 +231,14 @@ def repeat_spatial(self, factor: int) -> "ComputeStats": return new if type(factor) is float and factor == int(factor): factor = int(factor) - new.total_ops = new.total_ops * factor + new.total_ops = _scale_op_dict(new.total_ops, factor) return new def __add__(self, other: "ComputeStats") -> "ComputeStats": new = copy.copy(self) - new.total_ops += other.total_ops - new.max_per_unit_ops += other.max_per_unit_ops - new.max_latency += other.max_latency + new.total_ops = _sum_op_dicts(new.total_ops, other.total_ops) + new.max_per_unit_ops = _sum_op_dicts(new.max_per_unit_ops, other.max_per_unit_ops) + new.max_latency = _sum_op_dicts(new.max_latency, other.max_latency) # max_first_latency is only ever updated across loops ABOVE the loop # for which we calculated that first latency, so we should MAX new.max_first_latency = max_dict( @@ -221,9 +247,9 @@ def __add__(self, other: "ComputeStats") -> "ComputeStats": return new def combine_temporal(self, other: "ComputeStats"): - self.total_ops += other.total_ops - self.max_per_unit_ops += other.max_per_unit_ops - self.max_latency += other.max_latency + self.total_ops = _sum_op_dicts(self.total_ops, other.total_ops) + self.max_per_unit_ops = _sum_op_dicts(self.max_per_unit_ops, other.max_per_unit_ops) + self.max_latency = _sum_op_dicts(self.max_latency, other.max_latency) # max_first_latency is only ever updated across loops ABOVE the loop # for which we calculated that first latency, so we should MAX self.max_first_latency = max_dict( @@ -231,11 +257,9 @@ def combine_temporal(self, other: "ComputeStats"): ) # FIRST LATENCY def combine_spatial(self, other: "ComputeStats"): - self.total_ops += other.total_ops - self.max_per_unit_ops = max_nonzero( - self.max_per_unit_ops, other.max_per_unit_ops - ) - self.max_latency = max_nonzero(self.max_latency, other.max_latency) + self.total_ops = _sum_op_dicts(self.total_ops, other.total_ops) + self.max_per_unit_ops = _max_op_dicts(self.max_per_unit_ops, other.max_per_unit_ops) + self.max_latency = _max_op_dicts(self.max_latency, other.max_latency) # max_first_latency is only ever updated across loops ABOVE the loop # for which we calculated that first latency, so we should MAX self.max_first_latency = max_dict( diff --git a/accelforge/model/_looptree/reuse/symbolic/_symbolic.py b/accelforge/model/_looptree/reuse/symbolic/_symbolic.py index 3f1aab71..aa207b5b 100755 --- a/accelforge/model/_looptree/reuse/symbolic/_symbolic.py +++ b/accelforge/model/_looptree/reuse/symbolic/_symbolic.py @@ -64,7 +64,6 @@ PRINT_FORMULAS = False - def quick_insert_reservation_nodes( job: Job, mapping: Mapping | None = None, tensors: oset[TensorName] | None = None ) -> Mapping: @@ -550,7 +549,11 @@ def handle_repeated_value(repeated_shape): for key in child_result.compute_stats: if first_latency is None: - first_latency = child_result.compute_stats[key].max_latency + # max_latency is now a per-op-kind dict; collapse to scalar + # for the FIRST LATENCY column (one number per loop_idx). + first_latency = sum( + child_result.compute_stats[key].max_latency.values(), 0 + ) compute_stats = result_accumulator.compute_stats.setdefault( key, ComputeStats() @@ -925,27 +928,95 @@ def analyze_reservation(node_idx, current_shape, info: AnalysisInfo): return child_result +def _apply_compute_fusion(profile: dict[str, Any], compute_obj: Any) -> dict[str, Any]: + """Coalesce op_profile entries via the Compute's `fuses` declarations. + + For each ComputeAction whose `effective_fuses` lists more than one op_kind: if + the profile contains all of them with EQUAL counts, remove those entries and add + the shared count under the action's primary `op_kind`. Actions that fuse nothing + (the common case) are no-ops. Actions are processed in declaration order, and the + first match wins on overlap. + + This is how a fused-MAC arch charges one MAC per (mul, add) pair instead of two + separate ops: the standard sum-of-products einsum emits `{mul: N, add: N}`, and a + `mac` action (which fuses `[mul, add]` by default) collapses that to `{mac: N}` + before downstream binding. + """ + if not profile: + return profile + profile = dict(profile) + for action in getattr(compute_obj, "actions", []) or []: + fuses = getattr(action, "effective_fuses", None) + if not fuses or len(fuses) <= 1: + continue + if any(ok not in profile for ok in fuses): + continue + counts = [profile[ok] for ok in fuses] + # Require all counts equal -- otherwise we cannot safely collapse. + first = counts[0] + if not all(c is first or c == first for c in counts[1:]): + continue + for ok in fuses: + del profile[ok] + primary = action.op_kind + if primary in profile: + profile[primary] = profile[primary] + first + else: + profile[primary] = first + return profile + + def analyze_compute( node_idx, current_shape, info: AnalysisInfo ) -> SymbolicAnalysisOutput: - einsum = info.mapping[-1].einsum + einsum_name = info.mapping[-1].einsum + einsum = info.workload.einsums[einsum_name] node = info.mapping[node_idx] - - computes = 0 if info.is_copy_operation else 1 - component_object = info.job.flattened_arch[node.component] skip_initial = ( component_object.skip_initial_output_write and not info.is_copy_operation ) + # Seed per-op-kind counts from the einsum's map_op/reduce_op (one + # iteration's worth of work per op_kind), after coalescing any fused-MAC + # pairs the arch declares. Loops above scale these via repeat_temporal + # / repeat_spatial so the final value is `(#iteration-space points) * + # ops_per_point[kind]`. + # + # NOTE (structural vs. semantic accounting): this counts ops structurally + # -- the adder fires on every iteration-space point of a reducing einsum, + # including the first iteration along the reduction dim where the value + # is being added to zero. The "semantic" count would subtract one add per + # unique output cell (the `+= 0` that isn't really accumulating). This + # over-counts by `#output_cells` adds per reducing einsum; for GEMM + # MNK=4,4,8 that's 16 over 128 = 12.5%, shrinking with K. If/when an arch + # that physically skips initial writes (see Compute.skip_initial_output_write) + # needs exact accounting, split max_latency into reduction-loop vs. + # map-loop products inside ComputeStats and consult `skip_initial` here. + op_profile = einsum.effective_op_profile() + op_profile = _apply_compute_fusion(op_profile, component_object) + if info.is_copy_operation: + seed_total_ops = {"mac": 0} + seed_max_per_unit_ops = {"mac": 0} + seed_max_latency = {"mac": 0} + else: + seed_total_ops = dict(op_profile) + seed_max_per_unit_ops = dict(op_profile) + seed_max_latency = dict(op_profile) + + # `temporal_steps` is a single scalar count of iterations per einsum used + # by other code paths; keep it equal to (1 if non-copy else 0) to match + # legacy semantics. + temporal_step_seed = 0 if info.is_copy_operation else 1 + result_accumulator = SymbolicAnalysisOutput() - compute_key = Compute(einsum, node.component) - result_accumulator.temporal_steps[einsum] = computes + compute_key = Compute(einsum_name, node.component) + result_accumulator.temporal_steps[einsum_name] = temporal_step_seed result_accumulator.compute_stats[compute_key] = ComputeStats( - computes, - computes, - computes, + total_ops=seed_total_ops, + max_per_unit_ops=seed_max_per_unit_ops, + max_latency=seed_max_latency, ) if info.is_copy_operation: @@ -953,11 +1024,11 @@ def analyze_compute( tensors = info.all_tensors if info.current_tensor is None else [info.current_tensor] for tensor in tensors: - buffet = Buffet(tensor, einsum, node.component) + buffet = Buffet(tensor, einsum_name, node.component) stats = BuffetStats() stats.total_reads_to_parent = 1 stats.max_per_parent_reads_to_parent = 1 - if tensor in info.workload.einsums[einsum].output_tensor_names: + if tensor in einsum.output_tensor_names: stats.total_writes_to_parent = 1 stats.max_per_parent_writes_to_parent = 1 if skip_initial: diff --git a/examples/arches/simple_add_max.yaml b/examples/arches/simple_add_max.yaml new file mode 100644 index 00000000..4b1842ac --- /dev/null +++ b/examples/arches/simple_add_max.yaml @@ -0,0 +1,31 @@ +# Tropical-GEMM-targeting arch: separate adder and max unit. + +arch: + nodes: + - !Memory + name: MainMemory + size: inf + leak_power: 0 + area: 0 + tensors: {keep: ~Intermediates, may_keep: All} + actions: + - {name: read, energy: 100, throughput: .inf} + - {name: write, energy: 100, throughput: .inf} + + - !Memory + name: Buffer + size: 1024 + leak_power: 0 + area: 0 + tensors: {keep: ~MainMemory, may_keep: All} + actions: + - {name: read, energy: 1, throughput: .inf} + - {name: write, energy: 1, throughput: .inf} + + - !Compute + name: TropicalALU + leak_power: 0 + area: 0 + actions: + - {name: add, op_kind: add, energy: 0.3, throughput: 1} + - {name: max, op_kind: max, energy: 0.2, throughput: 1} diff --git a/examples/workloads/tropical_gemm.yaml b/examples/workloads/tropical_gemm.yaml new file mode 100644 index 00000000..e1b07de5 --- /dev/null +++ b/examples/workloads/tropical_gemm.yaml @@ -0,0 +1,33 @@ +# Tropical (max-plus) GEMM: A[m,n] = max over k of (B[m,k] + C[k,n]). +# +# Same iteration space and tensor shapes as standard GEMM, but the inner op is +# an add followed by a max-reduction instead of a multiply followed by a +# sum-reduction. Set via map_op: add / reduce_op: max, so effective_op_profile() +# emits {add: 1, max: 1}. The "add"+"max" pair is not a fused-MAC pattern, so no +# arch fuses it; arch_add_max.yaml charges add and max separately. +# +# Useful for shortest-path-style algorithms (Floyd-Warshall inner loop), +# HMM forward/Viterbi, and other dynamic-programming kernels that share GEMM's +# data movement but use a different semiring. + +workload: + rank_sizes: + M: 64 + N: 64 + K: 64 + + iteration_space_shape: + m: 0 <= m < 64 + n: 0 <= n < 64 + k: 0 <= k < 64 + + bits_per_value: {All: 8} + + einsums: + - name: TropicalMatmul + tensor_accesses: + - {name: B, projection: [m, k]} + - {name: C, projection: [k, n]} + - {name: A, projection: [m, n], output: True} + map_op: add + reduce_op: max diff --git a/tests/test_compute_fusion.py b/tests/test_compute_fusion.py new file mode 100644 index 00000000..55b5f356 --- /dev/null +++ b/tests/test_compute_fusion.py @@ -0,0 +1,176 @@ +"""Unit tests for map_op/reduce_op op-profiles and arch-side op-kind fusion. + +An einsum declares `map_op`/`reduce_op`; `effective_op_profile()` turns those +into per-iteration op_kind counts. A Compute component may declare a +ComputeAction with `fuses=[...]` that coalesces several op_kinds into a single +entry keyed under the action's primary `op_kind`. The canonical case is a +fused-MAC unit that pairs (mul, add) into one charge per iteration -- and an +`op_kind="mac"` action fuses [mul, add] by default, so legacy single-MAC arches +keep working unchanged. +""" + +import pytest + +from accelforge import Spec, examples +from accelforge.frontend.arch.components import ComputeAction +from accelforge.frontend.workload import Einsum, TensorAccess +from accelforge.model._looptree.reuse.symbolic._symbolic import _apply_compute_fusion + + +def _compute(actions): + """Minimal object exposing an `.actions` list of real ComputeActions + (so `effective_fuses`, including the op_kind-aware default, is exercised).""" + + class _C: + pass + + c = _C() + c.actions = [ + ComputeAction(name=f"a{i}", op_kind=ok, fuses=list(fz)) + for i, (ok, fz) in enumerate(actions) + ] + return c + + +# ---------- _apply_compute_fusion ---------- + + +def test_apply_fusion_collapses_mul_add_to_mac(): + compute = _compute([("mac", ["mul", "add"])]) + assert _apply_compute_fusion({"mul": 5, "add": 5}, compute) == {"mac": 5} + + +def test_apply_fusion_requires_equal_counts(): + compute = _compute([("mac", ["mul", "add"])]) + assert _apply_compute_fusion({"mul": 3, "add": 5}, compute) == {"mul": 3, "add": 5} + + +def test_apply_fusion_missing_op_kind_no_op(): + compute = _compute([("mac", ["mul", "add"])]) + assert _apply_compute_fusion({"add": 1}, compute) == {"add": 1} + + +def test_apply_fusion_no_fuses_declared(): + # Separate mul/add units (no mac) -> nothing fuses. + compute = _compute([("mul", []), ("add", [])]) + assert _apply_compute_fusion({"mul": 7, "add": 7}, compute) == {"mul": 7, "add": 7} + + +def test_apply_fusion_singleton_fuses_is_noop(): + compute = _compute([("mul", ["mul"])]) + assert _apply_compute_fusion({"mul": 4, "add": 4}, compute) == {"mul": 4, "add": 4} + + +def test_apply_fusion_preserves_unrelated_keys(): + compute = _compute([("mac", ["mul", "add"])]) + assert _apply_compute_fusion({"mul": 2, "add": 2, "max": 1}, compute) == { + "mac": 2, + "max": 1, + } + + +def test_apply_fusion_multiple_actions(): + compute = _compute([("mac", ["mul", "add"]), ("logical", ["and", "or"])]) + assert _apply_compute_fusion( + {"mul": 3, "add": 3, "and": 2, "or": 2}, compute + ) == {"mac": 3, "logical": 2} + + +def test_apply_fusion_empty_profile_returns_empty(): + compute = _compute([("mac", ["mul", "add"])]) + assert _apply_compute_fusion({}, compute) == {} + + +def test_mac_action_fuses_mul_add_by_default(): + """An `op_kind="mac"` action with no explicit `fuses` still coalesces + [mul, add] -- this is what keeps legacy single-MAC arches unchanged.""" + compute = _compute([("mac", [])]) # fuses unset + assert _apply_compute_fusion({"mul": 6, "add": 6}, compute) == {"mac": 6} + + +def test_non_mac_action_does_not_fuse_by_default(): + compute = _compute([("mul", []), ("add", [])]) # fuses unset, not mac + assert _apply_compute_fusion({"mul": 1, "add": 1}, compute) == {"mul": 1, "add": 1} + + +# ---------- Einsum.effective_op_profile ---------- + + +def test_default_ops_are_mul_add(): + e = Einsum( + name="GEMM", + tensor_accesses=[ + TensorAccess(name="A", projection=["m", "k"]), + TensorAccess(name="B", projection=["k", "n"]), + TensorAccess(name="C", projection=["m", "n"], output=True), + ], + ) + assert e.effective_op_profile() == {"mul": 1, "add": 1} + + +def test_equal_map_reduce_collapses_to_count_two(): + e = Einsum( + name="AddReduce", + tensor_accesses=[ + TensorAccess(name="X", projection=["m", "k"]), + TensorAccess(name="Y", projection=["m"], output=True), + ], + map_op="add", + reduce_op="add", + ) + assert e.effective_op_profile() == {"add": 2} + + +def test_square_map_op_rewritten_to_mul(): + e = Einsum( + name="SumSq", + tensor_accesses=[ + TensorAccess(name="X", projection=["m"]), + TensorAccess(name="Y", projection=[], output=True), + ], + map_op="square", + ) + assert e.effective_op_profile() == {"mul": 1, "add": 1} + + +def test_copy_operation_has_no_ops(): + e = Einsum( + name="Copy", + tensor_accesses=[ + TensorAccess(name="X", projection=["m"]), + TensorAccess(name="Y", projection=["m"], output=True), + ], + is_copy_operation=True, + ) + assert e.effective_op_profile() == {} + + +def test_square_plus_add_fuses_to_mac(): + """RMSNorm-style sum-of-squares: square->mul, then mul+add fuses to one MAC.""" + e = Einsum( + name="SumSq", + tensor_accesses=[ + TensorAccess(name="X", projection=["m"]), + TensorAccess(name="Y", projection=[], output=True), + ], + map_op="square", + ) + compute = _compute([("mac", ["mul", "add"])]) + assert _apply_compute_fusion(e.effective_op_profile(), compute) == {"mac": 1} + + +# ---------- end-to-end via Spec ---------- + + +def test_simple_arch_mac_fuses_mul_add_by_default(): + """examples/arches/simple.yaml declares a single op_kind="mac" compute + action with no explicit fuses; its effective_fuses defaults to [mul, add].""" + spec = Spec.from_yaml( + examples.arches.simple, + examples.workloads.matmuls, + jinja_parse_data={"N_EINSUMS": 1, "M": 8, "KN": 8}, + ) + mac = spec.arch.find("MAC") + assert len(mac.actions) == 1 + assert mac.actions[0].op_kind == "mac" + assert mac.actions[0].effective_fuses == ["mul", "add"]