
refactor: extract shared metric aggregation logic in Eagle3Trainer #11

@cicirori


Problem

_aggregate_metrics (lines 368-415) and _aggregate_eval_metrics (lines 285-321) in eagle3_trainer.py share ~30 lines of nearly identical logic:

  • torch.stack + mean + all_reduce for plosses and acces
  • simulated_acc_len cumulative calculation (cumulative *= acces[i])
  • 0.8**i weighted loss computation
  • Per-position metric extraction
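For reference, the two shared formulas can be written as follows. This assumes n draft positions with per-position averaged losses ploss_i and accuracies acc_i. It also assumes that simulated_acc_len sums the running product of accuracies, because the issue shows only the update cumulative *= acces[i]:

```latex
\text{cumulative}_k = \prod_{i=0}^{k} \text{acc}_i, \qquad
\text{simulated\_acc\_len} = \sum_{k=0}^{n-1} \text{cumulative}_k, \qquad
L_{\text{weighted}} = \sum_{i=0}^{n-1} 0.8^{i}\,\text{ploss}_i
```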

Bug fixes must be applied in two places, which is error-prone.

Proposed Solution

Extract a shared helper method:

def _compute_weighted_loss_and_acc(self, avg_plosses, avg_acces, prefix="train"):
    """Shared logic: simulated_acc_len, weighted loss, per-position metrics."""
    ...

Both _aggregate_metrics and _aggregate_eval_metrics would call this helper and add only their unique fields (e.g., grad_norm and lr for training).
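The refactor could look roughly like the sketch below. It makes several assumptions. The class name Eagle3TrainerSketch, the LOSS_DECAY constant, the metric key format ("{prefix}/..."), and the caller signatures are all hypothetical. The inputs are taken to be plain Python floats that were already reduced across ranks (the torch.stack + mean + all_reduce step), not tensors. simulated_acc_len is assumed to be the sum of the running products.

```python
from typing import Dict, Sequence


class Eagle3TrainerSketch:
    """Hypothetical stand-in for Eagle3Trainer; only the metric code is sketched."""

    # Assumed per-position loss decay (the issue's 0.8**i weighting).
    LOSS_DECAY = 0.8

    def _compute_weighted_loss_and_acc(
        self,
        avg_plosses: Sequence[float],
        avg_acces: Sequence[float],
        prefix: str = "train",
    ) -> Dict[str, float]:
        """Shared logic: simulated_acc_len, weighted loss, per-position metrics."""
        if len(avg_plosses) != len(avg_acces):
            raise ValueError("avg_plosses and avg_acces must have the same length")
        metrics: Dict[str, float] = {}
        cumulative = 1.0
        simulated_acc_len = 0.0
        weighted_loss = 0.0
        for i, (ploss, acc) in enumerate(zip(avg_plosses, avg_acces)):
            # Running product: chance that positions 0..i are all accepted.
            cumulative *= acc
            simulated_acc_len += cumulative
            # Later draft positions contribute less to the total loss.
            weighted_loss += self.LOSS_DECAY ** i * ploss
            # Per-position metrics (hypothetical key format).
            metrics[f"{prefix}/ploss_{i}"] = ploss
            metrics[f"{prefix}/acc_{i}"] = acc
        metrics[f"{prefix}/simulated_acc_len"] = simulated_acc_len
        metrics[f"{prefix}/weighted_loss"] = weighted_loss
        return metrics

    def _aggregate_metrics(self, avg_plosses, avg_acces, grad_norm, lr):
        # Training: shared metrics plus training-only fields.
        metrics = self._compute_weighted_loss_and_acc(avg_plosses, avg_acces, prefix="train")
        metrics["train/grad_norm"] = grad_norm
        metrics["train/lr"] = lr
        return metrics

    def _aggregate_eval_metrics(self, avg_plosses, avg_acces):
        # Evaluation: only the shared metrics.
        return self._compute_weighted_loss_and_acc(avg_plosses, avg_acces, prefix="eval")
```

For example, with avg_plosses=[1.0, 2.0] and avg_acces=[0.5, 0.5], the helper gives simulated_acc_len = 0.5 + 0.25 = 0.75 and weighted_loss = 1.0 + 0.8 * 2.0 = 2.6. A fix to either formula then lands in a single place.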

Files

  • torchspec/training/eagle3_trainer.py
