diff --git a/mostlyai/engine/_language/training.py b/mostlyai/engine/_language/training.py
index b86c813e..57fef09a 100644
--- a/mostlyai/engine/_language/training.py
+++ b/mostlyai/engine/_language/training.py
@@ -31,7 +31,7 @@
 from huggingface_hub import get_safetensors_metadata
 from opacus import GradSampleModule, PrivacyEngine
 from opacus.accountants import GaussianAccountant, PRVAccountant, RDPAccountant
-from opacus.grad_sample import register_grad_sampler
+from opacus.grad_sample import GradSampleHooks, register_grad_sampler
 from opacus.utils.batch_memory_manager import wrap_data_loader
 from peft import LoraConfig, PeftModel
 from torch import nn
@@ -626,6 +626,7 @@ def concat_prompt_and_response(x):
     # this can help accelerate GPU compute
     torch.backends.cudnn.benchmark = True
 
+    dp_grad_sample_hooks: GradSampleHooks | None = None
     if with_dp:
         if isinstance(differential_privacy, DifferentialPrivacyConfig):
             dp_config = differential_privacy.model_dump()
@@ -650,18 +651,21 @@ def concat_prompt_and_response(x):
             privacy_engine.accountant.load_state_dict(
                 torch.load(workspace.model_dp_accountant_path, map_location=device, weights_only=True),
             )
-        # Opacus will return the modified objects
-        # - model: wrapped in GradSampleModule and contains additional hooks for computing per-sample gradients
-        # - optimizer: wrapped in DPOptimizer and will do different operations during virtual steps and logical steps
-        # - dataloader: the dataloader with batch_sampler=UniformWithReplacementSampler (for Poisson sampling)
-        model, optimizer, trn_dataloader = privacy_engine.make_private(
+        # Opacus returns GradSampleHooks when wrap_model=False: hooks attach to the original module so HF /
+        # Transformers sees an unwrapped PreTrainedModel (requires Opacus >= 1.6).
+        # - dp_grad_sample_hooks: must call .cleanup() after training to remove backward hooks and param attrs
+        # - optimizer: wrapped in DPOptimizer (virtual vs logical steps)
+        # - dataloader: UniformWithReplacementSampler when poisson_sampling=True
+        dp_grad_sample_hooks, optimizer, trn_dataloader = privacy_engine.make_private(
             module=model,
             optimizer=optimizer,
             data_loader=trn_dataloader,
             noise_multiplier=dp_config.get("noise_multiplier"),
             max_grad_norm=dp_config.get("max_grad_norm"),
             poisson_sampling=True,
+            wrap_model=False,
         )
+        model = dp_grad_sample_hooks._module
         # this further wraps the dataloader with batch_sampler=BatchSplittingSampler to achieve gradient accumulation
         # it will split the sampled logical batches into smaller sub-batches with batch_size
         trn_dataloader = wrap_data_loader(
@@ -835,6 +839,9 @@ def concat_prompt_and_response(x):
             if total_training_time > max_training_time:
                 do_stop = True
 
+    if dp_grad_sample_hooks is not None:
+        dp_grad_sample_hooks.cleanup()
+
     # no checkpoint is saved yet because the training stopped before the first epoch ended
     if not model_checkpoint.has_saved_once():
         _LOG.info("saving model weights, as none were saved so far")