diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py
index edc5207839..0ea5c4f031 100644
--- a/swift/megatron/trainers/rlhf_mixin.py
+++ b/swift/megatron/trainers/rlhf_mixin.py
@@ -50,10 +50,20 @@ def null_ref_context(self):
         if has_ref_adapter:
             for m in self.peft_models:
                 m.set_adapter('ref_adapter')
-        yield ref_models
-        if has_ref_adapter:
-            for m in self.peft_models:
-                m.set_adapter('default')
+        # Temporarily set ref_models to eval mode to disable checkpointing/recompute.
+        # This is important when using TE + checkpointing, as checkpointing in a no_grad context can raise errors.
+        training_states = [m.training for m in ref_models]
+        for m in ref_models:
+            m.eval()
+        try:
+            yield ref_models
+        finally:
+            for m, was_training in zip(ref_models, training_states):
+                if was_training:
+                    m.train()
+            if has_ref_adapter:
+                for m in self.peft_models:
+                    m.set_adapter('default')
 
     def get_logps(self, output_tensor, labels, packed_seq_params, num_samples, per_token=False):
         args = get_args()
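
For reference, here is a minimal standalone sketch of the eval/restore pattern the patch applies to `ref_models`. The helper name `eval_mode` is hypothetical and not part of swift, and the snippet assumes plain PyTorch modules rather than Megatron/PEFT-wrapped models:

```python
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def eval_mode(models):
    # Hypothetical helper, not part of swift: temporarily put `models` in
    # eval mode and restore each model's original training flag on exit.
    # Snapshot the per-model flags so mixed train/eval states survive.
    training_states = [m.training for m in models]
    for m in models:
        m.eval()
    try:
        yield models
    finally:
        # Restore only the models that were in train mode before entry.
        for m, was_training in zip(models, training_states):
            if was_training:
                m.train()


# Usage: the training flag is restored even if the body raises.
model = nn.Linear(4, 4)
with eval_mode([model]):
    assert not model.training
assert model.training
```

Recording the prior `training` flags rather than unconditionally calling `train()` matters because the context may be entered while some models are already in eval mode, and the `finally` block guarantees the restoration (and, in the patch itself, the adapter reset back to `'default'`) runs even if the caller raises inside the `with` block.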