diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index 1de7af33c7..7cee5de738 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -257,6 +257,9 @@ gradient_checkpointing: true - 🔥neftune_noise_alpha: neftune添加的噪声系数。默认为0,通常可以设置为5、10、15。 - 🔥use_liger_kernel: 是否启用[Liger](https://github.com/linkedin/Liger-Kernel)内核加速训练并减少显存消耗。默认为False。示例shell参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/liger)。 - 注意:liger_kernel不支持device_map,请使用DDP/DeepSpeed进行多卡训练。liger_kernel目前只支持`task_type='causal_lm'`。 +- use_cce: 是否启用[cut-cross-entropy](https://github.com/apple/ml-cross-entropy)融合算子降低显存并加速训练。默认为False。示例shell参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/cce)。 +- use_tiled_mlp: 是否启用Tiled MLP进行内存高效的长序列训练。启用后,MLP层会被替换为分块实现,将序列分成多个shard进行计算以减少显存占用。默认为False。 +- tiled_mlp_num_shards: Tiled MLP计算时将序列分成的shard数量。默认为None,即设置为4。较大的值可以减少显存但可能增加计算时间。 - average_tokens_across_devices: 是否在设备之间进行token数平均。如果设置为True,将使用all_reduce同步`num_tokens_in_batch`以进行精确的损失计算。默认为False。 - max_grad_norm: 梯度裁剪。默认为1.。 - 注意:日志中的grad_norm记录的是裁剪前的值。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 485825ba98..217e197b23 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -260,6 +260,9 @@ Other important parameters: - 🔥neftune_noise_alpha: Noise magnitude for NEFTune. Default is 0. Common values: 5, 10, 15. - 🔥use_liger_kernel: Whether to enable the [Liger](https://github.com/linkedin/Liger-Kernel) kernel to accelerate training and reduce GPU memory consumption. Defaults to False. Example shell script can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/liger). - Note: Liger kernel does not support `device_map`. Use DDP or DeepSpeed for multi-GPU training. Currently, liger_kernel only supports `task_type='causal_lm'`. +- use_cce: Whether to enable the [cut-cross-entropy](https://github.com/apple/ml-cross-entropy) fused operator to reduce GPU memory usage and accelerate training. Defaults to `False`. Example shell script can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/cce). +- use_tiled_mlp: Whether to enable Tiled MLP for memory-efficient long sequence training. When enabled, MLP layers are replaced with a tiled implementation that processes sequences in chunks to reduce memory usage. Defaults to False. +- tiled_mlp_num_shards: Number of shards to split the sequence for tiled MLP computation. Defaults to None, which sets it to 4. Larger values reduce memory but may increase computation time. - average_tokens_across_devices: Whether to average token counts across devices. If `True`, `num_tokens_in_batch` is synchronized via `all_reduce` for accurate loss computation. Default is `False`. - max_grad_norm: Gradient clipping. Default is 1. - Note: The logged `grad_norm` reflects the value **before** clipping. 
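For quick reference, the three new switches documented above can also be set through the Python API. A minimal sketch (model and dataset IDs are placeholders; it mirrors tests/train/test_cce.py added at the end of this diff, with the tiled-MLP flags layered on as an assumption rather than part of that test):

from swift.llm import sft_main, TrainArguments

sft_main(
    TrainArguments(
        model='Qwen/Qwen2.5-0.5B-Instruct',
        dataset=['gsm8k#1024'],
        train_type='lora',
        use_cce=True,            # fused cut-cross-entropy loss kernel
        use_tiled_mlp=True,      # tiled MLP for memory-efficient long-sequence training
        tiled_mlp_num_shards=4,  # None also resolves to 4 shards
    ))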
diff --git a/examples/train/activation_cpu_offload/fsdp2.json b/examples/train/activation_cpu_offload/fsdp2.json
new file mode 100644
index 0000000000..73d856389a
--- /dev/null
+++ b/examples/train/activation_cpu_offload/fsdp2.json
@@ -0,0 +1,26 @@
+{
+  "_description": "FSDP2 configuration for distributed training (PyTorch native FSDP v2)",
+  "_requires": "torch>=2.4.0",
+  "_note": "This is the recommended configuration for multi-GPU training without parameter CPU offloading. NOTE: When using FSDP2, do NOT use --gradient_checkpointing, use activation_checkpointing in fsdp_config instead.",
+
+  "_param_docs": {
+    "fsdp": "FSDP strategy string. Options: 'full_shard' (ZeRO-3 style, shards params+grads+optimizer), 'shard_grad_op' (ZeRO-2 style, shards grads+optimizer only). Add 'auto_wrap' to enable automatic layer wrapping. Add 'offload' to enable CPU offloading.",
+    "fsdp_version": "FSDP version. Use 2 for PyTorch native FSDP2 (recommended). FSDP2 uses DTensor for per-parameter sharding and supports LoRA/QLoRA natively.",
+    "auto_wrap_policy": "How to wrap model layers. 'TRANSFORMER_BASED_WRAP' wraps transformer decoder layers (from model._no_split_modules). 'SIZE_BASED_WRAP' wraps modules exceeding min_num_params.",
+    "cpu_ram_efficient_loading": "If true, only rank 0 loads full model weights, then broadcasts to other ranks. Reduces CPU RAM usage during initialization.",
+    "state_dict_type": "'SHARDED_STATE_DICT' (recommended): each rank saves its own shard without extra communication. 'FULL_STATE_DICT': gathers the full model on rank 0 (higher memory, slower).",
+    "reshard_after_forward": "true = FULL_SHARD (ZeRO-3), reshards params after forward pass. false = SHARD_GRAD_OP (ZeRO-2), keeps params gathered during forward/backward.",
+    "activation_checkpointing": "Use FSDP's native activation checkpointing instead of gradient_checkpointing. This is the correct way to save memory with FSDP.",
+    "activation_cpu_offload": "true = offload activations to CPU. false = keep activations on GPU. Can be enabled together with activation_checkpointing."
+  },
+  "fsdp": "full_shard auto_wrap",
+  "fsdp_config": {
+    "fsdp_version": 2,
+    "reshard_after_forward": true,
+    "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+    "cpu_ram_efficient_loading": true,
+    "state_dict_type": "SHARDED_STATE_DICT",
+    "activation_checkpointing": false,
+    "activation_cpu_offload": true
+  }
+}
diff --git a/examples/train/activation_cpu_offload/train.sh b/examples/train/activation_cpu_offload/train.sh
new file mode 100644
index 0000000000..e5fee8e54c
--- /dev/null
+++ b/examples/train/activation_cpu_offload/train.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+CUDA_VISIBLE_DEVICES=0,1 \
+swift sft \
+    --model 'Qwen/Qwen3-0.6B' \
+    --dataset 'swift/self-cognition#1000' \
+    --load_from_cache_file true \
+    --split_dataset_ratio 0.01 \
+    --train_type lora \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-4 \
+    --lora_rank 8 \
+    --lora_alpha 32 \
+    --target_modules all-linear \
+    --freeze_vit true \
+    --gradient_accumulation_steps 16 \
+    --eval_steps 100 \
+    --save_steps 100 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --fsdp './examples/train/activation_cpu_offload/fsdp2.json'
diff --git a/examples/train/cce/sft.sh b/examples/train/cce/sft.sh
new file mode 100644
index 0000000000..5c34b74db3
--- /dev/null
+++ b/examples/train/cce/sft.sh
@@ -0,0 +1,17 @@
+# test env: 1 * A10
+# Using use_cce: 2.62GB
+# Not using use_cce: 16.24GB
+
+# Install CCE dependency
+pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"
+
+# Run ms-swift (example)
+swift sft \
+    --model Qwen/Qwen2.5-0.5B-Instruct \
+    --dataset gsm8k#1024 \
+    --train_type lora \
+    --per_device_train_batch_size 64 \
+    --per_device_eval_batch_size 64 \
+    --use_hf true \
+    --use_cce true \
+    "$@"
diff --git a/examples/train/tiled_mlp/fsdp2.json b/examples/train/tiled_mlp/fsdp2.json
new file mode 100644
index 0000000000..18cce13780
--- /dev/null
+++ b/examples/train/tiled_mlp/fsdp2.json
@@ -0,0 +1,25 @@
+{
+  "compute_environment": "LOCAL_MACHINE",
+  "debug": false,
+  "distributed_type": "FSDP",
+  "downcast_bf16": "no",
+  "fsdp_config": {
+    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+    "fsdp_cpu_ram_efficient_loading": true,
+    "fsdp_reshard_after_forward": true,
+    "fsdp_state_dict_type": "FULL_STATE_DICT",
+    "fsdp_activation_checkpointing": true,
+    "fsdp_version": 2
+  },
+  "machine_rank": 0,
+  "main_training_function": "main",
+  "mixed_precision": "bf16",
+  "num_machines": 1,
+  "num_processes": 2,
+  "rdzv_backend": "static",
+  "same_network": true,
+  "tpu_env": [],
+  "tpu_use_cluster": false,
+  "tpu_use_sudo": false,
+  "use_cpu": false
+}
diff --git a/examples/train/tiled_mlp/train_deepspeed.sh b/examples/train/tiled_mlp/train_deepspeed.sh
new file mode 100644
index 0000000000..244677b3ac
--- /dev/null
+++ b/examples/train/tiled_mlp/train_deepspeed.sh
@@ -0,0 +1,24 @@
+CUDA_VISIBLE_DEVICES=0,1 \
+NPROC_PER_NODE=2 \
+swift sft \
+    --model Qwen/Qwen3-4B \
+    --dataset swift/self-cognition#200 \
+    --train_type full \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 4 \
+    --learning_rate 1e-5 \
+    --weight_decay 0.1 \
+    --gradient_accumulation_steps 1 \
+    --eval_steps 100 \
+    --save_steps 100 \
+    --save_total_limit 2 \
+    --logging_steps 1 \
+    --max_length 2048 \
+    --output_dir output \
+    --system 'You are a helpful assistant.'
\ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --use_tiled_mlp true \ + --tiled_mlp_num_shards 4 \ + --deepspeed zero3 diff --git a/examples/train/tiled_mlp/train_fsdp2.sh b/examples/train/tiled_mlp/train_fsdp2.sh new file mode 100644 index 0000000000..5d2372602d --- /dev/null +++ b/examples/train/tiled_mlp/train_fsdp2.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# FSDP2 training with tiled MLP +# Requires accelerate config with fsdp_version: 2 + +# First, create the accelerate config (fsdp2.json) or use the one in examples/train/multi-gpu/fsdp2_lora/ + +# FSDP2 with tiled MLP +accelerate launch --config_file fsdp2.json \ + -m swift sft \ + --model Qwen/Qwen3-4B \ + --dataset swift/self-cognition#200 \ + --train_type full \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --learning_rate 1e-5 \ + --gradient_checkpointing false \ + --weight_decay 0.1 \ + --gradient_accumulation_steps 1 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 1 \ + --max_length 2048 \ + --output_dir output \ + --system 'You are a helpful assistant.' \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --use_tiled_mlp true \ + --tiled_mlp_num_shards 4 diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 841bdb9ffa..5673826376 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -51,6 +51,10 @@ def _prepare_generation_config(self): @RayHelper.function(group='default') def _prepare_model_tokenizer(self, **kwargs): args = self.args + # Apply tiled MLP before model instantiation + if getattr(args, 'use_tiled_mlp', False): + from swift.plugin.tiled_mlp import apply_tiled_mlp + apply_tiled_mlp(args.model_type, num_shards=getattr(args, 'tiled_mlp_num_shards', None)) self.model, self.processor = args.get_model_processor(**kwargs) if args.sequence_parallel_size > 1: from swift.trainers.sequence_parallel import sequence_parallel @@ -265,6 +269,7 @@ def train(self, trainer): @RayHelper.function(group='default') def _prepare_callbacks(self): from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback + from swift.plugin import ActivationCpuOffloadCallBack args = self.args callbacks = [] if args.lisa_activated_layers > 0: @@ -275,6 +280,10 @@ def _prepare_callbacks(self): model=self.model) lisa_callback.switch_active_layers() # Make trainable parameters printing a correct value callbacks.append(lisa_callback) + # Check activation_cpu_offload from fsdp_config + fsdp_config = getattr(self.args, 'fsdp_config', {}) + if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False): + callbacks.append(ActivationCpuOffloadCallBack()) if args.is_adapter and args.train_type == 'adalora': callbacks.append(TrainerAdapterCallback(args)) diff --git a/swift/llm/train/tuner.py b/swift/llm/train/tuner.py index 286a9f4b04..6cc75e496e 100644 --- a/swift/llm/train/tuner.py +++ b/swift/llm/train/tuner.py @@ -86,6 +86,73 @@ def apply_liger(model_type: str): 'by running `pip install -U liger-kernel`') +def apply_cce(model_type: str): + try: + from cut_cross_entropy.transformers import cce_patch + from swift.llm import ModelType + except ImportError: + raise ImportError('Please upgrade cut-cross-entropy to apply cce kernels to this model ' + 'by running `pip install "cut-cross-entropy[transformers] @ ' + 'git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`') + + model_type_map = { + # llama family + ModelType.llama: 'llama', + ModelType.llama3: 'llama', + ModelType.llama3_1: 
'llama', + ModelType.llama3_2: 'llama', + ModelType.llama4: 'llama4', + ModelType.llama3_2_vision: 'mllama', + # mistral & mixtral family + ModelType.mistral: 'mistral', + ModelType.mixtral: 'mixtral', + # phi + ModelType.phi3: 'phi3', + # gemma family + ModelType.gemma: 'gemma', + ModelType.gemma2: 'gemma2', + ModelType.gemma3_text: 'gemma3_text', + ModelType.gemma3_vision: 'gemma3', + ModelType.gemma3n: 'gemma3n', + # glm4 family + ModelType.glm4: 'glm4', + ModelType.glm4_0414: 'glm4', + ModelType.glm4_5: 'glm4_moe', + ModelType.glm4_z1_rumination: 'glm4_moe', + ModelType.glm4v: 'glm4v', + ModelType.glm4_1v: 'glm4v', + ModelType.glm4_5v: 'glm4v_moe', + # llava + ModelType.llava1_5_hf: 'llava', + ModelType.llava_llama3_hf: 'llava', + # qwen2 family + ModelType.qwen2: 'qwen2', + ModelType.qwen2_5: 'qwen2', + ModelType.qwen2_vl: 'qwen2_vl', + ModelType.qwen2_5_vl: 'qwen2_5_vl', + # qwen3 family + ModelType.qwen3: 'qwen3', + ModelType.qwen3_guard: 'qwen3', + ModelType.qwen3_thinking: 'qwen3', + ModelType.qwen3_nothinking: 'qwen3', + ModelType.qwen3_coder: 'qwen3', + ModelType.qwen3_moe: 'qwen3_moe', + ModelType.qwen3_moe_thinking: 'qwen3_moe', + ModelType.qwen3_next: 'qwen3_next', + ModelType.qwen3_next_thinking: 'qwen3_next', + ModelType.qwen3_vl: 'qwen3_vl', + ModelType.qwen3_moe_vl: 'qwen3_vl_moe', + } + + cce_model_type = model_type_map.get(model_type) + if cce_model_type: + cce_patch(cce_model_type) + return + + supported_models = ', '.join(sorted(set(model_type_map.values()))) + raise ValueError(f'Unsupported cce model_type: {model_type}. Supported types: {supported_models}') + + def get_multimodal_target_regex( model, *, @@ -375,6 +442,9 @@ def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_t # Apply liger apply_liger(args.model_type) + if args.use_cce and 'use_cce' not in inspect.signature(TrainingArguments).parameters: + apply_cce(args.model_type) + if args.is_adapter: if args.tuner_backend != 'unsloth' and args.train_type not in extra_tuners: # Fix the name of the layer in xcomposer that contains Plora. 
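A rough usage sketch of the helper added above (the import path follows this diff; cut-cross-entropy must be installed, and the string passed to cce_patch is the transformers-side name resolved by model_type_map):

from swift.llm.train.tuner import apply_cce

apply_cce('qwen2_5')  # resolved to cce_patch('qwen2') via model_type_map above
# Any model_type missing from model_type_map raises a ValueError listing the supported cce types.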
diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index 870ece61cd..0b36bf27d6 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -17,6 +17,8 @@ from .rm_plugin import rm_plugins from .env import envs, Env from .context_manager import context_managers, ContextManager + from .tiled_mlp import (TiledSwiGLUMLP, apply_tiled_mlp, is_fsdp2_enabled, is_fsdp1_enabled, get_tiled_mlp_mode) + from swift.plugin.activation_cpu_offload import ActivationCpuOffloadCallBack else: _import_structure = { diff --git a/swift/plugin/activation_cpu_offload.py b/swift/plugin/activation_cpu_offload.py new file mode 100644 index 0000000000..dc348805dc --- /dev/null +++ b/swift/plugin/activation_cpu_offload.py @@ -0,0 +1,612 @@ +"""Functionality for CPU offloading of tensors saved for backward pass.""" +from __future__ import annotations +import functools +import logging +import os +from contextlib import nullcontext +from typing import Any, Dict, Optional + +import torch +from torch.distributed.fsdp import FSDPModule as FSDP2 +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import TrainerCallback +from transformers.trainer_callback import TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from swift.utils import get_logger + +logger = get_logger() +logger.setLevel(logging.WARNING) + + +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + if hasattr(torch, 'npu') and callable(getattr(torch.npu, 'is_available', None)): + return torch.npu.is_available() + return False + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def _get_unique_tensor_key(tensor): + key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype) + return key + + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. + This currently only supports CPU, CUDA, NPU. + Returns: + device + """ + if is_cuda_available: + device = 'cuda' + elif is_npu_available: + device = 'npu' + else: + device = 'cpu' + return device + + +class FSDPParameterFilter: + + def __init__(self): + self.model_parameters_storage = set() + + def __call__(self, tensor): + return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage + + def update_model_parameters(self, model): + new_storage = set() + for p in model.parameters(): + new_storage.add(p.data.untyped_storage().data_ptr()) + self.model_parameters_storage = new_storage + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. + """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + return torch.cuda + + +class CpuOffloadHookWithOffloadHandler: + """Context-manager that offloads/recovers tensors through an offload hander. + + The hook just offloads/recovers the tensor object to the handler through `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. 
+ """ + + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[dict[str, Any]] = None, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs + self.inside_context = False + + def __enter__(self): + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor) + + def __exit__(self, *args: Any): + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError( + '`tensor_push is not implented in OffloadHandler class. Inherit this class and implement your ' + 'custom tensor_push.') + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError( + '`tensor_pop is not implented in OffloadHandler class. Inherit this class and implement your ' + 'custom tensor_pop.') + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. 
+ # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device='cpu', + pin_memory=pin_memory, + ) + cpu_backup.copy_(src_tensor, non_blocking=True) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. 
Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + self.group_offload_mapping = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + # allocate streams and events for synchronization + self.d2h_stream = get_torch_device().Stream() + self.h2d_stream = get_torch_device().Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + torch_stray_tensor = isinstance( + tensor, + torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor, + ) + need_offload = not torch_stray_tensor + need_offload = need_offload and self.tensor_need_offloading_checker(tensor) + + if need_offload: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + + assert tensor_tag not in self.tensor_tag_to_state + self.tensor_tag_to_state[tensor_tag] = tensor + + if self.current_group < self.num_offload_group: + self.tensor_tag_to_buf[tensor_tag] = tensor + else: + tensor_tag = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + if isinstance(tensor_tag, torch.Tensor): + return tensor_tag + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. 
+ assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + offload_mapping = {} + offload_size = 0 + with get_torch_device().stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + key = _get_unique_tensor_key(state) + if key not in offload_mapping: + offload_mapping[key] = state + # if offload, return the reference to cpu copy + self.tensor_tag_to_state[tensor_tag] = (key, state.shape) + for key, tensor in offload_mapping.items(): + state = SynchronizedGroupOffloadHandler.offload(tensor) + offload_size += tensor.numel() * tensor.element_size() + offload_mapping[key] = state + + self.group_offload_mapping[group_to_offload] = offload_mapping + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + # Stream synchronization both ways + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + @torch.no_grad + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + with get_torch_device().stream(self.h2d_stream): + # move back tensors + offload_mapping = self.group_offload_mapping.pop(group_to_reload) + assert offload_mapping is not None + for key, state in offload_mapping.items(): + offload_mapping[key] = SynchronizedGroupOffloadHandler.reload(state) + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload and not isinstance(state, torch.Tensor): + assert isinstance(state, tuple), f'{group_id} {state}' + key, shape = state + recovered_tensor = offload_mapping[key].view(shape) + self.tensor_tag_to_state[tensor_label] = recovered_tensor + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. 
+ self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + # Stream synchronization both ways + self.h2d_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + get_torch_device().current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_activation_offload_context(num_layers: int = 1, + model_layers: int = 1, + tensor_need_offloading_checker=(lambda t: True)): + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + + +class ActivationHandler: + + def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt): + self._offload_ctx = offload_ctx + self._sync_func = sync_func + self._enable_ckpt = enable_ckpt + self._tensor_filter = tensor_filter + if enable_ckpt: + self.checkpoint_fn = functools.partial( + torch.utils.checkpoint.checkpoint, + use_reentrant=True, + ) + + def pre_forward(self, module): + if module.training: + self._offload_ctx.__enter__() + self._tensor_filter.update_model_parameters(module) + + def post_forward(self, module): + if module.training: + self._offload_ctx.__exit__(None, None, None) + + def _pack_kwargs(self, *args, **kwargs): + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + + return tuple(flat_args), tuple(kwarg_keys) + + def _unpack_kwargs(self, flat_args, kwarg_keys): + assert len(kwarg_keys) <= len(flat_args), f'too many keys {len(kwarg_keys)} vs. 
{len(flat_args)}' + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[:-len(kwarg_keys)] + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):], strict=True)) + return args, kwargs + + def _ckpt_forward(self, forward_method, *args, **kwargs): + flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs) + + def my_function(*inputs): + # unpack back into args and kwargs + nonlocal forward_method, kwarg_keys + unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys) + # run original module + return forward_method(*unpacked_args, **unpacked_kwargs) + + return self.checkpoint_fn( + my_function, + *flat_args, + ) + + def forward(self, module, forward_method, *args, **kwargs): + if not module.training: + return forward_method(*args, **kwargs) + if not self._enable_ckpt: + ret = forward_method(*args, **kwargs) + else: + ret = self._ckpt_forward(forward_method, *args, **kwargs) + binded_tensor = ret + if isinstance(ret, tuple): + binded_tensor = ret[0] + binded_tensor = self._sync_func(binded_tensor) + final_ret = binded_tensor + if isinstance(ret, tuple): + final_ret = (final_ret, ) + ret[1:] + return final_ret + + def wrap_module_forward_method(self, module): + orig_method = module.forward + handler = self + + @functools.wraps(orig_method) + def wrapped_method(model_self, *args, **kwargs): + nonlocal handler + handler.pre_forward(model_self) + out = handler.forward(model_self, orig_method, *args, **kwargs) + handler.post_forward(model_self) + return out + + module.forward = wrapped_method.__get__(module, type(module)) + + +def enable_activation_offloading(model, strategy, enable_ckpt=False): + """ + Enable activation offloading for the model. It groups activations by TransformerLayer and offloads activation + groups asynchronously. This means that the offloading of the i-th activation group and the computation of the i+1-th + activation group happen at the same time, and there are at most two activation groups in GPU memory. + + Args: + model: the model to enable activation offloading + strategy: the training strategy of the model, such as "fsdp" + enable_ckpt: whether activation checkpointing(also called gradient checkpointing) has been enabled for the model + + Note: + For best efficiency, activation offloading is usually combined with activation checkpointing. However, this + implementation of activation offloading is conflicted with the implementation of activation checkpointing in + some training strategies. This function resolves this conflict, and therefore requires the "strategy" and + "enable_ckpt" arguments. + + Returns: + + """ + + assert strategy == 'fsdp' or strategy == 'fsdp2', 'activation offloading only supports fsdp strategy' + layers = [] + + def get_layers(module): + for name, child in module.named_children(): + if not isinstance(child, FSDP | FSDP2): + get_layers(child) + else: + wrapped_module = child + if isinstance(child, FSDP): + wrapped_module = child._fsdp_wrapped_module + # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation + # size of torch.nn.Embedding is small, so it's not necessary to offload it. 
+                if not isinstance(wrapped_module, torch.nn.Embedding):
+                    layers.append(child)
+
+    get_layers(model)
+    if len(layers) < 3:
+        logger.warning(f'Found only {len(layers)} FSDP layers; no need to enable async activation offloading')
+        return
+
+    tensor_filter = FSDPParameterFilter()
+    context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter)
+    if enable_ckpt:
+        # The activation checkpointing implementation in the transformers library is incompatible with
+        # activation offloading,
+        # so it is disabled here. This module supplies its own activation checkpointing (see ActivationHandler), so
+        # these two features can still be enabled at the same time.
+        for module in model.modules():
+            if hasattr(module, 'gradient_checkpointing_disable'):
+                module.gradient_checkpointing_disable()
+
+    handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt)
+    for layer in layers:
+        module = layer
+        if isinstance(layer, FSDP):
+            module = module._fsdp_wrapped_module
+        handler.wrap_module_forward_method(module)
+
+
+class ActivationCpuOffloadCallBack(TrainerCallback):
+
+    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the beginning of training.
+        """
+        model = kwargs['model']
+
+        # Check if model is wrapped with FSDP
+        if isinstance(model, (FSDP, FSDP2)):
+            if args is not None and hasattr(args, 'fsdp_config'):
+                fsdp_config = args.fsdp_config
+                # Check if fsdp_config is a dictionary and has activation_cpu_offload enabled
+                if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False):
+                    # Get FSDP version from fsdp_config
+                    strategy = fsdp_config.get('version', None)
+                    if strategy is not None:
+                        fsdp_version = 'fsdp' if strategy == 1 else 'fsdp2'
+                        # Get activation checkpointing setting from fsdp_config
+                        enable_ckpt = fsdp_config.get('activation_checkpointing', False)
+                        if enable_ckpt and hasattr(model, 'enable_input_require_grads'):
+                            model.enable_input_require_grads()
+                        enable_activation_offloading(model, strategy=fsdp_version, enable_ckpt=enable_ckpt)
diff --git a/swift/plugin/tiled_mlp.py b/swift/plugin/tiled_mlp.py
new file mode 100644
index 0000000000..f061e7c385
--- /dev/null
+++ b/swift/plugin/tiled_mlp.py
@@ -0,0 +1,409 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""
+Tiled MLP implementation for memory-efficient training.
+
+This module provides a tiled MLP implementation that is compatible with FSDP2.
+- FSDP2: Uses custom TiledMLP implementation (this file) +- DeepSpeed/Single GPU: Uses liger_kernel's LigerTiledSwiGLUMLP +- FSDP1: Raises error (not compatible) +""" +import os +import threading +from typing import List, Optional + +import torch +import torch.nn as nn + +from swift.utils import get_logger + +logger = get_logger() + +# ============================================================================ +# FSDP2 Compatible TiledMLP Implementation +# ============================================================================ + + +class GradientAccumulator: + """Gradient accumulator for TiledMLP (FSDP2 compatible)""" + + def __init__(self, params: List[torch.nn.Parameter], total_shards: int, dtype: torch.dtype = None): + self.params = params + self.total_shards = total_shards + self.grad_accumulation_dtype = dtype or torch.float32 + self.accumulated_grads = {} + self.hooks = [] + self.lock = threading.Lock() + + for param in self.params: + if param.grad is not None: + self.accumulated_grads[param] = param.grad.to(self.grad_accumulation_dtype) + param.grad = None + else: + self.accumulated_grads[param] = torch.zeros_like(param, dtype=self.grad_accumulation_dtype) + + def install_hooks(self, is_last_shard: bool): + self._remove_hooks() + + def create_hook(param): + + def hook(grad): + with self.lock: + grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype) + self.accumulated_grads[param] += grad_to_accum_dtype + + if is_last_shard: + param.grad = None # Critical: prevent double accumulation + final_grad = self.accumulated_grads[param].to(param.dtype) + return final_grad + return None + + return hook + + for param in self.params: + if param.requires_grad: + hook = param.register_hook(create_hook(param)) + self.hooks.append(hook) + + def _remove_hooks(self): + for hook in self.hooks: + hook.remove() + self.hooks.clear() + + def cleanup(self): + self._remove_hooks() + + +class TiledMLPFunction(torch.autograd.Function): + """TiledMLP autograd function for FSDP2 compatibility""" + + @staticmethod + def forward(ctx, fn, self, x, shards, compute_params): + ctx.fn = fn + ctx.self = self + ctx.shards = shards + ctx.compute_params = [p for p in compute_params if p.requires_grad] + ctx.save_for_backward(x) + + # Split on dim=-2 (seqlen dimension) + x_shards = list(torch.chunk(x, chunks=shards, dim=-2)) + with torch.no_grad(): + output_shards = [fn(self, x_shard) for x_shard in x_shards] + output_unsharded = torch.cat(output_shards, dim=-2) + return output_unsharded + + @staticmethod + def backward(ctx, *grads): + fn = ctx.fn + (x, ) = ctx.saved_tensors + self = ctx.self + shards = ctx.shards + compute_params = ctx.compute_params + + x_requires_grad = x.requires_grad + x = x.detach() + x.requires_grad_(x_requires_grad) + + # Flatten to [bs*seqlen, hidden_size] + hidden_size = x.shape[-1] + x_shape_orig = x.shape + x = x.view(-1, hidden_size) + incoming_grad = grads[0].view(-1, hidden_size) + + # Pre-allocate input gradient + x_grad = torch.zeros_like(x) + + # Split on dim=0 + x_shards = list(torch.chunk(x, chunks=shards, dim=0)) + + grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype) + + for i, x_shard in enumerate(x_shards): + x_shard.requires_grad_(x_requires_grad) + + shard_step = x_shards[i].shape[0] + shard_offset = i * x_shards[0].shape[0] + + # narrow(0, ...) 
creates a view that can correctly receive gradients
+            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step)
+            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step)
+
+            is_last_shard = i + 1 == shards
+            grad_accumulator.install_hooks(is_last_shard)
+
+            with torch.enable_grad():
+                output = fn(self, x_shard)
+            torch.autograd.backward(output, incoming_grad_shard)
+
+        grad_accumulator.cleanup()
+        del grad_accumulator
+
+        # Restore original shape
+        x_grad = x_grad.view(x_shape_orig) if x_requires_grad else None
+        return (None, None, x_grad, None, None)
+
+
+class TiledSwiGLUMLP(nn.Module):
+    """
+    Memory-efficient SwiGLU MLP using tiled computation for FSDP2.
+
+    This module combines SwiGLU activation with tiled processing to handle
+    very long sequences efficiently. The forward pass is recomputed during
+    backward to save memory.
+
+    Args:
+        config: Model configuration with hidden_size and intermediate_size attributes
+        num_shards: Number of shards to split the sequence into. If None, defaults to 4
+            (matching the CLI default for tiled_mlp_num_shards).
+    """
+
+    def __init__(self, config, num_shards: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.num_shards = num_shards or 4  # Default to 4 shards
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act = nn.SiLU()
+
+    def _mlp_forward(self, module, x):
+        """Internal MLP forward function for tiled computation."""
+        gate = module.gate_proj(x)
+        up = module.up_proj(x)
+        return module.down_proj(module.act(gate) * up)
+
+    def forward(self, x):
+        """
+        Forward pass with tiled computation.
+ + Args: + x: Input tensor of shape [batch_size, seq_len, hidden_size] + or [seq_len, hidden_size] + Returns: + Output tensor of the same shape as input + """ + compute_params = [ + self.gate_proj.weight, + self.up_proj.weight, + self.down_proj.weight, + ] + return TiledMLPFunction.apply( + self._mlp_forward, + self, + x, + self.num_shards, + compute_params, + ) + + +# ============================================================================ +# Environment Detection Functions +# ============================================================================ + + +def is_fsdp2_enabled() -> bool: + """Check if FSDP2 is enabled via accelerate config.""" + # Check environment variable set by accelerate + if os.environ.get('ACCELERATE_USE_FSDP', 'false').lower() == 'true': + # Check fsdp_version from accelerate config + # FSDP_VERSION is set by accelerate when fsdp_version is specified in config + fsdp_version = os.environ.get('FSDP_VERSION', '1') + if fsdp_version == '2': + return True + # Also check accelerate state if available + try: + from accelerate import PartialState + state = PartialState() + if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None: + # Check if fsdp_version is 2 in the plugin + if hasattr(state.fsdp_plugin, 'fsdp_version'): + return state.fsdp_plugin.fsdp_version == 2 + except Exception: + pass + return False + + +def is_fsdp1_enabled() -> bool: + """Check if FSDP1 is enabled via accelerate config.""" + if os.environ.get('ACCELERATE_USE_FSDP', 'false').lower() == 'true': + fsdp_version = os.environ.get('FSDP_VERSION', '1') + if fsdp_version == '2': + return False + # Also check accelerate state if available + try: + from accelerate import PartialState + state = PartialState() + if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None: + if hasattr(state.fsdp_plugin, 'fsdp_version'): + return state.fsdp_plugin.fsdp_version != 2 + except Exception: + pass + return True + return False + + +def is_deepspeed_enabled() -> bool: + """Check if DeepSpeed is enabled.""" + from swift.utils import is_deepspeed_enabled as _is_deepspeed_enabled + return _is_deepspeed_enabled() + + +def get_tiled_mlp_mode() -> str: + """ + Determine which tiled MLP implementation to use. + + Returns: + 'fsdp2': Use custom TiledSwiGLUMLP implementation + 'liger': Use liger_kernel's LigerTiledSwiGLUMLP + 'error': FSDP1 detected, should raise error + """ + if is_fsdp2_enabled(): + return 'fsdp2' + elif is_fsdp1_enabled(): + return 'error' + else: + # DeepSpeed, Single GPU, or DDP - use liger kernel + return 'liger' + + +# ============================================================================ +# MLP Replacement Functions +# ============================================================================ + +# Supported model types for tiled MLP +SUPPORTED_MODEL_TYPES = { + 'qwen2', + 'qwen2_5', + 'qwen3', + 'qwen3_vl', +} + + +def _get_mlp_class_for_model(model_type: str) -> str: + """Get the MLP class name for different model architectures.""" + # Map model types to their MLP class names + mlp_class_mapping = { + 'qwen2': 'Qwen2MLP', + 'qwen2_5': 'Qwen2MLP', + 'qwen3': 'Qwen3MLP', + 'qwen3_vl': 'Qwen3VLTextMLP', + } + + if model_type in mlp_class_mapping: + return mlp_class_mapping[model_type] + + # Fallback: capitalize model_type and append 'MLP' + # e.g., 'mistral' -> 'MistralMLP' + return model_type.capitalize() + 'MLP' + + +def apply_tiled_mlp(model_type: str, num_shards: Optional[int] = None): + """ + Apply tiled MLP replacement before model instantiation. 
+ + This function should be called BEFORE loading the model to replace + the MLP class in the transformers module. + + Args: + model_type: The model type (e.g., 'llama', 'qwen2') + num_shards: Number of shards for tiled computation + + Raises: + ValueError: If FSDP1 is detected (not compatible) + """ + mode = get_tiled_mlp_mode() + + if mode == 'error': + raise ValueError('Tiled MLP is not compatible with FSDP1. ' + 'Please use FSDP2 (set fsdp_version: 2 in accelerate config) or DeepSpeed.') + + if mode == 'fsdp2': + _apply_custom_tiled_mlp(model_type, num_shards) + elif mode == 'liger': + _apply_liger_tiled_mlp(model_type, num_shards) + + +def _apply_custom_tiled_mlp(model_type: str, num_shards: Optional[int] = None): + """Apply custom FSDP2-compatible tiled MLP.""" + num_shards = num_shards or 4 + mlp_class_name = _get_mlp_class_for_model(model_type) + + # Get the transformers module for this model + model_module = _get_transformers_module(model_type) + if model_module is None: + raise ValueError(f'Tiled MLP: Could not find transformers module for model_type={model_type}. ' + f'Supported model types: {SUPPORTED_MODEL_TYPES}') + + # Check if MLP class exists in the module + original_mlp_class = getattr(model_module, mlp_class_name, None) + if original_mlp_class is None: + raise ValueError(f'Tiled MLP: Could not find {mlp_class_name} in {model_module.__name__}. ' + f'model_type={model_type} may not be supported.') + + # Create a wrapper class that uses TiledSwiGLUMLP + class TiledMLPWrapper(TiledSwiGLUMLP): + + def __init__(self, config, **kwargs): + super().__init__(config, num_shards=num_shards) + + # Replace the MLP class + setattr(model_module, mlp_class_name, TiledMLPWrapper) + logger.info(f'Tiled MLP: Replaced {mlp_class_name} with TiledSwiGLUMLP (FSDP2 mode, num_shards={num_shards})') + + +def _apply_liger_tiled_mlp(model_type: str, num_shards: Optional[int] = None): + """Apply liger_kernel's tiled MLP implementation.""" + try: + from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP + except ImportError: + raise ImportError('Tiled MLP: liger_kernel not installed or LigerTiledSwiGLUMLP not available. ' + 'Please install liger-kernel: pip install liger-kernel') + + num_shards = num_shards or 4 + mlp_class_name = _get_mlp_class_for_model(model_type) + + model_module = _get_transformers_module(model_type) + if model_module is None: + raise ValueError(f'Tiled MLP: Could not find transformers module for model_type={model_type}. ' + f'Supported model types: {SUPPORTED_MODEL_TYPES}') + + # Check if MLP class exists in the module + original_mlp_class = getattr(model_module, mlp_class_name, None) + if original_mlp_class is None: + raise ValueError(f'Tiled MLP: Could not find {mlp_class_name} in {model_module.__name__}. 
' + f'model_type={model_type} may not be supported.') + + # Create a wrapper class + class LigerTiledMLPWrapper(LigerTiledSwiGLUMLP): + + def __init__(self, config, **kwargs): + super().__init__(config, num_shards=num_shards) + + setattr(model_module, mlp_class_name, LigerTiledMLPWrapper) + logger.info(f'Tiled MLP: Replaced {mlp_class_name} with LigerTiledSwiGLUMLP (liger mode, num_shards={num_shards})') + + +def _get_transformers_module(model_type: str): + """Get the transformers modeling module for a given model type.""" + import importlib + + module_mapping = { + 'qwen2': 'transformers.models.qwen2.modeling_qwen2', + 'qwen2_5': 'transformers.models.qwen2.modeling_qwen2', + 'qwen3': 'transformers.models.qwen3.modeling_qwen3', + 'qwen3_vl': 'transformers.models.qwen3_vl.modeling_qwen3_vl', + } + + module_name = module_mapping.get(model_type) + + # Fallback: try to construct module name from model_type + if module_name is None: + base_type = model_type + module_name = f'transformers.models.{base_type}.modeling_{base_type}' + + try: + return importlib.import_module(module_name) + except ImportError: + return None diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 5640bc6a6c..53e7bfb928 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -9,7 +9,7 @@ from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments from swift.plugin import loss_mapping -from swift.utils import get_dist_setting, get_logger, is_liger_available, is_mp, json_parse_to_dict +from swift.utils import get_dist_setting, get_logger, is_cce_available, is_liger_available, is_mp, json_parse_to_dict from .optimizers.galore import GaLoreConfig logger = get_logger() @@ -53,6 +53,15 @@ class TrainArgumentsMixin: dataloader_prefetch_factor (Optional[int]): The number of batches loaded in advance by each worker. Defaults to None. use_liger_kernel (bool): Whether to use the Liger kernel for optimization. Defaults to False. + use_tiled_mlp (bool): Whether to use tiled MLP for memory-efficient training. When enabled, the MLP layers + are replaced with a tiled implementation that processes sequences in chunks to reduce memory usage. + - FSDP2: Uses custom TiledSwiGLUMLP implementation (compatible) + - DeepSpeed/Single GPU: Uses liger_kernel's LigerTiledSwiGLUMLP + - FSDP1: Raises error (not compatible) + Defaults to False. + tiled_mlp_num_shards (Optional[int]): Number of shards to split the sequence for tiled MLP computation. + If None, defaults to 4. Larger values reduce memory but may increase computation time. Defaults to None. + use_cce (bool): Whether to use ml-cross-entropy fused kernels for optimization. Defaults to False. check_model (bool): If True, checks local model files for corruption or modification and provides a warning. Should be set to False in an offline environment. Defaults to True. acc_strategy (Literal['token', 'seq']): The strategy for calculating accuracy during training and validation. @@ -165,11 +174,20 @@ def _init_liger(self): except Exception: pass + def _init_cce(self): + if self.use_cce: + assert is_cce_available(), ('use_cce requires cut-cross-entropy, try ' + '`pip install "cut-cross-entropy[transformers] @ ' + 'git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`') + def __post_init__(self): if is_mp() and self.use_liger_kernel: raise ValueError('liger_kernel does not support device_map. 
' 'Please use DDP/DeepSpeed for multi-GPU training.') + if self.use_cce and self.use_liger_kernel: + logger.warning('Enabling both use_cce and use_liger_kernel may lead to duplicated kernel patches.') + if self.optimizer is None and (self.vit_lr is not None or self.aligner_lr is not None): self.optimizer = 'multimodal' if self.gradient_accumulation_steps is None: @@ -183,6 +201,7 @@ def __post_init__(self): if self.gradient_checkpointing_kwargs: self.gradient_checkpointing_kwargs = json_parse_to_dict(self.gradient_checkpointing_kwargs) self._init_liger() + self._init_cce() if self.dataloader_num_workers is None: if platform.system() == 'Windows': self.dataloader_num_workers = 0 diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index d1b383e582..b352051710 100644 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -346,9 +346,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if (self.label_smoother is not None or compute_loss_func is not None or loss_scale is not None or self.args.enable_dft_loss or self.args.enable_channel_loss or self.template.sequence_parallel_size > 1) and 'labels' in inputs: - if self.args.use_liger_kernel: - logger.warning_once('The cross_entropy loss function defined in Liger Kernel will not ' - 'take effect, potentially leading to increased GPU memory consumption.') + if self.args.use_liger_kernel or getattr(self.args, 'use_cce', False): + logger.warning_once('The cross_entropy loss function defined in Liger Kernel or ml-cross-entropy will ' + 'not take effect, potentially leading to increased GPU memory consumption.') labels = inputs.pop('labels') outputs = model(**inputs) if getattr(outputs, 'aux_loss', None) is not None: diff --git a/swift/ui/llm_grpo/llm_grpo.py b/swift/ui/llm_grpo/llm_grpo.py index 0cd13462fe..ff6256e4ee 100644 --- a/swift/ui/llm_grpo/llm_grpo.py +++ b/swift/ui/llm_grpo/llm_grpo.py @@ -157,6 +157,16 @@ class LLMGRPO(LLMTrain): 'en': 'Liger kernel can reduce memory usage' } }, + 'use_cce': { + 'label': { + 'zh': '使用CCE加速', + 'en': 'Use CCE acceleration' + }, + 'info': { + 'zh': 'CCE (ml-cross-entropy) 提供融合的交叉熵算子', + 'en': 'CCE (ml-cross-entropy) provides fused cross-entropy kernels' + } + }, 'sequence_parallel_size': { 'label': { 'zh': '序列并行大小', @@ -233,6 +243,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): gr.Textbox(elem_id='seed', scale=4) gr.Dropdown(elem_id='torch_dtype', scale=4) gr.Checkbox(elem_id='use_liger_kernel', scale=4) + gr.Checkbox(elem_id='use_cce', scale=4) gr.Textbox(elem_id='sequence_parallel_size', lines=1, scale=4) with gr.Row(): gr.Dropdown( diff --git a/swift/ui/llm_rlhf/llm_rlhf.py b/swift/ui/llm_rlhf/llm_rlhf.py index d7f1f740c7..1b8c028fce 100644 --- a/swift/ui/llm_rlhf/llm_rlhf.py +++ b/swift/ui/llm_rlhf/llm_rlhf.py @@ -170,6 +170,16 @@ class LLMRLHF(LLMTrain): 'en': 'Liger kernel can reduce memory usage' } }, + 'use_cce': { + 'label': { + 'zh': '使用CCE加速', + 'en': 'Use CCE acceleration' + }, + 'info': { + 'zh': 'CCE (ml-cross-entropy) 提供融合的交叉熵算子', + 'en': 'CCE (ml-cross-entropy) provides fused cross-entropy kernels' + } + }, 'sequence_parallel_size': { 'label': { 'zh': '序列并行大小', @@ -246,6 +256,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): gr.Textbox(elem_id='seed', scale=2) gr.Dropdown(elem_id='torch_dtype', scale=2) gr.Checkbox(elem_id='use_liger_kernel', scale=2) + gr.Checkbox(elem_id='use_cce', scale=2) with gr.Row(): gr.Dropdown( elem_id='gpu_id', diff --git a/swift/ui/llm_train/llm_train.py b/swift/ui/llm_train/llm_train.py index 
b1b2b94f6d..a6c1c7dcce 100644 --- a/swift/ui/llm_train/llm_train.py +++ b/swift/ui/llm_train/llm_train.py @@ -177,6 +177,16 @@ class LLMTrain(BaseUI): 'en': 'Liger kernel can reduce memory usage' } }, + 'use_cce': { + 'label': { + 'zh': '使用CCE加速', + 'en': 'Use CCE acceleration' + }, + 'info': { + 'zh': 'CCE (ml-cross-entropy) 提供融合的交叉熵算子', + 'en': 'CCE (ml-cross-entropy) provides fused cross-entropy kernels' + } + }, 'sequence_parallel_size': { 'label': { 'zh': '序列并行大小', @@ -257,6 +267,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): gr.Textbox(elem_id='seed', scale=4) gr.Dropdown(elem_id='torch_dtype', scale=4) gr.Checkbox(elem_id='use_liger_kernel', scale=4) + gr.Checkbox(elem_id='use_cce', scale=4) with gr.Row(): gr.Dropdown( elem_id='gpu_id', @@ -390,6 +401,9 @@ def train(cls, *args): use_liger_kernel = kwargs.get('use_liger_kernel', None) if use_liger_kernel: kwargs.pop('use_liger_kernel') + use_cce = kwargs.get('use_cce', None) + if use_cce: + kwargs.pop('use_cce') if other_kwargs.get('use_muon'): kwargs['use_muon'] = other_kwargs.pop('use_muon') @@ -428,6 +442,9 @@ def train(cls, *args): if use_liger_kernel: params += f'--use_liger_kernel {cls.quote}{use_liger_kernel}{cls.quote} ' command.extend(['--use_liger_kernel', f'{use_liger_kernel}']) + if use_cce: + params += f'--use_cce {cls.quote}{use_cce}{cls.quote} ' + command.extend(['--use_cce', f'{use_cce}']) if use_muon: params += f'--optimizer {cls.quote}muon{cls.quote} ' command.extend(['--optimizer', 'muon']) diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py index dccf48f0be..2e389209e0 100644 --- a/swift/utils/__init__.py +++ b/swift/utils/__init__.py @@ -2,7 +2,7 @@ from .env import (get_dist_setting, get_hf_endpoint, get_node_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist, is_last_rank, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, use_hf_hub) -from .import_utils import (is_flash_attn_2_available, is_flash_attn_3_available, is_liger_available, +from .import_utils import (is_cce_available, is_flash_attn_2_available, is_flash_attn_3_available, is_liger_available, is_lmdeploy_available, is_megatron_available, is_swanlab_available, is_trl_available, is_unsloth_available, is_vllm_ascend_available, is_vllm_available, is_wandb_available) from .io_utils import JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, read_from_jsonl, write_to_jsonl diff --git a/swift/utils/import_utils.py b/swift/utils/import_utils.py index 95e9ba1f47..fc46160877 100644 --- a/swift/utils/import_utils.py +++ b/swift/utils/import_utils.py @@ -28,6 +28,10 @@ def is_liger_available(): return importlib.util.find_spec('liger_kernel') is not None +def is_cce_available(): + return importlib.util.find_spec('cut_cross_entropy') is not None + + def is_swanlab_available(): return importlib.util.find_spec('swanlab') is not None diff --git a/tests/train/test_cce.py b/tests/train/test_cce.py new file mode 100644 index 0000000000..728aa2f76c --- /dev/null +++ b/tests/train/test_cce.py @@ -0,0 +1,29 @@ +import os + +# os.environ['CUDA_VISIBLE_DEVICES'] = '0' +# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + +kwargs = { + 'per_device_train_batch_size': 64, + 'save_steps': 30, + 'gradient_accumulation_steps': 2, + 'num_train_epochs': 1, +} + + +def test_sft(): + from swift.llm import sft_main, TrainArguments, infer_main, InferArguments + result = sft_main( + TrainArguments( + model='Qwen/Qwen2.5-0.5B-Instruct', + dataset=['gsm8k#1024'], + split_dataset_ratio=0.01, + use_cce=True, + # 
use_liger_kernel=True, + **kwargs)) + last_model_checkpoint = result['last_model_checkpoint'] + infer_main(InferArguments(adapters=last_model_checkpoint, load_data_args=True)) + + +if __name__ == '__main__': + test_sft()
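Since tests/train/test_cce.py only exercises the CCE path, a quick sanity check of the new TiledSwiGLUMLP can be sketched as follows (hypothetical snippet, not part of this diff; runs on CPU, with SimpleNamespace standing in for a model config):

import torch
from types import SimpleNamespace

from swift.plugin.tiled_mlp import TiledSwiGLUMLP

# Stand-in config exposing the two attributes TiledSwiGLUMLP reads.
cfg = SimpleNamespace(hidden_size=64, intermediate_size=128)
mlp = TiledSwiGLUMLP(cfg, num_shards=4)

x = torch.randn(2, 32, 64)
tiled_out = mlp(x)  # forward splits x along the sequence dim into 4 shards and concatenates the results

# A SwiGLU MLP acts on each token independently, so the tiled forward should match
# the dense computation performed with the same weights.
dense_out = mlp.down_proj(mlp.act(mlp.gate_proj(x)) * mlp.up_proj(x))
assert torch.allclose(tiled_out, dense_out, atol=1e-5)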