From ec7b9fb2200e855f24cd6abdbd43ff6e12acc7e3 Mon Sep 17 00:00:00 2001 From: ymxyll <2313418841@qq.com> Date: Mon, 12 Jan 2026 15:40:51 +0800 Subject: [PATCH 1/3] [FEATURE] eaft --- examples/megatron/eaft.sh | 33 ++++++++++++++++++++ examples/train/eaft.sh | 24 +++++++++++++++ swift/megatron/argument/megatron_args.py | 2 ++ swift/megatron/model/gpt_model.py | 6 ++++ swift/megatron/trainers/trainer.py | 38 +++++++++++++++++++++--- swift/trainers/arguments.py | 5 ++++ swift/trainers/trainers.py | 16 +++++++--- swift/trainers/utils.py | 32 ++++++++++++++++++-- 8 files changed, 146 insertions(+), 10 deletions(-) create mode 100644 examples/megatron/eaft.sh create mode 100644 examples/train/eaft.sh diff --git a/examples/megatron/eaft.sh b/examples/megatron/eaft.sh new file mode 100644 index 0000000000..fbbbfc35cb --- /dev/null +++ b/examples/megatron/eaft.sh @@ -0,0 +1,33 @@ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +megatron sft \ + --model Qwen/Qwen3-0.6B \ + --load_safetensors true \ + --save_safetensors true \ + --dataset 'swift_shuf_19k_data.jsonl' \ + --tensor_model_parallel_size 1 \ + --sequence_parallel true \ + --micro_batch_size 4 \ + --global_batch_size 64 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-5 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-6 \ + --max_epochs 1 \ + --save megatron_output/Qwen3-0.6B/eaft \ + --save_interval 100 \ + --max_length 16384 \ + --system 'You are a helpful assistant.' \ + --num_workers 4 \ + --no_save_optim true \ + --no_save_rng true \ + --dataset_num_proc 4 \ + --tensorboard_dir /tensorboard/Qwen3-0.6B/eaft \ + --enable_eaft_loss true \ + --eaft_alpha 1.0 + diff --git a/examples/train/eaft.sh b/examples/train/eaft.sh new file mode 100644 index 0000000000..e7dfbcc9d8 --- /dev/null +++ b/examples/train/eaft.sh @@ -0,0 +1,24 @@ +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model Qwen/Qwen3-0.6B \ + --train_type full \ + --dataset 'swift_shuf_19k_data.jsonl' \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --learning_rate 1e-5 \ + --gradient_accumulation_steps 16 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 16384 \ + --output_dir swift_output/Qwen3-0.6B/eaft \ + --system 'You are a helpful assistant.' 
\ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --enable_eaft_loss true \ + --eaft_alpha 1.0 \ + --deepspeed zero3 \ + --report_to tensorboard \ + --logging_dir tensorboard/swift_output/Qwen3-0.6B/eaft \ diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 2b12abdd16..252ea8522e 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -346,6 +346,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): llm_architectures: Optional[str] = None max_epochs: Optional[int] = None enable_dft_loss: bool = False + enable_eaft_loss: bool = False + eaft_alpha: float = 1.0 enable_channel_loss: bool = False task_type: Literal['causal_lm', 'seq_cls'] = None num_labels: Optional[int] = None diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 37582f34d6..569e12e0f2 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -272,6 +272,7 @@ def forward( *, inference_params: Optional[BaseInferenceContext] = None, loss_mask: Optional[torch.Tensor] = None, + return_logits: bool = False, **kwargs, ) -> torch.Tensor: """Forward function of the GPT Model This function passes the input tensors @@ -327,6 +328,7 @@ def forward( runtime_gather_output=runtime_gather_output, extra_block_kwargs=extra_block_kwargs, inference_context=inference_context, + return_logits=return_logits, ) def _postprocess( @@ -347,6 +349,7 @@ def _postprocess( runtime_gather_output=None, extra_block_kwargs=None, inference_context=None, + return_logits=False, ): """Postprocesses decoder hidden states to generate logits or compute loss. @@ -473,6 +476,9 @@ def _postprocess( loss = self.compute_language_model_loss(labels, logits) + if return_logits: + return loss, logits + return loss def get_input_tensor(self): diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 484b01c8ff..db9ebadcdc 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -52,13 +52,36 @@ def loss_func(self, labels: torch.Tensor, loss_scale: Optional[torch.Tensor] = None, channels: Optional[List[str]] = None, - packed_seq_params=None): + packed_seq_params=None, + logits: Optional[torch.Tensor] = None): args = get_args() losses = output_tensor.float() loss_mask = labels != -100 if args.enable_dft_loss: losses = losses * torch.exp(-losses.detach()) + if args.enable_eaft_loss and logits is not None: + with torch.no_grad(): + logits_float = logits.float() + vocab_size = logits_float.shape[-1] + + batch_size = labels.shape[0] + seq_length = labels.shape[1] + logits_reshaped = logits_float.view(batch_size * seq_length, vocab_size) + + topk_logits, topk_indices = torch.topk(logits_reshaped, k=20, dim=-1) + logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True) + log_probs_topk = topk_logits - logsumexp_topk + probs_topk = torch.exp(log_probs_topk) + entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1) + normalized_entropy = entropy_approx / 3.0 + eaft_weight = torch.pow(normalized_entropy, args.eaft_alpha) + eaft_weight = eaft_weight.view(batch_size, seq_length) + eaft_weight = torch.where(loss_mask, eaft_weight, torch.ones_like(eaft_weight)) + + + losses = losses * eaft_weight + if loss_scale is not None: losses = losses * loss_scale if args.enable_channel_loss and channels is not None: @@ -146,9 +169,15 @@ def forward_step(self, data_iterator, model): labels = data.get('labels') if 
self.args.task_type == 'seq_cls': data.pop('labels', None) - with self.stimer: - output_tensor = model(**data) packed_seq_params = data.get('packed_seq_params') + + + with self.stimer: + if self.args.task_type != 'seq_cls': + output_tensor, logits = model(**data, return_logits=True) + else: + output_tensor = model(**data) + if self.args.task_type == 'seq_cls': loss_func = partial( self.seq_cls_loss_func, @@ -161,5 +190,6 @@ def forward_step(self, data_iterator, model): labels=labels, loss_scale=loss_scale, channels=channels, - packed_seq_params=packed_seq_params) + packed_seq_params=packed_seq_params, + logits=logits) return output_tensor, loss_func diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index ec894fbbb5..55e4494216 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -39,6 +39,9 @@ class TrainArgumentsMixin: enable_dft_loss (bool): Whether to enable Diversity-from-Diversity (DFD) loss. See https://arxiv.org/abs/2508.05629. Defaults to False. enable_channel_loss (bool): Whether to enable channel loss. Defaults to False. + enable_eaft_loss (bool): Whether to enable Entropy-Adaptive Fine-Tuning (EAFT) loss. Defaults to False. + eaft_alpha (float): The alpha parameter for EAFT loss. The final loss is calculated as + `(token_entropy / 3.0)^alpha * ce_loss`. Defaults to 1.0. weight_decay (float): The weight decay to apply (if not zero) to all layers except bias and LayerNorm weights. Defaults to 0.1. adam_beta2 (float): The beta2 hyperparameter for the AdamW optimizer. Defaults to 0.95. @@ -106,6 +109,8 @@ class TrainArgumentsMixin: router_aux_loss_coef: float = 0. enable_dft_loss: bool = False # https://arxiv.org/abs/2508.05629 enable_channel_loss: bool = False + enable_eaft_loss: bool = False + eaft_alpha: float = 1.0 weight_decay: float = 0.1 adam_beta2: float = 0.95 diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index 00e72559cf..bdf84da0a9 100644 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -288,7 +288,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N channels = inputs.pop('channel', None) if (self.label_smoother is not None or compute_loss_func is not None or loss_scale is not None - or self.args.enable_dft_loss or self.args.enable_channel_loss + or self.args.enable_dft_loss or self.args.enable_eaft_loss or self.args.enable_channel_loss or self.template.sequence_parallel_size > 1) and 'labels' in inputs: if self.args.use_liger_kernel: logger.warning_once('The cross_entropy loss function defined in Liger Kernel will not ' @@ -318,12 +318,20 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0] else: outputs.loss = None - if (self.args.enable_dft_loss or loss_scale is not None or self.args.enable_channel_loss + if (self.args.enable_dft_loss or self.args.enable_eaft_loss or loss_scale is not None or self.args.enable_channel_loss or self.template.sequence_parallel_size > 1): if self.template.sequence_parallel_size > 1: - outputs.loss = per_token_loss_func_sp(outputs, labels, enable_dft_loss=self.args.enable_dft_loss) + outputs.loss = per_token_loss_func_sp( + outputs, labels, + enable_dft_loss=self.args.enable_dft_loss, + enable_eaft_loss=self.args.enable_eaft_loss, + eaft_alpha=self.args.eaft_alpha) else: - outputs.loss = per_token_loss_func(outputs, labels, enable_dft_loss=self.args.enable_dft_loss) + outputs.loss = per_token_loss_func( + outputs, labels, + 
enable_dft_loss=self.args.enable_dft_loss, + enable_eaft_loss=self.args.enable_eaft_loss, + eaft_alpha=self.args.eaft_alpha) if loss_scale is not None: loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1).view(-1) diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py index bae8b79928..02f6f0df18 100644 --- a/swift/trainers/utils.py +++ b/swift/trainers/utils.py @@ -56,7 +56,7 @@ def is_instance_of_ms_model(model: Module) -> bool: return False -def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) -> torch.Tensor: +def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, enable_eaft_loss=False, eaft_alpha=1.0, **kwargs) -> torch.Tensor: """Common loss function for sequence parallel training""" if hasattr(outputs, 'logits'): logits = outputs.logits @@ -78,6 +78,17 @@ def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) -> with torch.no_grad(): target_probs = torch.exp(-loss) loss *= target_probs + if enable_eaft_loss: + with torch.no_grad(): + topk_logits, topk_indices = torch.topk(logits_detach, k=20, dim=-1) + logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True) + log_probs_topk = topk_logits - logsumexp_topk + probs_topk = torch.exp(log_probs_topk) + entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1) + normalized_entropy = entropy_approx / 3.0 + eaft_weight = torch.pow(normalized_entropy, eaft_alpha) + + loss *= eaft_weight from swift.trainers.sequence_parallel import sequence_parallel position_ids = sequence_parallel.real_position_ids if position_ids is not None: @@ -91,7 +102,8 @@ def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) -> return loss -def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs): +def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, enable_eaft_loss: bool = False, + eaft_alpha: float = 1.0, **kwargs): logits = outputs.logits # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() @@ -106,4 +118,20 @@ def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs with torch.no_grad(): target_probs = torch.exp(-loss) loss *= target_probs + if enable_eaft_loss: + with torch.no_grad(): + valid_mask = labels != -100 + logits_detach = logits[valid_mask].detach() + topk_logits, topk_indices = torch.topk(logits_detach, k=20, dim=-1) + logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True) + log_probs_topk = topk_logits - logsumexp_topk + probs_topk = torch.exp(log_probs_topk) + entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1) + normalized_entropy = entropy_approx / 3.0 + eaft_weight = torch.pow(normalized_entropy, eaft_alpha) + + eaft_weight_full = torch.ones_like(loss) + eaft_weight_full[valid_mask] = eaft_weight + + loss *= eaft_weight_full return loss From eca0b27ef172d1cd1121d05c528d2a57fd26d41a Mon Sep 17 00:00:00 2001 From: ymxyll <2313418841@qq.com> Date: Mon, 12 Jan 2026 16:13:22 +0800 Subject: [PATCH 2/3] fix bugs --- swift/megatron/trainers/trainer.py | 17 +++++++++++------ swift/trainers/utils.py | 11 +++++++++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index db9ebadcdc..24d6081c9f 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -67,18 +67,23 @@ def loss_func(self, batch_size = labels.shape[0] seq_length = labels.shape[1] - logits_reshaped = logits_float.view(batch_size * 
seq_length, vocab_size) - - topk_logits, topk_indices = torch.topk(logits_reshaped, k=20, dim=-1) + + logits_transposed = logits_float.transpose(0, 1) + logits_reshaped = logits_transposed.view(batch_size * seq_length, vocab_size) + + logits_valid = logits_reshaped[loss_mask.view(-1)] + + topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1) logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True) log_probs_topk = topk_logits - logsumexp_topk probs_topk = torch.exp(log_probs_topk) entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1) normalized_entropy = entropy_approx / 3.0 - eaft_weight = torch.pow(normalized_entropy, args.eaft_alpha) - eaft_weight = eaft_weight.view(batch_size, seq_length) - eaft_weight = torch.where(loss_mask, eaft_weight, torch.ones_like(eaft_weight)) + eaft_weight_valid = torch.pow(normalized_entropy, args.eaft_alpha) + eaft_weight = torch.ones(batch_size * seq_length, device=logits_float.device) + eaft_weight[loss_mask.view(-1)] = eaft_weight_valid + eaft_weight = eaft_weight.view(batch_size, seq_length) losses = losses * eaft_weight diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py index 02f6f0df18..6ca67b7558 100644 --- a/swift/trainers/utils.py +++ b/swift/trainers/utils.py @@ -80,13 +80,20 @@ def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, enable_eaft_l loss *= target_probs if enable_eaft_loss: with torch.no_grad(): - topk_logits, topk_indices = torch.topk(logits_detach, k=20, dim=-1) + logits_detach = logits.detach() + valid_mask = labels != -100 + logits_valid = logits_detach[valid_mask] + + topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1) logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True) log_probs_topk = topk_logits - logsumexp_topk probs_topk = torch.exp(log_probs_topk) entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1) normalized_entropy = entropy_approx / 3.0 - eaft_weight = torch.pow(normalized_entropy, eaft_alpha) + eaft_weight_valid = torch.pow(normalized_entropy, eaft_alpha) + + eaft_weight = torch.ones_like(loss) + eaft_weight[valid_mask] = eaft_weight_valid loss *= eaft_weight from swift.trainers.sequence_parallel import sequence_parallel From 7683b1e3b6dd509b95d77d3ee557aba59b9c7a34 Mon Sep 17 00:00:00 2001 From: ymxyll <2313418841@qq.com> Date: Thu, 15 Jan 2026 15:16:26 +0800 Subject: [PATCH 3/3] fix bug when pp>1 --- swift/megatron/trainers/trainer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 24d6081c9f..94a563ad26 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -7,6 +7,7 @@ from megatron.core import mpu from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.training import get_args, get_timers +from megatron.core import parallel_state from torch.distributed.nn import all_reduce from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -178,10 +179,14 @@ def forward_step(self, data_iterator, model): with self.stimer: - if self.args.task_type != 'seq_cls': - output_tensor, logits = model(**data, return_logits=True) + is_last = parallel_state.is_pipeline_last_stage() + + if is_last: + loss, logits = model(**data, return_logits=True) + output_tensor = loss else: - output_tensor = model(**data) + output_tensor = model(**data) # only hidden_states tensor + logits = None if self.args.task_type == 'seq_cls': loss_func = partial(
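
Reviewer note (not part of the patch): below is a minimal standalone sketch of the EAFT per-token weighting that this series converges on — cross-entropy scaled by (top-k entropy / 3.0) ** alpha, with k = 20 and a weight of 1 on masked (-100) positions, computed under no_grad as in the diffs above. The function name, the toy vocabulary size, and the explicit next-token shift are illustrative assumptions for self-containment and do not mirror swift's internal call signatures.

import torch
import torch.nn.functional as F


def eaft_per_token_loss(logits, labels, alpha=1.0, topk=20):
    """logits: [batch, seq, vocab]; labels: [batch, seq], padded with -100.

    Returns per-token cross-entropy weighted by (top-k entropy / 3.0) ** alpha.
    """
    logits = logits.float()
    # Next-token shift done here for self-containment; swift's helpers receive
    # labels that are already prepared upstream by the template.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    vocab_size = shift_logits.size(-1)

    flat_logits = shift_logits.view(-1, vocab_size)
    flat_labels = shift_labels.view(-1)

    # Per-token CE; ignored positions (-100) get a loss of 0 and a weight of 1.
    loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100, reduction='none')

    with torch.no_grad():
        valid = flat_labels != -100
        # Entropy estimated from the top-k logits only, avoiding a full-vocab softmax.
        topk_logits, _ = torch.topk(flat_logits[valid], k=topk, dim=-1)
        log_p = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
        entropy = -(log_p.exp() * log_p).sum(dim=-1)
        # 3.0 is roughly ln(20), the maximum entropy of a 20-way distribution,
        # so the weight stays approximately in [0, 1]; masked positions keep 1.
        weight = torch.ones_like(loss)
        weight[valid] = (entropy / 3.0).pow(alpha)

    return loss * weight


# Toy usage: mean of the weighted per-token loss over non-padding tokens.
logits = torch.randn(2, 8, 1000)
labels = torch.randint(0, 1000, (2, 8))
labels[:, :3] = -100
per_token = eaft_per_token_loss(logits, labels, alpha=1.0)
mask = labels[:, 1:].reshape(-1) != -100
print(per_token[mask].mean())

With alpha = 1.0 the weight is just the normalized entropy, so tokens the model already predicts confidently (entropy near 0) are strongly down-weighted while uncertain tokens keep close to their full cross-entropy gradient; a larger eaft_alpha sharpens that down-weighting. Because the weight is computed under no_grad, gradients still flow only through the cross-entropy term, matching the behaviour of the patched loss functions.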