Conversation
Summary of Changes

Hello @ymxyll, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates Entropy-Adaptive Fine-Tuning (EAFT) loss into the project, offering a new method for fine-tuning foundation models. By dynamically adjusting the loss based on token entropy, EAFT aims to improve model performance and generalization. The changes add new command-line arguments for configuration, modify the core loss calculation to incorporate entropy-based weighting, and update model outputs to expose the logits needed for this calculation. Example scripts are also included so the new feature can be adopted and tested immediately.
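In rough terms, and only as a sketch of the mechanism summarized above (the function name, the 3.0 normalizer, and the example values are illustrative assumptions, not taken from this PR), each token's loss is scaled by its normalized predictive entropy raised to a configurable exponent alpha:

import torch

def entropy_weight(entropy: torch.Tensor, max_entropy: float, alpha: float) -> torch.Tensor:
    # Confident tokens (low entropy) are down-weighted; uncertain tokens keep more of their loss.
    return torch.pow(entropy / max_entropy, alpha)

per_token_loss = torch.tensor([2.1, 0.3, 1.7])
per_token_entropy = torch.tensor([2.9, 0.2, 1.5])   # hypothetical per-token entropies
weighted_loss = per_token_loss * entropy_weight(per_token_entropy, max_entropy=3.0, alpha=1.0)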
Code Review
This pull request adds support for Entropy-Adaptive Fine-Tuning (EAFT) loss. The changes are well-structured: they introduce new arguments, modify the model to expose the necessary outputs (logits), and update the trainers to compute and apply the EAFT loss. The implementation appears to follow the principles of EAFT correctly.
However, I've identified a critical bug in the Megatron trainer, a tensor shape mismatch that would lead to an incorrect loss calculation. There is also a high-severity bug that could cause an UnboundLocalError. Additionally, I've provided suggestions to improve maintainability by refactoring duplicated code and making hardcoded values configurable. Addressing these points will enhance the robustness and quality of the implementation.
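The excerpt here does not show where the potential UnboundLocalError lives, but the usual pattern such a warning points at is a variable assigned only inside a conditional branch and then read unconditionally. A hypothetical illustration (none of these names come from the PR):

# Hypothetical illustration, not the PR's code: if the weight is only assigned
# when EAFT is enabled but used unconditionally, Python raises UnboundLocalError.
def compute_loss(losses, enable_eaft_loss: bool):
    if enable_eaft_loss:
        eaft_weight = 0.5  # placeholder for the real entropy-based weight
    return losses * eaft_weight  # fails with UnboundLocalError when the flag is False

# Safer: give the variable a default before the branch.
def compute_loss_safe(losses, enable_eaft_loss: bool):
    eaft_weight = 1.0
    if enable_eaft_loss:
        eaft_weight = 0.5  # placeholder for the real entropy-based weight
    return losses * eaft_weight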
eaft_weight[loss_mask.view(-1)] = eaft_weight_valid
eaft_weight = eaft_weight.view(batch_size, seq_length)

losses = losses * eaft_weight
There's a shape mismatch when applying the EAFT weights. The losses tensor has a shape of [sequence_length, batch_size], while eaft_weight has a shape of [batch_size, sequence_length]. Multiplying them directly will lead to incorrect broadcasting or a runtime error. You need to transpose eaft_weight before multiplication.
Suggested change:
-    losses = losses * eaft_weight
+    losses = losses * eaft_weight.transpose(0, 1)
Thanks for the review. I’ve double-checked the shapes, and the implementation is actually correct. Since the input logits are [Seq, Batch], I explicitly transpose them via transpose(0, 1) to align with the [Batch, Seq] layout of the losses tensor. Given that eaft_weight serves as a weight map for the loss, applying element-wise multiplication here is the correct operation, so no further transposition is required.
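A toy shape check along the lines of the reply above (the dimensions and names are illustrative assumptions, not the PR's actual tensors): once the weight map is built in [Batch, Seq] layout from the transposed logits, it multiplies a [Batch, Seq] loss tensor elementwise with no further transpose.

import torch

batch_size, seq_length, vocab = 2, 5, 11

# Hypothetical Megatron-style logits laid out as [Seq, Batch, Vocab] ...
logits = torch.randn(seq_length, batch_size, vocab)
# ... transposed up front to [Batch, Seq, Vocab], as described in the reply.
logits_bs = logits.transpose(0, 1)

losses = torch.rand(batch_size, seq_length)        # per-token losses, [Batch, Seq]
eaft_weight = torch.ones(batch_size, seq_length)   # weight map built in the same layout

weighted = losses * eaft_weight                    # elementwise, shapes already match
assert weighted.shape == (batch_size, seq_length)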
output_tensor = model(**data)
packed_seq_params = data.get('packed_seq_params')
logits_valid = logits_reshaped[loss_mask.view(-1)]

topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1)
if enable_eaft_loss:
    with torch.no_grad():
        logits_detach = logits.detach()
        valid_mask = labels != -100
        logits_valid = logits_detach[valid_mask]

        topk_logits, topk_indices = torch.topk(logits_valid, k=20, dim=-1)
        logsumexp_topk = torch.logsumexp(topk_logits, dim=-1, keepdim=True)
        log_probs_topk = topk_logits - logsumexp_topk
        probs_topk = torch.exp(log_probs_topk)
        entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
        normalized_entropy = entropy_approx / 3.0
        eaft_weight_valid = torch.pow(normalized_entropy, eaft_alpha)

        eaft_weight = torch.ones_like(loss)
        eaft_weight[valid_mask] = eaft_weight_valid

    loss *= eaft_weight
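For context, here is a self-contained sketch of the same top-k entropy weighting on toy tensors. The function name, shapes, the k=20 default, and the use of ln(20) ≈ 3.0 as the normalizer are assumptions inferred from the hunk above, not verified against the rest of the PR.

import math
import torch

def eaft_weights(logits: torch.Tensor, labels: torch.Tensor,
                 alpha: float = 1.0, k: int = 20) -> torch.Tensor:
    """Per-token weights from a top-k entropy approximation; masked positions get weight 1."""
    with torch.no_grad():
        valid_mask = labels != -100                      # ignore padded / masked positions
        logits_valid = logits.detach()[valid_mask]       # [num_valid, vocab]

        topk_logits, _ = torch.topk(logits_valid, k=k, dim=-1)
        log_probs = topk_logits - torch.logsumexp(topk_logits, dim=-1, keepdim=True)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)

        # Max entropy of a k-way distribution is ln(k); ln(20) ~= 3.0, which looks like
        # where the hardcoded 3.0 in the hunk above comes from (assumption).
        weights = torch.pow(entropy / math.log(k), alpha)

        full = torch.ones(labels.shape, dtype=logits.dtype, device=logits.device)
        full[valid_mask] = weights
        return full

# Toy usage: weight a per-token loss tensor of shape [batch, seq].
batch, seq, vocab = 2, 6, 50
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))
labels[:, :2] = -100                                     # pretend the prompt tokens are masked out
per_token_loss = torch.rand(batch, seq)

weighted_loss = per_token_loss * eaft_weights(logits, labels, alpha=1.0)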
PR type
PR information
Dear all,
Thanks for maintaining such a fantastic project that enables the community to explore various fine-tuning approaches on foundation models.
I would like to contribute by adding Entropy-Adaptive Fine-Tuning (EAFT) support to this repository.
Thank you so much, and I welcome any feedback to improve this PR.
Best regards,