[feature] add support for EAFT loss #7361

Open
ymxyll wants to merge 3 commits into modelscope:main from ymxyll:feature/eaft

Conversation


ymxyll commented on Jan 12, 2026

PR type

  • Bug Fix
  • [✅] New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Dear all,
Thanks for maintaining such a fantastic project that enables the community to explore various fine-tuning approaches on foundation models.

I would like to contribute by adding Entropy-Adaptive Fine-Tuning (EAFT) support to this repository.

Thank you so much, and I welcome any feedback for improving this PR.

Best regards,

@gemini-code-assist
Contributor

Summary of Changes

Hello @ymxyll, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Entropy-Adaptive Fine-Tuning (EAFT) loss into the project, offering a new method for fine-tuning foundation models. By dynamically adjusting the loss based on token entropy, EAFT aims to improve model performance and generalization. The changes involve adding new command-line arguments for configuration, modifying the core loss calculation logic to incorporate entropy-based weighting, and updating model outputs to provide the necessary logits for this calculation. Example scripts are also included to facilitate immediate adoption and testing of this new feature.

Highlights

  • New Feature: EAFT Loss Support: Introduced Entropy-Adaptive Fine-Tuning (EAFT) loss, a novel approach to dynamically weight cross-entropy loss based on token entropy, enhancing model fine-tuning capabilities.
  • Configuration Arguments: Added new arguments, --enable_eaft_loss (boolean) and --eaft_alpha (float), to both Megatron and Swift training configurations, allowing users to enable and control the EAFT loss.
  • Loss Calculation Logic: Modified the loss computation functions in swift/megatron/trainers/trainer.py and swift/trainers/utils.py to incorporate the EAFT weighting mechanism, which involves calculating approximate entropy from model logits.
  • Model Output Modification: Updated swift/megatron/model/gpt_model.py to optionally return logits from the forward and _postprocess methods, providing the necessary input for EAFT loss calculation.
  • Example Scripts: Provided new example shell scripts (examples/megatron/eaft.sh and examples/train/eaft.sh) to demonstrate how to enable and use the EAFT loss with both Megatron and Swift training frameworks.
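To make the mechanism described above concrete, here is a minimal sketch of how an entropy-adaptive per-token weighting of this kind can be applied to a standard cross-entropy loss. It is based on the diff excerpts quoted in the review below; the function name, shape conventions, defaults, and the /3.0 normalization constant are illustrative, not the exact code in this PR.

import torch
import torch.nn.functional as F

def eaft_weighted_loss(logits, labels, eaft_alpha=1.0, top_k=20, ignore_index=-100):
    # logits: [batch, seq, vocab]; labels: [batch, seq], ignore_index marks prompt/padding tokens
    vocab = logits.size(-1)
    loss = F.cross_entropy(
        logits.view(-1, vocab), labels.view(-1),
        ignore_index=ignore_index, reduction='none').view_as(labels)

    with torch.no_grad():
        valid_mask = labels != ignore_index
        logits_valid = logits.detach()[valid_mask]                    # [n_valid, vocab]
        topk_logits, _ = torch.topk(logits_valid, k=top_k, dim=-1)    # top-k entropy approximation
        log_probs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        weight = torch.ones_like(loss)
        weight[valid_mask] = torch.pow(entropy / 3.0, eaft_alpha)     # /3.0 mirrors the diff; treat it as a tunable constant

    loss = loss * weight    # with alpha > 0, low-entropy (confident) tokens are down-weighted
    return loss[valid_mask].mean()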


Contributor

gemini-code-assist bot left a comment


Code Review

This pull request adds support for Entropy-Adaptive Fine-Tuning (EAFT) loss. The changes are well-structured: new arguments are introduced, the model is modified to expose the necessary outputs (logits), and the trainers are updated to compute and apply the EAFT loss. The implementation appears to follow the principles of EAFT correctly.

However, I've identified a critical bug in the Megatron trainer concerning a tensor shape mismatch which will lead to incorrect loss calculation. There is also a high-severity bug that could cause an UnboundLocalError. Additionally, I've provided suggestions to improve maintainability by refactoring duplicated code and making hardcoded values configurable. Addressing these points will enhance the robustness and quality of the implementation.

eaft_weight[loss_mask.view(-1)] = eaft_weight_valid
eaft_weight = eaft_weight.view(batch_size, seq_length)

losses = losses * eaft_weight
Contributor

critical

There's a shape mismatch when applying the EAFT weights. The losses tensor has a shape of [sequence_length, batch_size], while eaft_weight has a shape of [batch_size, sequence_length]. Multiplying them directly will lead to incorrect broadcasting or a runtime error. You need to transpose eaft_weight before multiplication.

Suggested change
losses = losses * eaft_weight
losses = losses * eaft_weight.transpose(0, 1)

Author

Thanks for the review. I’ve double-checked the shapes, and the implementation is actually correct. Since the input logits are [Seq, Batch], I explicitly transpose them via transpose(0, 1) to align with the [Batch, Seq] layout of the losses tensor. Given that eaft_weight serves as a weight map for the loss, applying element-wise multiplication here is the correct operation, so no further transposition is required.
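For readers following the shape discussion, here is a shape-only illustration of the layout the author describes. The conventions below are taken from the two comments above and are assumptions, not the actual trainer code; whether the real tensors follow this layout is exactly the point under review.

import torch

seq_len, batch, vocab = 8, 2, 32
logits = torch.randn(seq_len, batch, vocab)   # Megatron-style [Seq, Batch, Vocab], per the author's comment
logits_bs = logits.transpose(0, 1)            # -> [Batch, Seq, Vocab], aligned with the losses layout the author describes
losses = torch.randn(batch, seq_len)          # per-token losses assumed here to be [Batch, Seq]
eaft_weight = torch.ones(batch, seq_len)      # weights derived from logits_bs are then also [Batch, Seq]
weighted = losses * eaft_weight               # element-wise multiply; no further transpose needed under these assumptions
print(weighted.shape)                         # torch.Size([2, 8])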

output_tensor = model(**data)
packed_seq_params = data.get('packed_seq_params')


Contributor

high

The logits variable is used on line 199 but is only conditionally defined within the if block on line 181. If self.args.task_type is 'seq_cls', this will raise an UnboundLocalError. Initialize logits to None here to ensure it is always defined.

Suggested change
logits = None
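For context, a minimal sketch of the pattern the reviewer is pointing at; the function, condition, and variable names are illustrative stand-ins, not the actual trainer code.

def loss_func(output_tensor, task_type, enable_eaft_loss):
    logits = None                    # proposed fix: bind the name on every path
    if task_type != 'seq_cls':       # hypothetical stand-in for the branch on line 181
        logits = output_tensor
    # ... later, around the use on line 199 ...
    if enable_eaft_loss and logits is not None:
        pass                         # safe: logits is defined even on the 'seq_cls' path
    return output_tensor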


logits_valid = logits_reshaped[loss_mask.view(-1)]

topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1)
Contributor

medium

The value for k in torch.topk is hardcoded to 20. This reduces flexibility for experimentation. It would be better to make this a configurable parameter by adding an eaft_top_k argument to ExtraMegatronArguments (with a default of 20).
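A minimal sketch of what this suggestion could look like, assuming the argument containers are dataclasses like other argument classes in the project; the EAFT field names come from this PR, while their defaults here are illustrative.

from dataclasses import dataclass

@dataclass
class ExtraMegatronArguments:        # existing class in swift/megatron; only the EAFT-related fields are sketched here
    enable_eaft_loss: bool = False   # added by this PR
    eaft_alpha: float = 1.0          # added by this PR; default shown here is illustrative
    eaft_top_k: int = 20             # proposed: expose the top-k used by the entropy approximation

The same field could be mirrored on the Swift-trainer side, as noted in a later comment.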

Comment on lines +81 to +98
if enable_eaft_loss:
    with torch.no_grad():
        logits_detach = logits.detach()
        valid_mask = labels != -100
        logits_valid = logits_detach[valid_mask]

        topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1)
        logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
        log_probs_topk = topk_logits - logsumexp_topk
        probs_topk = torch.exp(log_probs_topk)
        entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
        normalized_entropy = entropy_approx / 3.0
        eaft_weight_valid = torch.pow(normalized_entropy, eaft_alpha)

        eaft_weight = torch.ones_like(loss)
        eaft_weight[valid_mask] = eaft_weight_valid

    loss *= eaft_weight
Contributor

medium

The logic for calculating the EAFT loss here is nearly identical to the implementation in per_token_loss_func (lines 128-144). This duplication makes the code harder to maintain. Please refactor this shared logic into a single helper function to improve maintainability and reduce redundancy.
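One possible shape for that refactor, offered as a sketch only (the helper and parameter names are illustrative; compare the inline version sketched under "Highlights" above):

import torch

def compute_eaft_weight(logits, valid_mask, eaft_alpha, top_k=20):
    # Shared helper: top-k entropy approximation -> per-token EAFT weight.
    # logits: [..., vocab]; valid_mask: boolean mask over the token positions.
    with torch.no_grad():
        logits_valid = logits.detach()[valid_mask]
        topk_logits, _ = torch.topk(logits_valid, k=top_k, dim=-1)
        log_probs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        weight = torch.ones(valid_mask.shape, dtype=logits.dtype, device=logits.device)
        weight[valid_mask] = torch.pow(entropy / 3.0, eaft_alpha)
    return weight

# Both call sites would then reduce to roughly:
#   eaft_weight = compute_eaft_weight(logits, labels != -100, eaft_alpha, top_k)
#   loss = loss * eaft_weight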

valid_mask = labels != -100
logits_valid = logits_detach[valid_mask]

topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1)
Contributor

medium

The value for k in torch.topk is hardcoded to 20. This is also the case in per_token_loss_func. This should be a configurable parameter to allow for easier experimentation. Consider adding an eaft_top_k argument to TrainArgumentsMixin and passing it down to these functions.
