33 changes: 33 additions & 0 deletions examples/megatron/eaft.sh
@@ -0,0 +1,33 @@
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
megatron sft \
--model Qwen/Qwen3-0.6B \
--load_safetensors true \
--save_safetensors true \
--dataset 'swift_shuf_19k_data.jsonl' \
--tensor_model_parallel_size 1 \
--sequence_parallel true \
--micro_batch_size 4 \
--global_batch_size 64 \
--recompute_granularity full \
--recompute_method uniform \
--recompute_num_layers 1 \
--finetune true \
--cross_entropy_loss_fusion true \
--lr 1e-5 \
--lr_warmup_fraction 0.05 \
--min_lr 1e-6 \
--max_epochs 1 \
--save megatron_output/Qwen3-0.6B/eaft \
--save_interval 100 \
--max_length 16384 \
--system 'You are a helpful assistant.' \
--num_workers 4 \
--no_save_optim true \
--no_save_rng true \
--dataset_num_proc 4 \
--tensorboard_dir /tensorboard/Qwen3-0.6B/eaft \
--enable_eaft_loss true \
--eaft_alpha 1.0

24 changes: 24 additions & 0 deletions examples/train/eaft.sh
@@ -0,0 +1,24 @@
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
swift sft \
--model Qwen/Qwen3-0.6B \
--train_type full \
--dataset 'swift_shuf_19k_data.jsonl' \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 16 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 16384 \
--output_dir swift_output/Qwen3-0.6B/eaft \
--system 'You are a helpful assistant.' \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--enable_eaft_loss true \
--eaft_alpha 1.0 \
--deepspeed zero3 \
--report_to tensorboard \
    --logging_dir tensorboard/swift_output/Qwen3-0.6B/eaft
2 changes: 2 additions & 0 deletions swift/megatron/argument/megatron_args.py
@@ -346,6 +346,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
llm_architectures: Optional[str] = None
max_epochs: Optional[int] = None
enable_dft_loss: bool = False
enable_eaft_loss: bool = False
eaft_alpha: float = 1.0
enable_channel_loss: bool = False
task_type: Literal['causal_lm', 'seq_cls'] = None
num_labels: Optional[int] = None
6 changes: 6 additions & 0 deletions swift/megatron/model/gpt_model.py
@@ -272,6 +272,7 @@ def forward(
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[torch.Tensor] = None,
return_logits: bool = False,
**kwargs,
) -> torch.Tensor:
"""Forward function of the GPT Model This function passes the input tensors
@@ -327,6 +328,7 @@ def forward(
runtime_gather_output=runtime_gather_output,
extra_block_kwargs=extra_block_kwargs,
inference_context=inference_context,
return_logits=return_logits,
)

def _postprocess(
@@ -347,6 +349,7 @@ def _postprocess(
runtime_gather_output=None,
extra_block_kwargs=None,
inference_context=None,
return_logits=False,
):
"""Postprocesses decoder hidden states to generate logits or compute loss.

@@ -473,6 +476,9 @@ def _postprocess(

loss = self.compute_language_model_loss(labels, logits)

if return_logits:
return loss, logits

return loss

def get_input_tensor(self):
48 changes: 44 additions & 4 deletions swift/megatron/trainers/trainer.py
@@ -7,6 +7,7 @@
from megatron.core import mpu
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.training import get_args, get_timers
from megatron.core import parallel_state
from torch.distributed.nn import all_reduce
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -52,13 +53,41 @@ def loss_func(self,
labels: torch.Tensor,
loss_scale: Optional[torch.Tensor] = None,
channels: Optional[List[str]] = None,
packed_seq_params=None):
packed_seq_params=None,
logits: Optional[torch.Tensor] = None):
args = get_args()

losses = output_tensor.float()
loss_mask = labels != -100
if args.enable_dft_loss:
losses = losses * torch.exp(-losses.detach())
if args.enable_eaft_loss and logits is not None:
with torch.no_grad():
logits_float = logits.float()
vocab_size = logits_float.shape[-1]

batch_size = labels.shape[0]
seq_length = labels.shape[1]

logits_transposed = logits_float.transpose(0, 1)
logits_reshaped = logits_transposed.view(batch_size * seq_length, vocab_size)

logits_valid = logits_reshaped[loss_mask.view(-1)]

topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1)
Contributor (review comment, medium): The value for k in torch.topk is hardcoded to 20. This reduces flexibility for experimentation. It would be better to make this a configurable parameter by adding an eaft_top_k argument to ExtraMegatronArguments (with a default of 20).

logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
log_probs_topk = topk_logits - logsumexp_topk
probs_topk = torch.exp(log_probs_topk)
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
normalized_entropy = entropy_approx / 3.0
eaft_weight_valid = torch.pow(normalized_entropy, args.eaft_alpha)

eaft_weight = torch.ones(batch_size * seq_length, device=logits_float.device)
eaft_weight[loss_mask.view(-1)] = eaft_weight_valid
eaft_weight = eaft_weight.view(batch_size, seq_length)

losses = losses * eaft_weight
Contributor (review comment, critical): There's a shape mismatch when applying the EAFT weights. The losses tensor has a shape of [sequence_length, batch_size], while eaft_weight has a shape of [batch_size, sequence_length]. Multiplying them directly will lead to incorrect broadcasting or a runtime error. You need to transpose eaft_weight before multiplication.

Suggested change:
- losses = losses * eaft_weight
+ losses = losses * eaft_weight.transpose(0, 1)

Author reply: Thanks for the review. I’ve double-checked the shapes, and the implementation is actually correct. Since the input logits are [Seq, Batch], I explicitly transpose them via transpose(0, 1) to align with the [Batch, Seq] layout of the losses tensor. Given that eaft_weight serves as a weight map for the loss, applying element-wise multiplication here is the correct operation, so no further transposition is required.

(A standalone shape sketch illustrating these layouts follows this file's diff.)


if loss_scale is not None:
losses = losses * loss_scale
if args.enable_channel_loss and channels is not None:
@@ -146,9 +175,19 @@ def forward_step(self, data_iterator, model):
labels = data.get('labels')
if self.args.task_type == 'seq_cls':
data.pop('labels', None)
with self.stimer:
output_tensor = model(**data)
packed_seq_params = data.get('packed_seq_params')


Contributor (review comment, high): The logits variable is used on line 199 but is only conditionally defined within the if block on line 181. If self.args.task_type is 'seq_cls', this will raise an UnboundLocalError. Initialize logits to None here to ensure it is always defined.

Suggested change:
+ logits = None

with self.stimer:
is_last = parallel_state.is_pipeline_last_stage()

if is_last:
loss, logits = model(**data, return_logits=True)
output_tensor = loss
else:
output_tensor = model(**data) # only hidden_states tensor
logits = None

if self.args.task_type == 'seq_cls':
loss_func = partial(
self.seq_cls_loss_func,
@@ -161,5 +200,6 @@ def forward_step(self, data_iterator, model):
labels=labels,
loss_scale=loss_scale,
channels=channels,
packed_seq_params=packed_seq_params)
packed_seq_params=packed_seq_params,
logits=logits)
return output_tensor, loss_func
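
A standalone sketch of the EAFT weight shape flow discussed in the review thread above, using made-up dimensions (seq=8, batch=2, vocab=50) and the layout the author describes: Megatron-style logits in [seq, batch, vocab] and per-token losses in [batch, seq]. This illustrates the layout conversion only, not the actual Megatron call path, and the loss layout is an assumption taken from the author's reply.

import torch

seq_len, batch_size, vocab_size, top_k, alpha = 8, 2, 50, 20, 1.0
logits = torch.randn(seq_len, batch_size, vocab_size)          # [S, B, V], as produced by the model
losses = torch.rand(batch_size, seq_len)                        # [B, S] per-token CE losses (assumed layout)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, :2] = -100                                            # pretend the first two tokens are prompt
loss_mask = labels != -100                                      # [B, S]

with torch.no_grad():
    flat = logits.float().transpose(0, 1).reshape(-1, vocab_size)  # [B*S, V] (.reshape: transpose is non-contiguous)
    valid = flat[loss_mask.view(-1)]                                # keep only supervised positions
    topk_logits, _ = torch.topk(valid, k=top_k, dim=-1)
    log_probs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)            # top-k entropy, at most ln(top_k) ~ 3.0
    weights = (entropy / 3.0).pow(alpha)

    eaft_weight = torch.ones(batch_size * seq_len)
    eaft_weight[loss_mask.view(-1)] = weights
    eaft_weight = eaft_weight.view(batch_size, seq_len)             # [B, S], same layout as `losses`

weighted = losses * eaft_weight                                     # element-wise; no transpose needed
print(weighted.shape)                                               # torch.Size([2, 8])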
5 changes: 5 additions & 0 deletions swift/trainers/arguments.py
@@ -39,6 +39,9 @@ class TrainArgumentsMixin:
enable_dft_loss (bool): Whether to enable Dynamic Fine-Tuning (DFT) loss.
See https://arxiv.org/abs/2508.05629. Defaults to False.
enable_channel_loss (bool): Whether to enable channel loss. Defaults to False.
enable_eaft_loss (bool): Whether to enable Entropy-Adaptive Fine-Tuning (EAFT) loss. Defaults to False.
eaft_alpha (float): The alpha parameter for EAFT loss. The final loss is calculated as
`(token_entropy / 3.0)^alpha * ce_loss`. Defaults to 1.0.
weight_decay (float): The weight decay to apply (if not zero) to all layers except bias and LayerNorm weights.
Defaults to 0.1.
adam_beta2 (float): The beta2 hyperparameter for the AdamW optimizer. Defaults to 0.95.
@@ -106,6 +109,8 @@ class TrainArgumentsMixin:
router_aux_loss_coef: float = 0.
enable_dft_loss: bool = False # https://arxiv.org/abs/2508.05629
enable_channel_loss: bool = False
enable_eaft_loss: bool = False
eaft_alpha: float = 1.0

weight_decay: float = 0.1
adam_beta2: float = 0.95
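A hedged numeric illustration of the `(token_entropy / 3.0)^alpha * ce_loss` weighting described in the docstring above. The entropy is taken over the renormalized top-20 candidates (as in this PR), so its maximum is ln(20) ≈ 3.0; the two distributions below are made-up examples, not values from any real model.

import math

def eaft_weight(topk_probs, alpha=1.0):
    # token entropy over the (renormalized) top-k probabilities
    entropy = -sum(p * math.log(p) for p in topk_probs if p > 0)
    return (entropy / 3.0) ** alpha         # 3.0 ~ ln(20), the maximum entropy over 20 candidates

uniform = [1 / 20] * 20                      # maximally uncertain token
peaked = [0.981] + [0.001] * 19              # token the model is already confident about

print(round(eaft_weight(uniform), 3))        # 0.999 -> CE loss kept essentially unchanged
print(round(eaft_weight(peaked), 3))         # 0.05  -> confident token is strongly down-weighted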
16 changes: 12 additions & 4 deletions swift/trainers/trainers.py
@@ -288,7 +288,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
channels = inputs.pop('channel', None)

if (self.label_smoother is not None or compute_loss_func is not None or loss_scale is not None
or self.args.enable_dft_loss or self.args.enable_channel_loss
or self.args.enable_dft_loss or self.args.enable_eaft_loss or self.args.enable_channel_loss
or self.template.sequence_parallel_size > 1) and 'labels' in inputs:
if self.args.use_liger_kernel:
logger.warning_once('The cross_entropy loss function defined in Liger Kernel will not '
@@ -318,12 +318,20 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]
else:
outputs.loss = None
if (self.args.enable_dft_loss or loss_scale is not None or self.args.enable_channel_loss
if (self.args.enable_dft_loss or self.args.enable_eaft_loss or loss_scale is not None or self.args.enable_channel_loss
or self.template.sequence_parallel_size > 1):
if self.template.sequence_parallel_size > 1:
outputs.loss = per_token_loss_func_sp(outputs, labels, enable_dft_loss=self.args.enable_dft_loss)
outputs.loss = per_token_loss_func_sp(
outputs, labels,
enable_dft_loss=self.args.enable_dft_loss,
enable_eaft_loss=self.args.enable_eaft_loss,
eaft_alpha=self.args.eaft_alpha)
else:
outputs.loss = per_token_loss_func(outputs, labels, enable_dft_loss=self.args.enable_dft_loss)
outputs.loss = per_token_loss_func(
outputs, labels,
enable_dft_loss=self.args.enable_dft_loss,
enable_eaft_loss=self.args.enable_eaft_loss,
eaft_alpha=self.args.eaft_alpha)

if loss_scale is not None:
loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1).view(-1)
39 changes: 37 additions & 2 deletions swift/trainers/utils.py
@@ -56,7 +56,7 @@ def is_instance_of_ms_model(model: Module) -> bool:
return False


def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) -> torch.Tensor:
def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, enable_eaft_loss=False, eaft_alpha=1.0, **kwargs) -> torch.Tensor:
"""Common loss function for sequence parallel training"""
if hasattr(outputs, 'logits'):
logits = outputs.logits
@@ -78,6 +78,24 @@ def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) ->
with torch.no_grad():
target_probs = torch.exp(-loss)
loss *= target_probs
if enable_eaft_loss:
with torch.no_grad():
logits_detach = logits.detach()
valid_mask = labels != -100
logits_valid = logits_detach[valid_mask]

topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1)
Contributor (review comment, medium): The value for k in torch.topk is hardcoded to 20. This is also the case in per_token_loss_func. This should be a configurable parameter to allow for easier experimentation. Consider adding an eaft_top_k argument to TrainArgumentsMixin and passing it down to these functions.

logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
log_probs_topk = topk_logits - logsumexp_topk
probs_topk = torch.exp(log_probs_topk)
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
normalized_entropy = entropy_approx / 3.0
eaft_weight_valid = torch.pow(normalized_entropy, eaft_alpha)

eaft_weight = torch.ones_like(loss)
eaft_weight[valid_mask] = eaft_weight_valid

loss *= eaft_weight
Contributor (review comment on lines +81 to +98, medium): The logic for calculating the EAFT loss here is nearly identical to the implementation in per_token_loss_func (lines 128-144). This duplication makes the code harder to maintain. Please refactor this shared logic into a single helper function to improve maintainability and reduce redundancy. (A hedged sketch of such a helper appears after this file's diff.)

from swift.trainers.sequence_parallel import sequence_parallel
position_ids = sequence_parallel.real_position_ids
if position_ids is not None:
@@ -91,7 +109,8 @@ def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) ->
return loss


def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs):
def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, enable_eaft_loss: bool = False,
eaft_alpha: float = 1.0, **kwargs):
logits = outputs.logits
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
@@ -106,4 +125,20 @@ def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs
with torch.no_grad():
target_probs = torch.exp(-loss)
loss *= target_probs
if enable_eaft_loss:
with torch.no_grad():
valid_mask = labels != -100
logits_detach = logits[valid_mask].detach()
topk_logits, topk_indices = torch.topk(logits_detach, k=20, dim=-1)
logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
log_probs_topk = topk_logits - logsumexp_topk
probs_topk = torch.exp(log_probs_topk)
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
normalized_entropy = entropy_approx / 3.0
eaft_weight = torch.pow(normalized_entropy, eaft_alpha)

eaft_weight_full = torch.ones_like(loss)
eaft_weight_full[valid_mask] = eaft_weight

loss *= eaft_weight_full
return loss
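
A possible consolidation of the duplicated EAFT logic flagged in the review comments above: a single hypothetical helper (compute_eaft_weight, not part of this PR) with a configurable top_k in place of the hardcoded 20, where eaft_top_k stands for the reviewer's suggested new argument. This is a sketch only, keeping the PR's masking convention (labels == -100 marks ignored positions); it normalizes by ln(top_k), which matches the PR's 3.0 constant when top_k is 20.

import math
import torch


def compute_eaft_weight(logits: torch.Tensor, valid_mask: torch.Tensor,
                        alpha: float = 1.0, top_k: int = 20) -> torch.Tensor:
    """Per-token EAFT weights, shaped like `valid_mask`; masked positions get weight 1.0."""
    with torch.no_grad():
        topk_logits, _ = torch.topk(logits[valid_mask].float(), k=top_k, dim=-1)
        log_probs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        weight_valid = (entropy / math.log(top_k)).pow(alpha)   # ln(20) ~ 3.0 for the default top_k

        weight = torch.ones(valid_mask.shape, dtype=torch.float32, device=logits.device)
        weight[valid_mask] = weight_valid
    return weight


# Usage sketch inside either per-token loss function:
#   valid_mask = labels != -100
#   loss = loss * compute_eaft_weight(logits, valid_mask, eaft_alpha, eaft_top_k)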