19 changes: 13 additions & 6 deletions mostlyai/engine/_language/training.py
@@ -31,7 +31,7 @@
from huggingface_hub import get_safetensors_metadata
from opacus import GradSampleModule, PrivacyEngine
from opacus.accountants import GaussianAccountant, PRVAccountant, RDPAccountant
-from opacus.grad_sample import register_grad_sampler
+from opacus.grad_sample import GradSampleHooks, register_grad_sampler
from opacus.utils.batch_memory_manager import wrap_data_loader
from peft import LoraConfig, PeftModel
from torch import nn
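For context on the accountant imports above and the accountant checkpoint loading further down in this file: a minimal sketch of saving and restoring DP accountant state across runs (the file name is a placeholder; accountant="prv" is one of Opacus' built-in accountants, alongside "rdp" and "gdp"):

import torch
from opacus import PrivacyEngine

privacy_engine = PrivacyEngine(accountant="prv")
# ... train with the engine, which updates the accountant on every optimizer step ...
torch.save(privacy_engine.accountant.state_dict(), "dp_accountant.pt")  # placeholder path
# on resume, restore the spent-budget state before continuing training
privacy_engine.accountant.load_state_dict(
    torch.load("dp_accountant.pt", map_location="cpu", weights_only=True)
)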
@@ -626,6 +626,7 @@ def concat_prompt_and_response(x):
# this can help accelerate GPU compute
torch.backends.cudnn.benchmark = True

+dp_grad_sample_hooks: GradSampleHooks | None = None
if with_dp:
if isinstance(differential_privacy, DifferentialPrivacyConfig):
dp_config = differential_privacy.model_dump()
@@ -650,18 +651,21 @@
privacy_engine.accountant.load_state_dict(
torch.load(workspace.model_dp_accountant_path, map_location=device, weights_only=True),
)
-# Opacus will return the modified objects
-# - model: wrapped in GradSampleModule and contains additional hooks for computing per-sample gradients
-# - optimizer: wrapped in DPOptimizer and will do different operations during virtual steps and logical steps
-# - dataloader: the dataloader with batch_sampler=UniformWithReplacementSampler (for Poisson sampling)
-model, optimizer, trn_dataloader = privacy_engine.make_private(
+# Opacus returns GradSampleHooks when wrap_model=False: hooks attach to the original module, so
+# HF Transformers sees an unwrapped PreTrainedModel (requires Opacus >= 1.6).
+# - dp_grad_sample_hooks: must call .cleanup() after training to remove backward hooks and param attrs
+# - optimizer: wrapped in DPOptimizer (virtual vs logical steps)
+# - dataloader: UniformWithReplacementSampler when poisson_sampling=True
+dp_grad_sample_hooks, optimizer, trn_dataloader = privacy_engine.make_private(
module=model,
optimizer=optimizer,
data_loader=trn_dataloader,
noise_multiplier=dp_config.get("noise_multiplier"),
max_grad_norm=dp_config.get("max_grad_norm"),
poisson_sampling=True,
+wrap_model=False,
)
+model = dp_grad_sample_hooks._module
# this further wraps the dataloader with batch_sampler=BatchSplittingSampler to achieve gradient accumulation
# it will split each sampled logical batch into smaller sub-batches of size batch_size
trn_dataloader = wrap_data_loader(
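For context: a minimal end-to-end sketch of the wrap_model=False flow introduced in this hunk, outside the diff (model, optimizer, dataloader and the numeric values are placeholders; wrap_model=False follows the Opacus >= 1.6 behavior described in the comment above):

from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import wrap_data_loader

privacy_engine = PrivacyEngine()
hooks, optimizer, dataloader = privacy_engine.make_private(
    module=model,             # stays a plain PreTrainedModel; hooks attach in place
    optimizer=optimizer,
    data_loader=dataloader,
    noise_multiplier=1.0,     # placeholder DP parameters
    max_grad_norm=1.0,
    poisson_sampling=True,
    wrap_model=False,
)
model = hooks._module         # the same module object that was passed in
# split each Poisson-sampled logical batch into physical sub-batches for gradient accumulation
dataloader = wrap_data_loader(
    data_loader=dataloader,
    max_batch_size=8,         # placeholder physical batch size
    optimizer=optimizer,
)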
@@ -835,6 +839,9 @@ def concat_prompt_and_response(x):
if total_training_time > max_training_time:
do_stop = True

+if dp_grad_sample_hooks is not None:
+    dp_grad_sample_hooks.cleanup()

# no checkpoint is saved yet because the training stopped before the first epoch ended
if not model_checkpoint.has_saved_once():
_LOG.info("saving model weights, as none were saved so far")