[feature] add support for EAFT loss #7361
New file: `megatron sft` example script enabling EAFT loss.

```shell
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
megatron sft \
    --model Qwen/Qwen3-0.6B \
    --load_safetensors true \
    --save_safetensors true \
    --dataset 'swift_shuf_19k_data.jsonl' \
    --tensor_model_parallel_size 1 \
    --sequence_parallel true \
    --micro_batch_size 4 \
    --global_batch_size 64 \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --finetune true \
    --cross_entropy_loss_fusion true \
    --lr 1e-5 \
    --lr_warmup_fraction 0.05 \
    --min_lr 1e-6 \
    --max_epochs 1 \
    --save megatron_output/Qwen3-0.6B/eaft \
    --save_interval 100 \
    --max_length 16384 \
    --system 'You are a helpful assistant.' \
    --num_workers 4 \
    --no_save_optim true \
    --no_save_rng true \
    --dataset_num_proc 4 \
    --tensorboard_dir /tensorboard/Qwen3-0.6B/eaft \
    --enable_eaft_loss true \
    --eaft_alpha 1.0
```
New file: `swift sft` full-parameter example script enabling EAFT loss (the dangling trailing backslash on the last line has been dropped).

```shell
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
swift sft \
    --model Qwen/Qwen3-0.6B \
    --train_type full \
    --dataset 'swift_shuf_19k_data.jsonl' \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 16 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --max_length 16384 \
    --output_dir swift_output/Qwen3-0.6B/eaft \
    --system 'You are a helpful assistant.' \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --enable_eaft_loss true \
    --eaft_alpha 1.0 \
    --deepspeed zero3 \
    --report_to tensorboard \
    --logging_dir tensorboard/swift_output/Qwen3-0.6B/eaft
```
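Both scripts only toggle the new `--enable_eaft_loss` / `--eaft_alpha` flags; the weighting they enable is sketched below as a minimal, self-contained illustration (the function and variable names are mine, not the PR's API). The per-token weight is the entropy of the top-20 candidates, divided by 3.0 (which appears to approximate ln 20, the maximum top-20 entropy) and raised to `eaft_alpha`, so confident (low-entropy) tokens are down-weighted while uncertain ones keep close to their full loss.

```python
import torch


def eaft_token_weight(logits: torch.Tensor, alpha: float = 1.0, top_k: int = 20) -> torch.Tensor:
    """Illustrative re-implementation of the per-token EAFT weight in this PR.

    logits: [num_tokens, vocab_size] for the valid (label != -100) positions.
    Returns (top-k entropy / 3.0) ** alpha per token, roughly in [0, 1].
    """
    topk_logits, _ = torch.topk(logits, k=top_k, dim=-1)
    log_probs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)
    return (entropy / 3.0) ** alpha


# A peaked (confident) distribution gets a near-zero weight; a flat one stays near 1.
vocab_size = 1000                      # arbitrary toy vocabulary
peaked = torch.zeros(1, vocab_size)
peaked[0, 0] = 20.0                    # almost all probability mass on one token
flat = torch.zeros(1, vocab_size)      # uniform over the vocabulary
print(eaft_token_weight(peaked))       # tensor([~0.])
print(eaft_token_weight(flat))         # tensor([~1.])
```

With `eaft_alpha` above 1.0 the down-weighting of confident tokens becomes more aggressive; the example scripts use 1.0.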
Changes to the Megatron trainer (loss function and forward step):

```diff
@@ -7,6 +7,7 @@
 from megatron.core import mpu
 from megatron.core.rerun_state_machine import get_rerun_state_machine
 from megatron.training import get_args, get_timers
+from megatron.core import parallel_state
 from torch.distributed.nn import all_reduce
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -52,13 +53,41 @@ def loss_func(self,
                   labels: torch.Tensor,
                   loss_scale: Optional[torch.Tensor] = None,
                   channels: Optional[List[str]] = None,
-                  packed_seq_params=None):
+                  packed_seq_params=None,
+                  logits: Optional[torch.Tensor] = None):
         args = get_args()

         losses = output_tensor.float()
         loss_mask = labels != -100
         if args.enable_dft_loss:
             losses = losses * torch.exp(-losses.detach())
+        if args.enable_eaft_loss and logits is not None:
+            with torch.no_grad():
+                logits_float = logits.float()
+                vocab_size = logits_float.shape[-1]
+
+                batch_size = labels.shape[0]
+                seq_length = labels.shape[1]
+
+                logits_transposed = logits_float.transpose(0, 1)
+                logits_reshaped = logits_transposed.view(batch_size * seq_length, vocab_size)
+
+                logits_valid = logits_reshaped[loss_mask.view(-1)]
+
+                topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1)
+                logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
+                log_probs_topk = topk_logits - logsumexp_topk
+                probs_topk = torch.exp(log_probs_topk)
+                entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
+                normalized_entropy = entropy_approx / 3.0
+                eaft_weight_valid = torch.pow(normalized_entropy, args.eaft_alpha)
+
+                eaft_weight = torch.ones(batch_size * seq_length, device=logits_float.device)
+                eaft_weight[loss_mask.view(-1)] = eaft_weight_valid
+                eaft_weight = eaft_weight.view(batch_size, seq_length)
+
+            losses = losses * eaft_weight
```
Contributor: There's a shape mismatch when applying the EAFT weights. The […]

Author: Thanks for the review. I've double-checked the shapes, and the implementation is actually correct. Since the input logits are `[Seq, Batch]`, I explicitly transpose them via `transpose(0, 1)` to align with the `[Batch, Seq]` layout of the `losses` tensor. Given that `eaft_weight` serves as a weight map for the loss, applying element-wise multiplication here is the correct operation, so no further transposition is required.
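To make the exchange above concrete, here is a small standalone shape check (not part of the PR) that follows the same reshaping path with toy sizes: logits arrive as `[seq, batch, vocab]`, are transposed and flattened, and the resulting weight map ends up in the same `[batch, seq]` layout as `losses`, so the element-wise multiplication lines up without a further transpose. All names and sizes below are illustrative.

```python
import torch

seq_length, batch_size, vocab_size = 8, 2, 50                # toy sizes for the check

logits = torch.randn(seq_length, batch_size, vocab_size)     # Megatron layout: [seq, batch, vocab]
losses = torch.rand(batch_size, seq_length)                  # per-token CE losses: [batch, seq]
labels = torch.randint(0, vocab_size, (batch_size, seq_length))
labels[:, :2] = -100                                         # mask a few prompt tokens
loss_mask = labels != -100

# Same path as the PR: transpose to [batch, seq, vocab], flatten, keep valid tokens.
# (reshape, rather than view, copes with the non-contiguous result of transpose here)
logits_valid = logits.transpose(0, 1).reshape(batch_size * seq_length, vocab_size)[loss_mask.view(-1)]

topk_logits, _ = torch.topk(logits_valid, k=20, dim=-1)
log_probs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
entropy = -(log_probs.exp() * log_probs).sum(dim=-1)

eaft_weight = torch.ones(batch_size * seq_length)
eaft_weight[loss_mask.view(-1)] = (entropy / 3.0) ** 1.0
eaft_weight = eaft_weight.view(batch_size, seq_length)

assert eaft_weight.shape == losses.shape      # both [batch, seq]
print((losses * eaft_weight).shape)           # torch.Size([2, 8])
```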
```diff
         if loss_scale is not None:
             losses = losses * loss_scale
         if args.enable_channel_loss and channels is not None:
```
```diff
@@ -146,9 +175,19 @@ def forward_step(self, data_iterator, model):
         labels = data.get('labels')
         if self.args.task_type == 'seq_cls':
             data.pop('labels', None)
-        with self.stimer:
-            output_tensor = model(**data)
         packed_seq_params = data.get('packed_seq_params')
+        with self.stimer:
+            is_last = parallel_state.is_pipeline_last_stage()
+
+            if is_last:
+                loss, logits = model(**data, return_logits=True)
+                output_tensor = loss
+            else:
+                output_tensor = model(**data)  # only hidden_states tensor
+                logits = None
+
         if self.args.task_type == 'seq_cls':
             loss_func = partial(
                 self.seq_cls_loss_func,
@@ -161,5 +200,6 @@ def forward_step(self, data_iterator, model):
                 labels=labels,
                 loss_scale=loss_scale,
                 channels=channels,
-                packed_seq_params=packed_seq_params)
+                packed_seq_params=packed_seq_params,
+                logits=logits)
         return output_tensor, loss_func
```
Changes to the shared per-token loss functions (sequence-parallel and standard):

```diff
@@ -56,7 +56,7 @@ def is_instance_of_ms_model(model: Module) -> bool:
     return False


-def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) -> torch.Tensor:
+def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, enable_eaft_loss=False, eaft_alpha=1.0, **kwargs) -> torch.Tensor:
     """Common loss function for sequence parallel training"""
     if hasattr(outputs, 'logits'):
         logits = outputs.logits
@@ -78,6 +78,24 @@ def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) ->
         with torch.no_grad():
             target_probs = torch.exp(-loss)
         loss *= target_probs
+    if enable_eaft_loss:
+        with torch.no_grad():
+            logits_detach = logits.detach()
+            valid_mask = labels != -100
+            logits_valid = logits_detach[valid_mask]
+
+            topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1)
+            logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
+            log_probs_topk = topk_logits - logsumexp_topk
+            probs_topk = torch.exp(log_probs_topk)
+            entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
+            normalized_entropy = entropy_approx / 3.0
+            eaft_weight_valid = torch.pow(normalized_entropy, eaft_alpha)
+
+            eaft_weight = torch.ones_like(loss)
+            eaft_weight[valid_mask] = eaft_weight_valid
+
+        loss *= eaft_weight
+
     from swift.trainers.sequence_parallel import sequence_parallel
     position_ids = sequence_parallel.real_position_ids
     if position_ids is not None:
@@ -91,7 +109,8 @@ def per_token_loss_func_sp(outputs, labels, enable_dft_loss=False, **kwargs) ->
     return loss


-def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs):
+def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, enable_eaft_loss: bool = False,
+                        eaft_alpha: float = 1.0, **kwargs):
     logits = outputs.logits
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
@@ -106,4 +125,20 @@ def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs
         with torch.no_grad():
             target_probs = torch.exp(-loss)
         loss *= target_probs
+    if enable_eaft_loss:
+        with torch.no_grad():
+            valid_mask = labels != -100
+            logits_detach = logits[valid_mask].detach()
+            topk_logits, topk_indices = torch.topk(logits_detach, k=20, dim=-1)
+            logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
+            log_probs_topk = topk_logits - logsumexp_topk
+            probs_topk = torch.exp(log_probs_topk)
+            entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
+            normalized_entropy = entropy_approx / 3.0
+            eaft_weight = torch.pow(normalized_entropy, eaft_alpha)
+
+            eaft_weight_full = torch.ones_like(loss)
+            eaft_weight_full[valid_mask] = eaft_weight
+
+        loss *= eaft_weight_full
     return loss
```
Review comment: The value for `k` in `torch.topk` is hardcoded to 20. This reduces flexibility for experimentation. It would be better to make this a configurable parameter by adding an `eaft_top_k` argument to `ExtraMegatronArguments` (with a default of 20).
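If the suggestion above were adopted, the plumbing might look roughly like the sketch below. The `eaft_top_k` field, the trimmed-down `ExtraMegatronArguments`, and the helper are all hypothetical illustrations rather than code from the PR; as an optional extra tweak, the sketch normalizes by `ln(k)` instead of the fixed 3.0 so the weight stays in [0, 1] for any choice of `k`.

```python
import math
from dataclasses import dataclass

import torch


@dataclass
class ExtraMegatronArguments:          # hypothetical, trimmed to the relevant fields
    enable_eaft_loss: bool = False
    eaft_alpha: float = 1.0
    eaft_top_k: int = 20               # the proposed knob, defaulting to today's hardcoded value


def eaft_weights(logits_valid: torch.Tensor, args: ExtraMegatronArguments) -> torch.Tensor:
    """Entropy-based weights with a configurable top-k (illustrative only)."""
    topk_logits, _ = torch.topk(logits_valid, k=args.eaft_top_k, dim=-1)
    log_probs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    max_entropy = math.log(args.eaft_top_k)   # ln(k): maximum entropy of a top-k distribution
    return (entropy / max_entropy) ** args.eaft_alpha


args = ExtraMegatronArguments(enable_eaft_loss=True, eaft_top_k=50)
print(eaft_weights(torch.randn(4, 1000), args).shape)   # torch.Size([4])
```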