refactor(nyz): audio language model RL pipeline#58
Open
PaParaZz1 wants to merge 13 commits into
puyuan1996
requested changes
Apr 24, 2026
| # Build question template (matches R1-AQA) | ||
| choice_str = f"Please choose the answer from the following options: {multi_choice}." | ||
| # There should be a space between <answer> and </answer> |
Collaborator
Does this have a big impact on the results? Would it be better for the comment to explain how adding or omitting the space affects performance?
Member
Author
It has a large effect on the initial convergence speed of Qwen2-Audio. "Should be" here means the space must be added.
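For reference, a minimal sketch of the prompt construction under discussion. It assumes `multi_choice` is a list of option strings and that the instruction ends with `<answer> </answer>` (with the space). The function name `build_question`, the `with_space` switch, and the sample inputs are hypothetical and not taken from the PR.

```python
from typing import List


def build_question(question: str, multi_choice: List[str], with_space: bool = True) -> str:
    """Build an R1-AQA-style multiple-choice prompt (illustrative only)."""
    choice_str = f"Please choose the answer from the following options: {multi_choice}."
    # Per the thread, the space between the tags affects Qwen2-Audio's
    # initial convergence speed, so it is the default in this sketch.
    answer_tags = "<answer> </answer>" if with_space else "<answer></answer>"
    return f"{question} {choice_str} Output the final answer in {answer_tags}."


# Hypothetical sample input, for illustration only.
prompt = build_question("What is making the sound?", ["dog", "cat"])
```

Keeping the tag string in one place would make it easy to run both variants side by side if the effect needs to be documented in the comment.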
| parser = argparse.ArgumentParser( | ||
| description="Preprocess AVQA dataset (R1-AQA format) for LightRFT training" | ||
| ) | ||
| parser = argparse.ArgumentParser(description="Preprocess AVQA dataset (R1-AQA format) for LightRFT training") |
| Each item returns ``(prompt_text, audio_payload, reference, label)`` where: | ||
| - ``prompt_text`` is rendered through the Qwen2-Audio chat template | ||
| - ``audio_payload`` is kept as raw waveform + sampling rate for rollout-side processing | ||
| - ``reference`` and ``label`` are passed through to reward computation |
Collaborator
Why were the `:param dataset:` entries removed?
| all_videos: Optional[List] = None, | ||
| videos_num: Optional[List[int]] = None, | ||
| ) -> EasyDict: | ||
| @staticmethod |
Collaborator
Is `AudioMultimodalProcessor` no longer used?
| "Output the final answer in <answer></answer>." | ||
| ) | ||
| question_template = (f"{obj_dict['question']} {choice_str} " | ||
| "Output the final answer in <answer></answer>.") |
| log_probs: Optional[torch.Tensor] = None, | ||
| old_log_probs: Optional[torch.Tensor] = None, | ||
| ratio: Optional[torch.Tensor] = None, | ||
| ) -> None: |
| experience.sequences[0].unsqueeze(0), skip_special_tokens=True | ||
| ) | ||
| self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) | ||
| self.strategy.print( |
Collaborator
How about adding a debug option to the launch script and printing this information only when it is enabled? That would make debugging and analysis easier.
Member
Author
That is out of scope for this PR; it should be handled in a separate polish PR.
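A rough sketch of what the suggested debug switch could look like in a follow-up PR. The flag name `--debug_rollout` and the helper `maybe_print_sample` are hypothetical; the real launch script and `self.strategy.print` wiring may differ.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Launch script sketch (hypothetical)")
    # Hypothetical flag: not part of this PR.
    parser.add_argument(
        "--debug_rollout",
        action="store_true",
        help="Print decoded rollout samples during the collect phase",
    )
    return parser


def maybe_print_sample(debug: bool, decoded: str, printer=print) -> bool:
    """Print the decoded sample only when debugging is enabled."""
    if debug:
        printer("collect phase: experience.sequences w skip_special_tokens: ", decoded)
    return debug


# Parse an explicit argument list so the sketch runs without a real CLI call.
args = build_parser().parse_args(["--debug_rollout"])
```

Passing the printer in as an argument keeps the helper independent of the strategy object, which could be supplied in its place.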
| else: | ||
| sequences = experience.sequences | ||
| pixel_values = experience.pixel_values |
Collaborator
Why was everything related to `pixel_values` removed?
Member
Author
It has been consolidated into the `_build_model_kwargs` method, so that every modality can be handled in one place.
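To illustrate the idea, here is a minimal sketch of a modality-agnostic kwargs builder. This is not the PR's actual `_build_model_kwargs`: the function name `build_model_kwargs`, the `MODALITY_FIELDS` tuple, and the use of `SimpleNamespace` as a stand-in for the experience object are assumptions made for illustration.

```python
from types import SimpleNamespace
from typing import Any, Dict

# Assumed field list: only names mentioned in this PR's diff and README text.
MODALITY_FIELDS = ("pixel_values", "audio_values", "feature_attention_mask")


def build_model_kwargs(experience: Any) -> Dict[str, Any]:
    """Collect whichever modality inputs the experience actually carries."""
    kwargs: Dict[str, Any] = {}
    for name in MODALITY_FIELDS:
        value = getattr(experience, name, None)
        # Skip absent or unset fields so text-only batches get no extra kwargs.
        if value is not None:
            kwargs[name] = value
    return kwargs


# Stand-in for an audio experience; nested lists replace real tensors here.
audio_exp = SimpleNamespace(audio_values=[[0.1, 0.2]], feature_attention_mask=[[1, 1]])
audio_kwargs = build_model_kwargs(audio_exp)
```

With this shape, the training step can forward `**audio_kwargs` to the model without branching on modality at each call site.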
| # [Protection measure 2] Per-token KL Clamping | ||
| # NOTE: Adding this causes svkng training to not converge | ||
| # kl = torch.clamp(kl, min=0.0, max=20.0) |
| # Use wandb_log_counter to ensure eval has a unique system step | ||
| # This prevents eval metrics from being overwritten by train metrics | ||
| # The plots will still use eval/global_step as X-axis due to define_metric | ||
| self.wandb_log_counter += 1 |
puyuan1996
reviewed
Apr 24, 2026
| Audio RL now uses a dedicated rollout path in core LightRFT code: | ||
| - raw audio payloads stay on the generation side and are passed to SGLang as `audio_data` | ||
| - processed mel features are stored explicitly as `audio_values` | ||
| - Qwen2-Audio feature masking is stored explicitly as `feature_attention_mask` |