18 changes: 14 additions & 4 deletions swift/megatron/trainers/rlhf_mixin.py
@@ -50,10 +50,20 @@ def null_ref_context(self):
     if has_ref_adapter:
         for m in self.peft_models:
             m.set_adapter('ref_adapter')
-    yield ref_models
-    if has_ref_adapter:
-        for m in self.peft_models:
-            m.set_adapter('default')
+    # Temporarily set ref_models to eval mode to disable checkpointing/recompute.
+    # This is important when using TE + checkpointing, as checkpointing in no_grad context
+    training_states = [m.training for m in ref_models]
+    for m in ref_models:
+        m.eval()
+    try:
+        yield ref_models
+    finally:
+        for m, was_training in zip(ref_models, training_states):
+            if was_training:
+                m.train()
+        if has_ref_adapter:
+            for m in self.peft_models:
+                m.set_adapter('default')
Comment on lines +55 to +66
Contributor
medium

The use of try...finally is good practice for ensuring cleanup. To make it more robust, I recommend moving the state-modification loop (the m.eval() calls) inside the try block. If an exception occurs while calling m.eval() on one of the models, the finally block will then still run, guaranteeing that the training states of all models are restored. This prevents state corruption that could affect subsequent training steps.

Suggested change
-    training_states = [m.training for m in ref_models]
-    for m in ref_models:
-        m.eval()
-    try:
-        yield ref_models
-    finally:
-        for m, was_training in zip(ref_models, training_states):
-            if was_training:
-                m.train()
-        if has_ref_adapter:
-            for m in self.peft_models:
-                m.set_adapter('default')
+    training_states = [m.training for m in ref_models]
+    try:
+        for m in ref_models:
+            m.eval()
+        yield ref_models
+    finally:
+        for m, was_training in zip(ref_models, training_states):
+            if was_training:
+                m.train()
+        if has_ref_adapter:
+            for m in self.peft_models:
+                m.set_adapter('default')
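The pattern under discussion — snapshot state before the try, mutate it inside the try, restore it in finally — can be sketched in isolation. This is a minimal, hypothetical example, not the PR's code: `DummyModel` stands in for an `nn.Module` (tracking only the `training` flag), and `eval_mode` is an illustrative name, not an API from the repository.

```python
from contextlib import contextmanager

class DummyModel:
    """Stand-in for an nn.Module: tracks only the training flag."""
    def __init__(self):
        self.training = True
    def eval(self):
        self.training = False
    def train(self):
        self.training = True

@contextmanager
def eval_mode(models):
    # Snapshot BEFORE the try block, so the snapshot itself cannot fail
    # halfway through; flip to eval() INSIDE the try, so an exception
    # raised mid-loop (or in the managed block) still triggers restore.
    states = [m.training for m in models]
    try:
        for m in models:
            m.eval()
        yield models
    finally:
        for m, was_training in zip(models, states):
            if was_training:
                m.train()

models = [DummyModel(), DummyModel()]
with eval_mode(models):
    assert all(not m.training for m in models)  # eval inside the context
assert all(m.training for m in models)  # training flags restored on exit
```

The design point matches the review: reading `m.training` is effectively infallible, so it is safe outside the try, but any call that mutates state belongs inside it so that finally can undo a partial mutation.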


def get_logps(self, output_tensor, labels, packed_seq_params, num_samples, per_token=False):
args = get_args()