diff --git a/examples/megatron/eaft.sh b/examples/megatron/eaft.sh
new file mode 100644
index 0000000000..fbbbfc35cb
--- /dev/null
+++ b/examples/megatron/eaft.sh
@@ -0,0 +1,33 @@
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+NPROC_PER_NODE=2 \
+CUDA_VISIBLE_DEVICES=0,1 \
+megatron sft \
+    --model Qwen/Qwen3-0.6B \
+    --load_safetensors true \
+    --save_safetensors true \
+    --dataset 'swift_shuf_19k_data.jsonl' \
+    --tensor_model_parallel_size 1 \
+    --sequence_parallel true \
+    --micro_batch_size 4 \
+    --global_batch_size 64 \
+    --recompute_granularity full \
+    --recompute_method uniform \
+    --recompute_num_layers 1 \
+    --finetune true \
+    --cross_entropy_loss_fusion true \
+    --lr 1e-5 \
+    --lr_warmup_fraction 0.05 \
+    --min_lr 1e-6 \
+    --max_epochs 1 \
+    --save megatron_output/Qwen3-0.6B/eaft \
+    --save_interval 100 \
+    --max_length 16384 \
+    --system 'You are a helpful assistant.' \
+    --num_workers 4 \
+    --no_save_optim true \
+    --no_save_rng true \
+    --dataset_num_proc 4 \
+    --tensorboard_dir /tensorboard/Qwen3-0.6B/eaft \
+    --enable_eaft_loss true \
+    --eaft_alpha 1.0
+
diff --git a/examples/train/eaft.sh b/examples/train/eaft.sh
new file mode 100644
index 0000000000..e7dfbcc9d8
--- /dev/null
+++ b/examples/train/eaft.sh
@@ -0,0 +1,24 @@
+NPROC_PER_NODE=2 \
+CUDA_VISIBLE_DEVICES=0,1 \
+swift sft \
+    --model Qwen/Qwen3-0.6B \
+    --train_type full \
+    --dataset 'swift_shuf_19k_data.jsonl' \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 2 \
+    --learning_rate 1e-5 \
+    --gradient_accumulation_steps 16 \
+    --save_steps 100 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --max_length 16384 \
+    --output_dir swift_output/Qwen3-0.6B/eaft \
+    --system 'You are a helpful assistant.' \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --enable_eaft_loss true \
+    --eaft_alpha 1.0 \
+    --deepspeed zero3 \
+    --report_to tensorboard \
+    --logging_dir tensorboard/swift_output/Qwen3-0.6B/eaft
diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py
index 2b12abdd16..252ea8522e 100644
--- a/swift/megatron/argument/megatron_args.py
+++ b/swift/megatron/argument/megatron_args.py
@@ -346,6 +346,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
     llm_architectures: Optional[str] = None
     max_epochs: Optional[int] = None
     enable_dft_loss: bool = False
+    enable_eaft_loss: bool = False
+    eaft_alpha: float = 1.0
     enable_channel_loss: bool = False
     task_type: Literal['causal_lm', 'seq_cls'] = None
     num_labels: Optional[int] = None
diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py
index 37582f34d6..569e12e0f2 100644
--- a/swift/megatron/model/gpt_model.py
+++ b/swift/megatron/model/gpt_model.py
@@ -272,6 +272,7 @@ def forward(
         *,
         inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[torch.Tensor] = None,
+        return_logits: bool = False,
         **kwargs,
     ) -> torch.Tensor:
         """Forward function of the GPT Model This function passes the input tensors
@@ -327,6 +328,7 @@ def forward(
             runtime_gather_output=runtime_gather_output,
             extra_block_kwargs=extra_block_kwargs,
             inference_context=inference_context,
+            return_logits=return_logits,
         )
 
     def _postprocess(
@@ -347,6 +349,7 @@ def _postprocess(
         runtime_gather_output=None,
         extra_block_kwargs=None,
         inference_context=None,
+        return_logits=False,
     ):
         """Postprocesses decoder hidden states to generate logits or compute loss.
@@ -473,6 +476,9 @@ def _postprocess(
         loss = self.compute_language_model_loss(labels, logits)
 
+        if return_logits:
+            return loss, logits
+
         return loss
 
     def get_input_tensor(self):
diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py
index 484b01c8ff..94a563ad26 100644
--- a/swift/megatron/trainers/trainer.py
+++ b/swift/megatron/trainers/trainer.py
@@ -7,6 +7,7 @@
 from megatron.core import mpu
+from megatron.core import parallel_state
 from megatron.core.rerun_state_machine import get_rerun_state_machine
 from megatron.training import get_args, get_timers
 from torch.distributed.nn import all_reduce
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -52,13 +53,41 @@ def loss_func(self,
                   labels: torch.Tensor,
                   loss_scale: Optional[torch.Tensor] = None,
                   channels: Optional[List[str]] = None,
-                  packed_seq_params=None):
+                  packed_seq_params=None,
+                  logits: Optional[torch.Tensor] = None):
         args = get_args()
         losses = output_tensor.float()
         loss_mask = labels != -100
         if args.enable_dft_loss:
             losses = losses * torch.exp(-losses.detach())
 
+        if args.enable_eaft_loss and logits is not None:
+            with torch.no_grad():
+                logits_float = logits.float()
+                vocab_size = logits_float.shape[-1]
+
+                batch_size = labels.shape[0]
+                seq_length = labels.shape[1]
+
+                logits_transposed = logits_float.transpose(0, 1)  # [seq, batch, vocab] -> [batch, seq, vocab]
+                logits_reshaped = logits_transposed.reshape(batch_size * seq_length, vocab_size)
+
+                logits_valid = logits_reshaped[loss_mask.view(-1)]
+
+                topk_logits, _ = torch.topk(logits_valid, k=20, dim=-1)
+                logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
+                log_probs_topk = topk_logits - logsumexp_topk
+                probs_topk = torch.exp(log_probs_topk)
+                entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
+                normalized_entropy = entropy_approx / 3.0  # ln(20) ~= 3.0, so entropy is normalized to ~[0, 1]
+                eaft_weight_valid = torch.pow(normalized_entropy, args.eaft_alpha)
+
+                eaft_weight = torch.ones(batch_size * seq_length, device=logits_float.device)
+                eaft_weight[loss_mask.view(-1)] = eaft_weight_valid
+                eaft_weight = eaft_weight.view(batch_size, seq_length)
+
+            losses = losses * eaft_weight
+
         if loss_scale is not None:
             losses = losses * loss_scale
         if args.enable_channel_loss and channels is not None:
@@ -146,9 +175,19 @@ def forward_step(self, data_iterator, model):
         labels = data.get('labels')
         if self.args.task_type == 'seq_cls':
             data.pop('labels', None)
-        with self.stimer:
-            output_tensor = model(**data)
         packed_seq_params = data.get('packed_seq_params')
+
+        # The last pipeline stage returns logits as well, so loss_func can compute EAFT weights.
+        with self.stimer:
+            is_last = parallel_state.is_pipeline_last_stage()
+
+            if is_last:
+                loss, logits = model(**data, return_logits=True)
+                output_tensor = loss
+            else:
+                output_tensor = model(**data)  # only hidden_states tensor
+                logits = None
+
         if self.args.task_type == 'seq_cls':
             loss_func = partial(
                 self.seq_cls_loss_func,
@@ -161,5 +200,6 @@ def forward_step(self, data_iterator, model):
                 labels=labels,
                 loss_scale=loss_scale,
                 channels=channels,
-                packed_seq_params=packed_seq_params)
+                packed_seq_params=packed_seq_params,
+                logits=logits)
         return output_tensor, loss_func
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
index ec894fbbb5..55e4494216 100644
--- a/swift/trainers/arguments.py
+++ b/swift/trainers/arguments.py
@@ -39,6 +39,9 @@ class TrainArgumentsMixin:
         enable_dft_loss (bool): Whether to enable Dynamic Fine-Tuning (DFT) loss.
             See https://arxiv.org/abs/2508.05629. Defaults to False.
         enable_channel_loss (bool): Whether to enable channel loss. Defaults to False.
+        enable_eaft_loss (bool): Whether to enable Entropy-Adaptive Fine-Tuning (EAFT) loss. Defaults to False.
+        eaft_alpha (float): The alpha parameter for EAFT loss. The final loss is calculated as
+            `(token_entropy / 3.0)^alpha * ce_loss`. Defaults to 1.0.
         weight_decay (float): The weight decay to apply (if not zero) to all layers except bias and LayerNorm
             weights. Defaults to 0.1.
         adam_beta2 (float): The beta2 hyperparameter for the AdamW optimizer. Defaults to 0.95.
@@ -106,6 +109,8 @@
     router_aux_loss_coef: float = 0.
     enable_dft_loss: bool = False  # https://arxiv.org/abs/2508.05629
     enable_channel_loss: bool = False
+    enable_eaft_loss: bool = False
+    eaft_alpha: float = 1.0
     weight_decay: float = 0.1
     adam_beta2: float = 0.95
 
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index 00e72559cf..bdf84da0a9 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -288,7 +288,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         channels = inputs.pop('channel', None)
 
         if (self.label_smoother is not None or compute_loss_func is not None or loss_scale is not None
-                or self.args.enable_dft_loss or self.args.enable_channel_loss
+                or self.args.enable_dft_loss or self.args.enable_eaft_loss or self.args.enable_channel_loss
                 or self.template.sequence_parallel_size > 1) and 'labels' in inputs:
             if self.args.use_liger_kernel:
                 logger.warning_once('The cross_entropy loss function defined in Liger Kernel will not '
@@ -318,12 +318,20 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
            loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]
        else:
            outputs.loss = None
-            if (self.args.enable_dft_loss or loss_scale is not None or self.args.enable_channel_loss
-                    or self.template.sequence_parallel_size > 1):
+            if (self.args.enable_dft_loss or self.args.enable_eaft_loss or loss_scale is not None
+                    or self.args.enable_channel_loss or self.template.sequence_parallel_size > 1):
                if self.template.sequence_parallel_size > 1:
-                    outputs.loss = per_token_loss_func_sp(outputs, labels, enable_dft_loss=self.args.enable_dft_loss)
+                    outputs.loss = per_token_loss_func_sp(
+                        outputs, labels,
+                        enable_dft_loss=self.args.enable_dft_loss,
+                        enable_eaft_loss=self.args.enable_eaft_loss,
+                        eaft_alpha=self.args.eaft_alpha)
                else:
-                    outputs.loss = per_token_loss_func(outputs, labels, enable_dft_loss=self.args.enable_dft_loss)
+                    outputs.loss = per_token_loss_func(
+                        outputs, labels,
+                        enable_dft_loss=self.args.enable_dft_loss,
+                        enable_eaft_loss=self.args.enable_eaft_loss,
+                        eaft_alpha=self.args.eaft_alpha)
 
            if loss_scale is not None:
                loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1).view(-1)
diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py
index bae8b79928..6ca67b7558 100644
--- a/swift/trainers/utils.py
+++ b/swift/trainers/utils.py
@@ -56,7 +56,8 @@ def is_instance_of_ms_model(model: Module) -> bool:
     return False
 
 
-def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) -> torch.Tensor:
+def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, enable_eaft_loss=False, eaft_alpha=1.0,
+                           **kwargs) -> torch.Tensor:
     """Common loss function for sequence parallel training"""
     if hasattr(outputs, 'logits'):
         logits = outputs.logits
@@ -78,6 +78,24 @@ def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) ->
         with torch.no_grad():
             target_probs = torch.exp(-loss)
         loss *= target_probs
+    if enable_eaft_loss:
+        with torch.no_grad():
+            logits_detach = logits.detach()
+            valid_mask = labels != -100
+            logits_valid = logits_detach[valid_mask]
+
+            topk_logits, _ = torch.topk(logits_valid, k=20, dim=-1)
+            logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
+            log_probs_topk = topk_logits - logsumexp_topk
+            probs_topk = torch.exp(log_probs_topk)
+            entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
+            normalized_entropy = entropy_approx / 3.0  # ln(20) ~= 3.0, so entropy is normalized to ~[0, 1]
+            eaft_weight_valid = torch.pow(normalized_entropy, eaft_alpha)
+
+            eaft_weight = torch.ones_like(loss)
+            eaft_weight[valid_mask] = eaft_weight_valid
+
+        loss *= eaft_weight
     from swift.trainers.sequence_parallel import sequence_parallel
     position_ids = sequence_parallel.real_position_ids
     if position_ids is not None:
@@ -91,7 +109,8 @@ def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) ->
     return loss
 
 
-def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs):
+def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, enable_eaft_loss: bool = False,
+                        eaft_alpha: float = 1.0, **kwargs):
     logits = outputs.logits
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
@@ -106,4 +125,20 @@ def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs
         with torch.no_grad():
             target_probs = torch.exp(-loss)
         loss *= target_probs
+    if enable_eaft_loss:
+        with torch.no_grad():
+            valid_mask = labels != -100
+            logits_detach = logits[valid_mask].detach()
+            topk_logits, _ = torch.topk(logits_detach, k=20, dim=-1)
+            logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
+            log_probs_topk = topk_logits - logsumexp_topk
+            probs_topk = torch.exp(log_probs_topk)
+            entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
+            normalized_entropy = entropy_approx / 3.0  # ln(20) ~= 3.0, so entropy is normalized to ~[0, 1]
+            eaft_weight = torch.pow(normalized_entropy, eaft_alpha)
+
+            eaft_weight_full = torch.ones_like(loss)
+            eaft_weight_full[valid_mask] = eaft_weight
+
+        loss *= eaft_weight_full
     return loss
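Reference sketch (not part of the diff): a minimal, self-contained illustration of the per-token weighting that the three EAFT code paths above implement, so the `(token_entropy / 3.0)^alpha` formula can be inspected in isolation. The function name eaft_token_weights and the toy shapes are illustrative only, not identifiers from the repository; it assumes unshifted [batch, seq, vocab] logits and -100 as the ignore index.

import torch
import torch.nn.functional as F


def eaft_token_weights(logits: torch.Tensor, labels: torch.Tensor, alpha: float = 1.0, k: int = 20) -> torch.Tensor:
    """Per-token EAFT weights: (top-k entropy / ln(k))**alpha on supervised tokens, 1.0 on -100 positions."""
    valid_mask = labels != -100
    with torch.no_grad():
        # Entropy of the renormalized top-k distribution approximates the full softmax entropy.
        topk_logits, _ = torch.topk(logits[valid_mask].float(), k=k, dim=-1)
        log_probs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        # 3.0 matches the constant used in the diff; ln(20) ~= 3.0 maps the entropy to roughly [0, 1].
        weights_valid = (entropy / 3.0).pow(alpha)
    weights = torch.ones_like(labels, dtype=logits.dtype)
    weights[valid_mask] = weights_valid
    return weights


if __name__ == '__main__':
    torch.manual_seed(0)
    vocab, seq = 1000, 4
    logits = torch.randn(1, seq, vocab)   # noisy, fairly flat distributions -> entropy not far from ln(20)
    logits[0, 0, 0] += 50.0               # make the first token almost deterministic -> weight near 0
    labels = torch.randint(0, vocab, (1, seq))
    labels[0, 1] = -100                   # ignored position keeps weight 1.0
    ce = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), ignore_index=-100, reduction='none')
    weights = eaft_token_weights(logits, labels, alpha=1.0)
    print(weights)                        # small weight at position 0, 1.0 at the ignored position
    print((ce.view(1, seq) * weights).mean())

Because the weight is (entropy / 3.0)^alpha, alpha = 0 recovers plain cross-entropy, while larger alpha down-weights confident (low-entropy) tokens more aggressively.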